13 #include <dmlc/json.h> 14 #include <dmlc/thread_local.h> 16 #include <unordered_map> 19 #include "../compiler/param.h" 20 #include "../common/filesystem.h" 21 #include "../common/math.h" 27 struct CompilerHandleImpl {
29 std::vector<std::pair<std::string, std::string>> cfg;
30 std::unique_ptr<Compiler> compiler;
31 CompilerHandleImpl(
const std::string& name)
32 : name(name), cfg(), compiler(nullptr) {}
33 ~CompilerHandleImpl() =
default;
37 struct TreeliteAPIThreadLocalEntry {
// Per-thread storage accessor (dmlc::ThreadLocalStore) for
// TreeliteAPIThreadLocalEntry; e.g. TreeliteDMatrixGetPreview keeps the
// string it returns to C callers in this entry's ret_str.
43 using TreeliteAPIThreadLocalStore
44 = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
// Fragment of TreeliteDMatrixCreateFromCSR: copies an in-memory CSR matrix
// into the DMatrix, dropping entries whose value is NaN.
60 const unsigned* col_ind,
61 const size_t* row_ptr,
68 auto& data_ = dmat->
data;
// Reserve for the input's total entry count (row_ptr[num_row]); NaN entries
// are skipped below, so the stored count may be smaller.
71 data_.reserve(row_ptr[num_row]);
72 col_ind_.reserve(row_ptr[num_row]);
73 row_ptr_.reserve(num_row + 1);
// Row i occupies the half-open input range [row_ptr[i], row_ptr[i + 1]).
74 for (
size_t i = 0; i < num_row; ++i) {
75 const size_t jbegin = row_ptr[i];
76 const size_t jend = row_ptr[i + 1];
77 for (
size_t j = jbegin; j < jend; ++j) {
// Keep only non-NaN values; NaN entries are treated as absent.
78 if (!common::math::CheckNAN(data[j])) {
79 data_.push_back(data[j]);
// Column indices are stored as uint32_t; guard the narrowing cast below.
80 CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
81 <<
"feature index too big to fit into uint32_t";
82 col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
// Record the running count of stored entries as the row boundary.
85 row_ptr_.push_back(data_.size());
// Request release of reserved capacity left unused (e.g. by dropped NaNs).
87 data_.shrink_to_fit();
88 col_ind_.shrink_to_fit();
91 dmat->
nelem = data_.size();
// Fragment of TreeliteDMatrixCreateFromMat: converts a dense row-major
// matrix (num_row rows of num_col floats) into CSR form, omitting NaN cells
// and cells equal to missing_value.
102 const bool nan_missing = common::math::CheckNAN(missing_value);
// Column indices are stored as uint32_t, so num_col must fit.
104 CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
105 <<
"num_col argument is too big";
108 auto& data_ = dmat->
data;
109 auto& col_ind_ = dmat->
col_ind;
110 auto& row_ptr_ = dmat->
row_ptr;
// Heuristic initial capacity: the smallest of the full matrix size,
// 1000 entries per row, and 64 * 2^20 entries.
113 const size_t guess_size
114 = std::min(std::min(num_row * num_col, num_row * 1000),
115 static_cast<size_t>(64 * 1024 * 1024));
116 data_.reserve(guess_size);
117 col_ind_.reserve(guess_size);
118 row_ptr_.reserve(num_row + 1);
// `row` points at the start of the current row and advances by num_col
// per row (row-major traversal).
119 const float* row = &data[0];
120 for (
size_t i = 0; i < num_row; ++i, row += num_col) {
121 for (
size_t j = 0; j < num_col; ++j) {
// NaN cells are never stored. The message below indicates an error unless
// missing_value is also NaN; the check itself is on a line not visible here.
122 if (common::math::CheckNAN(row[j])) {
124 <<
"The missing_value argument must be set to NaN if there is any " 125 <<
"NaN in the matrix.";
126 }
// Store the cell if missing_value is NaN (then no non-NaN cell is
// "missing") or if the cell differs from missing_value.
else if (nan_missing || row[j] != missing_value) {
128 data_.push_back(row[j]);
129 col_ind_.push_back(static_cast<uint32_t>(j));
// Record the running count of stored entries as the row boundary.
132 row_ptr_.push_back(data_.size());
// Request release of reserved capacity beyond what was actually stored.
134 data_.shrink_to_fit();
135 col_ind_.shrink_to_fit();
138 dmat->
nelem = data_.size();
152 *out_nelem = dmat->
nelem;
// Fragment of TreeliteDMatrixGetPreview: formats stored entries as
// " (row, col)\tvalue" lines for a human-readable preview.
157 const char** out_preview) {
// The result string lives in thread-local storage so the pointer handed
// back to the caller does not refer to a local.
160 std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
161 std::ostringstream oss;
// At most 50 entries: list them all. Otherwise list the first 25 here and
// the last 25 in the branch below.
162 const size_t iend = (dmat->
nelem <= 50) ? dmat->
nelem : 25;
163 for (
size_t i = 0; i < iend; ++i) {
164 const size_t row_ind =
167 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 168 << dmat->
data[i] <<
"\n";
// More than 50 entries: also list the last 25.
170 if (dmat->
nelem > 50) {
172 for (
size_t i = dmat->
nelem - 25; i < dmat->nelem; ++i) {
173 const size_t row_ind =
176 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 177 << dmat->
data[i] <<
"\n";
// Points into the thread-local ret_str; valid until ret_str is modified
// again on this thread.
181 *out_preview = ret_str.c_str();
186 const float** out_data,
187 const uint32_t** out_col_ind,
188 const size_t** out_row_ptr) {
191 *out_data = &dmat_->
data[0];
192 *out_col_ind = &dmat_->
col_ind[0];
193 *out_row_ptr = &dmat_->
row_ptr[0];
199 delete static_cast<DMatrix*
>(handle);
210 const Model* model_ =
static_cast<Model*
>(model);
212 annotator->
Annotate(*model_, dmat_, nthread, verbose);
221 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path,
"r"));
222 annotator->
Load(fi.get());
231 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path,
"w"));
232 annotator->
Save(fo.get());
253 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(handle);
254 auto& cfg_ = impl->cfg;
255 std::string name_(name);
256 std::string value_(value);
258 auto it = std::find_if(cfg_.begin(), cfg_.end(),
259 [&name_](
const std::pair<std::string, std::string>& x) {
260 return x.first == name_;
262 if (it == cfg_.end()) {
263 cfg_.emplace_back(name_, value_);
273 const char* dirpath) {
277 std::to_string(verbose).c_str());
282 const Model* model_ =
static_cast<Model*
>(model);
283 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(compiler);
286 const std::string& dirpath_(dirpath);
287 common::filesystem::CreateDirectoryIfNotExist(dirpath);
290 cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
295 auto compiled_model = impl->compiler->Compile(*model_);
297 LOG(INFO) <<
"Code generation finished. Writing code to files...";
300 if (!compiled_model.file_prefix.empty()) {
301 const std::vector<std::string> tokens
302 = common::Split(compiled_model.file_prefix,
'/');
303 std::string accum = dirpath_ +
"/" + tokens[0];
304 for (
size_t i = 0; i < tokens.size(); ++i) {
305 common::filesystem::CreateDirectoryIfNotExist(accum.c_str());
306 if (i < tokens.size() - 1) {
308 accum += tokens[i + 1];
313 for (
const auto& it : compiled_model.files) {
314 LOG(INFO) <<
"Writing file " << it.first <<
"...";
315 const std::string filename_full = dirpath_ +
"/" + it.first;
316 common::WriteToFile(filename_full, it.second);
324 delete static_cast<CompilerHandleImpl*
>(handle);
331 Model* model =
new Model(std::move(frontend::LoadLightGBMModel(filename)));
339 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(filename)));
347 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(buf, len)));
355 Model* model =
new Model(std::move(frontend::LoadProtobufModel(filename)));
362 const char* name_obj) {
365 frontend::ExportXGBoostModel(filename, *model_, name_obj);
371 delete static_cast<Model*
>(handle);
391 return (builder->CreateNode(node_key)) ? 0 : -1;
398 return (builder->DeleteNode(node_key)) ? 0 : -1;
405 return (builder->SetRootNode(node_key)) ? 0 : -1;
410 int node_key,
unsigned feature_id,
412 float threshold,
int default_left,
414 int right_child_key) {
417 CHECK_GT(
optable.count(opname), 0)
418 <<
"No operator `" << opname <<
"\" exists";
419 return (builder->SetNumericalTestNode(node_key, feature_id,
423 left_child_key, right_child_key)) \
430 int node_key,
unsigned feature_id,
431 const unsigned int* left_categories,
432 size_t left_categories_len,
435 int right_child_key) {
438 std::vector<uint32_t> vec(left_categories_len);
439 for (
size_t i = 0; i < left_categories_len; ++i) {
440 CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
441 vec[i] =
static_cast<uint32_t
>(left_categories[i]);
443 return (builder->SetCategoricalTestNode(node_key, feature_id, vec,
445 left_child_key, right_child_key)) \
454 return (builder->SetLeafNode(node_key, static_cast<tl_float>(leaf_value))) \
461 const float* leaf_vector,
462 size_t leaf_vector_len) {
465 std::vector<tl_float> vec(leaf_vector_len);
466 for (
size_t i = 0; i < leaf_vector_len; ++i) {
467 vec[i] =
static_cast<tl_float>(leaf_vector[i]);
469 return (builder->SetLeafVectorNode(node_key, vec)) ? 0 : -1;
474 int num_output_group,
475 int random_forest_flag,
479 (random_forest_flag != 0));
505 return model_builder->InsertTree(tree_builder, index);
513 auto tree_builder = &model_builder->
GetTree(index);
521 return (builder->DeleteTree(index)) ? 0 : -1;
530 const bool result = builder->CommitModel(model);
C API of treelite, used for interfacing with other languages. This header is excluded from the runtime...
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, float threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
std::vector< float > data
feature values
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
#define API_BEGIN()
macro to guard beginning and end section of all functions
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from a (in-memory) CSR matrix
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
parameters for tree compiler
Input data structure of treelite.
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
void SetModelParam(const char *name, const char *value)
Set a model parameter.
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
std::vector< uint32_t > col_ind
feature indices
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const float *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Interface of compiler that compiles a tree ensemble model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from a (in-memory) dense matrix
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
void * TreeBuilderHandle
handle to tree builder class
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
void * AnnotationHandle
handle to branch annotation data
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
void Clear()
clear all data fields
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
int TreeliteAnnotationLoad(const char *path, AnnotationHandle *out)
load branch annotation from a JSON file
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print the first and last 25 non-zero entries...
void * ModelHandle
handle to a decision tree ensemble model
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
void * DMatrixHandle
handle to a data matrix
size_t num_col
number of columns
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value)
Turn an empty node into a leaf node.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
void * ModelBuilderHandle
handle to ensemble builder class
size_t nelem
number of nonzero entries
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
void * CompilerHandle
handle to compiler class
int TreeliteExportXGBoostModel(const char *filename, ModelHandle model, const char *name_obj)
(EXPERIMENTAL FEATURE) export a model in XGBoost format. The exported model can be read by XGBoost (d...
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.