17 #include <dmlc/thread_local.h> 26 struct CompilerHandleImpl {
28 std::vector<std::pair<std::string, std::string>> cfg;
29 std::unique_ptr<Compiler> compiler;
30 explicit CompilerHandleImpl(
const std::string& name)
31 : name(name), cfg(), compiler(nullptr) {}
32 ~CompilerHandleImpl() =
default;
36 struct TreeliteAPIThreadLocalEntry {
42 using TreeliteAPIThreadLocalStore
43 = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
59 const unsigned* col_ind,
60 const size_t* row_ptr,
65 std::unique_ptr<DMatrix> dmat{
new DMatrix()};
67 auto& data_ = dmat->data;
68 auto& col_ind_ = dmat->col_ind;
69 auto& row_ptr_ = dmat->row_ptr;
70 data_.reserve(row_ptr[num_row]);
71 col_ind_.reserve(row_ptr[num_row]);
72 row_ptr_.reserve(num_row + 1);
73 for (
size_t i = 0; i < num_row; ++i) {
74 const size_t jbegin = row_ptr[i];
75 const size_t jend = row_ptr[i + 1];
76 for (
size_t j = jbegin; j < jend; ++j) {
77 if (!math::CheckNAN(data[j])) {
78 data_.push_back(data[j]);
79 CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
80 <<
"feature index too big to fit into uint32_t";
81 col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
84 row_ptr_.push_back(data_.size());
86 data_.shrink_to_fit();
87 col_ind_.shrink_to_fit();
88 dmat->num_row = num_row;
89 dmat->num_col = num_col;
90 dmat->nelem = data_.size();
101 const bool nan_missing = math::CheckNAN(missing_value);
103 CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
104 <<
"num_col argument is too big";
105 std::unique_ptr<DMatrix> dmat{
new DMatrix()};
107 auto& data_ = dmat->data;
108 auto& col_ind_ = dmat->col_ind;
109 auto& row_ptr_ = dmat->row_ptr;
112 const size_t guess_size
113 = std::min(std::min(num_row * num_col, num_row * 1000),
114 static_cast<size_t>(64 * 1024 * 1024));
115 data_.reserve(guess_size);
116 col_ind_.reserve(guess_size);
117 row_ptr_.reserve(num_row + 1);
118 const float* row = &data[0];
119 for (
size_t i = 0; i < num_row; ++i, row += num_col) {
120 for (
size_t j = 0; j < num_col; ++j) {
121 if (math::CheckNAN(row[j])) {
123 <<
"The missing_value argument must be set to NaN if there is any " 124 <<
"NaN in the matrix.";
125 }
else if (nan_missing || row[j] != missing_value) {
127 data_.push_back(row[j]);
128 col_ind_.push_back(static_cast<uint32_t>(j));
131 row_ptr_.push_back(data_.size());
133 data_.shrink_to_fit();
134 col_ind_.shrink_to_fit();
135 dmat->num_row = num_row;
136 dmat->num_col = num_col;
137 dmat->nelem = data_.size();
151 *out_nelem = dmat->
nelem;
156 const char** out_preview) {
159 std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
160 std::ostringstream oss;
161 const size_t iend = (dmat->
nelem <= 50) ? dmat->
nelem : 25;
162 for (
size_t i = 0; i < iend; ++i) {
163 const size_t row_ind =
166 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 167 << dmat->
data[i] <<
"\n";
169 if (dmat->
nelem > 50) {
171 for (
size_t i = dmat->
nelem - 25; i < dmat->nelem; ++i) {
172 const size_t row_ind =
175 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 176 << dmat->
data[i] <<
"\n";
180 *out_preview = ret_str.c_str();
185 const float** out_data,
186 const uint32_t** out_col_ind,
187 const size_t** out_row_ptr) {
190 *out_data = &dmat_->
data[0];
191 *out_col_ind = &dmat_->
col_ind[0];
192 *out_row_ptr = &dmat_->
row_ptr[0];
198 delete static_cast<DMatrix*
>(handle);
209 const Model* model_ =
static_cast<Model*
>(model);
211 annotator->Annotate(*model_, dmat_, nthread, verbose);
220 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path,
"w"));
221 annotator->
Save(fo.get());
234 std::unique_ptr<CompilerHandleImpl> compiler{
new CompilerHandleImpl(name)};
243 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(handle);
244 auto& cfg_ = impl->cfg;
245 std::string name_(name);
246 std::string value_(value);
248 auto it = std::find_if(cfg_.begin(), cfg_.end(),
249 [&name_](
const std::pair<std::string, std::string>& x) {
250 return x.first == name_;
252 if (it == cfg_.end()) {
253 cfg_.emplace_back(name_, value_);
263 const char* dirpath) {
267 std::to_string(verbose).c_str());
272 const Model* model_ =
static_cast<Model*
>(model);
273 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(compiler);
276 const std::string& dirpath_(dirpath);
277 filesystem::CreateDirectoryIfNotExist(dirpath);
280 cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
284 auto compiled_model = impl->compiler->Compile(*model_);
286 LOG(INFO) <<
"Code generation finished. Writing code to files...";
289 for (
const auto& it : compiled_model.files) {
291 LOG(INFO) <<
"Writing file " << it.first <<
"...";
293 const std::string filename_full = dirpath_ +
"/" + it.first;
294 if (it.second.is_binary) {
295 filesystem::WriteToFile(filename_full, it.second.content_binary);
297 filesystem::WriteToFile(filename_full, it.second.content);
306 delete static_cast<CompilerHandleImpl*
>(handle);
313 std::unique_ptr<Model> model{
new Model()};
314 frontend::LoadLightGBMModel(filename, model.get());
322 std::unique_ptr<Model> model{
new Model()};
323 frontend::LoadXGBoostModel(filename, model.get());
331 std::unique_ptr<Model> model{
new Model()};
332 frontend::LoadXGBoostModel(buf, len, model.get());
340 std::unique_ptr<Model> model{
new Model()};
341 frontend::LoadProtobufModel(filename, model.get());
349 auto model_ =
static_cast<Model*
>(model);
350 frontend::ExportProtobufModel(filename, *model_);
356 delete static_cast<Model*
>(handle);
362 auto model_ =
static_cast<const Model*
>(handle);
363 *out = model_->
trees.size();
369 auto model_ =
static_cast<const Model*
>(handle);
370 *out =
static_cast<size_t>(model_->num_feature);
376 auto model_ =
static_cast<const Model*
>(handle);
377 *out =
static_cast<size_t>(model_->num_output_group);
383 CHECK_GT(limit, 0) <<
"limit should be greater than 0!";
384 auto model_ =
static_cast<Model*
>(handle);
385 CHECK_GE(model_->trees.size(), limit)
386 <<
"Model contains less trees(" << model_->
trees.size() <<
") than limit";
387 model_->trees.resize(limit);
407 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
415 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
423 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
429 int node_key,
unsigned feature_id,
431 float threshold,
int default_left,
433 int right_child_key) {
436 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
438 (default_left != 0), left_child_key, right_child_key);
444 int node_key,
unsigned feature_id,
445 const unsigned int* left_categories,
446 size_t left_categories_len,
449 int right_child_key) {
452 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
453 std::vector<uint32_t> vec(left_categories_len);
454 for (
size_t i = 0; i < left_categories_len; ++i) {
455 CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
456 vec[i] =
static_cast<uint32_t
>(left_categories[i]);
458 builder->SetCategoricalTestNode(node_key, feature_id, vec, (default_left != 0),
459 left_child_key, right_child_key);
466 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
467 builder->
SetLeafNode(node_key, static_cast<tl_float>(leaf_value));
473 const float* leaf_vector,
474 size_t leaf_vector_len) {
477 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
478 std::vector<tl_float> vec(leaf_vector_len);
479 for (
size_t i = 0; i < leaf_vector_len; ++i) {
480 vec[i] =
static_cast<tl_float>(leaf_vector[i]);
482 builder->SetLeafVectorNode(node_key, vec);
487 int num_output_group,
488 int random_forest_flag,
494 (random_forest_flag != 0))};
504 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
520 CHECK(model_builder) <<
"Detected dangling reference to deleted ModelBuilder object";
522 CHECK(tree_builder) <<
"Detected dangling reference to deleted TreeBuilder object";
523 return model_builder->InsertTree(tree_builder, index);
531 CHECK(model_builder) <<
"Detected dangling reference to deleted ModelBuilder object";
532 auto tree_builder = model_builder->
GetTree(index);
533 CHECK(tree_builder) <<
"Detected dangling reference to deleted TreeBuilder object";
541 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
550 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
551 std::unique_ptr<Model> model{
new Model()};
552 builder->CommitModel(model.get());
Some useful math utilities.
C API of Treelite, used for interfacing with other languages. This header is excluded from the runtime...
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, float threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
Parameters for tree compiler.
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
std::vector< float > data
feature values
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
void DeleteNode(int node_key)
Remove a node from a tree.
#define API_BEGIN()
macro to guard beginning and end section of all functions
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from a (in-memory) CSR matrix
std::vector< Tree > trees
member trees
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
parameters for tree compiler
Input data structure of Treelite.
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
void SetModelParam(const char *name, const char *value)
Set a model parameter.
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
void SetRootNode(int node_key)
Set a node as the root of a tree.
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
std::vector< uint32_t > col_ind
feature indices
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const float *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t *out)
Query the number of output groups of the model.
Interface of compiler that compiles a tree ensemble model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from a (in-memory) dense matrix
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
keep the first N trees of the model; limit must be greater than 0 and must not exceed the number of trees.
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
Cross-platform wrapper for common filesystem functions.
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, tl_float threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
void * TreeBuilderHandle
handle to tree builder class
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
void * AnnotationHandle
handle to branch annotation data
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
void Clear()
clear all data fields
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
void DeleteTree(int index)
Remove a tree from the ensemble.
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print first and last 25 non-zero entries...
void * ModelHandle
handle to a decision tree ensemble model
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
void * DMatrixHandle
handle to a data matrix
size_t num_col
number of columns
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
int TreeliteExportProtobufModel(const char *filename, ModelHandle model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value)
Turn an empty node into a leaf node.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
void SetLeafNode(int node_key, tl_float leaf_value)
Turn an empty node into a leaf node.
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
void * ModelBuilderHandle
handle to ensemble builder class
void CreateNode(int node_key)
Create an empty node within a tree.
size_t nelem
number of nonzero entries
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
void * CompilerHandle
handle to compiler class
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.