treelite
cli_main.cc
Go to the documentation of this file.
1 
#include <treelite/frontend.h>
#include <treelite/annotator.h>
#include <treelite/compiler.h>
#include <treelite/predictor.h>
#include <treelite/logging.h>
#include <treelite/omp.h>
#include <dmlc/config.h>
#include <dmlc/data.h>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "./compiler/param.h"
#include "./common/filesystem.h"
#include "./compiler/ast/builder.h"
25 
26 namespace treelite {
27 
/*!
 * \brief Tasks the command-line driver can perform.
 *
 * CLIParam::task holds one of these values after the textual task name from
 * the configuration has been mapped through add_enum().
 */
enum CLITask {
  kCodegen = 0,   /*!< compile the model and write generated source files */
  kAnnotate = 1,  /*!< annotate branches using a training data set */
  kPredict = 2,   /*!< predict with a compiled shared library */
  kDumpAST = 3    /*!< build and dump the compiler AST of the model */
};
34 
/*!
 * \brief Text formats accepted for data files (training and test sets).
 *
 * FileFormatString() converts a value to the format name handed to
 * DMatrix::Create().
 */
enum InputFormat {
  kLibSVM = 0,  /*!< LibSVM sparse text format */
  kCSV = 1,     /*!< comma-separated values */
  kLibFM = 2    /*!< LibFM text format */
};
40 
/*!
 * \brief Input model formats; ParseModel() selects a frontend loader by
 *        this value.
 */
enum ModelFormat {
  kXGBModel = 0,  /*!< XGBoost model, loaded by frontend::LoadXGBoostModel */
  kLGBModel = 1,  /*!< LightGBM model, loaded by frontend::LoadLightGBMModel */
  kProtobuf = 2   /*!< Protobuf model, loaded by frontend::LoadProtobufModel */
};
46 
47 inline const char* FileFormatString(int format) {
48  switch (format) {
49  case kLibSVM: return "libsvm";
50  case kCSV: return "csv";
51  case kLibFM: return "libfm";
52  }
53  return "";
54 }
55 
56 struct CLIParam : public dmlc::Parameter<CLIParam> {
58  int task;
60  int verbose;
62  std::string compiler;
64  int format;
66  std::string model_in;
68  std::string name_codegen_dir;
70  std::string name_annotate;
72  std::string name_pred;
74  std::string train_data_path;
76  std::string test_data_path;
78  std::string codelib_path;
83  // number of threads to use if OpenMP is enabled
84  // if equals 0, use system default
85  int nthread;
90  std::vector<std::pair<std::string, std::string> > cfg;
91 
92  // declare parameters
93  DMLC_DECLARE_PARAMETER(CLIParam) {
94  DMLC_DECLARE_FIELD(task).set_default(kCodegen)
95  .add_enum("codegen", kCodegen)
96  .add_enum("annotate", kAnnotate)
97  .add_enum("predict", kPredict)
98  .add_enum("dump_ast", kDumpAST)
99  .describe("Task to be performed by the CLI program.");
100  DMLC_DECLARE_FIELD(verbose).set_default(0)
101  .describe("Produce extra messages if >0");
102  DMLC_DECLARE_FIELD(compiler).set_default("ast_native")
103  .describe("kind of compiler to use");
104  DMLC_DECLARE_FIELD(format)
105  .add_enum("xgboost", kXGBModel)
106  .add_enum("lightgbm", kLGBModel)
107  .add_enum("protobuf", kProtobuf)
108  .describe("Model format");
109  DMLC_DECLARE_FIELD(model_in).set_default("NULL")
110  .describe("Input model path");
111  DMLC_DECLARE_FIELD(name_codegen_dir).set_default("codegen")
112  .describe("directory name for generated code files");
113  DMLC_DECLARE_FIELD(name_annotate).set_default("annotate.json")
114  .describe("Name of generated annotation file");
115  DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
116  .describe("Name of text file to save prediction");
117  DMLC_DECLARE_FIELD(train_data_path).set_default("NULL")
118  .describe("Training data path; used for annotation");
119  DMLC_DECLARE_FIELD(test_data_path).set_default("NULL")
120  .describe("Test data path; used prediction");
121  DMLC_DECLARE_FIELD(codelib_path).set_default("NULL")
122  .describe("Path to compiled dynamic shared library (.so/.dll/.dylib); "
123  "used for prediction");
124  DMLC_DECLARE_FIELD(train_format).set_default(kLibSVM)
125  .add_enum("libsvm", kLibSVM)
126  .add_enum("csv", kCSV)
127  .add_enum("libfm", kLibFM)
128  .describe("training set data format");
129  DMLC_DECLARE_FIELD(test_format).set_default(kLibSVM)
130  .add_enum("libsvm", kLibSVM)
131  .add_enum("csv", kCSV)
132  .add_enum("libfm", kLibFM)
133  .describe("test set data format");
134  DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
135  "Number of threads to use.");
136  DMLC_DECLARE_FIELD(pred_margin).set_default(0).describe(
137  "if >0, predict margin instead of transformed probability");
138 
139  // alias
140  DMLC_DECLARE_ALIAS(train_data_path, data);
141  DMLC_DECLARE_ALIAS(test_data_path, test:data);
142  DMLC_DECLARE_ALIAS(train_format, data_format);
143  }
144  // customized configure function of CLIParam
145  inline void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) {
146  this->cfg = cfg;
147  this->InitAllowUnknown(cfg);
148  }
149 };
150 
151 DMLC_REGISTER_PARAMETER(CLIParam);
152 
153 Model ParseModel(const CLIParam& param) {
154  CHECK(param.model_in != "NULL") << "model_in parameter must be provided";
155  switch (param.format) {
156  case kXGBModel:
157  return frontend::LoadXGBoostModel(param.model_in.c_str());
158  case kLGBModel:
159  return frontend::LoadLightGBMModel(param.model_in.c_str());
160  case kProtobuf:
161  return frontend::LoadProtobufModel(param.model_in.c_str());
162  default:
163  LOG(FATAL) << "Unknown model format";
164  return {}; // avoid compiler warning
165  }
166 }
167 
168 void CLICodegen(const CLIParam& param) {
170  cparam.InitAllowUnknown(param.cfg);
171 
172  Model model = ParseModel(param);
173  LOG(INFO) << "model size = " << model.trees.size();
174 
175  // create directory named name_codegen_dir
176  common::filesystem::CreateDirectoryIfNotExist(param.name_codegen_dir.c_str());
177  std::unique_ptr<Compiler> compiler(Compiler::Create(param.compiler, cparam));
178  auto compiled_model = compiler->Compile(model);
179  if (param.verbose > 0) {
180  LOG(INFO) << "Code generation finished. Writing code to files...";
181  }
182 
183  if (!compiled_model.file_prefix.empty()) {
184  const std::vector<std::string> tokens
185  = common::Split(compiled_model.file_prefix, '/');
186  std::string accum = param.name_codegen_dir + "/" + tokens[0];
187  for (size_t i = 0; i < tokens.size(); ++i) {
188  common::filesystem::CreateDirectoryIfNotExist(accum.c_str());
189  if (i < tokens.size() - 1) {
190  accum += "/";
191  accum += tokens[i + 1];
192  }
193  }
194  }
195 
196  for (const auto& it : compiled_model.files) {
197  LOG(INFO) << "Writing file " << it.first << "...";
198  const std::string filename_full = param.name_codegen_dir + "/" + it.first;
199  common::WriteToFile(filename_full, it.second);
200  }
201 }
202 
203 void CLIAnnotate(const CLIParam& param) {
204  Model model = ParseModel(param);
205  LOG(INFO) << "model size = " << model.trees.size();
206 
207  CHECK_NE(param.train_data_path, "NULL")
208  << "Need to specify train_data_path paramter for annotation task";
209  std::unique_ptr<DMatrix> dmat(DMatrix::Create(param.train_data_path.c_str(),
210  FileFormatString(param.train_format),
211  param.nthread, param.verbose));
212  BranchAnnotator annotator;
213  annotator.Annotate(model, dmat.get(), param.nthread, param.verbose);
214  // write to json file
215  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
216  param.name_annotate.c_str(), "w"));
217  annotator.Save(fo.get());
218 }
219 
220 void CLIPredict(const CLIParam& param) {
221 
222  CHECK_NE(param.codelib_path, "NULL")
223  << "Need to specify codelib_path paramter for prediction task";
224  CHECK_NE(param.test_data_path, "NULL")
225  << "Need to specify test_data_path paramter for prediction task";
226  std::unique_ptr<DMatrix> dmat(DMatrix::Create(param.test_data_path.c_str(),
227  FileFormatString(param.test_format),
228  param.nthread, param.verbose));
229  std::unique_ptr<CSRBatch> batch = common::make_unique<CSRBatch>();
230  batch->data = &dmat->data[0];
231  batch->col_ind = &dmat->col_ind[0];
232  batch->row_ptr = &dmat->row_ptr[0];
233  batch->num_row = dmat->num_row;
234  batch->num_col = dmat->num_col;
235  Predictor predictor;
236  predictor.Load(param.codelib_path.c_str());
237  size_t result_size = predictor.QueryResultSize(batch.get());
238  std::vector<float> result(result_size);
239  result_size = predictor.PredictBatch(batch.get(), param.verbose,
240  param.pred_margin, &result[0]);
241  // write to text file
242  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
243  param.name_pred.c_str(), "w"));
244  dmlc::ostream os(fo.get());
245  for (size_t i = 0; i < result_size; ++i) {
246  os << result[i] << std::endl;
247  }
248  // force flush before fo destruct.
249  os.set_stream(nullptr);
250 }
251 
252 void CLIDumpAST(const CLIParam& param) {
254  cparam.InitAllowUnknown(param.cfg);
255 
256  Model model = ParseModel(param);
257  LOG(INFO) << "model size = " << model.trees.size();
258  compiler::ASTBuilder builder;
259  builder.Build(model);
260  if (cparam.annotate_in != "NULL") {
261  BranchAnnotator annotator;
262  std::unique_ptr<dmlc::Stream> fi(
263  dmlc::Stream::Create(cparam.annotate_in.c_str(), "r"));
264  annotator.Load(fi.get());
265  const auto annotation = annotator.Get();
266  builder.AnnotateBranches(annotation);
267  LOG(INFO) << "Using branch annotation file `"
268  << cparam.annotate_in << "'";
269  }
270  builder.Split(cparam.parallel_comp);
271  if (cparam.quantize > 0) {
272  builder.QuantizeThresholds();
273  }
274  builder.Dump();
275 }
276 
277 int CLIRunTask(int argc, char* argv[]) {
278  if (argc < 2) {
279  printf("Usage: <config>\n");
280  return 0;
281  }
282 
283  std::vector<std::pair<std::string, std::string> > cfg;
284 
285  {
286  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(argv[1], "r"));
287  dmlc::istream cfgfile(fi.get());
288  dmlc::Config itr(cfgfile);
289  for (const auto& entry : itr) {
290  cfg.push_back(std::make_pair(entry.first, entry.second));
291  }
292  }
293 
294  for (int i = 2; i < argc; ++i) {
295  char name[256], val[256];
296  if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
297  cfg.push_back(std::make_pair(std::string(name), std::string(val)));
298  }
299  }
300 
301  CLIParam param;
302  param.Configure(cfg);
303 
304  switch (param.task) {
305  case kCodegen: CLICodegen(param); break;
306  case kAnnotate: CLIAnnotate(param); break;
307  case kPredict: CLIPredict(param); break;
308  case kDumpAST: CLIDumpAST(param); break;
309  }
310 
311  return 0;
312 }
313 
314 } // namespace treelite
315 
316 int main(int argc, char* argv[]) {
318  = treelite::LogCallbackRegistryStore::Get();
319  registry->Register([] (const char* msg) {
320  std::cerr << msg << std::endl;
321  });
322  return treelite::CLIRunTask(argc, argv);
323 }
Load prediction function exported as a shared library.
branch annotator class
Definition: annotator.h:16
int pred_margin
whether to predict margin instead of transformed probability
Definition: cli_main.cc:88
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
Definition: tree.h:351
std::vector< std::pair< std::string, std::string > > cfg
all the configurations
Definition: cli_main.cc:90
std::vector< Tree > trees
member trees
Definition: tree.h:353
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:96
parameters for tree compiler
Definition: param.h:16
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
Definition: annotator.cc:95
logging facility for treelite
std::string test_data_path
the path of test set: used for prediction
Definition: cli_main.cc:76
std::string model_in
model file
Definition: cli_main.cc:66
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Definition: data.cc:17
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::string name_codegen_dir
directory name for generated code files
Definition: cli_main.cc:68
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
int verbose
verbosity level; extra messages are produced if &gt;0
Definition: cli_main.cc:60
Cross-platform wrapper for common filesystem functions.
std::string train_data_path
the path of training set: used for annotation
Definition: cli_main.cc:74
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:210
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:138
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:145
size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin, float *out_result)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
Definition: predictor.cc:355
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
std::string codelib_path
the path of compiled dynamic shared library: used for prediction
Definition: cli_main.cc:78
int task
the task name
Definition: cli_main.cc:58
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
int format
model format
Definition: cli_main.cc:64
compatibility wrapper for systems that don't support OpenMP
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
std::string name_annotate
name of generated annotation file
Definition: cli_main.cc:70
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
std::string name_pred
name of text file to save prediction
Definition: cli_main.cc:72
std::string compiler
which compiler to use
Definition: cli_main.cc:62
int test_format
test set file format
Definition: cli_main.cc:82
int train_format
training set file format
Definition: cli_main.cc:80