treelite
cli_main.cc
Go to the documentation of this file.
1 
8 #include <treelite/frontend.h>
9 #include <treelite/annotator.h>
10 #include <treelite/compiler.h>
11 #include <treelite/semantic.h>
12 #include <treelite/predictor.h>
13 #include <treelite/logging.h>
14 #include <treelite/omp.h>
15 #include <dmlc/config.h>
16 #include <dmlc/data.h>
17 #include <fstream>
18 #include <memory>
19 #include <vector>
20 #include <queue>
21 #include <iterator>
22 #include <string>
23 #include "./compiler/param.h"
24 #include "./common/filesystem.h"
25 
26 namespace treelite {
27 
/*! \brief tasks the CLI program can perform; selected via the `task` parameter */
enum CLITask {
  kCodegen = 0,   /*!< "codegen": dispatched to CLICodegen() */
  kAnnotate = 1,  /*!< "annotate": dispatched to CLIAnnotate() */
  kPredict = 2    /*!< "predict": dispatched to CLIPredict() */
};
33 
/*! \brief file formats accepted for the training and test data sets */
enum InputFormat {
  kLibSVM = 0,  /*!< "libsvm" */
  kCSV = 1,     /*!< "csv" */
  kLibFM = 2    /*!< "libfm" */
};
39 
/*! \brief formats of the input model file; selected via the `format` parameter */
enum ModelFormat {
  kXGBModel = 0,  /*!< "xgboost" */
  kLGBModel = 1,  /*!< "lightgbm" */
  kProtobuf = 2   /*!< "protobuf" */
};
45 
46 inline const char* FileFormatString(int format) {
47  switch (format) {
48  case kLibSVM: return "libsvm";
49  case kCSV: return "csv";
50  case kLibFM: return "libfm";
51  }
52  return "";
53 }
54 
55 struct CLIParam : public dmlc::Parameter<CLIParam> {
57  int task;
59  int verbose;
61  int format;
63  std::string model_in;
65  std::string name_codegen_dir;
67  std::string name_annotate;
69  std::string name_pred;
71  std::string train_data_path;
73  std::string test_data_path;
75  std::string codelib_path;
80  // number of threads to use if OpenMP is enabled
81  // if equals 0, use system default
82  int nthread;
87  std::vector<std::pair<std::string, std::string> > cfg;
88 
89  // declare parameters
90  DMLC_DECLARE_PARAMETER(CLIParam) {
91  DMLC_DECLARE_FIELD(task).set_default(kCodegen)
92  .add_enum("codegen", kCodegen)
93  .add_enum("annotate", kAnnotate)
94  .add_enum("predict", kPredict)
95  .describe("Task to be performed by the CLI program.");
96  DMLC_DECLARE_FIELD(verbose).set_default(0)
97  .describe("Produce extra messages if >0");
98  DMLC_DECLARE_FIELD(format)
99  .add_enum("xgboost", kXGBModel)
100  .add_enum("lightgbm", kLGBModel)
101  .add_enum("protobuf", kProtobuf)
102  .describe("Model format");
103  DMLC_DECLARE_FIELD(model_in).set_default("NULL")
104  .describe("Input model path");
105  DMLC_DECLARE_FIELD(name_codegen_dir).set_default("codegen")
106  .describe("directory name for generated code files");
107  DMLC_DECLARE_FIELD(name_annotate).set_default("annotate.json")
108  .describe("Name of generated annotation file");
109  DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
110  .describe("Name of text file to save prediction");
111  DMLC_DECLARE_FIELD(train_data_path).set_default("NULL")
112  .describe("Training data path; used for annotation");
113  DMLC_DECLARE_FIELD(test_data_path).set_default("NULL")
114  .describe("Test data path; used prediction");
115  DMLC_DECLARE_FIELD(codelib_path).set_default("NULL")
116  .describe("Path to compiled dynamic shared library (.so/.dll/.dylib); "
117  "used for prediction");
118  DMLC_DECLARE_FIELD(train_format).set_default(kLibSVM)
119  .add_enum("libsvm", kLibSVM)
120  .add_enum("csv", kCSV)
121  .add_enum("libfm", kLibFM)
122  .describe("training set data format");
123  DMLC_DECLARE_FIELD(test_format).set_default(kLibSVM)
124  .add_enum("libsvm", kLibSVM)
125  .add_enum("csv", kCSV)
126  .add_enum("libfm", kLibFM)
127  .describe("test set data format");
128  DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
129  "Number of threads to use.");
130  DMLC_DECLARE_FIELD(pred_margin).set_default(0).describe(
131  "if >0, predict margin instead of transformed probability");
132 
133  // alias
134  DMLC_DECLARE_ALIAS(train_data_path, data);
135  DMLC_DECLARE_ALIAS(test_data_path, test:data);
136  DMLC_DECLARE_ALIAS(train_format, data_format);
137  }
138  // customized configure function of CLIParam
139  inline void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) {
140  this->cfg = cfg;
141  this->InitAllowUnknown(cfg);
142  }
143 };
144 
145 DMLC_REGISTER_PARAMETER(CLIParam);
146 
147 Model ParseModel(const CLIParam& param) {
148  CHECK(param.model_in != "NULL") << "model_in parameter must be provided";
149  switch (param.format) {
150  case kXGBModel:
151  return frontend::LoadXGBoostModel(param.model_in.c_str());
152  case kLGBModel:
153  return frontend::LoadLightGBMModel(param.model_in.c_str());
154  case kProtobuf:
155  return frontend::LoadProtobufModel(param.model_in.c_str());
156  default:
157  LOG(FATAL) << "Unknown model format";
158  return {}; // avoid compiler warning
159  }
160 }
161 
162 void CLICodegen(const CLIParam& param) {
164  cparam.InitAllowUnknown(param.cfg);
165 
166  Model model = ParseModel(param);
167  LOG(INFO) << "model size = " << model.trees.size();
168 
169  // create directory named name_codegen_dir
170  common::filesystem::CreateDirectoryIfNotExist(param.name_codegen_dir.c_str());
171  const std::string basename
172  = common::filesystem::GetBasename(param.name_codegen_dir);
173 
174  std::unique_ptr<Compiler> compiler(Compiler::Create("recursive", cparam));
175  auto semantic_model = compiler->Compile(model);
176  /* write header */
177  const std::string header_filename
178  = param.name_codegen_dir + "/" + basename + ".h";
179  {
180  std::vector<std::string> lines;
181  common::TransformPushBack(&lines, semantic_model.common_header->Compile(),
182  [] (std::string line) {
183  return line;
184  });
185  lines.emplace_back();
186  std::ostringstream oss;
187  using FunctionEntry = semantic::SemanticModel::FunctionEntry;
188  std::copy(semantic_model.function_registry.begin(),
189  semantic_model.function_registry.end(),
190  std::ostream_iterator<FunctionEntry>(oss));
191  lines.push_back(oss.str());
192  common::WriteToFile(header_filename, lines);
193  }
194  /* write source file(s) */
195  if (semantic_model.units.size() == 1) { // single file (translation unit)
196  const std::string filename = basename + ".c";
197  const std::string filename_full = param.name_codegen_dir + "/" + filename;
198  const std::string objname = basename + ".o";
199  auto lines = semantic_model.units[0].Compile(header_filename);
200  common::WriteToFile(filename_full, lines);
201  } else { // multiple files (translation units)
202  for (size_t i = 0; i < semantic_model.units.size(); ++i) {
203  const std::string filename = basename + std::to_string(i) + ".c";
204  const std::string filename_full = param.name_codegen_dir + "/" + filename;
205  const std::string objname = basename + std::to_string(i) + ".o";
206  auto lines = semantic_model.units[i].Compile(header_filename);
207  common::WriteToFile(filename_full, lines);
208  }
209  }
210 }
211 
212 void CLIAnnotate(const CLIParam& param) {
213  Model model = ParseModel(param);
214  LOG(INFO) << "model size = " << model.trees.size();
215 
216  CHECK_NE(param.train_data_path, "NULL")
217  << "Need to specify train_data_path paramter for annotation task";
218  std::unique_ptr<DMatrix> dmat(DMatrix::Create(param.train_data_path.c_str(),
219  FileFormatString(param.train_format),
220  param.nthread, param.verbose));
221  BranchAnnotator annotator;
222  annotator.Annotate(model, dmat.get(), param.nthread, param.verbose);
223  // write to json file
224  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
225  param.name_annotate.c_str(), "w"));
226  annotator.Save(fo.get());
227 }
228 
229 void CLIPredict(const CLIParam& param) {
230 
231  CHECK_NE(param.codelib_path, "NULL")
232  << "Need to specify codelib_path paramter for prediction task";
233  CHECK_NE(param.test_data_path, "NULL")
234  << "Need to specify test_data_path paramter for prediction task";
235  std::unique_ptr<DMatrix> dmat(DMatrix::Create(param.test_data_path.c_str(),
236  FileFormatString(param.test_format),
237  param.nthread, param.verbose));
238  std::unique_ptr<CSRBatch> batch = common::make_unique<CSRBatch>();
239  batch->data = &dmat->data[0];
240  batch->col_ind = &dmat->col_ind[0];
241  batch->row_ptr = &dmat->row_ptr[0];
242  batch->num_row = dmat->num_row;
243  batch->num_col = dmat->num_col;
244  Predictor predictor;
245  predictor.Load(param.codelib_path.c_str());
246  size_t result_size = predictor.QueryResultSize(batch.get());
247  std::vector<float> result(result_size);
248  result_size = predictor.PredictBatch(batch.get(), param.nthread,
249  param.verbose, param.pred_margin,
250  &result[0]);
251  // write to text file
252  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(
253  param.name_pred.c_str(), "w"));
254  dmlc::ostream os(fo.get());
255  for (size_t i = 0; i < result_size; ++i) {
256  os << result[i] << std::endl;
257  }
258  // force flush before fo destruct.
259  os.set_stream(nullptr);
260 }
261 
262 int CLIRunTask(int argc, char* argv[]) {
263  if (argc < 2) {
264  printf("Usage: <config>\n");
265  return 0;
266  }
267 
268  std::vector<std::pair<std::string, std::string> > cfg;
269 
270  {
271  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(argv[1], "r"));
272  dmlc::istream cfgfile(fi.get());
273  dmlc::Config itr(cfgfile);
274  for (const auto& entry : itr) {
275  cfg.push_back(std::make_pair(entry.first, entry.second));
276  }
277  }
278 
279  for (int i = 2; i < argc; ++i) {
280  char name[256], val[256];
281  if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
282  cfg.push_back(std::make_pair(std::string(name), std::string(val)));
283  }
284  }
285 
286  CLIParam param;
287  param.Configure(cfg);
288 
289  switch (param.task) {
290  case kCodegen: CLICodegen(param); break;
291  case kAnnotate: CLIAnnotate(param); break;
292  case kPredict: CLIPredict(param); break;
293  }
294 
295  return 0;
296 }
297 
298 } // namespace treelite
299 
300 int main(int argc, char* argv[]) {
302  = treelite::LogCallbackRegistryStore::Get();
303  registry->Register([] (const char* msg) {
304  std::cerr << msg << std::endl;
305  });
306  return treelite::CLIRunTask(argc, argv);
307 }
Load prediction function exported as a shared library.
branch annotator class
Definition: annotator.h:16
int pred_margin
whether to predict margin instead of transformed probability
Definition: cli_main.cc:85
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
Definition: tree.h:351
size_t PredictBatch(const CSRBatch *batch, int nthread, int verbose, bool pred_margin, float *out_result) const
make predictions on a batch of data rows
Definition: predictor.cc:239
std::vector< std::pair< std::string, std::string > > cfg
all the configurations
Definition: cli_main.cc:87
std::vector< Tree > trees
member trees
Definition: tree.h:353
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:95
parameters for tree compiler
Definition: param.h:16
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
Definition: annotator.cc:95
logging facility for treelite
std::string test_data_path
the path of test set: used for prediction
Definition: cli_main.cc:73
std::string model_in
model file
Definition: cli_main.cc:63
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Definition: data.cc:17
Parameters for tree compiler.
Interface of compiler that translates a tree ensemble model into a semantic model.
std::string name_codegen_dir
directory name for generated code files
Definition: cli_main.cc:65
int verbose
whether verbose
Definition: cli_main.cc:59
Cross-platform wrapper for common filesystem functions.
std::string train_data_path
the path of training set: used for annotation
Definition: cli_main.cc:71
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:186
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:145
std::string codelib_path
the path of compiled dynamic shared library: used for prediction
Definition: cli_main.cc:75
int task
the task name
Definition: cli_main.cc:57
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
int format
model format
Definition: cli_main.cc:61
compatibility wrapper for systems that don't support OpenMP
Branch annotation tools.
Definition: semantic.h:91
Building blocks for semantic model of tree prediction code.
std::string name_annotate
name of generated annotation file
Definition: cli_main.cc:67
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
std::string name_pred
name of text file to save prediction
Definition: cli_main.cc:69
int test_format
test set file format
Definition: cli_main.cc:79
int train_format
training set file format
Definition: cli_main.cc:77