9 #include <unordered_map> 16 #include <dmlc/json.h> 17 #include <dmlc/thread_local.h> 18 #include "./c_api_error.h" 19 #include "../compiler/param.h" 20 #include "../common/filesystem.h" 21 #include "../common/math.h" 27 struct CompilerHandleImpl {
29 std::vector<std::pair<std::string, std::string>> cfg;
30 std::unique_ptr<Compiler> compiler;
31 explicit CompilerHandleImpl(
const std::string& name)
32 : name(name), cfg(), compiler(nullptr) {}
33 ~CompilerHandleImpl() =
default;
37 struct TreeliteAPIThreadLocalEntry {
43 using TreeliteAPIThreadLocalStore
44 = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
60 const unsigned* col_ind,
61 const size_t* row_ptr,
68 auto& data_ = dmat->
data;
71 data_.reserve(row_ptr[num_row]);
72 col_ind_.reserve(row_ptr[num_row]);
73 row_ptr_.reserve(num_row + 1);
74 for (
size_t i = 0; i < num_row; ++i) {
75 const size_t jbegin = row_ptr[i];
76 const size_t jend = row_ptr[i + 1];
77 for (
size_t j = jbegin; j < jend; ++j) {
78 if (!common::math::CheckNAN(data[j])) {
79 data_.push_back(data[j]);
80 CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
81 <<
"feature index too big to fit into uint32_t";
82 col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
85 row_ptr_.push_back(data_.size());
87 data_.shrink_to_fit();
88 col_ind_.shrink_to_fit();
91 dmat->
nelem = data_.size();
102 const bool nan_missing = common::math::CheckNAN(missing_value);
104 CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
105 <<
"num_col argument is too big";
108 auto& data_ = dmat->
data;
109 auto& col_ind_ = dmat->
col_ind;
110 auto& row_ptr_ = dmat->
row_ptr;
113 const size_t guess_size
114 = std::min(std::min(num_row * num_col, num_row * 1000),
115 static_cast<size_t>(64 * 1024 * 1024));
116 data_.reserve(guess_size);
117 col_ind_.reserve(guess_size);
118 row_ptr_.reserve(num_row + 1);
119 const float* row = &data[0];
120 for (
size_t i = 0; i < num_row; ++i, row += num_col) {
121 for (
size_t j = 0; j < num_col; ++j) {
122 if (common::math::CheckNAN(row[j])) {
124 <<
"The missing_value argument must be set to NaN if there is any " 125 <<
"NaN in the matrix.";
126 }
else if (nan_missing || row[j] != missing_value) {
128 data_.push_back(row[j]);
129 col_ind_.push_back(static_cast<uint32_t>(j));
132 row_ptr_.push_back(data_.size());
134 data_.shrink_to_fit();
135 col_ind_.shrink_to_fit();
138 dmat->
nelem = data_.size();
152 *out_nelem = dmat->
nelem;
157 const char** out_preview) {
160 std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
161 std::ostringstream oss;
162 const size_t iend = (dmat->
nelem <= 50) ? dmat->
nelem : 25;
163 for (
size_t i = 0; i < iend; ++i) {
164 const size_t row_ind =
167 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 168 << dmat->
data[i] <<
"\n";
170 if (dmat->
nelem > 50) {
172 for (
size_t i = dmat->
nelem - 25; i < dmat->nelem; ++i) {
173 const size_t row_ind =
176 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 177 << dmat->
data[i] <<
"\n";
181 *out_preview = ret_str.c_str();
186 const float** out_data,
187 const uint32_t** out_col_ind,
188 const size_t** out_row_ptr) {
191 *out_data = &dmat_->
data[0];
192 *out_col_ind = &dmat_->
col_ind[0];
193 *out_row_ptr = &dmat_->
row_ptr[0];
199 delete static_cast<DMatrix*
>(handle);
210 const Model* model_ =
static_cast<Model*
>(model);
212 annotator->
Annotate(*model_, dmat_, nthread, verbose);
221 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path,
"w"));
222 annotator->
Save(fo.get());
243 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(handle);
244 auto& cfg_ = impl->cfg;
245 std::string name_(name);
246 std::string value_(value);
248 auto it = std::find_if(cfg_.begin(), cfg_.end(),
249 [&name_](
const std::pair<std::string, std::string>& x) {
250 return x.first == name_;
252 if (it == cfg_.end()) {
253 cfg_.emplace_back(name_, value_);
263 const char* dirpath) {
267 std::to_string(verbose).c_str());
272 const Model* model_ =
static_cast<Model*
>(model);
273 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(compiler);
276 const std::string& dirpath_(dirpath);
277 common::filesystem::CreateDirectoryIfNotExist(dirpath);
280 cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
284 auto compiled_model = impl->compiler->Compile(*model_);
286 LOG(INFO) <<
"Code generation finished. Writing code to files...";
289 for (
const auto& it : compiled_model.files) {
291 LOG(INFO) <<
"Writing file " << it.first <<
"...";
293 const std::string filename_full = dirpath_ +
"/" + it.first;
294 if (it.second.is_binary) {
295 common::WriteToFile(filename_full, it.second.content_binary);
297 common::WriteToFile(filename_full, it.second.content);
306 delete static_cast<CompilerHandleImpl*
>(handle);
313 Model* model =
new Model(std::move(frontend::LoadLightGBMModel(filename)));
321 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(filename)));
329 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(buf, len)));
337 Model* model =
new Model(std::move(frontend::LoadProtobufModel(filename)));
346 frontend::ExportProtobufModel(filename, *model_);
352 delete static_cast<Model*
>(handle);
358 const Model* model_ =
static_cast<Model*
>(handle);
359 *out = model_->
trees.size();
365 const Model* model_ =
static_cast<Model*
>(handle);
372 const Model* model_ =
static_cast<Model*
>(handle);
379 CHECK_GT(limit, 0) <<
"limit should be greater than 0!";
380 auto* model_ =
static_cast<Model*
>(handle);
381 CHECK_GE(model_->trees.size(), limit)
382 <<
"Model contains less trees(" << model_->
trees.size() <<
") than limit";
383 model_->trees.resize(limit);
403 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
404 return (builder->CreateNode(node_key)) ? 0 : -1;
411 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
412 return (builder->DeleteNode(node_key)) ? 0 : -1;
419 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
420 return (builder->SetRootNode(node_key)) ? 0 : -1;
425 int node_key,
unsigned feature_id,
427 double threshold,
int default_left,
429 int right_child_key) {
432 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
433 CHECK_GT(
optable.count(opname), 0)
434 <<
"No operator `" << opname <<
"\" exists";
435 return (builder->SetNumericalTestNode(node_key, feature_id,
439 left_child_key, right_child_key)) \
446 int node_key,
unsigned feature_id,
447 const unsigned int* left_categories,
448 size_t left_categories_len,
451 int right_child_key) {
454 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
455 std::vector<uint32_t> vec(left_categories_len);
456 for (
size_t i = 0; i < left_categories_len; ++i) {
457 CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
458 vec[i] =
static_cast<uint32_t
>(left_categories[i]);
460 return (builder->SetCategoricalTestNode(node_key, feature_id, vec,
462 left_child_key, right_child_key)) \
471 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
472 return (builder->SetLeafNode(node_key, static_cast<tl_float>(leaf_value))) \
479 const double* leaf_vector,
480 size_t leaf_vector_len) {
483 CHECK(builder) <<
"Detected dangling reference to deleted TreeBuilder object";
484 std::vector<tl_float> vec(leaf_vector_len);
485 for (
size_t i = 0; i < leaf_vector_len; ++i) {
486 vec[i] =
static_cast<tl_float>(leaf_vector[i]);
488 return (builder->SetLeafVectorNode(node_key, vec)) ? 0 : -1;
493 int num_output_group,
494 int random_forest_flag,
498 (random_forest_flag != 0));
508 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
524 CHECK(model_builder) <<
"Detected dangling reference to deleted ModelBuilder object";
526 CHECK(tree_builder) <<
"Detected dangling reference to deleted TreeBuilder object";
527 return model_builder->InsertTree(tree_builder, index);
535 CHECK(model_builder) <<
"Detected dangling reference to deleted ModelBuilder object";
536 auto tree_builder = &model_builder->
GetTree(index);
537 CHECK(tree_builder) <<
"Detected dangling reference to deleted TreeBuilder object";
545 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
546 return (builder->DeleteTree(index)) ? 0 : -1;
554 CHECK(builder) <<
"Detected dangling reference to deleted ModelBuilder object";
556 const bool result = builder->CommitModel(model);
C API of treelite, used for interfacing with other languages. This header is excluded from the runtime...
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
std::vector< float > data
feature values
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
int TreeliteQueryNumTree(ModelHandle handle, size_t *out)
Query the number of trees in the model.
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from a (in-memory) CSR matrix
std::vector< Tree > trees
member trees
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
parameters for tree compiler
Input data structure of treelite.
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
void SetModelParam(const char *name, const char *value)
Set a model parameter.
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
std::vector< uint32_t > col_ind
feature indices
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t *out)
Query the number of output groups of the model.
Interface of compiler that compiles a tree ensemble model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from a (in-memory) dense matrix
int TreeliteSetTreeLimit(ModelHandle handle, size_t limit)
keep the first N trees of the model; limit must be greater than 0 and must not exceed the number of trees.
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
void * TreeBuilderHandle
handle to tree builder class
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
void * AnnotationHandle
handle to branch annotation data
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
void Clear()
clear all data fields
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const double *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print first and last 25 non-zero entries...
void * ModelHandle
handle to a decision tree ensemble model
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
double tl_float
float type to be used internally
void * DMatrixHandle
handle to a data matrix
size_t num_col
number of columns
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, double threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
int TreeliteExportProtobufModel(const char *filename, ModelHandle model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
int TreeliteQueryNumFeature(ModelHandle handle, size_t *out)
Query the number of features used in the model.
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
void * ModelBuilderHandle
handle to ensemble builder class
size_t nelem
number of nonzero entries
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
void * CompilerHandle
handle to compiler class
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, double leaf_value)
Turn an empty node into a leaf node.
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.