14 #include <dmlc/json.h> 15 #include <dmlc/thread_local.h> 17 #include <unordered_map> 20 #include "../compiler/param.h" 21 #include "../common/filesystem.h" 22 #include "../common/math.h" 28 struct CompilerHandleImpl {
30 std::vector<std::pair<std::string, std::string>> cfg;
31 std::unique_ptr<Compiler> compiler;
32 CompilerHandleImpl(
const std::string& name)
33 : name(name), cfg(), compiler(nullptr) {}
34 ~CompilerHandleImpl() =
default;
38 struct TreeliteAPIThreadLocalEntry {
// Accessor for a TreeliteAPIThreadLocalEntry held in a dmlc::ThreadLocalStore
// (presumably one instance per calling thread — confirm against dmlc-core).
// Its ret_str member backs C strings handed back across the C API, e.g. the
// preview string returned via out_preview in TreeliteDMatrixGetPreview, so the
// returned pointer stays valid after the API call returns.
44 using TreeliteAPIThreadLocalStore
45 = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
61 const unsigned* col_ind,
62 const size_t* row_ptr,
69 auto& data_ = dmat->
data;
72 data_.reserve(row_ptr[num_row]);
73 col_ind_.reserve(row_ptr[num_row]);
74 row_ptr_.reserve(num_row + 1);
75 for (
size_t i = 0; i < num_row; ++i) {
76 const size_t jbegin = row_ptr[i];
77 const size_t jend = row_ptr[i + 1];
78 for (
size_t j = jbegin; j < jend; ++j) {
79 if (!common::math::CheckNAN(data[j])) {
80 data_.push_back(data[j]);
81 CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
82 <<
"feature index too big to fit into uint32_t";
83 col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
86 row_ptr_.push_back(data_.size());
88 data_.shrink_to_fit();
89 col_ind_.shrink_to_fit();
92 dmat->
nelem = data_.size();
103 const bool nan_missing = common::math::CheckNAN(missing_value);
105 CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
106 <<
"num_col argument is too big";
109 auto& data_ = dmat->
data;
110 auto& col_ind_ = dmat->
col_ind;
111 auto& row_ptr_ = dmat->
row_ptr;
114 const size_t guess_size
115 = std::min(std::min(num_row * num_col, num_row * 1000),
116 static_cast<size_t>(64 * 1024 * 1024));
117 data_.reserve(guess_size);
118 col_ind_.reserve(guess_size);
119 row_ptr_.reserve(num_row + 1);
120 const float* row = &data[0];
121 for (
size_t i = 0; i < num_row; ++i, row += num_col) {
122 for (
size_t j = 0; j < num_col; ++j) {
123 if (common::math::CheckNAN(row[j])) {
125 <<
"The missing_value argument must be set to NaN if there is any " 126 <<
"NaN in the matrix.";
127 }
else if (nan_missing || row[j] != missing_value) {
129 data_.push_back(row[j]);
130 col_ind_.push_back(static_cast<uint32_t>(j));
133 row_ptr_.push_back(data_.size());
135 data_.shrink_to_fit();
136 col_ind_.shrink_to_fit();
139 dmat->
nelem = data_.size();
153 *out_nelem = dmat->
nelem;
158 const char** out_preview) {
161 std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
162 std::ostringstream oss;
163 const size_t iend = (dmat->
nelem <= 50) ? dmat->
nelem : 25;
164 for (
size_t i = 0; i < iend; ++i) {
165 const size_t row_ind =
168 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 169 << dmat->
data[i] <<
"\n";
171 if (dmat->
nelem > 50) {
173 for (
size_t i = dmat->
nelem - 25; i < dmat->nelem; ++i) {
174 const size_t row_ind =
177 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 178 << dmat->
data[i] <<
"\n";
182 *out_preview = ret_str.c_str();
187 const float** out_data,
188 const uint32_t** out_col_ind,
189 const size_t** out_row_ptr) {
192 *out_data = &dmat_->
data[0];
193 *out_col_ind = &dmat_->
col_ind[0];
194 *out_row_ptr = &dmat_->
row_ptr[0];
200 delete static_cast<DMatrix*
>(handle);
211 const Model* model_ =
static_cast<Model*
>(model);
213 annotator->
Annotate(*model_, dmat_, nthread, verbose);
222 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path,
"r"));
223 annotator->
Load(fi.get());
232 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path,
"w"));
233 annotator->
Save(fo.get());
254 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(handle);
255 auto& cfg_ = impl->cfg;
256 std::string name_(name);
257 std::string value_(value);
259 auto it = std::find_if(cfg_.begin(), cfg_.end(),
260 [&name_](
const std::pair<std::string, std::string>& x) {
261 return x.first == name_;
263 if (it == cfg_.end()) {
264 cfg_.emplace_back(name_, value_);
274 const char* dirpath) {
278 std::to_string(verbose).c_str());
283 const Model* model_ =
static_cast<Model*
>(model);
284 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(compiler);
287 const std::string& dirpath_(dirpath);
288 common::filesystem::CreateDirectoryIfNotExist(dirpath);
289 const std::string basename = common::filesystem::GetBasename(dirpath);
292 cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
296 auto semantic_model = impl->compiler->Compile(*model_);
298 LOG(INFO) <<
"Code generation finished. Writing code to files...";
302 const std::string header_filename = dirpath_ +
"/" + basename +
".h";
304 LOG(INFO) <<
"Writing " << header_filename <<
" ...";
307 std::vector<std::string> lines;
308 common::TransformPushBack(&lines, semantic_model.common_header->Compile(),
309 [] (std::string line) {
312 lines.emplace_back();
313 std::ostringstream oss;
315 std::copy(semantic_model.function_registry.begin(),
316 semantic_model.function_registry.end(),
317 std::ostream_iterator<FunctionEntry>(oss));
318 lines.push_back(oss.str());
319 common::WriteToFile(header_filename, lines);
322 std::vector<std::unordered_map<std::string, std::string>> source_list;
323 if (semantic_model.units.size() == 1) {
324 const std::string filename = basename +
".c";
325 const std::string filename_full = dirpath_ +
"/" + filename;
326 const std::string objname = basename +
".o";
328 LOG(INFO) <<
"Writing " << filename_full <<
" ...";
330 auto lines = semantic_model.units[0].Compile(header_filename);
331 source_list.push_back({ {
"name", basename},
332 {
"length", std::to_string(lines.size())} });
333 common::WriteToFile(filename_full, lines);
335 for (
size_t i = 0; i < semantic_model.units.size(); ++i) {
336 const std::string filename = basename + std::to_string(i) +
".c";
337 const std::string filename_full = dirpath_ +
"/" + filename;
338 const std::string objname = basename + std::to_string(i) +
".o";
340 LOG(INFO) <<
"Writing " << filename_full <<
" ...";
342 auto lines = semantic_model.units[i].Compile(header_filename);
343 source_list.push_back({ {
"name", basename + std::to_string(i)},
344 {
"length", std::to_string(lines.size())} });
345 common::WriteToFile(filename_full, lines);
350 const std::string recipe_name = dirpath_ +
"/recipe.json";
352 LOG(INFO) <<
"Writing " << recipe_name <<
" ...";
354 std::unique_ptr<dmlc::Stream> fo(
355 dmlc::Stream::Create(recipe_name.c_str(),
"w"));
356 dmlc::ostream os(fo.get());
357 auto writer = common::make_unique<dmlc::JSONWriter>(&os);
358 writer->BeginObject();
359 writer->WriteObjectKeyValue(
"target", basename);
360 writer->WriteObjectKeyValue(
"sources", source_list);
363 os.set_stream(
nullptr);
370 delete static_cast<CompilerHandleImpl*
>(handle);
377 Model* model =
new Model(std::move(frontend::LoadLightGBMModel(filename)));
385 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(filename)));
393 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(buf, len)));
401 Model* model =
new Model(std::move(frontend::LoadProtobufModel(filename)));
408 delete static_cast<Model*
>(handle);
428 return (builder->CreateNode(node_key)) ? 0 : -1;
435 return (builder->DeleteNode(node_key)) ? 0 : -1;
442 return (builder->SetRootNode(node_key)) ? 0 : -1;
447 int node_key,
unsigned feature_id,
449 float threshold,
int default_left,
451 int right_child_key) {
454 CHECK_GT(
optable.count(opname), 0)
455 <<
"No operator `" << opname <<
"\" exists";
456 return (builder->SetNumericalTestNode(node_key, feature_id,
460 left_child_key, right_child_key)) \
467 int node_key,
unsigned feature_id,
468 const unsigned int* left_categories,
469 size_t left_categories_len,
472 int right_child_key) {
475 std::vector<uint32_t> vec(left_categories_len);
476 for (
size_t i = 0; i < left_categories_len; ++i) {
477 CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
478 vec[i] =
static_cast<uint32_t
>(left_categories[i]);
480 return (builder->SetCategoricalTestNode(node_key, feature_id, vec,
482 left_child_key, right_child_key)) \
491 return (builder->SetLeafNode(node_key, static_cast<tl_float>(leaf_value))) \
498 const float* leaf_vector,
499 size_t leaf_vector_len) {
502 std::vector<tl_float> vec(leaf_vector_len);
503 for (
size_t i = 0; i < leaf_vector_len; ++i) {
504 vec[i] =
static_cast<tl_float>(leaf_vector[i]);
506 return (builder->SetLeafVectorNode(node_key, vec)) ? 0 : -1;
511 int num_output_group,
512 int random_forest_flag,
516 (random_forest_flag != 0));
542 return model_builder->InsertTree(tree_builder, index);
550 auto tree_builder = &model_builder->
GetTree(index);
558 return (builder->DeleteTree(index)) ? 0 : -1;
567 const bool result = builder->CommitModel(model);
C API of treelite, used for interfacing with other languages. This header is excluded from the runtime...
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, float threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
std::vector< float > data
feature values
Collection of front-end methods to load or construct an ensemble model.
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
#define API_BEGIN()
macro to guard beginning and end section of all functions
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from a (in-memory) CSR matrix
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
parameters for tree compiler
Input data structure of treelite.
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned int *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
void SetModelParam(const char *name, const char *value)
Set a model parameter.
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
std::vector< uint32_t > col_ind
feature indices
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const float *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Interface of compiler that translates a tree ensemble model into a semantic model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from a (in-memory) dense matrix
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
void * TreeBuilderHandle
handle to tree builder class
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
void * AnnotationHandle
handle to branch annotation data
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
void Clear()
clear all data fields
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
int TreeliteAnnotationLoad(const char *path, AnnotationHandle *out)
load branch annotation from a JSON file
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print the first and last 25 non-zero entries...
void * ModelHandle
handle to a decision tree ensemble model
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
void * DMatrixHandle
handle to a data matrix
size_t num_col
number of columns
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value)
Turn an empty node into a leaf node.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Building blocks for semantic model of tree prediction code.
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
void * ModelBuilderHandle
handle to ensemble builder class
size_t nelem
number of nonzero entries
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
void * CompilerHandle
handle to compiler class
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.