Treelite
frontend.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_FRONTEND_H_
8 #define TREELITE_FRONTEND_H_
9 
10 #include <treelite/base.h>
11 #include <memory>
12 #include <vector>
13 #include <cstdint>
14 
15 namespace treelite {
16 
17 struct Model; // forward declaration
18 
19 namespace frontend {
20 
21 //--------------------------------------------------------------------------
22 // model loader interface: read from the disk
23 //--------------------------------------------------------------------------
// Load a model file generated by LightGBM (Microsoft/LightGBM) from disk.
//   filename: path of the model file to read
//   out:      destination Model object (presumably filled in on return — confirm)
30 void LoadLightGBMModel(const char *filename, Model* out);
// Load a model file generated by XGBoost (dmlc/xgboost) from disk.
//   filename: path of the model file to read
//   out:      destination Model object (presumably filled in on return — confirm)
37 void LoadXGBoostModel(const char* filename, Model* out);
// Overload that takes a memory buffer (buf, len) instead of a file name.
// NOTE(review): presumably buf holds the same XGBoost model format that the
// file-based overload reads, and len is its size in bytes — confirm.
// NOTE(review): size_t is used here but <cstddef> is not among the includes
// visible in this header; verify it is reachable via <treelite/base.h>.
44 void LoadXGBoostModel(const void* buf, size_t len, Model* out);
// Load a model stored in Protocol Buffers (google/protobuf) format.
//   filename: path of the model file to read
//   out:      destination Model object (presumably filled in on return — confirm)
52 void LoadProtobufModel(const char* filename, Model* out);
// Export a model in Protocol Buffers (google/protobuf) format.
//   filename: path of the file to write
//   model:    model to export; taken by const reference
60 void ExportProtobufModel(const char* filename, const Model& model);
61 
62 //--------------------------------------------------------------------------
63 // model builder interface: build trees incrementally
64 //--------------------------------------------------------------------------
65 struct TreeBuilderImpl; // forward declaration
66 struct ModelBuilderImpl; // ditto
67 class ModelBuilder; // ditto
68 
// Tree builder class: builds a single tree incrementally. Every member
// function addresses a node through the integer node_key supplied by the
// caller. The implementation is hidden behind TreeBuilderImpl.
70 class TreeBuilder {
71  public:
72  TreeBuilder(); // constructor
73  ~TreeBuilder(); // destructor
74  // this class is only move-constructible and move-assignable
    // NOTE(review): the move operations are defaulted here while
    // TreeBuilderImpl is only forward-declared; a translation unit that
    // actually moves a TreeBuilder presumably needs the complete
    // TreeBuilderImpl type (std::unique_ptr move assignment deletes the old
    // pointee) — confirm. A defaulted move also copies ensemble_id verbatim,
    // so the moved-from object keeps its old ensemble id — verify this is
    // intended.
75  TreeBuilder(const TreeBuilder&) = delete;
76  TreeBuilder(TreeBuilder&&) = default;
77  TreeBuilder& operator=(const TreeBuilder&) = delete;
78  TreeBuilder& operator=(TreeBuilder&&) = default;
    // Create an empty node within the tree, registered under node_key.
83  void CreateNode(int node_key);
    // Remove the node registered under node_key from the tree.
88  void DeleteNode(int node_key);
    // Set the node registered under node_key as the root of the tree.
93  void SetRootNode(int node_key);
    // Turn an empty node into a numerical test node; the test has the form
    // [feature value] OP [threshold]. This overload receives the operator as
    // a C string (presumably its textual spelling — confirm).
    //   feature_id:      feature whose value is tested
    //   default_left:    NOTE(review): presumably the branch taken when the
    //                    feature value is missing — confirm
    //   left_child_key / right_child_key: keys of the two child nodes
107  void SetNumericalTestNode(int node_key, unsigned feature_id,
108  const char* op, tl_float threshold, bool default_left,
109  int left_child_key, int right_child_key);
    // Same as above, but the comparison operator is passed as an Operator
    // enum value rather than a string.
110  void SetNumericalTestNode(int node_key, unsigned feature_id,
111  Operator op, tl_float threshold, bool default_left,
112  int left_child_key, int right_child_key);
    // Turn an empty node into a categorical test node. left_categories is the
    // list of categories that defines the test (presumably the categories
    // routed to the left child — confirm).
126  void SetCategoricalTestNode(int node_key,
127  unsigned feature_id,
128  const std::vector<uint32_t>& left_categories,
129  bool default_left, int left_child_key,
130  int right_child_key);
    // Turn an empty node into a leaf node carrying the value leaf_value.
137  void SetLeafNode(int node_key, tl_float leaf_value);
    // Turn an empty node into a leaf vector node; leaf_vector is a collection
    // of multiple leaf weights for that one leaf.
146  void SetLeafVectorNode(int node_key,
147  const std::vector<tl_float>& leaf_vector);
148 
149  private:
150  std::unique_ptr<TreeBuilderImpl> pimpl; // Pimpl pattern
151  void* ensemble_id; // id of ensemble (nullptr if not part of any)
    // ModelBuilder is granted access to the private members above.
152  friend class ModelBuilder;
153 };
154 
// Model builder class: holds a collection of trees (as TreeBuilder objects)
// and commits them into a Model. The implementation is hidden behind
// ModelBuilderImpl.
156 class ModelBuilder {
157  public:
    // Construct a model builder.
    // NOTE(review): the meaning of the arguments is not visible in this
    // header — presumably num_feature is the number of input features,
    // num_output_group the number of output groups, and random_forest_flag
    // selects random-forest behaviour; confirm against builder.cc.
169  ModelBuilder(int num_feature, int num_output_group, bool random_forest_flag);
170  ~ModelBuilder(); // destructor
    // Set a model parameter given as a (name, value) pair of C strings.
    // NOTE(review): the set of recognised names is not visible here.
176  void SetModelParam(const char* name, const char* value);
    // Insert the tree held by tree_builder at position index. The default
    // index is -1 (presumably "append at the end" — confirm).
    // NOTE(review): the meaning of the int return value is not visible here —
    // presumably the position at which the tree was inserted; confirm.
189  int InsertTree(TreeBuilder* tree_builder, int index = -1);
    // Access the tree builder at position index (mutable and const overloads).
    // NOTE(review): presumably returns a non-owning pointer to a tree kept
    // inside this builder — confirm lifetime rules with the implementation.
195  TreeBuilder* GetTree(int index);
196  const TreeBuilder* GetTree(int index) const;
    // Delete the tree at position index.
201  void DeleteTree(int index);
    // Commit the model built so far into *out_model.
206  void CommitModel(Model* out_model);
207 
208  private:
209  std::unique_ptr<ModelBuilderImpl> pimpl; // Pimpl pattern
210 };
211 
212 } // namespace frontend
213 } // namespace treelite
214 #endif // TREELITE_FRONTEND_H_
void ExportProtobufModel(const char *filename, const Model &model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
Definition: protobuf.cc:344
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
void DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:105
tree builder class
Definition: frontend.h:70
void SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:128
model builder class
Definition: frontend.h:156
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, tl_float threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
Definition: builder.cc:138
void LoadProtobufModel(const char *filename, Model *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
Definition: protobuf.cc:340
void SetCategoricalTestNode(int node_key, unsigned feature_id, const std::vector< uint32_t > &left_categories, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a categorical test node. A list defines all categories that would be classifi...
Definition: builder.cc:185
void LoadXGBoostModel(const char *filename, Model *out)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:28
void LoadLightGBMModel(const char *filename, Model *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:26
defines configuration macros of Treelite
void SetLeafNode(int node_key, tl_float leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:219
void CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:98
void SetLeafVectorNode(int node_key, const std::vector< tl_float > &leaf_vector)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
Definition: builder.cc:230
Operator
comparison operators
Definition: base.h:24