Treelite
c_api.cc
Go to the documentation of this file.
1 
9 #include <treelite/annotator.h>
10 #include <treelite/c_api.h>
11 #include <treelite/compiler.h>
13 #include <treelite/data.h>
14 #include <treelite/filesystem.h>
15 #include <treelite/frontend.h>
16 #include <treelite/math.h>
17 #include <dmlc/thread_local.h>
18 #include <memory>
19 #include <algorithm>
20 #include "./c_api_error.h"
21 
22 using namespace treelite;
23 
24 namespace {
25 
26 struct CompilerHandleImpl {
27  std::string name;
28  std::vector<std::pair<std::string, std::string>> cfg;
29  std::unique_ptr<Compiler> compiler;
30  explicit CompilerHandleImpl(const std::string& name)
31  : name(name), cfg(), compiler(nullptr) {}
32  ~CompilerHandleImpl() = default;
33 };
34 
36 struct TreeliteAPIThreadLocalEntry {
38  std::string ret_str;
39 };
40 
41 // define threadlocal store for returning information
42 using TreeliteAPIThreadLocalStore
43  = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
44 
45 } // anonymous namespace
46 
47 int TreeliteDMatrixCreateFromFile(const char* path,
48  const char* format,
49  int nthread,
50  int verbose,
51  DMatrixHandle* out) {
52  API_BEGIN();
53  *out = static_cast<DMatrixHandle>(DMatrix::Create(path, format,
54  nthread, verbose));
55  API_END();
56 }
57 
58 int TreeliteDMatrixCreateFromCSR(const float* data,
59  const unsigned* col_ind,
60  const size_t* row_ptr,
61  size_t num_row,
62  size_t num_col,
63  DMatrixHandle* out) {
64  API_BEGIN();
65  std::unique_ptr<DMatrix> dmat{new DMatrix()};
66  dmat->Clear();
67  auto& data_ = dmat->data;
68  auto& col_ind_ = dmat->col_ind;
69  auto& row_ptr_ = dmat->row_ptr;
70  data_.reserve(row_ptr[num_row]);
71  col_ind_.reserve(row_ptr[num_row]);
72  row_ptr_.reserve(num_row + 1);
73  for (size_t i = 0; i < num_row; ++i) {
74  const size_t jbegin = row_ptr[i];
75  const size_t jend = row_ptr[i + 1];
76  for (size_t j = jbegin; j < jend; ++j) {
77  if (!math::CheckNAN(data[j])) { // skip NaN
78  data_.push_back(data[j]);
79  CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
80  << "feature index too big to fit into uint32_t";
81  col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
82  }
83  }
84  row_ptr_.push_back(data_.size());
85  }
86  data_.shrink_to_fit();
87  col_ind_.shrink_to_fit();
88  dmat->num_row = num_row;
89  dmat->num_col = num_col;
90  dmat->nelem = data_.size(); // some nonzeros may have been deleted as NAN
91 
92  *out = static_cast<DMatrixHandle>(dmat.release());
93  API_END();
94 }
95 
96 int TreeliteDMatrixCreateFromMat(const float* data,
97  size_t num_row,
98  size_t num_col,
99  float missing_value,
100  DMatrixHandle* out) {
101  const bool nan_missing = math::CheckNAN(missing_value);
102  API_BEGIN();
103  CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
104  << "num_col argument is too big";
105  std::unique_ptr<DMatrix> dmat{new DMatrix()};
106  dmat->Clear();
107  auto& data_ = dmat->data;
108  auto& col_ind_ = dmat->col_ind;
109  auto& row_ptr_ = dmat->row_ptr;
110  // make an educated guess for initial sizes,
111  // so as to present initial wave of allocation
112  const size_t guess_size
113  = std::min(std::min(num_row * num_col, num_row * 1000),
114  static_cast<size_t>(64 * 1024 * 1024));
115  data_.reserve(guess_size);
116  col_ind_.reserve(guess_size);
117  row_ptr_.reserve(num_row + 1);
118  const float* row = &data[0]; // points to beginning of each row
119  for (size_t i = 0; i < num_row; ++i, row += num_col) {
120  for (size_t j = 0; j < num_col; ++j) {
121  if (math::CheckNAN(row[j])) {
122  CHECK(nan_missing)
123  << "The missing_value argument must be set to NaN if there is any "
124  << "NaN in the matrix.";
125  } else if (nan_missing || row[j] != missing_value) {
126  // row[j] is a valid entry
127  data_.push_back(row[j]);
128  col_ind_.push_back(static_cast<uint32_t>(j));
129  }
130  }
131  row_ptr_.push_back(data_.size());
132  }
133  data_.shrink_to_fit();
134  col_ind_.shrink_to_fit();
135  dmat->num_row = num_row;
136  dmat->num_col = num_col;
137  dmat->nelem = data_.size(); // some nonzeros may have been deleted as NaN
138 
139  *out = static_cast<DMatrixHandle>(dmat.release());
140  API_END();
141 }
142 
144  size_t* out_num_row,
145  size_t* out_num_col,
146  size_t* out_nelem) {
147  API_BEGIN();
148  const DMatrix* dmat = static_cast<DMatrix*>(handle);
149  *out_num_row = dmat->num_row;
150  *out_num_col = dmat->num_col;
151  *out_nelem = dmat->nelem;
152  API_END();
153 }
154 
156  const char** out_preview) {
157  API_BEGIN();
158  const DMatrix* dmat = static_cast<DMatrix*>(handle);
159  std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
160  std::ostringstream oss;
161  const size_t iend = (dmat->nelem <= 50) ? dmat->nelem : 25;
162  for (size_t i = 0; i < iend; ++i) {
163  const size_t row_ind =
164  std::upper_bound(&dmat->row_ptr[0], &dmat->row_ptr[dmat->num_row + 1], i)
165  - &dmat->row_ptr[0] - 1;
166  oss << " (" << row_ind << ", " << dmat->col_ind[i] << ")\t"
167  << dmat->data[i] << "\n";
168  }
169  if (dmat->nelem > 50) {
170  oss << " :\t:\n";
171  for (size_t i = dmat->nelem - 25; i < dmat->nelem; ++i) {
172  const size_t row_ind =
173  std::upper_bound(&dmat->row_ptr[0], &dmat->row_ptr[dmat->num_row + 1], i)
174  - &dmat->row_ptr[0] - 1;
175  oss << " (" << row_ind << ", " << dmat->col_ind[i] << ")\t"
176  << dmat->data[i] << "\n";
177  }
178  }
179  ret_str = oss.str();
180  *out_preview = ret_str.c_str();
181  API_END();
182 }
183 
185  const float** out_data,
186  const uint32_t** out_col_ind,
187  const size_t** out_row_ptr) {
188  API_BEGIN();
189  const DMatrix* dmat_ = static_cast<DMatrix*>(handle);
190  *out_data = &dmat_->data[0];
191  *out_col_ind = &dmat_->col_ind[0];
192  *out_row_ptr = &dmat_->row_ptr[0];
193  API_END();
194 }
195 
197  API_BEGIN();
198  delete static_cast<DMatrix*>(handle);
199  API_END();
200 }
201 
203  DMatrixHandle dmat,
204  int nthread,
205  int verbose,
206  AnnotationHandle* out) {
207  API_BEGIN();
208  std::unique_ptr<BranchAnnotator> annotator{new BranchAnnotator()};
209  const Model* model_ = static_cast<Model*>(model);
210  const DMatrix* dmat_ = static_cast<DMatrix*>(dmat);
211  annotator->Annotate(*model_, dmat_, nthread, verbose);
212  *out = static_cast<AnnotationHandle>(annotator.release());
213  API_END();
214 }
215 
217  const char* path) {
218  API_BEGIN();
219  const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
220  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path, "w"));
221  annotator->Save(fo.get());
222  API_END();
223 }
224 
226  API_BEGIN();
227  delete static_cast<BranchAnnotator*>(handle);
228  API_END();
229 }
230 
231 int TreeliteCompilerCreate(const char* name,
232  CompilerHandle* out) {
233  API_BEGIN();
234  std::unique_ptr<CompilerHandleImpl> compiler{new CompilerHandleImpl(name)};
235  *out = static_cast<CompilerHandle>(compiler.release());
236  API_END();
237 }
238 
240  const char* name,
241  const char* value) {
242  API_BEGIN();
243  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(handle);
244  auto& cfg_ = impl->cfg;
245  std::string name_(name);
246  std::string value_(value);
247  // check for duplicate parameters
248  auto it = std::find_if(cfg_.begin(), cfg_.end(),
249  [&name_](const std::pair<std::string, std::string>& x) {
250  return x.first == name_;
251  });
252  if (it == cfg_.end()) {
253  cfg_.emplace_back(name_, value_);
254  } else {
255  it->second = value;
256  }
257  API_END();
258 }
259 
261  ModelHandle model,
262  int verbose,
263  const char* dirpath) {
264  API_BEGIN();
265  if (verbose > 0) { // verbose enabled
266  int ret = TreeliteCompilerSetParam(compiler, "verbose",
267  std::to_string(verbose).c_str());
268  if (ret < 0) { // SetParam failed
269  return ret;
270  }
271  }
272  const Model* model_ = static_cast<Model*>(model);
273  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(compiler);
274 
275  // create directory named dirpath
276  const std::string& dirpath_(dirpath);
277  filesystem::CreateDirectoryIfNotExist(dirpath);
278 
280  cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
281 
282  /* compile model */
283  impl->compiler.reset(Compiler::Create(impl->name, cparam));
284  auto compiled_model = impl->compiler->Compile(*model_);
285  if (verbose > 0) {
286  LOG(INFO) << "Code generation finished. Writing code to files...";
287  }
288 
289  for (const auto& it : compiled_model.files) {
290  if (verbose > 0) {
291  LOG(INFO) << "Writing file " << it.first << "...";
292  }
293  const std::string filename_full = dirpath_ + "/" + it.first;
294  if (it.second.is_binary) {
295  filesystem::WriteToFile(filename_full, it.second.content_binary);
296  } else {
297  filesystem::WriteToFile(filename_full, it.second.content);
298  }
299  }
300 
301  API_END();
302 }
303 
305  API_BEGIN();
306  delete static_cast<CompilerHandleImpl*>(handle);
307  API_END();
308 }
309 
310 int TreeliteLoadLightGBMModel(const char* filename,
311  ModelHandle* out) {
312  API_BEGIN();
313  std::unique_ptr<Model> model{new Model()};
314  frontend::LoadLightGBMModel(filename, model.get());
315  *out = static_cast<ModelHandle>(model.release());
316  API_END();
317 }
318 
319 int TreeliteLoadXGBoostModel(const char* filename,
320  ModelHandle* out) {
321  API_BEGIN();
322  std::unique_ptr<Model> model{new Model()};
323  frontend::LoadXGBoostModel(filename, model.get());
324  *out = static_cast<ModelHandle>(model.release());
325  API_END();
326 }
327 
328 int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len,
329  ModelHandle* out) {
330  API_BEGIN();
331  std::unique_ptr<Model> model{new Model()};
332  frontend::LoadXGBoostModel(buf, len, model.get());
333  *out = static_cast<ModelHandle>(model.release());
334  API_END();
335 }
336 
337 int TreeliteLoadProtobufModel(const char* filename,
338  ModelHandle* out) {
339  API_BEGIN();
340  std::unique_ptr<Model> model{new Model()};
341  frontend::LoadProtobufModel(filename, model.get());
342  *out = static_cast<ModelHandle>(model.release());
343  API_END();
344 }
345 
346 int TreeliteExportProtobufModel(const char* filename,
347  ModelHandle model) {
348  API_BEGIN();
349  auto model_ = static_cast<Model*>(model);
350  frontend::ExportProtobufModel(filename, *model_);
351  API_END();
352 }
353 
355  API_BEGIN();
356  delete static_cast<Model*>(handle);
357  API_END();
358 }
359 
360 int TreeliteQueryNumTree(ModelHandle handle, size_t* out) {
361  API_BEGIN();
362  auto model_ = static_cast<const Model*>(handle);
363  *out = model_->trees.size();
364  API_END();
365 }
366 
367 int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) {
368  API_BEGIN();
369  auto model_ = static_cast<const Model*>(handle);
370  *out = static_cast<size_t>(model_->num_feature);
371  API_END();
372 }
373 
374 int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t* out) {
375  API_BEGIN();
376  auto model_ = static_cast<const Model*>(handle);
377  *out = static_cast<size_t>(model_->num_output_group);
378  API_END();
379 }
380 
381 int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) {
382  API_BEGIN();
383  CHECK_GT(limit, 0) << "limit should be greater than 0!";
384  auto model_ = static_cast<Model*>(handle);
385  CHECK_GE(model_->trees.size(), limit)
386  << "Model contains less trees(" << model_->trees.size() << ") than limit";
387  model_->trees.resize(limit);
388  API_END();
389 }
390 
392  API_BEGIN();
393  std::unique_ptr<frontend::TreeBuilder> builder{new frontend::TreeBuilder()};
394  *out = static_cast<TreeBuilderHandle>(builder.release());
395  API_END();
396 }
397 
399  API_BEGIN();
400  delete static_cast<frontend::TreeBuilder*>(handle);
401  API_END();
402 }
403 
405  API_BEGIN();
406  auto builder = static_cast<frontend::TreeBuilder*>(handle);
407  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
408  builder->CreateNode(node_key);
409  API_END();
410 }
411 
413  API_BEGIN();
414  auto builder = static_cast<frontend::TreeBuilder*>(handle);
415  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
416  builder->DeleteNode(node_key);
417  API_END();
418 }
419 
421  API_BEGIN();
422  auto builder = static_cast<frontend::TreeBuilder*>(handle);
423  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
424  builder->SetRootNode(node_key);
425  API_END();
426 }
427 
429  int node_key, unsigned feature_id,
430  const char* opname,
431  float threshold, int default_left,
432  int left_child_key,
433  int right_child_key) {
434  API_BEGIN();
435  auto builder = static_cast<frontend::TreeBuilder*>(handle);
436  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
437  builder->SetNumericalTestNode(node_key, feature_id, opname, static_cast<tl_float>(threshold),
438  (default_left != 0), left_child_key, right_child_key);
439  API_END();
440 }
441 
443  TreeBuilderHandle handle,
444  int node_key, unsigned feature_id,
445  const unsigned int* left_categories,
446  size_t left_categories_len,
447  int default_left,
448  int left_child_key,
449  int right_child_key) {
450  API_BEGIN();
451  auto builder = static_cast<frontend::TreeBuilder*>(handle);
452  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
453  std::vector<uint32_t> vec(left_categories_len);
454  for (size_t i = 0; i < left_categories_len; ++i) {
455  CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
456  vec[i] = static_cast<uint32_t>(left_categories[i]);
457  }
458  builder->SetCategoricalTestNode(node_key, feature_id, vec, (default_left != 0),
459  left_child_key, right_child_key);
460  API_END();
461 }
462 
463 int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value) {
464  API_BEGIN();
465  auto builder = static_cast<frontend::TreeBuilder*>(handle);
466  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
467  builder->SetLeafNode(node_key, static_cast<tl_float>(leaf_value));
468  API_END();
469 }
470 
472  int node_key,
473  const float* leaf_vector,
474  size_t leaf_vector_len) {
475  API_BEGIN();
476  auto builder = static_cast<frontend::TreeBuilder*>(handle);
477  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
478  std::vector<tl_float> vec(leaf_vector_len);
479  for (size_t i = 0; i < leaf_vector_len; ++i) {
480  vec[i] = static_cast<tl_float>(leaf_vector[i]);
481  }
482  builder->SetLeafVectorNode(node_key, vec);
483  API_END();
484 }
485 
486 int TreeliteCreateModelBuilder(int num_feature,
487  int num_output_group,
488  int random_forest_flag,
489  ModelBuilderHandle* out) {
490  API_BEGIN();
491  std::unique_ptr<frontend::ModelBuilder> builder{new frontend::ModelBuilder(
492  num_feature,
493  num_output_group,
494  (random_forest_flag != 0))};
495  *out = static_cast<ModelBuilderHandle>(builder.release());
496  API_END();
497 }
498 
500  const char* name,
501  const char* value) {
502  API_BEGIN();
503  auto builder = static_cast<frontend::ModelBuilder*>(handle);
504  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
505  builder->SetModelParam(name, value);
506  API_END();
507 }
508 
510  API_BEGIN();
511  delete static_cast<frontend::ModelBuilder*>(handle);
512  API_END();
513 }
514 
516  TreeBuilderHandle tree_builder_handle,
517  int index) {
518  API_BEGIN();
519  auto model_builder = static_cast<frontend::ModelBuilder*>(handle);
520  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
521  auto tree_builder = static_cast<frontend::TreeBuilder*>(tree_builder_handle);
522  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
523  return model_builder->InsertTree(tree_builder, index);
524  API_END();
525 }
526 
528  TreeBuilderHandle *out) {
529  API_BEGIN();
530  auto model_builder = static_cast<frontend::ModelBuilder*>(handle);
531  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
532  auto tree_builder = model_builder->GetTree(index);
533  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
534  *out = static_cast<TreeBuilderHandle>(tree_builder);
535  API_END();
536 }
537 
539  API_BEGIN();
540  auto builder = static_cast<frontend::ModelBuilder*>(handle);
541  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
542  builder->DeleteTree(index);
543  API_END();
544 }
545 
547  ModelHandle* out) {
548  API_BEGIN();
549  auto builder = static_cast<frontend::ModelBuilder*>(handle);
550  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
551  std::unique_ptr<Model> model{new Model()};
552  builder->CommitModel(model.get());
553  *out = static_cast<ModelHandle>(model.release());
554  API_END();
555 }
Some useful math utilities.
C API of Treelite, used for interfacing with other languages. This header is excluded from the runtime...
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, float threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
Definition: c_api.cc:428
Parameters for tree compiler.
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
Definition: c_api.cc:499
branch annotator class
Definition: annotator.h:17
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
Definition: c_api.cc:527
std::vector< float > data
feature values
Definition: data.h:18
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: c_api.cc:319
void DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:105
#define API_BEGIN()
macro to guard beginning and end section of all functions
Definition: c_api_error.h:15
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
Definition: c_api.cc:354
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
Definition: c_api.cc:216
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
Definition: c_api.cc:360
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from an (in-memory) CSR matrix
Definition: c_api.cc:58
tree builder class
Definition: frontend.h:70
std::vector< Tree > trees
member trees
Definition: tree.h:411
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
Definition: c_api.cc:196
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
Definition: c_api.cc:538
parameters for tree compiler
Input data structure of Treelite.
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
Definition: c_api.cc:239
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
Definition: c_api.cc:509
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
Definition: c_api.cc:442
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:249
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
Definition: c_api.cc:391
void SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:128
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
Definition: c_api.cc:225
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:20
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:297
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:142
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Definition: data.cc:17
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const float *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
Definition: c_api.cc:471
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Definition: c_api.cc:202
int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t *out)
Query the number of output groups of the model.
Definition: c_api.cc:374
Interface of compiler that compiles a tree ensemble model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from an (in-memory) dense matrix
Definition: c_api.cc:96
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
keep the first N trees of the model; limit must be positive and no greater than the number of trees.
Definition: c_api.cc:381
model builder class
Definition: frontend.h:156
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
Definition: c_api.cc:515
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
Definition: c_api.cc:184
Cross-platform wrapper for common filesystem functions.
size_t num_row
number of rows
Definition: data.h:24
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:16
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, tl_float threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
Definition: builder.cc:138
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
Definition: c_api.cc:420
void * TreeBuilderHandle
handle to tree builder class
Definition: c_api.h:27
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
Definition: c_api.cc:398
void * AnnotationHandle
handle to branch annotation data
Definition: c_api.h:31
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
Definition: c_api.cc:546
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
Definition: c_api.cc:260
void Clear()
clear all data fields
Definition: data.h:33
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
Definition: c_api.cc:486
void DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:307
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
Definition: c_api.cc:328
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print first and last 25 non-zero entries...
Definition: c_api.cc:155
void * ModelHandle
handle to a decision tree ensemble model
Definition: c_api.h:25
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: c_api.cc:310
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
void * DMatrixHandle
handle to a data matrix
Definition: c_api.h:23
size_t num_col
number of columns
Definition: data.h:26
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
Definition: c_api.cc:337
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
Definition: c_api.cc:231
int TreeliteExportProtobufModel(const char *filename, ModelHandle model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
Definition: c_api.cc:346
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value)
Turn an empty node into a leaf node.
Definition: c_api.cc:463
Branch annotation tools.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Definition: c_api.cc:412
void SetLeafNode(int node_key, tl_float leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:219
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
Definition: c_api.cc:367
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
Definition: c_api.cc:143
void * ModelBuilderHandle
handle to ensemble builder class
Definition: c_api.h:29
void CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:98
size_t nelem
number of nonzero entries
Definition: data.h:28
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Definition: data.h:22
void * CompilerHandle
handle to compiler class
Definition: c_api.h:33
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
Definition: c_api.cc:47
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
Definition: c_api.cc:304
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:18
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.
Definition: c_api.cc:404