Treelite
c_api.cc
Go to the documentation of this file.
1 
8 #include <treelite/annotator.h>
9 #include <treelite/c_api.h>
10 #include <treelite/c_api_error.h>
11 #include <treelite/compiler.h>
13 #include <treelite/data.h>
14 #include <treelite/filesystem.h>
15 #include <treelite/frontend.h>
16 #include <treelite/tree.h>
17 #include <treelite/math.h>
18 #include <treelite/gtil.h>
19 #include <treelite/logging.h>
20 #include <memory>
21 #include <algorithm>
22 #include <fstream>
23 #include <string>
24 #include <cstdio>
25 
26 using namespace treelite;
27 
28 namespace {
29 
31 struct TreeliteAPIThreadLocalEntry {
33  std::string ret_str;
34 };
35 
36 // define threadlocal store for returning information
37 using TreeliteAPIThreadLocalStore = ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
38 
39 } // anonymous namespace
40 
42  ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out) {
43  API_BEGIN();
44  std::unique_ptr<BranchAnnotator> annotator{new BranchAnnotator()};
45  const Model* model_ = static_cast<Model*>(model);
46  const auto* dmat_ = static_cast<const DMatrix*>(dmat);
47  TREELITE_CHECK(dmat_) << "Found a dangling reference to DMatrix";
48  annotator->Annotate(*model_, dmat_, nthread, verbose);
49  *out = static_cast<AnnotationHandle>(annotator.release());
50  API_END();
51 }
52 
54  const char* path) {
55  API_BEGIN();
56  const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
57  std::ofstream fo(path);
58  annotator->Save(fo);
59  API_END();
60 }
61 
63  API_BEGIN();
64  delete static_cast<BranchAnnotator*>(handle);
65  API_END();
66 }
67 
68 int TreeliteCompilerCreateV2(const char* name, const char* params_json_str, CompilerHandle* out) {
69  API_BEGIN();
70  std::unique_ptr<Compiler> compiler{Compiler::Create(name, params_json_str)};
71  *out = static_cast<CompilerHandle>(compiler.release());
72  API_END();
73 }
74 
76  ModelHandle model,
77  const char* dirpath) {
78  API_BEGIN();
79  const Model* model_ = static_cast<Model*>(model);
80  Compiler* compiler_ = static_cast<Compiler*>(compiler);
81  TREELITE_CHECK(model_);
82  TREELITE_CHECK(compiler_);
83  compiler::CompilerParam param = compiler_->QueryParam();
84 
85  // create directory named dirpath
86  const std::string& dirpath_(dirpath);
87  filesystem::CreateDirectoryIfNotExist(dirpath);
88 
89  /* compile model */
90  auto compiled_model = compiler_->Compile(*model_);
91  if (param.verbose > 0) {
92  TREELITE_LOG(INFO) << "Code generation finished. Writing code to files...";
93  }
94 
95  for (const auto& it : compiled_model.files) {
96  if (param.verbose > 0) {
97  TREELITE_LOG(INFO) << "Writing file " << it.first << "...";
98  }
99  const std::string filename_full = dirpath_ + "/" + it.first;
100  if (it.second.is_binary) {
101  filesystem::WriteToFile(filename_full, it.second.content_binary);
102  } else {
103  filesystem::WriteToFile(filename_full, it.second.content);
104  }
105  }
106 
107  API_END();
108 }
109 
111  API_BEGIN();
112  delete static_cast<Compiler*>(handle);
113  API_END();
114 }
115 
116 int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out) {
117  API_BEGIN();
118  std::unique_ptr<Model> model = frontend::LoadLightGBMModel(filename);
119  *out = static_cast<ModelHandle>(model.release());
120  API_END();
121 }
122 
123 int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out) {
124  API_BEGIN();
125  std::unique_ptr<Model> model = frontend::LoadXGBoostModel(filename);
126  *out = static_cast<ModelHandle>(model.release());
127  API_END();
128 }
129 
130 int TreeliteLoadXGBoostJSON(const char* filename, ModelHandle* out) {
131  API_BEGIN();
132  std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModel(filename);
133  *out = static_cast<ModelHandle>(model.release());
134  API_END();
135 }
136 
137 int TreeliteLoadXGBoostJSONString(const char* json_str, size_t length, ModelHandle* out) {
138  API_BEGIN();
139  std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModelString(json_str, length);
140  *out = static_cast<ModelHandle>(model.release());
141  API_END();
142 }
143 
144 int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out) {
145  API_BEGIN();
146  std::unique_ptr<Model> model = frontend::LoadXGBoostModel(buf, len);
147  *out = static_cast<ModelHandle>(model.release());
148  API_END();
149 }
150 
151 int TreeliteLoadLightGBMModelFromString(const char* model_str, ModelHandle* out) {
152  API_BEGIN();
153  std::unique_ptr<Model> model = frontend::LoadLightGBMModelFromString(model_str);
154  *out = static_cast<ModelHandle>(model.release());
155  API_END();
156 }
157 
159  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
160  const int64_t** children_right, const int64_t** feature, const double** threshold,
161  const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples,
162  const double** impurity, ModelHandle* out) {
163  API_BEGIN();
164  std::unique_ptr<Model> model = frontend::LoadSKLearnRandomForestRegressor(
165  n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
166  value, n_node_samples, weighted_n_node_samples, impurity);
167  *out = static_cast<ModelHandle>(model.release());
168  API_END();
169 }
170 
172  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
173  const int64_t** children_right, const int64_t** feature, const double** threshold,
174  const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples,
175  const double** impurity, const double ratio_c, ModelHandle* out) {
176  API_BEGIN();
177  std::unique_ptr<Model> model = frontend::LoadSKLearnIsolationForest(
178  n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
179  value, n_node_samples, weighted_n_node_samples, impurity, ratio_c);
180  *out = static_cast<ModelHandle>(model.release());
181  API_END();
182 }
183 
185  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
186  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
187  const double** threshold, const double** value, const int64_t** n_node_samples,
188  const double** weighted_n_node_samples, const double** impurity, ModelHandle* out) {
189  API_BEGIN();
190  std::unique_ptr<Model> model = frontend::LoadSKLearnRandomForestClassifier(
191  n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
192  threshold, value, n_node_samples, weighted_n_node_samples, impurity);
193  *out = static_cast<ModelHandle>(model.release());
194  API_END();
195 }
196 
198  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
199  const int64_t** children_right, const int64_t** feature, const double** threshold,
200  const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples,
201  const double** impurity, ModelHandle* out) {
202  API_BEGIN();
203  std::unique_ptr<Model> model = frontend::LoadSKLearnGradientBoostingRegressor(
204  n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
205  value, n_node_samples, weighted_n_node_samples, impurity);
206  *out = static_cast<ModelHandle>(model.release());
207  API_END();
208 }
209 
211  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
212  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
213  const double** threshold, const double** value, const int64_t** n_node_samples,
214  const double** weighted_n_node_samples, const double** impurity, ModelHandle* out) {
215  API_BEGIN();
216  std::unique_ptr<Model> model = frontend::LoadSKLearnGradientBoostingClassifier(
217  n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
218  threshold, value, n_node_samples, weighted_n_node_samples, impurity);
219  *out = static_cast<ModelHandle>(model.release());
220  API_END();
221 }
222 
223 int TreeliteSerializeModel(const char* filename, ModelHandle handle) {
224  API_BEGIN();
225  FILE* fp = std::fopen(filename, "wb");
226  TREELITE_CHECK(fp) << "Failed to open file '" << filename << "'";
227  auto* model_ = static_cast<Model*>(handle);
228  model_->SerializeToFile(fp);
229  std::fclose(fp);
230  API_END();
231 }
232 
233 int TreeliteDeserializeModel(const char* filename, ModelHandle* out) {
234  API_BEGIN();
235  FILE* fp = std::fopen(filename, "rb");
236  TREELITE_CHECK(fp) << "Failed to open file '" << filename << "'";
237  std::unique_ptr<Model> model = Model::DeserializeFromFile(fp);
238  std::fclose(fp);
239  *out = static_cast<ModelHandle>(model.release());
240  API_END();
241 }
242 
243 int TreeliteConcatenateModelObjects(const ModelHandle* objs, size_t len,
244  ModelHandle* out) {
245  API_BEGIN();
246  std::vector<const Model*> model_objs(len, nullptr);
247  std::transform(objs, objs + len, model_objs.begin(),
248  [](const ModelHandle e) { return static_cast<const Model*>(e); });
249  auto concatenated_model = ConcatenateModelObjects(model_objs);
250  *out = static_cast<ModelHandle>(concatenated_model.release());
251  API_END();
252 }
253 
254 int TreeliteDumpAsJSON(ModelHandle handle, int pretty_print, const char** out_json_str) {
255  API_BEGIN();
256  auto* model_ = static_cast<Model*>(handle);
257  std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
258  ret_str = model_->DumpAsJSON(pretty_print != 0);
259  *out_json_str = ret_str.c_str();
260  API_END();
261 }
262 
264  API_BEGIN();
265  delete static_cast<Model*>(handle);
266  API_END();
267 }
268 
269 int TreeliteGTILGetPredictOutputSize(ModelHandle model, size_t num_row, size_t* out) {
270  API_BEGIN();
271  const auto* model_ = static_cast<const Model*>(model);
272  *out = gtil::GetPredictOutputSize(model_, num_row);
273  API_END();
274 }
275 
276 int TreeliteGTILPredict(ModelHandle model, const float* input, size_t num_row, float* output,
277  int nthread, int pred_transform, size_t* out_result_size) {
278  API_BEGIN();
279  const auto* model_ = static_cast<const Model*>(model);
280  *out_result_size =
281  gtil::Predict(model_, input, num_row, output, nthread, (pred_transform == 1));
282  API_END();
283 }
284 
285 int TreeliteQueryNumTree(ModelHandle handle, size_t* out) {
286  API_BEGIN();
287  const auto* model_ = static_cast<const Model*>(handle);
288  *out = model_->GetNumTree();
289  API_END();
290 }
291 
292 int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) {
293  API_BEGIN();
294  const auto* model_ = static_cast<const Model*>(handle);
295  *out = static_cast<size_t>(model_->num_feature);
296  API_END();
297 }
298 
299 int TreeliteQueryNumClass(ModelHandle handle, size_t* out) {
300  API_BEGIN();
301  const auto* model_ = static_cast<const Model*>(handle);
302  *out = static_cast<size_t>(model_->task_param.num_class);
303  API_END();
304 }
305 
306 int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) {
307  API_BEGIN();
308  TREELITE_CHECK_GT(limit, 0) << "limit should be greater than 0!";
309  auto* model_ = static_cast<Model*>(handle);
310  const size_t num_tree = model_->GetNumTree();
311  TREELITE_CHECK_GE(num_tree, limit) << "Model contains fewer trees(" << num_tree << ") than limit";
312  model_->SetTreeLimit(limit);
313  API_END();
314 }
315 
316 int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type, ValueHandle* out) {
317  API_BEGIN();
318  std::unique_ptr<frontend::Value> value = std::make_unique<frontend::Value>();
319  *value = frontend::Value::Create(init_value, GetTypeInfoByName(type));
320  *out = static_cast<ValueHandle>(value.release());
321  API_END();
322 }
323 
325  API_BEGIN();
326  delete static_cast<frontend::Value*>(handle);
327  API_END();
328 }
329 
330 int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type,
331  TreeBuilderHandle* out) {
332  API_BEGIN();
333  std::unique_ptr<frontend::TreeBuilder> builder{
334  new frontend::TreeBuilder(GetTypeInfoByName(threshold_type),
335  GetTypeInfoByName(leaf_output_type))
336  };
337  *out = static_cast<TreeBuilderHandle>(builder.release());
338  API_END();
339 }
340 
342  API_BEGIN();
343  delete static_cast<frontend::TreeBuilder*>(handle);
344  API_END();
345 }
346 
348  API_BEGIN();
349  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
350  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
351  builder->CreateNode(node_key);
352  API_END();
353 }
354 
356  API_BEGIN();
357  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
358  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
359  builder->DeleteNode(node_key);
360  API_END();
361 }
362 
364  API_BEGIN();
365  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
366  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
367  builder->SetRootNode(node_key);
368  API_END();
369 }
370 
372  TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname,
373  ValueHandle threshold, int default_left, int left_child_key, int right_child_key) {
374  API_BEGIN();
375  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
376  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
377  builder->SetNumericalTestNode(node_key, feature_id, opname,
378  *static_cast<const frontend::Value*>(threshold),
379  (default_left != 0), left_child_key, right_child_key);
380  API_END();
381 }
382 
384  TreeBuilderHandle handle, int node_key, unsigned feature_id,
385  const unsigned int* left_categories, size_t left_categories_len, int default_left,
386  int left_child_key, int right_child_key) {
387  API_BEGIN();
388  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
389  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
390  std::vector<uint32_t> vec(left_categories_len);
391  for (size_t i = 0; i < left_categories_len; ++i) {
392  TREELITE_CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
393  vec[i] = static_cast<uint32_t>(left_categories[i]);
394  }
395  builder->SetCategoricalTestNode(node_key, feature_id, vec, (default_left != 0),
396  left_child_key, right_child_key);
397  API_END();
398 }
399 
400 int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value) {
401  API_BEGIN();
402  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
403  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
404  builder->SetLeafNode(node_key, *static_cast<const frontend::Value*>(leaf_value));
405  API_END();
406 }
407 
409  const ValueHandle* leaf_vector, size_t leaf_vector_len) {
410  API_BEGIN();
411  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
412  TREELITE_CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
413  std::vector<frontend::Value> vec(leaf_vector_len);
414  TREELITE_CHECK(leaf_vector) << "leaf_vector argument must not be null";
415  for (size_t i = 0; i < leaf_vector_len; ++i) {
416  TREELITE_CHECK(leaf_vector[i]) << "leaf_vector[" << i << "] contains an empty Value handle";
417  vec[i] = *static_cast<const frontend::Value*>(leaf_vector[i]);
418  }
419  builder->SetLeafVectorNode(node_key, vec);
420  API_END();
421 }
422 
424  int num_feature, int num_class, int average_tree_output, const char* threshold_type,
425  const char* leaf_output_type, ModelBuilderHandle* out) {
426  API_BEGIN();
427  std::unique_ptr<frontend::ModelBuilder> builder{new frontend::ModelBuilder(
428  num_feature, num_class, (average_tree_output != 0), GetTypeInfoByName(threshold_type),
429  GetTypeInfoByName(leaf_output_type))};
430  *out = static_cast<ModelBuilderHandle>(builder.release());
431  API_END();
432 }
433 
435  const char* value) {
436  API_BEGIN();
437  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
438  TREELITE_CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
439  builder->SetModelParam(name, value);
440  API_END();
441 }
442 
444  API_BEGIN();
445  delete static_cast<frontend::ModelBuilder*>(handle);
446  API_END();
447 }
448 
450  int index) {
451  API_BEGIN();
452  auto* model_builder = static_cast<frontend::ModelBuilder*>(handle);
453  TREELITE_CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
454  auto* tree_builder = static_cast<frontend::TreeBuilder*>(tree_builder_handle);
455  TREELITE_CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
456  return model_builder->InsertTree(tree_builder, index);
457  API_END();
458 }
459 
461  API_BEGIN();
462  auto* model_builder = static_cast<frontend::ModelBuilder*>(handle);
463  TREELITE_CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
464  auto* tree_builder = model_builder->GetTree(index);
465  TREELITE_CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
466  *out = static_cast<TreeBuilderHandle>(tree_builder);
467  API_END();
468 }
469 
471  API_BEGIN();
472  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
473  TREELITE_CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
474  builder->DeleteTree(index);
475  API_END();
476 }
477 
479  API_BEGIN();
480  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
481  TREELITE_CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
482  std::unique_ptr<Model> model = builder->CommitModel();
483  *out = static_cast<ModelHandle>(model.release());
484  API_END();
485 }
Some useful math utilities.
int TreeliteQueryNumClass(ModelHandle handle, size_t *out)
Query the number of classes of the model (1 if the model is a binary classifier or a regressor). ...
Definition: c_api.cc:299
C API of Treelite, used for interfacing with other languages. This header is excluded from the runtime...
Parameters for tree compiler.
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
Definition: c_api.cc:434
branch annotator class
Definition: annotator.h:21
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
Definition: c_api.cc:460
void Save(std::ostream &fo) const
save branch annotation to a JSON file
Definition: annotator.cc:264
Collection of front-end methods to load or construct ensemble model.
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: c_api.cc:123
int TreeliteDumpAsJSON(ModelHandle handle, int pretty_print, const char **out_json_str)
Dump a model object as a JSON string.
Definition: c_api.cc:254
#define API_BEGIN()
macro to guard beginning and end section of all functions
Definition: c_api_error.h:14
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
Definition: c_api.cc:263
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
Definition: c_api.cc:53
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
Definition: c_api.cc:285
tree builder class
Definition: frontend.h:257
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
Definition: c_api.cc:470
int TreeliteConcatenateModelObjects(const ModelHandle *objs, size_t len, ModelHandle *out)
Concatenate multiple model objects into a single model object by copying all member trees into the de...
Definition: c_api.cc:243
parameters for tree compiler
Input data structure of Treelite.
int TreeliteDeserializeModel(const char *filename, ModelHandle *out)
Deserialize (load) a model object from disk.
Definition: c_api.cc:233
int TreeliteLoadXGBoostJSON(const char *filename, ModelHandle *out)
load a json model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tr...
Definition: c_api.cc:130
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
Definition: c_api.cc:443
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
Definition: c_api.cc:383
model structure for tree ensemble
logging facility for Treelite
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
Definition: c_api.cc:62
int TreeliteLoadSKLearnIsolationForest(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity, const double ratio_c, ModelHandle *out)
Load a scikit-learn isolation forest model from a collection of arrays. Refer to https://scikit-learn...
Definition: c_api.cc:171
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, ValueHandle threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
Definition: c_api.cc:371
int TreeliteLoadSKLearnRandomForestClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn random forest classifier model from a collection of arrays. Refer to https://scik...
Definition: c_api.cc:184
interface of compiler
Definition: compiler.h:53
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Definition: c_api.cc:41
int TreeliteTreeBuilderCreateValue(const void *init_value, const char *type, ValueHandle *out)
Create a new Value object. Some model builder API functions accept this Value type to accommodate val...
Definition: c_api.cc:316
Interface of compiler that compiles a tree ensemble model.
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
Keep the first N trees of the model; limit must be greater than 0 and must not exceed the number of trees.
Definition: c_api.cc:306
model builder class
Definition: frontend.h:346
void * ValueHandle
handle to a polymorphic value type, used in the model builder API
Definition: c_api.h:33
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
Definition: c_api.cc:449
void * DMatrixHandle
handle to a data matrix
Definition: c_api_common.h:30
Cross-platform wrapper for common filesystem functions.
int TreeliteLoadLightGBMModelFromString(const char *model_str, ModelHandle *out)
Load a LightGBM model from a string. The string should be created with the model_to_string() method i...
Definition: c_api.cc:151
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
Definition: c_api.cc:363
int TreeliteLoadSKLearnGradientBoostingClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: c_api.cc:210
int TreeliteCompilerCreateV2(const char *name, const char *params_json_str, CompilerHandle *out)
Create a compiler with a given name.
Definition: c_api.cc:68
void * TreeBuilderHandle
handle to tree builder class
Definition: c_api.h:25
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
Definition: c_api.cc:341
int TreeliteCreateTreeBuilder(const char *threshold_type, const char *leaf_output_type, TreeBuilderHandle *out)
Create a new tree builder.
Definition: c_api.cc:330
void * AnnotationHandle
handle to branch annotation data
Definition: c_api.h:29
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
Definition: c_api.cc:478
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decis...
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value)
Turn an empty node into a leaf node.
Definition: c_api.cc:400
int TreeliteCompilerGenerateCodeV2(CompilerHandle compiler, ModelHandle model, const char *dirpath)
Generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
Definition: c_api.cc:75
int TreeliteCreateModelBuilder(int num_feature, int num_class, int average_tree_output, const char *threshold_type, const char *leaf_output_type, ModelBuilderHandle *out)
Create a new model builder.
Definition: c_api.cc:423
virtual compiler::CompiledModel Compile(const Model &model)=0
convert tree ensemble model
virtual compiler::CompilerParam QueryParam() const =0
Query the parameters used to initialize the compiler.
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
Definition: c_api.cc:144
int TreeliteLoadSKLearnGradientBoostingRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: c_api.cc:197
void * ModelHandle
handle to a decision tree ensemble model
Definition: c_api.h:23
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: c_api.cc:116
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
Definition: typeinfo.cc:15
int TreeliteTreeBuilderDeleteValue(ValueHandle handle)
Delete a Value object from memory.
Definition: c_api.cc:324
Branch annotation tools.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Definition: c_api.cc:355
thin wrapper for tree ensemble model
Definition: tree.h:734
int TreeliteLoadSKLearnRandomForestRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn random forest regressor model from a collection of arrays. Refer to https://sciki...
Definition: c_api.cc:158
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
Definition: c_api.cc:292
int TreeliteGTILPredict(ModelHandle model, const float *input, size_t num_row, float *output, int nthread, int pred_transform, size_t *out_result_size)
Predict with a 2D dense array.
Definition: c_api.cc:276
void * ModelBuilderHandle
handle to ensemble builder class
Definition: c_api.h:27
int TreeliteSerializeModel(const char *filename, ModelHandle handle)
Serialize (persist) a model object to disk.
Definition: c_api.cc:223
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const ValueHandle *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
Definition: c_api.cc:408
A thread-local storage.
Definition: thread_local.h:17
int verbose
if >0, produce extra messages
void * CompilerHandle
handle to compiler class
Definition: c_api.h:31
static Compiler * Create(const std::string &name, const char *param_json_str)
create a compiler from given name
Definition: compiler.cc:16
int TreeliteLoadXGBoostJSONString(const char *json_str, size_t length, ModelHandle *out)
load a model stored as a JSON string by XGBoost (dmlc/xgboost). The model json must contain a decision t...
Definition: c_api.cc:137
int TreeliteGTILGetPredictOutputSize(ModelHandle model, size_t num_row, size_t *out)
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: c_api.cc:269
std::unique_ptr< Model > ConcatenateModelObjects(const std::vector< const Model *> &objs)
Concatenate multiple model objects into a single model object by copying all member trees into the de...
Definition: model_concat.cc:16
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
Definition: c_api.cc:110
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:17
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.
Definition: c_api.cc:347