treelite
cli_main.cc
Go to the documentation of this file.
1 
8 #include <treelite/frontend.h>
9 #include <treelite/annotator.h>
10 #include <treelite/compiler.h>
11 #include <treelite/predictor.h>
12 #include <treelite/logging.h>
13 #include <treelite/omp.h>
14 #include <dmlc/config.h>
15 #include <dmlc/data.h>
16 #include <fstream>
17 #include <memory>
18 #include <vector>
19 #include <queue>
20 #include <iterator>
21 #include <string>
22 #include "./compiler/param.h"
23 #include "./common/filesystem.h"
24 #include "./compiler/ast/builder.h"
25 
26 namespace treelite {
27 
/*!
 * \brief tasks the CLI program can carry out; these are the integer values
 *        stored in CLIParam::task
 */
enum CLITask {
  kCodegen  = 0,  // selected by task=codegen
  kAnnotate = 1,  // selected by task=annotate
  kPredict  = 2,  // selected by task=predict
  kDumpAST  = 3   // selected by task=dump_ast
};
34 
/*!
 * \brief file formats accepted for data sets; these are the integer values
 *        that FileFormatString() translates to format names
 */
enum InputFormat {
  kLibSVM = 0,  // "libsvm"
  kCSV    = 1,  // "csv"
  kLibFM  = 2   // "libfm"
};
40 
/*!
 * \brief formats of the input model file; these are the integer values stored
 *        in CLIParam::format and dispatched on by ParseModel()
 */
enum ModelFormat {
  kXGBModel = 0,  // selected by format=xgboost
  kLGBModel = 1,  // selected by format=lightgbm
  kProtobuf = 2   // selected by format=protobuf
};
46 
47 inline const char* FileFormatString(int format) {
48  switch (format) {
49  case kLibSVM: return "libsvm";
50  case kCSV: return "csv";
51  case kLibFM: return "libfm";
52  }
53  return "";
54 }
55 
56 struct CLIParam : public dmlc::Parameter<CLIParam> {
58  int task;
60  int verbose;
62  std::string compiler;
64  int format;
66  std::string model_in;
68  std::string name_codegen_dir;
70  std::string name_annotate;
72  std::string name_pred;
74  std::string train_data_path;
76  std::string test_data_path;
78  std::string codelib_path;
83  // number of threads to use if OpenMP is enabled
84  // if equals 0, use system default
85  int nthread;
90  std::vector<std::pair<std::string, std::string> > cfg;
91 
92  // declare parameters
93  DMLC_DECLARE_PARAMETER(CLIParam) {
94  DMLC_DECLARE_FIELD(task).set_default(kCodegen)
95  .add_enum("codegen", kCodegen)
96  .add_enum("annotate", kAnnotate)
97  .add_enum("predict", kPredict)
98  .add_enum("dump_ast", kDumpAST)
99  .describe("Task to be performed by the CLI program.");
100  DMLC_DECLARE_FIELD(verbose).set_default(0)
101  .describe("Produce extra messages if >0");
102  DMLC_DECLARE_FIELD(compiler).set_default("ast_native")
103  .describe("kind of compiler to use");
104  DMLC_DECLARE_FIELD(format)
105  .add_enum("xgboost", kXGBModel)
106  .add_enum("lightgbm", kLGBModel)
107  .add_enum("protobuf", kProtobuf)
108  .describe("Model format");
109  DMLC_DECLARE_FIELD(model_in).set_default("NULL")
110  .describe("Input model path");
111  DMLC_DECLARE_FIELD(name_codegen_dir).set_default("codegen")
112  .describe("directory name for generated code files");
113  DMLC_DECLARE_FIELD(name_annotate).set_default("annotate.json")
114  .describe("Name of generated annotation file");
115  DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
116  .describe("Name of text file to save prediction");
117  DMLC_DECLARE_FIELD(train_data_path).set_default("NULL")
118  .describe("Training data path; used for annotation");
119  DMLC_DECLARE_FIELD(test_data_path).set_default("NULL")
120  .describe("Test data path; used prediction");
121  DMLC_DECLARE_FIELD(codelib_path).set_default("NULL")
122  .describe("Path to compiled dynamic shared library (.so/.dll/.dylib); "
123  "used for prediction");
124  DMLC_DECLARE_FIELD(train_format).set_default(kLibSVM)
125  .add_enum("libsvm", kLibSVM)
126  .add_enum("csv", kCSV)
127  .add_enum("libfm", kLibFM)
128  .describe("training set data format");
129  DMLC_DECLARE_FIELD(test_format).set_default(kLibSVM)
130  .add_enum("libsvm", kLibSVM)
131  .add_enum("csv", kCSV)
132  .add_enum("libfm", kLibFM)
133  .describe("test set data format");
134  DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
135  "Number of threads to use.");
136  DMLC_DECLARE_FIELD(pred_margin).set_default(0).describe(
137  "if >0, predict margin instead of transformed probability");
138 
139  // alias
140  DMLC_DECLARE_ALIAS(train_data_path, data);
141  DMLC_DECLARE_ALIAS(test_data_path, test:data);
142  DMLC_DECLARE_ALIAS(train_format, data_format);
143  }
144  // customized configure function of CLIParam
145  inline void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) {
146  this->cfg = cfg;
147  this->InitAllowUnknown(cfg);
148  }
149 };
150 
151 DMLC_REGISTER_PARAMETER(CLIParam);
152 
153 Model ParseModel(const CLIParam& param) {
154  CHECK(param.model_in != "NULL") << "model_in parameter must be provided";
155  switch (param.format) {
156  case kXGBModel:
157  return frontend::LoadXGBoostModel(param.model_in.c_str());
158  case kLGBModel:
159  return frontend::LoadLightGBMModel(param.model_in.c_str());
160  case kProtobuf:
161  return frontend::LoadProtobufModel(param.model_in.c_str());
162  default:
163  LOG(FATAL) << "Unknown model format";
164  return {}; // avoid compiler warning
165  }
166 }
167 
168 void CLICodegen(const CLIParam& param) {
170  cparam.InitAllowUnknown(param.cfg);
171 
172  Model model = ParseModel(param);
173  LOG(INFO) << "model size = " << model.trees.size();
174 
175  // create directory named name_codegen_dir
176  common::filesystem::CreateDirectoryIfNotExist(param.name_codegen_dir.c_str());
177  std::unique_ptr<Compiler> compiler(Compiler::Create(param.compiler, cparam));
178  auto compiled_model = compiler->Compile(model);
179  if (param.verbose > 0) {
180  LOG(INFO) << "Code generation finished. Writing code to files...";
181  }
182 
183  if (!compiled_model.file_prefix.empty()) {
184  const std::vector<std::string> tokens
185  = common::Split(compiled_model.file_prefix, '/');
186  std::string accum = param.name_codegen_dir + "/" + tokens[0];
187  for (size_t i = 0; i < tokens.size(); ++i) {
188  common::filesystem::CreateDirectoryIfNotExist(accum.c_str());
189  if (i < tokens.size() - 1) {
190  accum += "/";
191  accum += tokens[i + 1];
192  }
193  }
194  }
195 
196  for (const auto& it : compiled_model.files) {
197  LOG(INFO) << "Writing file " << it.first << "...";
198  const std::string filename_full = param.name_codegen_dir + "/" + it.first;
199  common::WriteToFile(filename_full, it.second);
200  }
201 }
202 
203 void CLIAnnotate(const CLIParam& param) {
204  Model model = ParseModel(param);
205  LOG(INFO) << "model size = " << model.trees.size();
206 
207  CHECK_NE(param.train_data_path, "NULL")
208  << "Need to specify train_data_path paramter for annotation task";
209  std::unique_ptr<DMatrix> dmat(DMatrix::Create(param.train_data_path.c_str(),
210  FileFormatString(param.train_format),
211  param.nthread, param.verbose));
212  BranchAnnotator annotator;
213  annotator.Annotate(model, dmat.get(), param.nthread, param.verbose);
214  // write to json file
215  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
216  param.name_annotate.c_str(), "w"));
217  annotator.Save(fo.get());
218 }
219 
220 void CLIPredict(const CLIParam& param) {
221 
222  CHECK_NE(param.codelib_path, "NULL")
223  << "Need to specify codelib_path paramter for prediction task";
224  CHECK_NE(param.test_data_path, "NULL")
225  << "Need to specify test_data_path paramter for prediction task";
226  std::unique_ptr<DMatrix> dmat(DMatrix::Create(param.test_data_path.c_str(),
227  FileFormatString(param.test_format),
228  param.nthread, param.verbose));
229  std::unique_ptr<CSRBatch> batch = common::make_unique<CSRBatch>();
230  batch->data = &dmat->data[0];
231  batch->col_ind = &dmat->col_ind[0];
232  batch->row_ptr = &dmat->row_ptr[0];
233  batch->num_row = dmat->num_row;
234  batch->num_col = dmat->num_col;
235  Predictor predictor;
236  predictor.Load(param.codelib_path.c_str());
237  size_t result_size = predictor.QueryResultSize(batch.get());
238  std::vector<float> result(result_size);
239  result_size = predictor.PredictBatch(batch.get(), param.nthread,
240  param.verbose, param.pred_margin,
241  &result[0]);
242  // write to text file
243  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
244  param.name_pred.c_str(), "w"));
245  dmlc::ostream os(fo.get());
246  for (size_t i = 0; i < result_size; ++i) {
247  os << result[i] << std::endl;
248  }
249  // force flush before fo destruct.
250  os.set_stream(nullptr);
251 }
252 
253 void CLIDumpAST(const CLIParam& param) {
255  cparam.InitAllowUnknown(param.cfg);
256 
257  Model model = ParseModel(param);
258  LOG(INFO) << "model size = " << model.trees.size();
259  compiler::ASTBuilder builder;
260  builder.Build(model);
261  if (cparam.annotate_in != "NULL") {
262  BranchAnnotator annotator;
263  std::unique_ptr<dmlc::Stream> fi(
264  dmlc::Stream::Create(cparam.annotate_in.c_str(), "r"));
265  annotator.Load(fi.get());
266  const auto annotation = annotator.Get();
267  builder.AnnotateBranches(annotation);
268  LOG(INFO) << "Using branch annotation file `"
269  << cparam.annotate_in << "'";
270  }
271  builder.Split(cparam.parallel_comp);
272  if (cparam.quantize > 0) {
273  builder.QuantizeThresholds();
274  }
275  builder.Dump();
276 }
277 
278 int CLIRunTask(int argc, char* argv[]) {
279  if (argc < 2) {
280  printf("Usage: <config>\n");
281  return 0;
282  }
283 
284  std::vector<std::pair<std::string, std::string> > cfg;
285 
286  {
287  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(argv[1], "r"));
288  dmlc::istream cfgfile(fi.get());
289  dmlc::Config itr(cfgfile);
290  for (const auto& entry : itr) {
291  cfg.push_back(std::make_pair(entry.first, entry.second));
292  }
293  }
294 
295  for (int i = 2; i < argc; ++i) {
296  char name[256], val[256];
297  if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
298  cfg.push_back(std::make_pair(std::string(name), std::string(val)));
299  }
300  }
301 
302  CLIParam param;
303  param.Configure(cfg);
304 
305  switch (param.task) {
306  case kCodegen: CLICodegen(param); break;
307  case kAnnotate: CLIAnnotate(param); break;
308  case kPredict: CLIPredict(param); break;
309  case kDumpAST: CLIDumpAST(param); break;
310  }
311 
312  return 0;
313 }
314 
315 } // namespace treelite
316 
317 int main(int argc, char* argv[]) {
319  = treelite::LogCallbackRegistryStore::Get();
320  registry->Register([] (const char* msg) {
321  std::cerr << msg << std::endl;
322  });
323  return treelite::CLIRunTask(argc, argv);
324 }
Load prediction function exported as a shared library.
branch annotator class
Definition: annotator.h:16
int pred_margin
whether to predict margin instead of transformed probability
Definition: cli_main.cc:88
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
Definition: tree.h:351
size_t PredictBatch(const CSRBatch *batch, int nthread, int verbose, bool pred_margin, float *out_result) const
make predictions on a batch of data rows
Definition: predictor.cc:245
std::vector< std::pair< std::string, std::string > > cfg
all the configurations
Definition: cli_main.cc:90
std::vector< Tree > trees
member trees
Definition: tree.h:353
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:94
parameters for tree compiler
Definition: param.h:16
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
Definition: annotator.cc:114
logging facility for treelite
std::string test_data_path
the path of test set: used for prediction
Definition: cli_main.cc:76
std::string model_in
model file
Definition: cli_main.cc:66
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Definition: data.cc:17
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::string name_codegen_dir
directory name for generated code files
Definition: cli_main.cc:68
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
int verbose
whether verbose
Definition: cli_main.cc:60
Cross-platform wrapper for common filesystem functions.
std::string train_data_path
the path of training set: used for annotation
Definition: cli_main.cc:74
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:204
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:157
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:164
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
std::string codelib_path
the path of compiled dynamic shared library: used for prediction
Definition: cli_main.cc:78
int task
the task name
Definition: cli_main.cc:58
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
int format
model format
Definition: cli_main.cc:64
compatibility wrapper for systems that don't support OpenMP
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
std::string name_annotate
name of generated annotation file
Definition: cli_main.cc:70
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
std::string name_pred
name of text file to save prediction
Definition: cli_main.cc:72
std::string compiler
which compiler to use
Definition: cli_main.cc:62
int test_format
test set file format
Definition: cli_main.cc:82
int train_format
training set file format
Definition: cli_main.cc:80