14 #include <dmlc/config.h> 15 #include <dmlc/data.h> 24 #include "./compiler/ast/builder.h" 47 inline const char* FileFormatString(
int format) {
49 case kLibSVM:
return "libsvm";
50 case kCSV:
return "csv";
51 case kLibFM:
return "libfm";
56 struct CLIParam :
public dmlc::Parameter<CLIParam> {
90 std::vector<std::pair<std::string, std::string> >
cfg;
94 DMLC_DECLARE_FIELD(task).set_default(kCodegen)
95 .add_enum(
"codegen", kCodegen)
96 .add_enum(
"annotate", kAnnotate)
97 .add_enum(
"predict", kPredict)
98 .add_enum(
"dump_ast", kDumpAST)
99 .describe(
"Task to be performed by the CLI program.");
100 DMLC_DECLARE_FIELD(verbose).set_default(0)
101 .describe(
"Produce extra messages if >0");
102 DMLC_DECLARE_FIELD(compiler).set_default(
"ast_native")
103 .describe(
"kind of compiler to use");
104 DMLC_DECLARE_FIELD(format)
105 .add_enum(
"xgboost", kXGBModel)
106 .add_enum(
"lightgbm", kLGBModel)
107 .add_enum(
"protobuf", kProtobuf)
108 .describe(
"Model format");
109 DMLC_DECLARE_FIELD(model_in).set_default(
"NULL")
110 .describe(
"Input model path");
111 DMLC_DECLARE_FIELD(name_codegen_dir).set_default(
"codegen")
112 .describe(
"directory name for generated code files");
113 DMLC_DECLARE_FIELD(name_annotate).set_default(
"annotate.json")
114 .describe(
"Name of generated annotation file");
115 DMLC_DECLARE_FIELD(name_pred).set_default(
"pred.txt")
116 .describe(
"Name of text file to save prediction");
117 DMLC_DECLARE_FIELD(train_data_path).set_default(
"NULL")
118 .describe(
"Training data path; used for annotation");
119 DMLC_DECLARE_FIELD(test_data_path).set_default(
"NULL")
120 .describe(
"Test data path; used prediction");
121 DMLC_DECLARE_FIELD(codelib_path).set_default(
"NULL")
122 .describe(
"Path to compiled dynamic shared library (.so/.dll/.dylib); " 123 "used for prediction");
124 DMLC_DECLARE_FIELD(train_format).set_default(kLibSVM)
125 .add_enum(
"libsvm", kLibSVM)
126 .add_enum(
"csv", kCSV)
127 .add_enum(
"libfm", kLibFM)
128 .describe(
"training set data format");
129 DMLC_DECLARE_FIELD(test_format).set_default(kLibSVM)
130 .add_enum(
"libsvm", kLibSVM)
131 .add_enum(
"csv", kCSV)
132 .add_enum(
"libfm", kLibFM)
133 .describe(
"test set data format");
134 DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
135 "Number of threads to use.");
136 DMLC_DECLARE_FIELD(pred_margin).set_default(0).describe(
137 "if >0, predict margin instead of transformed probability");
140 DMLC_DECLARE_ALIAS(train_data_path, data);
141 DMLC_DECLARE_ALIAS(test_data_path, test:data);
142 DMLC_DECLARE_ALIAS(train_format, data_format);
145 inline void Configure(
const std::vector<std::pair<std::string, std::string> >& cfg) {
147 this->InitAllowUnknown(cfg);
154 CHECK(param.
model_in !=
"NULL") <<
"model_in parameter must be provided";
157 return frontend::LoadXGBoostModel(param.
model_in.c_str());
159 return frontend::LoadLightGBMModel(param.
model_in.c_str());
161 return frontend::LoadProtobufModel(param.
model_in.c_str());
163 LOG(FATAL) <<
"Unknown model format";
168 void CLICodegen(
const CLIParam& param) {
170 cparam.InitAllowUnknown(param.
cfg);
172 Model model = ParseModel(param);
173 LOG(INFO) <<
"model size = " << model.
trees.size();
176 common::filesystem::CreateDirectoryIfNotExist(param.
name_codegen_dir.c_str());
178 auto compiled_model = compiler->Compile(model);
180 LOG(INFO) <<
"Code generation finished. Writing code to files...";
183 if (!compiled_model.file_prefix.empty()) {
184 const std::vector<std::string> tokens
185 = common::Split(compiled_model.file_prefix,
'/');
187 for (
size_t i = 0; i < tokens.size(); ++i) {
188 common::filesystem::CreateDirectoryIfNotExist(accum.c_str());
189 if (i < tokens.size() - 1) {
191 accum += tokens[i + 1];
196 for (
const auto& it : compiled_model.files) {
197 LOG(INFO) <<
"Writing file " << it.first <<
"...";
199 common::WriteToFile(filename_full, it.second);
203 void CLIAnnotate(
const CLIParam& param) {
204 Model model = ParseModel(param);
205 LOG(INFO) <<
"model size = " << model.
trees.size();
208 <<
"Need to specify train_data_path paramter for annotation task";
211 param.nthread, param.
verbose));
215 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
217 annotator.
Save(fo.get());
220 void CLIPredict(
const CLIParam& param) {
223 <<
"Need to specify codelib_path paramter for prediction task";
225 <<
"Need to specify test_data_path paramter for prediction task";
228 param.nthread, param.
verbose));
229 std::unique_ptr<CSRBatch> batch = common::make_unique<CSRBatch>();
230 batch->data = &dmat->data[0];
231 batch->col_ind = &dmat->col_ind[0];
232 batch->row_ptr = &dmat->row_ptr[0];
233 batch->num_row = dmat->num_row;
234 batch->num_col = dmat->num_col;
238 std::vector<float> result(result_size);
239 result_size = predictor.
PredictBatch(batch.get(), param.nthread,
243 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
245 dmlc::ostream os(fo.get());
246 for (
size_t i = 0; i < result_size; ++i) {
247 os << result[i] << std::endl;
250 os.set_stream(
nullptr);
253 void CLIDumpAST(
const CLIParam& param) {
255 cparam.InitAllowUnknown(param.
cfg);
257 Model model = ParseModel(param);
258 LOG(INFO) <<
"model size = " << model.
trees.size();
260 builder.Build(model);
263 std::unique_ptr<dmlc::Stream> fi(
264 dmlc::Stream::Create(cparam.
annotate_in.c_str(),
"r"));
265 annotator.
Load(fi.get());
266 const auto annotation = annotator.
Get();
267 builder.AnnotateBranches(annotation);
268 LOG(INFO) <<
"Using branch annotation file `" 273 builder.QuantizeThresholds();
278 int CLIRunTask(
int argc,
char* argv[]) {
280 printf(
"Usage: <config>\n");
284 std::vector<std::pair<std::string, std::string> >
cfg;
287 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(argv[1],
"r"));
288 dmlc::istream cfgfile(fi.get());
289 dmlc::Config itr(cfgfile);
290 for (
const auto& entry : itr) {
291 cfg.push_back(std::make_pair(entry.first, entry.second));
295 for (
int i = 2; i < argc; ++i) {
296 char name[256], val[256];
297 if (sscanf(argv[i],
"%[^=]=%s", name, val) == 2) {
298 cfg.push_back(std::make_pair(std::string(name), std::string(val)));
303 param.Configure(cfg);
305 switch (param.
task) {
306 case kCodegen: CLICodegen(param);
break;
307 case kAnnotate: CLIAnnotate(param);
break;
308 case kPredict: CLIPredict(param);
break;
309 case kDumpAST: CLIDumpAST(param);
break;
317 int main(
int argc,
char* argv[]) {
319 = treelite::LogCallbackRegistryStore::Get();
320 registry->Register([] (
const char* msg) {
321 std::cerr << msg << std::endl;
323 return treelite::CLIRunTask(argc, argv);
Load prediction function exported as a shared library.
int pred_margin
whether to predict margin instead of transformed probability
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
size_t PredictBatch(const CSRBatch *batch, int nthread, int verbose, bool pred_margin, float *out_result) const
make predictions on a batch of data rows
std::vector< std::pair< std::string, std::string > > cfg
all the configurations
std::vector< Tree > trees
member trees
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
parameters for tree compiler
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
logging facility for treelite
std::string test_data_path
the path of test set: used for prediction
std::string model_in
model file
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::string name_codegen_dir
directory name for generated code files
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
int verbose
whether verbose
Cross-platform wrapper for common filesystem functions.
std::string train_data_path
the path of training set: used for annotation
void Load(const char *name)
load the prediction function from dynamic shared library.
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
std::string codelib_path
the path of compiled dynamic shared library: used for prediction
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
compatibility wrapper for systems that don't support OpenMP
int quantize
whether to quantize threshold points (0: no, >0: yes)
std::string name_annotate
name of generated annotation file
predictor class: wrapper for optimized prediction code
std::string name_pred
name of text file to save prediction
std::string compiler
which compiler to use
int test_format
test set file format
int train_format
training set file format