treelite
frontend.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_FRONTEND_H_
8 #define TREELITE_FRONTEND_H_
9 
#include <treelite/base.h>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>
14 
15 namespace treelite {
16 
17 struct Model; // forward declaration
18 
19 namespace frontend {
20 
21 //--------------------------------------------------------------------------
22 // model loader interface: read from the disk
23 //--------------------------------------------------------------------------
30 Model LoadLightGBMModel(const char* filename);
37 Model LoadXGBoostModel(const char* filename);
44 Model LoadXGBoostModel(const void* buf, size_t len);
52 Model LoadProtobufModel(const char* filename);
60 void ExportXGBoostModel(const char* filename, const Model& model,
61  const char* name_obj);
62 
63 //--------------------------------------------------------------------------
64 // model builder interface: build trees incrementally
65 //--------------------------------------------------------------------------
66 struct TreeBuilderImpl; // forward declaration
67 struct ModelBuilderImpl; // ditto
68 class ModelBuilder; // ditto
69 
71 class TreeBuilder {
72  public:
73  TreeBuilder(); // constructor
74  ~TreeBuilder(); // destructor
75  // this class is only move-constructible and move-assignable
76  TreeBuilder(const TreeBuilder&) = delete;
77  TreeBuilder(TreeBuilder&&) = default;
78  TreeBuilder& operator=(const TreeBuilder&) = delete;
79  TreeBuilder& operator=(TreeBuilder&&) = default;
85  bool CreateNode(int node_key);
91  bool DeleteNode(int node_key);
97  bool SetRootNode(int node_key);
112  bool SetNumericalTestNode(int node_key, unsigned feature_id,
113  Operator op, tl_float threshold, bool default_left,
114  int left_child_key, int right_child_key);
129  bool SetCategoricalTestNode(int node_key,
130  unsigned feature_id,
131  const std::vector<uint32_t>& left_categories,
132  bool default_left, int left_child_key,
133  int right_child_key);
141  bool SetLeafNode(int node_key, tl_float leaf_value);
151  bool SetLeafVectorNode(int node_key,
152  const std::vector<tl_float>& leaf_vector);
153 
154  private:
155  std::unique_ptr<TreeBuilderImpl> pimpl; // Pimpl pattern
156  void* ensemble_id; // id of ensemble (nullptr if not part of any)
157  friend class ModelBuilder;
158 };
159 
162  public:
174  ModelBuilder(int num_feature, int num_output_group, bool random_forest_flag);
175  ~ModelBuilder(); // destructor
181  void SetModelParam(const char* name, const char* value);
194  int InsertTree(TreeBuilder* tree_builder, int index = -1);
200  TreeBuilder& GetTree(int index);
201  const TreeBuilder& GetTree(int index) const;
207  bool DeleteTree(int index);
214  bool CommitModel(Model* out_model);
215 
216  private:
217  std::unique_ptr<ModelBuilderImpl> pimpl; // Pimpl pattern
218 };
219 
220 } // namespace frontend
221 } // namespace treelite
222 #endif // TREELITE_FRONTEND_H_
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
tree builder class
Definition: frontend.h:71
bool SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:135
bool CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:106
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision tree ensemble.
Definition: lightgbm.cc:24
Model LoadProtobufModel(const char *filename)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platform-neutral mechanism for serializing structured data.
Definition: protobuf.cc:210
model builder class
Definition: frontend.h:161
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
Definition: xgboost.cc:28
bool SetLeafNode(int node_key, tl_float leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:227
defines configuration macros of treelite
bool SetLeafVectorNode(int node_key, const std::vector< tl_float > &leaf_vector)
Turn an empty node into a leaf vector node. The leaf vector is a collection of multiple leaf weights per leaf node.
Definition: builder.cc:241
bool SetCategoricalTestNode(int node_key, unsigned feature_id, const std::vector< uint32_t > &left_categories, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a categorical test node. A list defines all categories that would be classified as the left side.
Definition: builder.cc:189
bool DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:115
void ExportXGBoostModel(const char *filename, const Model &model, const char *name_obj)
export a model in XGBoost format. The exported model can be read by XGBoost (dmlc/xgboost).
Definition: xgboost.cc:33
bool SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, tl_float threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold], and its outcome selects either the left or the right child node.
Definition: builder.cc:150
Operator
comparison operators
Definition: base.h:23