Treelite
c_api.cc
Go to the documentation of this file.
1 
9 #include <treelite/annotator.h>
10 #include <treelite/c_api.h>
11 #include <treelite/c_api_error.h>
12 #include <treelite/compiler.h>
14 #include <treelite/data.h>
15 #include <treelite/filesystem.h>
16 #include <treelite/frontend.h>
17 #include <treelite/tree.h>
18 #include <treelite/math.h>
19 #include <treelite/gtil.h>
20 #include <dmlc/thread_local.h>
21 #include <memory>
22 #include <algorithm>
23 
24 using namespace treelite;
25 
26 namespace {
27 
28 struct CompilerHandleImpl {
29  std::string name;
30  std::vector<std::pair<std::string, std::string>> cfg;
31  std::unique_ptr<Compiler> compiler;
32  explicit CompilerHandleImpl(const std::string& name)
33  : name(name), cfg(), compiler(nullptr) {}
34  ~CompilerHandleImpl() = default;
35 };
36 
37 } // anonymous namespace
38 
40  ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out) {
41  API_BEGIN();
42  std::unique_ptr<BranchAnnotator> annotator{new BranchAnnotator()};
43  const Model* model_ = static_cast<Model*>(model);
44  const auto* dmat_ = static_cast<const DMatrix*>(dmat);
45  CHECK(dmat_) << "Found a dangling reference to DMatrix";
46  annotator->Annotate(*model_, dmat_, nthread, verbose);
47  *out = static_cast<AnnotationHandle>(annotator.release());
48  API_END();
49 }
50 
52  const char* path) {
53  API_BEGIN();
54  const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
55  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path, "w"));
56  annotator->Save(fo.get());
57  API_END();
58 }
59 
61  API_BEGIN();
62  delete static_cast<BranchAnnotator*>(handle);
63  API_END();
64 }
65 
66 int TreeliteCompilerCreate(const char* name,
67  CompilerHandle* out) {
68  API_BEGIN();
69  std::unique_ptr<CompilerHandleImpl> compiler{new CompilerHandleImpl(name)};
70  *out = static_cast<CompilerHandle>(compiler.release());
71  API_END();
72 }
73 
75  const char* name,
76  const char* value) {
77  API_BEGIN();
78  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(handle);
79  auto& cfg_ = impl->cfg;
80  std::string name_(name);
81  std::string value_(value);
82  // check for duplicate parameters
83  auto it = std::find_if(cfg_.begin(), cfg_.end(),
84  [&name_](const std::pair<std::string, std::string>& x) {
85  return x.first == name_;
86  });
87  if (it == cfg_.end()) {
88  cfg_.emplace_back(name_, value_);
89  } else {
90  it->second = value;
91  }
92  API_END();
93 }
94 
96  ModelHandle model,
97  int verbose,
98  const char* dirpath) {
99  API_BEGIN();
100  if (verbose > 0) { // verbose enabled
101  int ret = TreeliteCompilerSetParam(compiler, "verbose",
102  std::to_string(verbose).c_str());
103  if (ret < 0) { // SetParam failed
104  return ret;
105  }
106  }
107  const Model* model_ = static_cast<Model*>(model);
108  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(compiler);
109 
110  // create directory named dirpath
111  const std::string& dirpath_(dirpath);
112  filesystem::CreateDirectoryIfNotExist(dirpath);
113 
115  cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
116 
117  /* compile model */
118  impl->compiler.reset(Compiler::Create(impl->name, cparam));
119  auto compiled_model = impl->compiler->Compile(*model_);
120  if (verbose > 0) {
121  LOG(INFO) << "Code generation finished. Writing code to files...";
122  }
123 
124  for (const auto& it : compiled_model.files) {
125  if (verbose > 0) {
126  LOG(INFO) << "Writing file " << it.first << "...";
127  }
128  const std::string filename_full = dirpath_ + "/" + it.first;
129  if (it.second.is_binary) {
130  filesystem::WriteToFile(filename_full, it.second.content_binary);
131  } else {
132  filesystem::WriteToFile(filename_full, it.second.content);
133  }
134  }
135 
136  API_END();
137 }
138 
140  API_BEGIN();
141  delete static_cast<CompilerHandleImpl*>(handle);
142  API_END();
143 }
144 
145 int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out) {
146  API_BEGIN();
147  std::unique_ptr<Model> model = frontend::LoadLightGBMModel(filename);
148  *out = static_cast<ModelHandle>(model.release());
149  API_END();
150 }
151 
152 int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out) {
153  API_BEGIN();
154  std::unique_ptr<Model> model = frontend::LoadXGBoostModel(filename);
155  *out = static_cast<ModelHandle>(model.release());
156  API_END();
157 }
158 
159 int TreeliteLoadXGBoostJSON(const char* filename, ModelHandle* out) {
160  API_BEGIN();
161  std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModel(filename);
162  *out = static_cast<ModelHandle>(model.release());
163  API_END();
164 }
165 
166 int TreeliteLoadXGBoostJSONString(const char* json_str, size_t length, ModelHandle* out) {
167  API_BEGIN();
168  std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModelString(json_str, length);
169  *out = static_cast<ModelHandle>(model.release());
170  API_END();
171 }
172 
173 int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out) {
174  API_BEGIN();
175  std::unique_ptr<Model> model = frontend::LoadXGBoostModel(buf, len);
176  *out = static_cast<ModelHandle>(model.release());
177  API_END();
178 }
179 
181  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
182  const int64_t** children_right, const int64_t** feature, const double** threshold,
183  const double** value, const int64_t** n_node_samples, const double** impurity,
184  ModelHandle* out) {
185  API_BEGIN();
186  std::unique_ptr<Model> model = frontend::LoadSKLearnRandomForestRegressor(
187  n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
188  value, n_node_samples, impurity);
189  *out = static_cast<ModelHandle>(model.release());
190  API_END();
191 }
192 
194  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
195  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
196  const double** threshold, const double** value, const int64_t** n_node_samples,
197  const double** impurity, ModelHandle* out) {
198  API_BEGIN();
199  std::unique_ptr<Model> model = frontend::LoadSKLearnRandomForestClassifier(
200  n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
201  threshold, value, n_node_samples, impurity);
202  *out = static_cast<ModelHandle>(model.release());
203  API_END();
204 }
205 
207  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
208  const int64_t** children_right, const int64_t** feature, const double** threshold,
209  const double** value, const int64_t** n_node_samples, const double** impurity,
210  ModelHandle* out) {
211  API_BEGIN();
212  std::unique_ptr<Model> model = frontend::LoadSKLearnGradientBoostingRegressor(
213  n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
214  value, n_node_samples, impurity);
215  *out = static_cast<ModelHandle>(model.release());
216  API_END();
217 }
218 
220  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
221  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
222  const double** threshold, const double** value, const int64_t** n_node_samples,
223  const double** impurity, ModelHandle* out) {
224  API_BEGIN();
225  std::unique_ptr<Model> model = frontend::LoadSKLearnGradientBoostingClassifier(
226  n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
227  threshold, value, n_node_samples, impurity);
228  *out = static_cast<ModelHandle>(model.release());
229  API_END();
230 }
231 
232 int TreeliteSerializeModel(const char* filename, ModelHandle handle) {
233  API_BEGIN();
234  FILE* fp = std::fopen(filename, "wb");
235  CHECK(fp) << "Failed to open file '" << filename << "'";
236  auto* model_ = static_cast<Model*>(handle);
237  model_->SerializeToFile(fp);
238  std::fclose(fp);
239  API_END();
240 }
241 
242 int TreeliteDeserializeModel(const char* filename, ModelHandle* out) {
243  API_BEGIN();
244  FILE* fp = std::fopen(filename, "rb");
245  CHECK(fp) << "Failed to open file '" << filename << "'";
246  std::unique_ptr<Model> model = Model::DeserializeFromFile(fp);
247  std::fclose(fp);
248  *out = static_cast<ModelHandle>(model.release());
249  API_END();
250 }
251 
253  API_BEGIN();
254  delete static_cast<Model*>(handle);
255  API_END();
256 }
257 
258 int TreeliteGTILGetPredictOutputSize(ModelHandle handle, size_t num_row, size_t* out) {
259  API_BEGIN();
260  const auto* model_ = static_cast<const Model*>(handle);
261  *out = gtil::GetPredictOutputSize(model_, num_row);
262  API_END();
263 }
264 
265 int TreeliteGTILPredict(ModelHandle handle, const float* input, size_t num_row, float* output,
266  int pred_transform, size_t* out_result_size) {
267  API_BEGIN();
268  const auto* model_ = static_cast<const Model*>(handle);
269  *out_result_size =
270  gtil::Predict(model_, input, num_row, output, (pred_transform == 1));
271  API_END();
272 }
273 
274 int TreeliteQueryNumTree(ModelHandle handle, size_t* out) {
275  API_BEGIN();
276  const auto* model_ = static_cast<const Model*>(handle);
277  *out = model_->GetNumTree();
278  API_END();
279 }
280 
281 int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) {
282  API_BEGIN();
283  const auto* model_ = static_cast<const Model*>(handle);
284  *out = static_cast<size_t>(model_->num_feature);
285  API_END();
286 }
287 
288 int TreeliteQueryNumClass(ModelHandle handle, size_t* out) {
289  API_BEGIN();
290  const auto* model_ = static_cast<const Model*>(handle);
291  *out = static_cast<size_t>(model_->task_param.num_class);
292  API_END();
293 }
294 
295 int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) {
296  API_BEGIN();
297  CHECK_GT(limit, 0) << "limit should be greater than 0!";
298  auto* model_ = static_cast<Model*>(handle);
299  const size_t num_tree = model_->GetNumTree();
300  CHECK_GE(num_tree, limit) << "Model contains less trees(" << num_tree << ") than limit";
301  model_->SetTreeLimit(limit);
302  API_END();
303 }
304 
305 int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type, ValueHandle* out) {
306  API_BEGIN();
307  std::unique_ptr<frontend::Value> value = std::make_unique<frontend::Value>();
308  *value = frontend::Value::Create(init_value, GetTypeInfoByName(type));
309  *out = static_cast<ValueHandle>(value.release());
310  API_END();
311 }
312 
314  API_BEGIN();
315  delete static_cast<frontend::Value*>(handle);
316  API_END();
317 }
318 
319 int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type,
320  TreeBuilderHandle* out) {
321  API_BEGIN();
322  std::unique_ptr<frontend::TreeBuilder> builder{
323  new frontend::TreeBuilder(GetTypeInfoByName(threshold_type),
324  GetTypeInfoByName(leaf_output_type))
325  };
326  *out = static_cast<TreeBuilderHandle>(builder.release());
327  API_END();
328 }
329 
331  API_BEGIN();
332  delete static_cast<frontend::TreeBuilder*>(handle);
333  API_END();
334 }
335 
337  API_BEGIN();
338  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
339  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
340  builder->CreateNode(node_key);
341  API_END();
342 }
343 
345  API_BEGIN();
346  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
347  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
348  builder->DeleteNode(node_key);
349  API_END();
350 }
351 
353  API_BEGIN();
354  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
355  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
356  builder->SetRootNode(node_key);
357  API_END();
358 }
359 
361  TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname,
362  ValueHandle threshold, int default_left, int left_child_key, int right_child_key) {
363  API_BEGIN();
364  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
365  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
366  builder->SetNumericalTestNode(node_key, feature_id, opname,
367  *static_cast<const frontend::Value*>(threshold),
368  (default_left != 0), left_child_key, right_child_key);
369  API_END();
370 }
371 
373  TreeBuilderHandle handle, int node_key, unsigned feature_id,
374  const unsigned int* left_categories, size_t left_categories_len, int default_left,
375  int left_child_key, int right_child_key) {
376  API_BEGIN();
377  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
378  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
379  std::vector<uint32_t> vec(left_categories_len);
380  for (size_t i = 0; i < left_categories_len; ++i) {
381  CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
382  vec[i] = static_cast<uint32_t>(left_categories[i]);
383  }
384  builder->SetCategoricalTestNode(node_key, feature_id, vec, (default_left != 0),
385  left_child_key, right_child_key);
386  API_END();
387 }
388 
389 int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value) {
390  API_BEGIN();
391  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
392  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
393  builder->SetLeafNode(node_key, *static_cast<const frontend::Value*>(leaf_value));
394  API_END();
395 }
396 
398  const ValueHandle* leaf_vector, size_t leaf_vector_len) {
399  API_BEGIN();
400  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
401  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
402  std::vector<frontend::Value> vec(leaf_vector_len);
403  CHECK(leaf_vector) << "leaf_vector argument must not be null";
404  for (size_t i = 0; i < leaf_vector_len; ++i) {
405  CHECK(leaf_vector[i]) << "leaf_vector[" << i << "] contains an empty Value handle";
406  vec[i] = *static_cast<const frontend::Value*>(leaf_vector[i]);
407  }
408  builder->SetLeafVectorNode(node_key, vec);
409  API_END();
410 }
411 
413  int num_feature, int num_class, int average_tree_output, const char* threshold_type,
414  const char* leaf_output_type, ModelBuilderHandle* out) {
415  API_BEGIN();
416  std::unique_ptr<frontend::ModelBuilder> builder{new frontend::ModelBuilder(
417  num_feature, num_class, (average_tree_output != 0), GetTypeInfoByName(threshold_type),
418  GetTypeInfoByName(leaf_output_type))};
419  *out = static_cast<ModelBuilderHandle>(builder.release());
420  API_END();
421 }
422 
424  const char* value) {
425  API_BEGIN();
426  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
427  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
428  builder->SetModelParam(name, value);
429  API_END();
430 }
431 
433  API_BEGIN();
434  delete static_cast<frontend::ModelBuilder*>(handle);
435  API_END();
436 }
437 
439  int index) {
440  API_BEGIN();
441  auto* model_builder = static_cast<frontend::ModelBuilder*>(handle);
442  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
443  auto* tree_builder = static_cast<frontend::TreeBuilder*>(tree_builder_handle);
444  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
445  return model_builder->InsertTree(tree_builder, index);
446  API_END();
447 }
448 
450  API_BEGIN();
451  auto* model_builder = static_cast<frontend::ModelBuilder*>(handle);
452  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
453  auto* tree_builder = model_builder->GetTree(index);
454  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
455  *out = static_cast<TreeBuilderHandle>(tree_builder);
456  API_END();
457 }
458 
460  API_BEGIN();
461  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
462  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
463  builder->DeleteTree(index);
464  API_END();
465 }
466 
468  API_BEGIN();
469  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
470  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
471  std::unique_ptr<Model> model = builder->CommitModel();
472  *out = static_cast<ModelHandle>(model.release());
473  API_END();
474 }
Some useful math utilities.
int TreeliteQueryNumClass(ModelHandle handle, size_t *out)
Query the number of classes of the model. (1 if the model is binary classifier or regressor) ...
Definition: c_api.cc:288
C API of Treelite, used for interfacing with other languages This header is excluded from the runtime...
Parameters for tree compiler.
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
Definition: c_api.cc:423
int TreeliteLoadSKLearnRandomForestRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn random forest regressor model from a collection of arrays. Refer to https://sciki...
Definition: c_api.cc:180
branch annotator class
Definition: annotator.h:17
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
Definition: c_api.cc:449
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Definition: builder.cc:418
Collection of front-end methods to load or construct ensemble model.
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: c_api.cc:152
void DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:185
#define API_BEGIN()
macro to guard beginning and end section of all functions
Definition: c_api_error.h:15
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
Definition: c_api.cc:252
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
Definition: c_api.cc:51
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
Definition: c_api.cc:274
tree builder class
Definition: frontend.h:211
int TreeliteLoadSKLearnGradientBoostingRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: c_api.cc:206
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
Definition: c_api.cc:459
parameters for tree compiler
Input data structure of Treelite.
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
Definition: builder.cc:218
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
Definition: c_api.cc:74
int TreeliteDeserializeModel(const char *filename, ModelHandle *out)
Deserialize (load) a model object from disk.
Definition: c_api.cc:242
int TreeliteLoadXGBoostJSON(const char *filename, ModelHandle *out)
load a json model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tr...
Definition: c_api.cc:159
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
Definition: c_api.cc:432
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
Definition: c_api.cc:372
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:337
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:296
void SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:208
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
Definition: c_api.cc:60
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:401
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:248
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, ValueHandle threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
Definition: c_api.cc:360
int TreeliteLoadSKLearnRandomForestClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn random forest classifier model from a collection of arrays. Refer to https://scik...
Definition: c_api.cc:193
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Definition: c_api.cc:39
int TreeliteTreeBuilderCreateValue(const void *init_value, const char *type, ValueHandle *out)
Create a new Value object. Some model builder API functions accept this Value type to accommodate val...
Definition: c_api.cc:305
Interface of compiler that compiles a tree ensemble model.
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
Keep the first N trees of the model; limit must be greater than 0 and must not exceed the number of trees.
Definition: c_api.cc:295
model builder class
Definition: frontend.h:300
void * ValueHandle
handle to a polymorphic value type, used in the model builder API
Definition: c_api.h:33
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
Definition: c_api.cc:438
void * DMatrixHandle
handle to a data matrix
Definition: c_api_common.h:30
Cross-platform wrapper for common filesystem functions.
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
Definition: c_api.cc:352
void * TreeBuilderHandle
handle to tree builder class
Definition: c_api.h:25
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
Definition: c_api.cc:330
int TreeliteCreateTreeBuilder(const char *threshold_type, const char *leaf_output_type, TreeBuilderHandle *out)
Create a new tree builder.
Definition: c_api.cc:319
void * AnnotationHandle
handle to branch annotation data
Definition: c_api.h:29
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
Definition: c_api.cc:467
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decis...
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value)
Turn an empty node into a leaf node.
Definition: c_api.cc:389
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
Definition: c_api.cc:95
int TreeliteCreateModelBuilder(int num_feature, int num_class, int average_tree_output, const char *threshold_type, const char *leaf_output_type, ModelBuilderHandle *out)
Create a new model builder.
Definition: c_api.cc:412
int TreeliteLoadSKLearnGradientBoostingClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: c_api.cc:219
void DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:411
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
Definition: c_api.cc:173
void * ModelHandle
handle to a decision tree ensemble model
Definition: c_api.h:23
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: c_api.cc:145
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
Definition: typeinfo.cc:16
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
int TreeliteTreeBuilderDeleteValue(ValueHandle handle)
Delete a Value object from memory.
Definition: c_api.cc:313
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
Definition: c_api.cc:66
Branch annotation tools.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Definition: c_api.cc:344
thin wrapper for tree ensemble model
Definition: tree.h:632
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
Definition: c_api.cc:281
void * ModelBuilderHandle
handle to ensemble builder class
Definition: c_api.h:27
int TreeliteSerializeModel(const char *filename, ModelHandle handle)
Serialize (persist) a model object to disk.
Definition: c_api.cc:232
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const ValueHandle *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node The leaf vector (collection of multiple leaf weights per l...
Definition: c_api.cc:397
void CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:178
void * CompilerHandle
handle to compiler class
Definition: c_api.h:31
int TreeliteLoadXGBoostJSONString(const char *json_str, size_t length, ModelHandle *out)
load a model stored as JSON string by XGBoost (dmlc/xgboost). The model json must contain a decision t...
Definition: c_api.cc:166
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
Definition: c_api.cc:139
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:18
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.
Definition: c_api.cc:336