14 #include <dmlc/json.h> 15 #include <dmlc/thread_local.h> 17 #include <unordered_map> 20 #include "../compiler/param.h" 21 #include "../common/filesystem.h" 22 #include "../common/math.h" 28 struct CompilerHandleImpl {
30 std::vector<std::pair<std::string, std::string>> cfg;
31 std::unique_ptr<Compiler> compiler;
32 CompilerHandleImpl(
const std::string& name)
33 : name(name), cfg(), compiler(nullptr) {}
34 ~CompilerHandleImpl() =
default;
38 struct TreeliteAPIThreadLocalEntry {
44 using TreeliteAPIThreadLocalStore
45 = dmlc::ThreadLocalStore<TreeliteAPIThreadLocalEntry>;
61 const unsigned* col_ind,
62 const size_t* row_ptr,
69 auto& data_ = dmat->
data;
72 data_.reserve(row_ptr[num_row]);
73 col_ind_.reserve(row_ptr[num_row]);
74 row_ptr_.reserve(num_row + 1);
75 for (
size_t i = 0; i < num_row; ++i) {
76 const size_t jbegin = row_ptr[i];
77 const size_t jend = row_ptr[i + 1];
78 for (
size_t j = jbegin; j < jend; ++j) {
79 if (!common::math::CheckNAN(data[j])) {
80 data_.push_back(data[j]);
81 CHECK_LT(col_ind[j], std::numeric_limits<uint32_t>::max())
82 <<
"feature index too big to fit into uint32_t";
83 col_ind_.push_back(static_cast<uint32_t>(col_ind[j]));
86 row_ptr_.push_back(data_.size());
88 data_.shrink_to_fit();
89 col_ind_.shrink_to_fit();
92 dmat->
nelem = data_.size();
103 const bool nan_missing = common::math::CheckNAN(missing_value);
105 CHECK_LT(num_col, std::numeric_limits<uint32_t>::max())
106 <<
"num_col argument is too big";
109 auto& data_ = dmat->
data;
110 auto& col_ind_ = dmat->
col_ind;
111 auto& row_ptr_ = dmat->
row_ptr;
114 const size_t guess_size
115 = std::min(std::min(num_row * num_col, num_row * 1000),
116 static_cast<size_t>(64 * 1024 * 1024));
117 data_.reserve(guess_size);
118 col_ind_.reserve(guess_size);
119 row_ptr_.reserve(num_row + 1);
120 const float* row = &data[0];
121 for (
size_t i = 0; i < num_row; ++i, row += num_col) {
122 for (
size_t j = 0; j < num_col; ++j) {
123 if (common::math::CheckNAN(row[j])) {
125 <<
"The missing_value argument must be set to NaN if there is any " 126 <<
"NaN in the matrix.";
127 }
else if (nan_missing || row[j] != missing_value) {
129 data_.push_back(row[j]);
130 col_ind_.push_back(static_cast<uint32_t>(j));
133 row_ptr_.push_back(data_.size());
135 data_.shrink_to_fit();
136 col_ind_.shrink_to_fit();
139 dmat->
nelem = data_.size();
153 *out_nelem = dmat->
nelem;
158 const char** out_preview) {
161 std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str;
162 std::ostringstream oss;
163 const size_t iend = (dmat->
nelem <= 50) ? dmat->
nelem : 25;
164 for (
size_t i = 0; i < iend; ++i) {
165 const size_t row_ind =
168 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 169 << dmat->
data[i] <<
"\n";
171 if (dmat->
nelem > 50) {
173 for (
size_t i = dmat->
nelem - 25; i < dmat->nelem; ++i) {
174 const size_t row_ind =
177 oss <<
" (" << row_ind <<
", " << dmat->
col_ind[i] <<
")\t" 178 << dmat->
data[i] <<
"\n";
182 *out_preview = ret_str.c_str();
187 const float** out_data,
188 const uint32_t** out_col_ind,
189 const size_t** out_row_ptr) {
192 *out_data = &dmat_->
data[0];
193 *out_col_ind = &dmat_->
col_ind[0];
194 *out_row_ptr = &dmat_->
row_ptr[0];
200 delete static_cast<DMatrix*
>(handle);
211 const Model* model_ =
static_cast<Model*
>(model);
213 annotator->
Annotate(*model_, dmat_, nthread, verbose);
222 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path,
"r"));
223 annotator->
Load(fi.get());
232 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path,
"w"));
233 annotator->
Save(fo.get());
254 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(handle);
255 auto& cfg_ = impl->cfg;
256 std::string name_(name);
257 std::string value_(value);
259 auto it = std::find_if(cfg_.begin(), cfg_.end(),
260 [&name_](
const std::pair<std::string, std::string>& x) {
261 return x.first == name_;
263 if (it == cfg_.end()) {
264 cfg_.emplace_back(name_, value_);
274 const char* dirpath) {
278 std::to_string(verbose).c_str());
283 const Model* model_ =
static_cast<Model*
>(model);
284 CompilerHandleImpl* impl =
static_cast<CompilerHandleImpl*
>(compiler);
287 const std::string& dirpath_(dirpath);
288 common::filesystem::CreateDirectoryIfNotExist(dirpath);
289 const std::string basename = common::filesystem::GetBasename(dirpath);
292 cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);
296 auto semantic_model = impl->compiler->Compile(*model_);
298 LOG(INFO) <<
"Code generation finished. Writing code to files...";
302 const std::string header_filename = dirpath_ +
"/" + basename +
".h";
304 LOG(INFO) <<
"Writing " << header_filename <<
" ...";
307 std::vector<std::string> lines;
308 common::TransformPushBack(&lines, semantic_model.common_header->Compile(),
309 [] (std::string line) {
312 lines.emplace_back();
313 std::ostringstream oss;
315 std::copy(semantic_model.function_registry.begin(),
316 semantic_model.function_registry.end(),
317 std::ostream_iterator<FunctionEntry>(oss));
318 lines.push_back(oss.str());
319 common::WriteToFile(header_filename, lines);
322 std::vector<std::unordered_map<std::string, std::string>> source_list;
323 std::vector<std::string> object_list;
324 if (semantic_model.units.size() == 1) {
325 const std::string filename = basename +
".c";
326 const std::string filename_full = dirpath_ +
"/" + filename;
327 const std::string objname = basename +
".o";
329 LOG(INFO) <<
"Writing " << filename_full <<
" ...";
331 auto lines = semantic_model.units[0].Compile(header_filename);
332 source_list.push_back({ {
"name", basename},
333 {
"length", std::to_string(lines.size())} });
334 object_list.push_back(objname);
335 common::WriteToFile(filename_full, lines);
337 for (
size_t i = 0; i < semantic_model.units.size(); ++i) {
338 const std::string filename = basename + std::to_string(i) +
".c";
339 const std::string filename_full = dirpath_ +
"/" + filename;
340 const std::string objname = basename + std::to_string(i) +
".o";
342 LOG(INFO) <<
"Writing " << filename_full <<
" ...";
344 auto lines = semantic_model.units[i].Compile(header_filename);
345 source_list.push_back({ {
"name", basename + std::to_string(i)},
346 {
"length", std::to_string(lines.size())} });
347 object_list.push_back(objname);
348 common::WriteToFile(filename_full, lines);
353 const std::string recipe_name = dirpath_ +
"/recipe.json";
355 LOG(INFO) <<
"Writing " << recipe_name <<
" ...";
357 std::unique_ptr<dmlc::Stream> fo(
358 dmlc::Stream::Create(recipe_name.c_str(),
"w"));
359 dmlc::ostream os(fo.get());
360 auto writer = common::make_unique<dmlc::JSONWriter>(&os);
361 writer->BeginObject();
362 writer->WriteObjectKeyValue(
"target", basename);
363 writer->WriteObjectKeyValue(
"sources", source_list);
366 os.set_stream(
nullptr);
371 std::string library_name = basename +
".so";
372 std::ostringstream oss;
373 oss <<
"all: " << library_name << std::endl << std::endl
374 << library_name <<
": ";
375 for (
const auto& e : object_list) {
379 <<
"\tgcc -shared -O3 -o $@ $? -fPIC -std=c99 -flto" 380 << std::endl << std::endl;
381 for (
size_t i = 0; i < object_list.size(); ++i) {
382 oss << object_list[i] <<
": " 383 << source_list[i][
"name"] <<
".c" << std::endl
384 <<
"\tgcc -c -O3 -o $@ $? -fPIC -std=c99 -flto" << std::endl;
387 <<
"clean:" << std::endl
388 <<
"\trm -fv " << library_name <<
" ";
389 for (
const auto& e : object_list) {
392 common::WriteToFile(dirpath_ +
"/Makefile", {oss.str()});
394 LOG(INFO) <<
"Writing " << dirpath_ +
"/Makefile ...";
403 delete static_cast<CompilerHandleImpl*
>(handle);
410 Model* model =
new Model(std::move(frontend::LoadLightGBMModel(filename)));
418 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(filename)));
426 Model* model =
new Model(std::move(frontend::LoadXGBoostModel(buf, len)));
434 Model* model =
new Model(std::move(frontend::LoadProtobufModel(filename)));
441 delete static_cast<Model*
>(handle);
461 return (builder->CreateNode(node_key)) ? 0 : -1;
468 return (builder->DeleteNode(node_key)) ? 0 : -1;
475 return (builder->SetRootNode(node_key)) ? 0 : -1;
480 int node_key,
unsigned feature_id,
482 float threshold,
int default_left,
484 int right_child_key) {
487 CHECK_GT(
optable.count(opname), 0)
488 <<
"No operator `" << opname <<
"\" exists";
489 return (builder->SetNumericalTestNode(node_key, feature_id,
493 left_child_key, right_child_key)) \
500 int node_key,
unsigned feature_id,
501 const unsigned char* left_categories,
502 size_t left_categories_len,
505 int right_child_key) {
508 std::vector<uint8_t> vec(left_categories_len);
509 for (
size_t i = 0; i < left_categories_len; ++i) {
510 CHECK(left_categories[i] <= std::numeric_limits<uint8_t>::max());
511 vec[i] =
static_cast<uint8_t
>(left_categories[i]);
513 return (builder->SetCategoricalTestNode(node_key, feature_id, vec,
515 left_child_key, right_child_key)) \
524 return (builder->SetLeafNode(node_key, static_cast<tl_float>(leaf_value))) \
531 const float* leaf_vector,
532 size_t leaf_vector_len) {
535 std::vector<tl_float> vec(leaf_vector_len);
536 for (
size_t i = 0; i < leaf_vector_len; ++i) {
537 vec[i] =
static_cast<tl_float>(leaf_vector[i]);
539 return (builder->SetLeafVectorNode(node_key, vec)) ? 0 : -1;
544 int num_output_group,
545 int random_forest_flag,
549 (random_forest_flag != 0));
575 return model_builder->InsertTree(tree_builder, index);
583 auto tree_builder = &model_builder->
GetTree(index);
591 return (builder->DeleteTree(index)) ? 0 : -1;
600 const bool result = builder->CommitModel(model);
C API of treelite, used for interfacing with other languages. This header is excluded from the runtime...
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char *opname, float threshold, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with numerical split. The test is in the form [feature value] OP ...
int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char *name, const char *value)
Set a model parameter.
int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out)
Get a reference to a tree in the ensemble.
std::vector< float > data
feature values
Collection of front-end methods to load or construct ensemble model.
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
int TreeliteLoadXGBoostModel(const char *filename, ModelHandle *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
#define API_BEGIN()
macro to guard beginning and end section of all functions
int TreeliteFreeModel(ModelHandle handle)
delete model from memory
int TreeliteAnnotationSave(AnnotationHandle handle, const char *path)
save branch annotation to a JSON file
int TreeliteDMatrixCreateFromCSR(const float *data, const unsigned *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, DMatrixHandle *out)
create DMatrix from a (in-memory) CSR matrix
int TreeliteDMatrixFree(DMatrixHandle handle)
delete DMatrix from memory
int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index)
Remove a tree from the ensemble.
parameters for tree compiler
Input data structure of treelite.
void Annotate(const Model &model, const DMatrix *dmat, int nthread, int verbose)
annotate branches in a given model using frequency patterns in the training data. The annotation can ...
int TreeliteCompilerSetParam(CompilerHandle handle, const char *name, const char *value)
set a parameter for a compiler
int TreeliteDeleteModelBuilder(ModelBuilderHandle handle)
Delete a model builder from memory.
void SetModelParam(const char *name, const char *value)
Set a model parameter.
int TreeliteCreateTreeBuilder(TreeBuilderHandle *out)
Create a new tree builder.
int TreeliteAnnotationFree(AnnotationHandle handle)
delete branch annotation from memory
std::vector< uint32_t > col_ind
feature indices
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, const float *leaf_vector, size_t leaf_vector_len)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
int TreeliteAnnotateBranch(ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle *out)
annotate branches in a given model using frequency patterns in the training data. ...
Interface of compiler that translates a tree ensemble model into a semantic model.
int TreeliteDMatrixCreateFromMat(const float *data, size_t num_row, size_t num_col, float missing_value, DMatrixHandle *out)
create DMatrix from a (in-memory) dense matrix
int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index)
Insert a tree at specified location.
int TreeliteDMatrixGetArrays(DMatrixHandle handle, const float **out_data, const uint32_t **out_col_ind, const size_t **out_row_ptr)
extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
void Save(dmlc::Stream *fo) const
save branch annotation to a JSON file
int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key)
Set a node as the root of a tree.
void * TreeBuilderHandle
handle to tree builder class
Error handling for C API.
int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle)
Delete a tree builder from memory.
void * AnnotationHandle
handle to branch annotation data
int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle *out)
finalize the model and produce the in-memory representation
int TreeliteCompilerGenerateCode(CompilerHandle compiler, ModelHandle model, int verbose, const char *dirpath)
generate prediction code from a tree ensemble model. The code will be C99 compliant. One header file (.h) will be generated, along with one or more source files (.c).
void Clear()
clear all data fields
int TreeliteCreateModelBuilder(int num_feature, int num_output_group, int random_forest_flag, ModelBuilderHandle *out)
Create a new model builder.
int TreeliteAnnotationLoad(const char *path, AnnotationHandle *out)
load branch annotation from a JSON file
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
int TreeliteLoadXGBoostModelFromMemoryBuffer(const void *buf, size_t len, ModelHandle *out)
load an XGBoost model from a memory buffer.
int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char **out_preview)
produce a human-readable preview of a DMatrix. Will print first and last 25 non-zero entries...
void * ModelHandle
handle to a decision tree ensemble model
int TreeliteLoadLightGBMModel(const char *filename, ModelHandle *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
static Compiler * Create(const std::string &name, const compiler::CompilerParam &param)
create a compiler from given name
void * DMatrixHandle
handle to a data matrix
size_t num_col
number of columns
int TreeliteLoadProtobufModel(const char *filename, ModelHandle *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
int TreeliteCompilerCreate(const char *name, CompilerHandle *out)
create a compiler with a given name
int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value)
Turn an empty node into a leaf node.
int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key)
Remove a node from a tree.
Building blocks for semantic model of tree prediction code.
int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t *out_num_row, size_t *out_num_col, size_t *out_nelem)
get dimensions of a DMatrix
void * ModelBuilderHandle
handle to ensemble builder class
int TreeliteTreeBuilderSetCategoricalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const unsigned char *left_categories, size_t left_categories_len, int default_left, int left_child_key, int right_child_key)
Turn an empty node into a test node with categorical split. A list defines all categories that would ...
size_t nelem
number of nonzero entries
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
void * CompilerHandle
handle to compiler class
int TreeliteDMatrixCreateFromFile(const char *path, const char *format, int nthread, int verbose, DMatrixHandle *out)
create DMatrix from a file
int TreeliteCompilerFree(CompilerHandle handle)
delete compiler from memory
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key)
Create an empty node within a tree.