Treelite
c_api.cc
Go to the documentation of this file.
1 
9 #include <treelite/annotator.h>
10 #include <treelite/c_api.h>
11 #include <treelite/c_api_error.h>
12 #include <treelite/compiler.h>
14 #include <treelite/data.h>
15 #include <treelite/filesystem.h>
16 #include <treelite/frontend.h>
17 #include <treelite/tree.h>
18 #include <treelite/math.h>
19 #include <dmlc/thread_local.h>
20 #include <memory>
21 #include <algorithm>
22 
23 using namespace treelite;
24 
25 namespace {
26 
27 struct CompilerHandleImpl {
28  std::string name;
29  std::vector<std::pair<std::string, std::string>> cfg;
30  std::unique_ptr<Compiler> compiler;
31  explicit CompilerHandleImpl(const std::string& name)
32  : name(name), cfg(), compiler(nullptr) {}
33  ~CompilerHandleImpl() = default;
34 };
35 
36 } // anonymous namespace
37 
39  ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out) {
40  API_BEGIN();
41  std::unique_ptr<BranchAnnotator> annotator{new BranchAnnotator()};
42  const Model* model_ = static_cast<Model*>(model);
43  const auto* dmat_ = static_cast<const DMatrix*>(dmat);
44  CHECK(dmat_) << "Found a dangling reference to DMatrix";
45  annotator->Annotate(*model_, dmat_, nthread, verbose);
46  *out = static_cast<AnnotationHandle>(annotator.release());
47  API_END();
48 }
49 
51  const char* path) {
52  API_BEGIN();
53  const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
54  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path, "w"));
55  annotator->Save(fo.get());
56  API_END();
57 }
58 
60  API_BEGIN();
61  delete static_cast<BranchAnnotator*>(handle);
62  API_END();
63 }
64 
65 int TreeliteCompilerCreate(const char* name,
66  CompilerHandle* out) {
67  API_BEGIN();
68  std::unique_ptr<CompilerHandleImpl> compiler{new CompilerHandleImpl(name)};
69  *out = static_cast<CompilerHandle>(compiler.release());
70  API_END();
71 }
72 
74  const char* name,
75  const char* value) {
76  API_BEGIN();
77  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(handle);
78  auto& cfg_ = impl->cfg;
79  std::string name_(name);
80  std::string value_(value);
81  // check for duplicate parameters
82  auto it = std::find_if(cfg_.begin(), cfg_.end(),
83  [&name_](const std::pair<std::string, std::string>& x) {
84  return x.first == name_;
85  });
86  if (it == cfg_.end()) {
87  cfg_.emplace_back(name_, value_);
88  } else {
89  it->second = value;
90  }
91  API_END();
92 }
93 
95  ModelHandle model,
96  int verbose,
97  const char* dirpath) {
98  API_BEGIN();
99  if (verbose > 0) { // verbose enabled
100  int ret = TreeliteCompilerSetParam(compiler, "verbose",
101  std::to_string(verbose).c_str());
102  if (ret < 0) { // SetParam failed
103  return ret;
104  }
105  }
106  const Model* model_ = static_cast<Model*>(model);
107  CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(compiler);
108 
109  // create directory named dirpath
110  const std::string& dirpath_(dirpath);
111  filesystem::CreateDirectoryIfNotExist(dirpath);
112 
114  cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
115 
116  /* compile model */
117  impl->compiler.reset(Compiler::Create(impl->name, cparam));
118  auto compiled_model = impl->compiler->Compile(*model_);
119  if (verbose > 0) {
120  LOG(INFO) << "Code generation finished. Writing code to files...";
121  }
122 
123  for (const auto& it : compiled_model.files) {
124  if (verbose > 0) {
125  LOG(INFO) << "Writing file " << it.first << "...";
126  }
127  const std::string filename_full = dirpath_ + "/" + it.first;
128  if (it.second.is_binary) {
129  filesystem::WriteToFile(filename_full, it.second.content_binary);
130  } else {
131  filesystem::WriteToFile(filename_full, it.second.content);
132  }
133  }
134 
135  API_END();
136 }
137 
139  API_BEGIN();
140  delete static_cast<CompilerHandleImpl*>(handle);
141  API_END();
142 }
143 
144 int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out) {
145  API_BEGIN();
146  std::unique_ptr<Model> model = frontend::LoadLightGBMModel(filename);
147  *out = static_cast<ModelHandle>(model.release());
148  API_END();
149 }
150 
151 int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out) {
152  API_BEGIN();
153  std::unique_ptr<Model> model = frontend::LoadXGBoostModel(filename);
154  *out = static_cast<ModelHandle>(model.release());
155  API_END();
156 }
157 
158 int TreeliteLoadXGBoostJSON(const char* filename, ModelHandle* out) {
159  API_BEGIN();
160  std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModel(filename);
161  *out = static_cast<ModelHandle>(model.release());
162  API_END();
163 }
164 
165 int TreeliteLoadXGBoostJSONString(const char* json_str, size_t length, ModelHandle* out) {
166  API_BEGIN();
167  std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModelString(json_str, length);
168  *out = static_cast<ModelHandle>(model.release());
169  API_END();
170 }
171 
172 int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out) {
173  API_BEGIN();
174  std::unique_ptr<Model> model = frontend::LoadXGBoostModel(buf, len);
175  *out = static_cast<ModelHandle>(model.release());
176  API_END();
177 }
178 
180  API_BEGIN();
181  delete static_cast<Model*>(handle);
182  API_END();
183 }
184 
185 int TreeliteQueryNumTree(ModelHandle handle, size_t* out) {
186  API_BEGIN();
187  const auto* model_ = static_cast<const Model*>(handle);
188  *out = model_->GetNumTree();
189  API_END();
190 }
191 
192 int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) {
193  API_BEGIN();
194  const auto* model_ = static_cast<const Model*>(handle);
195  *out = static_cast<size_t>(model_->num_feature);
196  API_END();
197 }
198 
199 int TreeliteQueryNumClass(ModelHandle handle, size_t* out) {
200  API_BEGIN();
201  const auto* model_ = static_cast<const Model*>(handle);
202  *out = static_cast<size_t>(model_->task_param.num_class);
203  API_END();
204 }
205 
206 int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) {
207  API_BEGIN();
208  CHECK_GT(limit, 0) << "limit should be greater than 0!";
209  auto* model_ = static_cast<Model*>(handle);
210  const size_t num_tree = model_->GetNumTree();
211  CHECK_GE(num_tree, limit) << "Model contains less trees(" << num_tree << ") than limit";
212  model_->SetTreeLimit(limit);
213  API_END();
214 }
215 
216 int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type, ValueHandle* out) {
217  API_BEGIN();
218  std::unique_ptr<frontend::Value> value = std::make_unique<frontend::Value>();
219  *value = frontend::Value::Create(init_value, GetTypeInfoByName(type));
220  *out = static_cast<ValueHandle>(value.release());
221  API_END();
222 }
223 
225  API_BEGIN();
226  delete static_cast<frontend::Value*>(handle);
227  API_END();
228 }
229 
230 int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type,
231  TreeBuilderHandle* out) {
232  API_BEGIN();
233  std::unique_ptr<frontend::TreeBuilder> builder{
234  new frontend::TreeBuilder(GetTypeInfoByName(threshold_type),
235  GetTypeInfoByName(leaf_output_type))
236  };
237  *out = static_cast<TreeBuilderHandle>(builder.release());
238  API_END();
239 }
240 
242  API_BEGIN();
243  delete static_cast<frontend::TreeBuilder*>(handle);
244  API_END();
245 }
246 
248  API_BEGIN();
249  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
250  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
251  builder->CreateNode(node_key);
252  API_END();
253 }
254 
256  API_BEGIN();
257  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
258  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
259  builder->DeleteNode(node_key);
260  API_END();
261 }
262 
264  API_BEGIN();
265  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
266  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
267  builder->SetRootNode(node_key);
268  API_END();
269 }
270 
272  TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname,
273  ValueHandle threshold, int default_left, int left_child_key, int right_child_key) {
274  API_BEGIN();
275  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
276  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
277  builder->SetNumericalTestNode(node_key, feature_id, opname,
278  *static_cast<const frontend::Value*>(threshold),
279  (default_left != 0), left_child_key, right_child_key);
280  API_END();
281 }
282 
284  TreeBuilderHandle handle, int node_key, unsigned feature_id,
285  const unsigned int* left_categories, size_t left_categories_len, int default_left,
286  int left_child_key, int right_child_key) {
287  API_BEGIN();
288  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
289  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
290  std::vector<uint32_t> vec(left_categories_len);
291  for (size_t i = 0; i < left_categories_len; ++i) {
292  CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
293  vec[i] = static_cast<uint32_t>(left_categories[i]);
294  }
295  builder->SetCategoricalTestNode(node_key, feature_id, vec, (default_left != 0),
296  left_child_key, right_child_key);
297  API_END();
298 }
299 
300 int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value) {
301  API_BEGIN();
302  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
303  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
304  builder->SetLeafNode(node_key, *static_cast<const frontend::Value*>(leaf_value));
305  API_END();
306 }
307 
309  const ValueHandle* leaf_vector, size_t leaf_vector_len) {
310  API_BEGIN();
311  auto* builder = static_cast<frontend::TreeBuilder*>(handle);
312  CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
313  std::vector<frontend::Value> vec(leaf_vector_len);
314  CHECK(leaf_vector) << "leaf_vector argument must not be null";
315  for (size_t i = 0; i < leaf_vector_len; ++i) {
316  CHECK(leaf_vector[i]) << "leaf_vector[" << i << "] contains an empty Value handle";
317  vec[i] = *static_cast<const frontend::Value*>(leaf_vector[i]);
318  }
319  builder->SetLeafVectorNode(node_key, vec);
320  API_END();
321 }
322 
324  int num_feature, int num_class, int average_tree_output, const char* threshold_type,
325  const char* leaf_output_type, ModelBuilderHandle* out) {
326  API_BEGIN();
327  std::unique_ptr<frontend::ModelBuilder> builder{new frontend::ModelBuilder(
328  num_feature, num_class, (average_tree_output != 0), GetTypeInfoByName(threshold_type),
329  GetTypeInfoByName(leaf_output_type))};
330  *out = static_cast<ModelBuilderHandle>(builder.release());
331  API_END();
332 }
333 
335  const char* value) {
336  API_BEGIN();
337  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
338  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
339  builder->SetModelParam(name, value);
340  API_END();
341 }
342 
344  API_BEGIN();
345  delete static_cast<frontend::ModelBuilder*>(handle);
346  API_END();
347 }
348 
350  int index) {
351  API_BEGIN();
352  auto* model_builder = static_cast<frontend::ModelBuilder*>(handle);
353  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
354  auto* tree_builder = static_cast<frontend::TreeBuilder*>(tree_builder_handle);
355  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
356  return model_builder->InsertTree(tree_builder, index);
357  API_END();
358 }
359 
361  API_BEGIN();
362  auto* model_builder = static_cast<frontend::ModelBuilder*>(handle);
363  CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object";
364  auto* tree_builder = model_builder->GetTree(index);
365  CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object";
366  *out = static_cast<TreeBuilderHandle>(tree_builder);
367  API_END();
368 }
369 
371  API_BEGIN();
372  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
373  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
374  builder->DeleteTree(index);
375  API_END();
376 }
377 
379  API_BEGIN();
380  auto* builder = static_cast<frontend::ModelBuilder*>(handle);
381  CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object";
382  std::unique_ptr<Model> model = builder->CommitModel();
383  *out = static_cast<ModelHandle>(model.release());
384  API_END();
385 }
Some useful math utilities.
int TreeliteQueryNumClass(ModelHandle handle, size_t *out)
Query the number of classes of the model. (1 if the model is a binary classifier or regressor) ...
Definition: c_api.cc:199
C API of Treelite, used for interfacing with other languages This header is excluded from the runtime...
Parameters for tree compiler.
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
Definition: c_api.cc:334
branch annotator class
Definition: annotator.h:17
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
Definition: c_api.cc:360
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Definition: builder.cc:418
Collection of front-end methods to load or construct ensemble model.
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: c_api.cc:151
void DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:185
#define API_BEGIN()
macro to guard beginning and end section of all functions
Definition: c_api_error.h:15
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
Definition: c_api.cc:179
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
Definition: c_api.cc:50
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
Definition: c_api.cc:185
tree builder class
Definition: frontend.h:96
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
Definition: c_api.cc:370
parameters for tree compiler
Input data structure of Treelite.
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
Definition: builder.cc:218
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
Definition: c_api.cc:73
int TreeliteLoadXGBoostJSON(const char *filename, ModelHandle *out)
load a json model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tr...
Definition: c_api.cc:158
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
Definition: c_api.cc:343
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
Definition: c_api.cc:283
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:337
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:296
void SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:208
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
Definition: c_api.cc:59
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:401
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
Definition: annotator.cc:248
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, ValueHandle threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
Definition: c_api.cc:271
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Definition: c_api.cc:38
int TreeliteTreeBuilderCreateValue(const void *init_value, const char *type, ValueHandle *out)
Create a new Value object. Some model builder API functions accept this Value type to accommodate val...
Definition: c_api.cc:216
Interface of compiler that compiles a tree ensemble model.
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
keep the first N trees of the model; limit must be greater than 0 and must not exceed the number of trees.
Definition: c_api.cc:206
model builder class
Definition: frontend.h:185
void * ValueHandle
handle to a polymorphic value type, used in the model builder API
Definition: c_api.h:33
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
Definition: c_api.cc:349
void * DMatrixHandle
handle to a data matrix
Definition: c_api_common.h:30
Cross-platform wrapper for common filesystem functions.
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
Definition: c_api.cc:263
void * TreeBuilderHandle
handle to tree builder class
Definition: c_api.h:25
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
Definition: c_api.cc:241
int TreeliteCreateTreeBuilder(const char *threshold_type, const char *leaf_output_type, TreeBuilderHandle *out)
Create a new tree builder.
Definition: c_api.cc:230
void * AnnotationHandle
handle to branch annotation data
Definition: c_api.h:29
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
Definition: c_api.cc:378
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value)
Turn an empty node into a leaf node.
Definition: c_api.cc:300
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
Definition: c_api.cc:94
int TreeliteCreateModelBuilder(int num_feature, int num_class, int average_tree_output, const char *threshold_type, const char *leaf_output_type, ModelBuilderHandle *out)
Create a new model builder.
Definition: c_api.cc:323
void DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:411
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
Definition: c_api.cc:172
void * ModelHandle
handle to a decision tree ensemble model
Definition: c_api.h:23
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: c_api.cc:144
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
Definition: typeinfo.cc:16
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
Definition: compiler.cc:15
int TreeliteTreeBuilderDeleteValue(ValueHandle handle)
Delete a Value object from memory.
Definition: c_api.cc:224
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
Definition: c_api.cc:65
Branch annotation tools.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Definition: c_api.cc:255
thin wrapper for tree ensemble model
Definition: tree.h:615
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
Definition: c_api.cc:192
void * ModelBuilderHandle
handle to ensemble builder class
Definition: c_api.h:27
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const ValueHandle *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
Definition: c_api.cc:308
void CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:178
void * CompilerHandle
handle to compiler class
Definition: c_api.h:31
int TreeliteLoadXGBoostJSONString(const char *json_str, size_t length, ModelHandle *out)
load a model stored as a JSON string by XGBoost (dmlc/xgboost). The model json must contain a decision t...
Definition: c_api.cc:165
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
Definition: c_api.cc:138
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:18
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.
Definition: c_api.cc:247