20 #include <dmlc/thread_local.h> 28 struct CompilerHandleImpl {
30 std::vector<std::pair<std::string, std::string>> cfg;
31 std::unique_ptr<Compiler> compiler;
32 explicit CompilerHandleImpl(
const std::string& name)
33 : name(name), cfg(), compiler(nullptr) {}
34 ~CompilerHandleImpl() =
default;
43 const Model* model_ =
static_cast<Model*
>(model);
44 const auto* dmat_ =
static_cast<const DMatrix*
>(dmat);
45 CHECK(dmat_) <<
"Found a dangling reference to DMatrix";
46 annotator->Annotate(*model_, dmat_, nthread, verbose);
55 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path,
"w"));
56 annotator->
Save(fo.get());
69 std::unique_ptr<CompilerHandleImpl> compiler{
new CompilerHandleImpl(name)};
78 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(handle);
79 auto& cfg_ = impl->cfg;
80 std::string name_(name);
81 std::string value_(value);
83 auto it = std::find_if(cfg_.begin(), cfg_.end(),
84 [&name_](
const std::pair<std::string, std::string>& x) {
85 return x.first == name_;
87 if (it == cfg_.end()) {
88 cfg_.emplace_back(name_, value_);
98 const char* dirpath) {
102 std::to_string(verbose).c_str());
107 const Model* model_ =
static_cast<Model*
>(model);
108 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(compiler);
111 const std::string& dirpath_(dirpath);
112 filesystem::CreateDirectoryIfNotExist(dirpath);
115 cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
119 auto compiled_model = impl->compiler->Compile(*model_);
121 LOG(INFO) <<
"Code generation finished. Writing code to files...";
124 for (
const auto& it : compiled_model.files) {
126 LOG(INFO) <<
"Writing file " << it.first <<
"...";
128 const std::string filename_full = dirpath_ +
"/" + it.first;
129 if (it.second.is_binary) {
130 filesystem::WriteToFile(filename_full, it.second.content_binary);
132 filesystem::WriteToFile(filename_full, it.second.content);
141 delete static_cast<CompilerHandleImpl*
>(handle);
147 std::unique_ptr<Model> model = frontend::LoadLightGBMModel(filename);
154 std::unique_ptr<Model> model = frontend::LoadXGBoostModel(filename);
161 std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModel(filename);
168 std::unique_ptr<Model> model = frontend::LoadXGBoostJSONModelString(json_str, length);
175 std::unique_ptr<Model> model = frontend::LoadXGBoostModel(buf, len);
181 int n_estimators,
int n_features,
const int64_t* node_count,
const int64_t** children_left,
182 const int64_t** children_right,
const int64_t** feature,
const double** threshold,
183 const double** value,
const int64_t** n_node_samples,
const double** impurity,
186 std::unique_ptr<Model> model = frontend::LoadSKLearnRandomForestRegressor(
187 n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
188 value, n_node_samples, impurity);
194 int n_estimators,
int n_features,
int n_classes,
const int64_t* node_count,
195 const int64_t** children_left,
const int64_t** children_right,
const int64_t** feature,
196 const double** threshold,
const double** value,
const int64_t** n_node_samples,
199 std::unique_ptr<Model> model = frontend::LoadSKLearnRandomForestClassifier(
200 n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
201 threshold, value, n_node_samples, impurity);
207 int n_estimators,
int n_features,
const int64_t* node_count,
const int64_t** children_left,
208 const int64_t** children_right,
const int64_t** feature,
const double** threshold,
209 const double** value,
const int64_t** n_node_samples,
const double** impurity,
212 std::unique_ptr<Model> model = frontend::LoadSKLearnGradientBoostingRegressor(
213 n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
214 value, n_node_samples, impurity);
220 int n_estimators,
int n_features,
int n_classes,
const int64_t* node_count,
221 const int64_t** children_left,
const int64_t** children_right,
const int64_t** feature,
222 const double** threshold,
const double** value,
const int64_t** n_node_samples,
225 std::unique_ptr<Model> model = frontend::LoadSKLearnGradientBoostingClassifier(
226 n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
227 threshold, value, n_node_samples, impurity);
234 FILE* fp = std::fopen(filename,
"wb");
235 CHECK(fp) <<
"Failed to open file '" << filename <<
"'";
236 auto* model_ =
static_cast<Model*
>(handle);
237 model_->SerializeToFile(fp);
244 FILE* fp = std::fopen(filename,
"rb");
245 CHECK(fp) <<
"Failed to open file '" << filename <<
"'";
246 std::unique_ptr<Model> model = Model::DeserializeFromFile(fp);
254 delete static_cast<Model*
>(handle);
258 int TreeliteGTILGetPredictOutputSize(
ModelHandle handle,
size_t num_row,
size_t* out) {
260 const auto* model_ =
static_cast<const Model*
>(handle);
261 *out = gtil::GetPredictOutputSize(model_, num_row);
265 int TreeliteGTILPredict(
ModelHandle handle,
const float* input,
size_t num_row,
float* output,
266 int pred_transform,
size_t* out_result_size) {
268 const auto* model_ =
static_cast<const Model*
>(handle);
270 gtil::Predict(model_, input, num_row, output, (pred_transform == 1));
276 const auto* model_ =
static_cast<const Model*
>(handle);
277 *out = model_->GetNumTree();
283 const auto* model_ =
static_cast<const Model*
>(handle);
284 *out =
static_cast<size_t>(model_->num_feature);
290 const auto* model_ =
static_cast<const Model*
>(handle);
291 *out =
static_cast<size_t>(model_->task_param.num_class);
297 CHECK_GT(limit, 0) <<
"limit should be greater than 0!";
298 auto* model_ =
static_cast<Model*
>(handle);
299 const size_t num_tree = model_->GetNumTree();
300 CHECK_GE(num_tree, limit) <<
"Model contains less trees(" << num_tree <<
") than limit";
301 model_->SetTreeLimit(limit);
307 std::unique_ptr<frontend::Value> value = std::make_unique<frontend::Value>();
322 std::unique_ptr<frontend::TreeBuilder> builder{
339 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
347 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
355 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
362 ValueHandle threshold,
int default_left,
int left_child_key,
int right_child_key) {
365 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
367 *static_cast<const frontend::Value*>(threshold),
368 (default_left != 0), left_child_key, right_child_key);
374 const unsigned int* left_categories,
size_t left_categories_len,
int default_left,
375 int left_child_key,
int right_child_key) {
378 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
379 std::vector<uint32_t> vec(left_categories_len);
380 for (
size_t i = 0; i < left_categories_len; ++i) {
381 CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
382 vec[i] =
static_cast<uint32_t
>(left_categories[i]);
384 builder->SetCategoricalTestNode(node_key, feature_id, vec, (default_left != 0),
385 left_child_key, right_child_key);
392 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
393 builder->
SetLeafNode(node_key, *static_cast<const frontend::Value*>(leaf_value));
398 const ValueHandle* leaf_vector,
size_t leaf_vector_len) {
401 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
402 std::vector<frontend::Value> vec(leaf_vector_len);
403 CHECK(leaf_vector) <<
"leaf_vector argument must not be null";
404 for (
size_t i = 0; i < leaf_vector_len; ++i) {
405 CHECK(leaf_vector[i]) <<
"leaf_vector[" << i <<
"] contains an empty Value handle";
408 builder->SetLeafVectorNode(node_key, vec);
413 int num_feature,
int num_class,
int average_tree_output,
const char* threshold_type,
417 num_feature, num_class, (average_tree_output != 0),
GetTypeInfoByName(threshold_type),
427 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
442 CHECK(model_builder) <<
"Detected dangling reference to deleted ModelBuilder object";
444 CHECK(tree_builder) <<
"Detected dangling reference to deleted TreeBuilder object";
445 return model_builder->InsertTree(tree_builder, index);
452 CHECK(model_builder) <<
"Detected dangling reference to deleted ModelBuilder object";
453 auto* tree_builder = model_builder->
GetTree(index);
454 CHECK(tree_builder) <<
"Detected dangling reference to deleted TreeBuilder object";
462 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
470 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
471 std::unique_ptr<Model> model = builder->
CommitModel();
Some useful math utilities.
int TreeliteQueryNumClass(ModelHandle handle, size_t *out)
Query the number of classes of the model. (1 if the model is a binary classifier or regressor) ...
C API of Treelite, used for interfacing with other languages This header is excluded from the runtime...
Parameters for tree compiler.
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
int TreeliteLoadSKLearnRandomForestRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn random forest regressor model from a collection of arrays. Refer to https://sciki...
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Collection of front-end methods to load or construct ensemble model.
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
void DeleteNode(int node_key)
Remove a node from a tree.
#define API_BEGIN()
macro to guard beginning and end section of all functions
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
int TreeliteLoadSKLearnGradientBoostingRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
parameters for tree compiler
Input data structure of Treelite.
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
int TreeliteDeserializeModel(const char *filename, ModelHandle *out)
Deserialize (load) a model object from disk.
int TreeliteLoadXGBoostJSON(const char *filename, ModelHandle *out)
load a json model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tr...
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
void SetRootNode(int node_key)
Set a node as the root of a tree.
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, ValueHandle threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
int TreeliteLoadSKLearnRandomForestClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn random forest classifier model from a collection of arrays. Refer to https://scik...
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
int TreeliteTreeBuilderCreateValue(const void *init_value, const char *type, ValueHandle *out)
Create a new Value object. Some model builder API functions accept this Value type to accommodate val...
Interface of compiler that compiles a tree ensemble model.
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
keep the first N trees of the model; limit must be greater than 0 and must not exceed the number of trees.
void * ValueHandle
handle to a polymorphic value type, used in the model builder API
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
void * DMatrixHandle
handle to a data matrix
Cross-platform wrapper for common filesystem functions.
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
void * TreeBuilderHandle
handle to tree builder class
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
int TreeliteCreateTreeBuilder(const char *threshold_type, const char *leaf_output_type, TreeBuilderHandle *out)
Create a new tree builder.
void * AnnotationHandle
handle to branch annotation data
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decis...
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value)
Turn an empty node into a leaf node.
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
int TreeliteCreateModelBuilder(int num_feature, int num_class, int average_tree_output, const char *threshold_type, const char *leaf_output_type, ModelBuilderHandle *out)
Create a new model builder.
int TreeliteLoadSKLearnGradientBoostingClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity, ModelHandle *out)
Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
void DeleteTree(int index)
Remove a tree from the ensemble.
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
void * ModelHandle
handle to a decision tree ensemble model
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
int TreeliteTreeBuilderDeleteValue(ValueHandle handle)
Delete a Value object from memory.
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
thin wrapper for tree ensemble model
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
void * ModelBuilderHandle
handle to ensemble builder class
int TreeliteSerializeModel(const char *filename, ModelHandle handle)
Serialize (persist) a model object to disk.
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const ValueHandle *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
void CreateNode(int node_key)
Create an empty node within a tree.
void * CompilerHandle
handle to compiler class
int TreeliteLoadXGBoostJSONString(const char *json_str, size_t length, ModelHandle *out)
load a model stored as a JSON string by XGBoost (dmlc/xgboost). The model JSON must contain a decision t...
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.