treelite
c_api.cc
Go to the documentation of this file.
1 
8 #include <memory>
9 #include <unordered_map>
10 #include <algorithm>
11 #include <treelite/annotator.h>
12 #include <treelite/c_api.h>
13 #include <treelite/compiler.h>
14 #include <treelite/data.h>
15 #include <treelite/frontend.h>
16 #include <dmlc/json.h>
17 #include <dmlc/thread_local.h>
18 #include "./c_api_error.h"
19 #include "../compiler/param.h"
20 #include "../common/filesystem.h"
21 #include "../common/math.h"
22 
23 using namespace treelite;
24 
25 namespace {
26 
27 struct CompilerHandleImpl {
28  std::string name;
29  std::vector<std::pair<std::string, std::string>> cfg;
30  std::unique_ptr<Compiler> compiler;
31  explicit CompilerHandleImpl(const std::string& name)
32  : name(name), cfg(), compiler(nullptr) {}
33  ~CompilerHandleImpl() = default;
34 };
35 
37 struct TreeliteAPIThreadLocalEntry {
39  std::string ret_str;
40 };
41 
42 // define threadlocal store for returning information
43 using TreeliteAPIThreadLocalStore
44  = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
45 
46 } // anonymous namespace
47 
48 int TreeliteDMatrixCreateFromFile(const char* path,
49  const char* format,
50  int nthread,
51  int verbose,
52  DMatrixHandle* out) {
53  API_BEGIN();
54  *out = static_cast<DMatrixHandle>(DMatrix::Create(path, format,
55  nthread, verbose));
56  API_END();
57 }
58 
59 int TreeliteDMatrixCreateFromCSR(const float* data,
60  const unsigned* col_ind,
61  const size_t* row_ptr,
62  size_t num_row,
63  size_t num_col,
64  DMatrixHandle* out) {
65  API_BEGIN();
66  DMatrix* dmat = new DMatrix();
67  dmat->Clear();
68  auto& data_ = dmat->data;
69  auto& col_ind_ = dmat->col_ind;
70  auto& row_ptr_ = dmat->row_ptr;
71  data_.reserve(row_ptr[num_row]);
72  col_ind_.reserve(row_ptr[num_row]);
73  row_ptr_.reserve(num_row + 1);
74  for (size_t i = 0; i < num_row; ++i) {
75  const size_t jbegin = row_ptr[i];
76  const size_t jend = row_ptr[i + 1];
77  for (size_t j = jbegin; j < jend; ++j) {
78  if (!common::math::CheckNAN(data[j])) { // skip NaN
79  data_.push_back(data[j]);
80  CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
81  << "feature index too big to fit into uint32_t";
82  col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
83  }
84  }
85  row_ptr_.push_back(data_.size());
86  }
87  data_.shrink_to_fit();
88  col_ind_.shrink_to_fit();
89  dmat->num_row = num_row;
90  dmat->num_col = num_col;
91  dmat->nelem = data_.size(); // some nonzeros may have been deleted as NAN
92 
93  *out = static_cast<DMatrixHandle>(dmat);
94  API_END();
95 }
96 
97 int TreeliteDMatrixCreateFromMat(const float* data,
98  size_t num_row,
99  size_t num_col,
100  float missing_value,
101  DMatrixHandle* out) {
102  const bool nan_missing = common::math::CheckNAN(missing_value);
103  API_BEGIN();
104  CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
105  << "num_col argument is too big";
106  DMatrix* dmat = new DMatrix();
107  dmat->Clear();
108  auto& data_ = dmat->data;
109  auto& col_ind_ = dmat->col_ind;
110  auto& row_ptr_ = dmat->row_ptr;
111  // make an educated guess for initial sizes,
112  // so as to present initial wave of allocation
113  const size_t guess_size
114  = std::min(std::min(num_row * num_col, num_row * 1000),
115  static_cast<size_t>(64 * 1024 * 1024));
116  data_.reserve(guess_size);
117  col_ind_.reserve(guess_size);
118  row_ptr_.reserve(num_row + 1);
119  const float* row = &data[0]; // points to beginning of each row
120  for (size_t i = 0; i < num_row; ++i, row += num_col) {
121  for (size_t j = 0; j < num_col; ++j) {
122  if (common::math::CheckNAN(row[j])) {
123  CHECK(nan_missing)
124  << "The missing_value argument must be set to NaN if there is any "
125  << "NaN in the matrix.";
126  } else if (nan_missing || row[j] != missing_value) {
127  // row[j] is a valid entry
128  data_.push_back(row[j]);
129  col_ind_.push_back(static_cast<uint32_t>(j));
130  }
131  }
132  row_ptr_.push_back(data_.size());
133  }
134  data_.shrink_to_fit();
135  col_ind_.shrink_to_fit();
136  dmat->num_row = num_row;
137  dmat->num_col = num_col;
138  dmat->nelem = data_.size(); // some nonzeros may have been deleted as NaN
139 
140  *out = static_cast<DMatrixHandle>(dmat);
141  API_END();
142 }
143 
145  size_t* out_num_row,
146  size_t* out_num_col,
147  size_t* out_nelem) {
148  API_BEGIN();
149  const DMatrix* dmat = static_cast<DMatrix*>(handle);
150  *out_num_row = dmat->num_row;
151  *out_num_col = dmat->num_col;
152  *out_nelem = dmat->nelem;
153  API_END();
154 }
155 
157  const char** out_preview) {
158  API_BEGIN();
159  const DMatrix* dmat = static_cast<DMatrix*>(handle);
160  std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
161  std::ostringstream oss;
162  const size_t iend = (dmat->nelem <= 50) ? dmat->nelem : 25;
163  for (size_t i = 0; i < iend; ++i) {
164  const size_t row_ind =
165  std::upper_bound(&dmat->row_ptr[0], &dmat->row_ptr[dmat->num_row + 1], i)
166  - &dmat->row_ptr[0] - 1;
167  oss << " (" << row_ind << ", " << dmat->col_ind[i] << ")\t"
168  << dmat->data[i] << "\n";
169  }
170  if (dmat->nelem > 50) {
171  oss << " :\t:\n";
172  for (size_t i = dmat->nelem - 25; i < dmat->nelem; ++i) {
173  const size_t row_ind =
174  std::upper_bound(&dmat->row_ptr[0], &dmat->row_ptr[dmat->num_row + 1], i)
175  - &dmat->row_ptr[0] - 1;
176  oss << " (" << row_ind << ", " << dmat->col_ind[i] << ")\t"
177  << dmat->data[i] << "\n";
178  }
179  }
180  ret_str = oss.str();
181  *out_preview = ret_str.c_str();
182  API_END();
183 }
184 
186  const float** out_data,
187  const uint32_t** out_col_ind,
188  const size_t** out_row_ptr) {
189  API_BEGIN();
190  const DMatrix* dmat_ = static_cast<DMatrix*>(handle);
191  *out_data = &dmat_->data[0];
192  *out_col_ind = &dmat_->col_ind[0];
193  *out_row_ptr = &dmat_->row_ptr[0];
194  API_END();
195 }
196 
198  API_BEGIN();
199  delete static_cast<DMatrix*>(handle);
200  API_END();
201 }
202 
204  DMatrixHandle dmat,
205  int nthread,
206  int verbose,
207  AnnotationHandle* out) {
208  API_BEGIN();
209  BranchAnnotator* annotator = new BranchAnnotator();
210  const Model* model_ = static_cast<Model*>(model);
211  const DMatrix* dmat_ = static_cast<DMatrix*>(dmat);
212  annotator->Annotate(*model_, dmat_, nthread, verbose);
213  *out = static_cast<AnnotationHandle>(annotator);
214  API_END();
215 }
216 
218  const char* path) {
219  API_BEGIN();
220  const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
221  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path, "w"));
222  annotator->Save(fo.get());
223  API_END();
224 }
225 
227  API_BEGIN();
228  delete static_cast<BranchAnnotator*>(handle);
229  API_END();
230 }
231 
232 int TreeliteCompilerCreate(const char* name,
233  CompilerHandle* out) {
234  API_BEGIN();
235  *out = static_cast<CompilerHandle>(new CompilerHandleImpl(name));
236  API_END();
237 }
238 
240  const char* name,
241  const char* value) {
242  API_BEGIN();
243  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(handle);
244  auto& cfg_ = impl->cfg;
245  std::string name_(name);
246  std::string value_(value);
247  // check for duplicate parameters
248  auto it = std::find_if(cfg_.begin(), cfg_.end(),
249  [&name_](const std::pair<std::string, std::string>& x) {
250  return x.first == name_;
251  });
252  if (it == cfg_.end()) {
253  cfg_.emplace_back(name_, value_);
254  } else {
255  it->second = value;
256  }
257  API_END();
258 }
259 
261  ModelHandle model,
262  int verbose,
263  const char* dirpath) {
264  API_BEGIN();
265  if (verbose > 0) { // verbose enabled
266  int ret = TreeliteCompilerSetParam(compiler, "verbose",
267  std::to_string(verbose).c_str());
268  if (ret < 0) { // SetParam failed
269  return ret;
270  }
271  }
272  const Model* model_ = static_cast<Model*>(model);
273  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(compiler);
274 
275  // create directory named dirpath
276  const std::string& dirpath_(dirpath);
277  common::filesystem::CreateDirectoryIfNotExist(dirpath);
278 
280  cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
281 
282  /* compile model */
283  impl->compiler.reset(Compiler::Create(impl->name, cparam));
284  auto compiled_model = impl->compiler->Compile(*model_);
285  if (verbose > 0) {
286  LOG(INFO) << "Code generation finished. Writing code to files...";
287  }
288 
289  for (const auto& it : compiled_model.files) {
290  if (verbose > 0) {
291  LOG(INFO) << "Writing file " << it.first << "...";
292  }
293  const std::string filename_full = dirpath_ + "/" + it.first;
294  if (it.second.is_binary) {
295  common::WriteToFile(filename_full, it.second.content_binary);
296  } else {
297  common::WriteToFile(filename_full, it.second.content);
298  }
299  }
300 
301  API_END();
302 }
303 
305  API_BEGIN();
306  delete static_cast<CompilerHandleImpl*>(handle);
307  API_END();
308 }
309 
310 int TreeliteLoadLightGBMModel(const char* filename,
311  ModelHandle* out) {
312  API_BEGIN();
313  Model* model = new Model(std::move(frontend::LoadLightGBMModel(filename)));
314  *out = static_cast<ModelHandle>(model);
315  API_END();
316 }
317 
318 int TreeliteLoadXGBoostModel(const char* filename,
319  ModelHandle* out) {
320  API_BEGIN();
321  Model* model = new Model(std::move(frontend::LoadXGBoostModel(filename)));
322  *out = static_cast<ModelHandle>(model);
323  API_END();
324 }
325 
326 int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len,
327  ModelHandle* out) {
328  API_BEGIN();
329  Model* model = new Model(std::move(frontend::LoadXGBoostModel(buf, len)));
330  *out = static_cast<ModelHandle>(model);
331  API_END();
332 }
333 
334 int TreeliteLoadProtobufModel(const char* filename,
335  ModelHandle* out) {
336  API_BEGIN();
337  Model* model = new Model(std::move(frontend::LoadProtobufModel(filename)));
338  *out = static_cast<ModelHandle>(model);
339  API_END();
340 }
341 
342 int TreeliteExportProtobufModel(const char* filename,
343  ModelHandle model) {
344  API_BEGIN();
345  Model* model_ = static_cast<Model*>(model);
346  frontend::ExportProtobufModel(filename, *model_);
347  API_END();
348 }
349 
351  API_BEGIN();
352  delete static_cast<Model*>(handle);
353  API_END();
354 }
355 
356 int TreeliteQueryNumTree(ModelHandle handle, size_t* out) {
357  API_BEGIN();
358  const Model* model_ = static_cast<Model*>(handle);
359  *out = model_->trees.size();
360  API_END();
361 }
362 
363 int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) {
364  API_BEGIN();
365  const Model* model_ = static_cast<Model*>(handle);
366  *out = static_cast<size_t>(model_->num_feature);
367  API_END();
368 }
369 
370 int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t* out) {
371  API_BEGIN();
372  const Model* model_ = static_cast<Model*>(handle);
373  *out = static_cast<size_t>(model_->num_output_group);
374  API_END();
375 }
376 
377 int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) {
378  API_BEGIN();
379  CHECK_GT(limit, 0) << "limit should be greater than 0!";
380  auto* model_ = static_cast<Model*>(handle);
381  CHECK_GE(model_->trees.size(), limit)
382  << "Model contains less trees(" << model_->trees.size() << ") than limit";
383  model_->trees.resize(limit);
384  API_END();
385 }
386 
388  API_BEGIN();
389  auto builder = new frontend::TreeBuilder();
390  *out = static_cast<TreeBuilderHandle>(builder);
391  API_END();
392 }
393 
395  API_BEGIN();
396  delete static_cast<frontend::TreeBuilder*>(handle);
397  API_END();
398 }
399 
401  API_BEGIN();
402  auto builder = static_cast<frontend::TreeBuilder*>(handle);
403  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
404  return (builder->CreateNode(node_key)) ? 0 : -1;
405  API_END();
406 }
407 
409  API_BEGIN();
410  auto builder = static_cast<frontend::TreeBuilder*>(handle);
411  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
412  return (builder->DeleteNode(node_key)) ? 0 : -1;
413  API_END();
414 }
415 
417  API_BEGIN();
418  auto builder = static_cast<frontend::TreeBuilder*>(handle);
419  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
420  return (builder->SetRootNode(node_key)) ? 0 : -1;
421  API_END();
422 }
423 
425  int node_key, unsigned feature_id,
426  const char* opname,
427  double threshold, int default_left,
428  int left_child_key,
429  int right_child_key) {
430  API_BEGIN();
431  auto builder = static_cast<frontend::TreeBuilder*>(handle);
432  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
433  CHECK_GT(optable.count(opname), 0)
434  << "No operator `" << opname << "\" exists";
435  return (builder->SetNumericalTestNode(node_key, feature_id,
436  optable.at(opname),
437  static_cast<tl_float>(threshold),
438  (default_left != 0),
439  left_child_key, right_child_key)) \
440  ? 0 : -1;
441  API_END();
442 }
443 
445  TreeBuilderHandle handle,
446  int node_key, unsigned feature_id,
447  const unsigned int* left_categories,
448  size_t left_categories_len,
449  int default_left,
450  int left_child_key,
451  int right_child_key) {
452  API_BEGIN();
453  auto builder = static_cast<frontend::TreeBuilder*>(handle);
454  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
455  std::vector<uint32_t> vec(left_categories_len);
456  for (size_t i = 0; i < left_categories_len; ++i) {
457  CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
458  vec[i] = static_cast<uint32_t>(left_categories[i]);
459  }
460  return (builder->SetCategoricalTestNode(node_key, feature_id, vec,
461  (default_left != 0),
462  left_child_key, right_child_key)) \
463  ? 0 : -1;
464  API_END();
465 }
466 
468  double leaf_value) {
469  API_BEGIN();
470  auto builder = static_cast<frontend::TreeBuilder*>(handle);
471  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
472  return (builder->SetLeafNode(node_key, static_cast<tl_float>(leaf_value))) \
473  ? 0 : -1;
474  API_END();
475 }
476 
478  int node_key,
479  const double* leaf_vector,
480  size_t leaf_vector_len) {
481  API_BEGIN();
482  auto builder = static_cast<frontend::TreeBuilder*>(handle);
483  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
484  std::vector<tl_float> vec(leaf_vector_len);
485  for (size_t i = 0; i < leaf_vector_len; ++i) {
486  vec[i] = static_cast<tl_float>(leaf_vector[i]);
487  }
488  return (builder->SetLeafVectorNode(node_key, vec)) ? 0 : -1;
489  API_END();
490 }
491 
492 int TreeliteCreateModelBuilder(int num_feature,
493  int num_output_group,
494  int random_forest_flag,
495  ModelBuilderHandle* out) {
496  API_BEGIN();
497  auto builder = new frontend::ModelBuilder(num_feature, num_output_group,
498  (random_forest_flag != 0));
499  *out = static_cast<ModelBuilderHandle>(builder);
500  API_END();
501 }
502 
504  const char* name,
505  const char* value) {
506  API_BEGIN();
507  auto builder = static_cast<frontend::ModelBuilder*>(handle);
508  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
509  builder->SetModelParam(name, value);
510  API_END();
511 }
512 
514  API_BEGIN();
515  delete static_cast<frontend::ModelBuilder*>(handle);
516  API_END();
517 }
518 
520  TreeBuilderHandle tree_builder_handle,
521  int index) {
522  API_BEGIN();
523  auto model_builder = static_cast<frontend::ModelBuilder*>(handle);
524  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
525  auto tree_builder = static_cast<frontend::TreeBuilder*>(tree_builder_handle);
526  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
527  return model_builder->InsertTree(tree_builder, index);
528  API_END();
529 }
530 
532  TreeBuilderHandle *out) {
533  API_BEGIN();
534  auto model_builder = static_cast<frontend::ModelBuilder*>(handle);
535  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
536  auto tree_builder = &model_builder->GetTree(index);
537  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
538  *out = static_cast<TreeBuilderHandle>(tree_builder);
539  API_END();
540 }
541 
543  API_BEGIN();
544  auto builder = static_cast<frontend::ModelBuilder*>(handle);
545  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
546  return (builder->DeleteTree(index)) ? 0 : -1;
547  API_END();
548 }
549 
551  ModelHandle* out) {
552  API_BEGIN();
553  auto builder = static_cast<frontend::ModelBuilder*>(handle);
554  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
555  Model* model = new Model();
556  const bool result = builder->CommitModel(model);
557  if (result) {
558  *out = static_cast<ModelHandle>(model);
559  return 0;
560  } else {
561  return -1;
562  }
563  API_END();
564 }
C API of treelite, used for interfacing with other languages. This header is excluded from the runtime...
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:438
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
Definition: c_api.cc:503
branch annotator class
Definition: annotator.h:17
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
Definition: c_api.cc:531
std::vector< float > data
feature values
Definition: data.h:18
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
Definition: tree.h:428
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: c_api.cc:318
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
Definition: c_api.cc:350
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
Definition: c_api.cc:217
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
Definition: c_api.cc:356
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from a (in-memory) CSR matrix
Definition: c_api.cc:59
tree builder class
Definition: frontend.h:70
std::vector< Tree > trees
member trees
Definition: tree.h:430
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
Definition: c_api.cc:197
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
Definition: c_api.cc:542
parameters for tree compiler
Definition: param.h:18
Input data structure of treelite.
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
Definition: annotator.cc:95
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
Definition: c_api.cc:239
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
Definition: c_api.cc:513
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
Definition: c_api.cc:444
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:261
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
Definition: c_api.cc:387
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
Definition: c_api.cc:226
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:20
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:145
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Definition: data.cc:17
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Definition: c_api.cc:203
int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t *out)
Query the number of output groups of the model.
Definition: c_api.cc:370
Interface of compiler that compiles a tree ensemble model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from a (in-memory) dense matrix
Definition: c_api.cc:97
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
keep first N trees of model; limit must be greater than 0 and must not exceed the number of trees.
Definition: c_api.cc:377
model builder class
Definition: frontend.h:160
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
Definition: c_api.cc:519
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
Definition: c_api.cc:185
size_t num_row
number of rows
Definition: data.h:24
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:16
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
Definition: c_api.cc:416
void * TreeBuilderHandle
handle to tree builder class
Definition: c_api.h:27
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
Definition: c_api.cc:394
void * AnnotationHandle
handle to branch annotation data
Definition: c_api.h:31
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
Definition: c_api.cc:550
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
Definition: c_api.cc:260
void Clear()
clear all data fields
Definition: data.h:33
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
Definition: c_api.cc:492
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Definition: optable.cc:12
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const double *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
Definition: c_api.cc:477
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
Definition: c_api.cc:326
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print first and last 25 non-zero entries...
Definition: c_api.cc:156
void * ModelHandle
handle to a decision tree ensemble model
Definition: c_api.h:25
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: c_api.cc:310
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
double tl_float
float type to be used internally
Definition: base.h:17
void * DMatrixHandle
handle to a data matrix
Definition: c_api.h:23
size_t num_col
number of columns
Definition: data.h:26
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
Definition: c_api.cc:334
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, double threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
Definition: c_api.cc:424
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
Definition: c_api.cc:232
int TreeliteExportProtobufModel(const char *filename, ModelHandle model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
Definition: c_api.cc:342
Branch annotation tools.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Definition: c_api.cc:408
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
Definition: c_api.cc:363
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
Definition: c_api.cc:144
void * ModelBuilderHandle
handle to ensemble builder class
Definition: c_api.h:29
size_t nelem
number of nonzero entries
Definition: data.h:28
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Definition: data.h:22
void * CompilerHandle
handle to compiler class
Definition: c_api.h:33
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, double leaf_value)
Turn an empty node into a leaf node.
Definition: c_api.cc:467
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
Definition: c_api.cc:48
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
Definition: c_api.cc:304
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:319
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:435
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.
Definition: c_api.cc:400