Treelite
frontend.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_FRONTEND_H_
8 #define TREELITE_FRONTEND_H_
9 
10 #include <treelite/base.h>
11 #include <string>
12 #include <memory>
13 #include <vector>
14 #include <cstdint>
15 
16 namespace treelite {
17 
18 class Model; // forward declaration
19 
20 namespace frontend {
21 
22 //--------------------------------------------------------------------------
23 // model loader interface: read from disk or from an in-memory buffer/string
24 //--------------------------------------------------------------------------
// Each loader returns a std::unique_ptr owning the loaded treelite::Model.
// How load failures are reported is not visible from this header.
// NOTE(review): size_t is used unqualified below and <cstddef> is not
// included; this relies on ::size_t arriving transitively (e.g. via
// treelite/base.h or the standard headers). Consider <cstddef> + std::size_t.
//
// Load a model file generated by LightGBM (defined in lightgbm.cc).
31 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char *filename);
// Load a model file generated by XGBoost (defined in xgboost.cc).
38 std::unique_ptr<treelite::Model> LoadXGBoostModel(const char* filename);
// Load an XGBoost model from an in-memory buffer of `len` bytes; presumably
// the same format as the file overload above -- confirm in xgboost.cc.
45 std::unique_ptr<treelite::Model> LoadXGBoostModel(const void* buf, size_t len);
// Load an XGBoost model file; the name and its definition in xgboost_json.cc
// indicate XGBoost's JSON format.
52 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename);
// Load an XGBoost model from a JSON string of `length` characters. Since the
// length is passed explicitly, json_str presumably need not be NUL-terminated
// -- confirm in xgboost_json.cc.
59 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length);
60 
61 //--------------------------------------------------------------------------
62 // model builder interface: build trees incrementally
63 //--------------------------------------------------------------------------
64 
65 /* forward declarations */
// Pimpl implementation types held by TreeBuilder::pimpl_ and
// ModelBuilder::pimpl_; they are only forward-declared in this header.
66 struct TreeBuilderImpl;
67 struct ModelBuilderImpl;
// Needed because TreeBuilder (below) stores a ModelBuilder* before
// ModelBuilder itself is defined.
68 class ModelBuilder;
69 
// Value: a type-erased value used by the builder API for test thresholds
// (SetNumericalTestNode) and leaf outputs (SetLeafNode / SetLeafVectorNode).
// The payload sits behind handle_ (std::shared_ptr<void>) and its runtime
// type is recorded in type_, reported by GetValueType().
// The template members are presumably defined in "frontend_impl.h", which
// this header includes at the end -- confirm there.
70 class Value {
71  private:
    // Shared-ownership, type-erased pointer to the stored object.
72  std::shared_ptr<void> handle_;
    // Type of the object behind handle_.
73  TypeInfo type_;
74  public:
    // Default state is not visible here (presumably an empty handle_) --
    // confirm in the implementation.
75  Value();
76  ~Value() = default;
    // NOTE(review): the defaulted copy operations copy the shared_ptr, so a
    // copy shares the stored object with its source instead of duplicating
    // it. If Get<T>() returns a reference into *handle_ (as its T& return
    // type suggests), writes through one copy are visible via the other.
77  Value(const Value&) = default;
78  Value(Value&&) noexcept = default;
79  Value& operator=(const Value&) = default;
80  Value& operator=(Value&&) noexcept = default;
    // Wrap a C++ value of type T; the TypeInfo is presumably derived from T.
81  template <typename T>
82  static Value Create(T init_value);
    // Wrap the object pointed to by init_value, interpreted as `type`.
    // Assumes init_value points to an object of that type -- TODO confirm
    // whether the pointee is copied or only referenced.
83  static Value Create(const void* init_value, TypeInfo type);
    // Typed access to the stored object. Behavior when T does not match
    // GetValueType() is not visible from this header.
84  template <typename T>
85  T& Get();
86  template <typename T>
87  const T& Get() const;
    // Invoke func with the stored value, presumably after casting it to the
    // C++ type selected by type_; the return type is deduced from func.
88  template <typename Func>
89  inline auto Dispatch(Func func);
90  template <typename Func>
91  inline auto Dispatch(Func func) const;
    // Runtime type of the stored value.
92  TypeInfo GetValueType() const;
93 };
94 
// Tree builder: constructs a single decision tree node by node. Nodes are
// identified by caller-chosen integer keys (node_key), and test nodes refer
// to their children by key. State is held behind a pimpl (TreeBuilderImpl).
96 class TreeBuilder {
97  public:
    // threshold_type / leaf_output_type give the TypeInfo of the thresholds
    // and leaf outputs this tree stores; Values passed to the setters below
    // are presumably expected to match them -- confirm in the implementation.
105  TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type); // constructor
    // Presumably declared out of line so pimpl_ is destroyed where
    // TreeBuilderImpl is a complete type.
106  ~TreeBuilder(); // destructor
107  // this class is only move-constructible and move-assignable
    // NOTE(review): the move operations are defaulted in-class, so they are
    // implicitly defined wherever they are used. Both need a complete
    // TreeBuilderImpl: the move constructor potentially invokes pimpl_'s
    // destructor, and move assignment destroys the previous impl. This header
    // only forward-declares TreeBuilderImpl, so unless "frontend_impl.h"
    // defines it, moving a TreeBuilder outside the implementation file is
    // expected to fail to compile. Usual fix: declare them here and write
    // `= default` in the .cc where TreeBuilderImpl is complete.
108  TreeBuilder(const TreeBuilder&) = delete;
109  TreeBuilder(TreeBuilder&&) = default;
110  TreeBuilder& operator=(const TreeBuilder&) = delete;
111  TreeBuilder& operator=(TreeBuilder&&) = default;
    // Node lifecycle: create / delete the node with the given key, and choose
    // which key is the root.
116  void CreateNode(int node_key);
121  void DeleteNode(int node_key);
126  void SetRootNode(int node_key);
    // Make node_key a numerical test on feature feature_id against
    // `threshold`, with children left_child_key / right_child_key.
    // The comparison operator is given either as a string or as an Operator
    // value (comparison operators, base.h). default_left presumably selects
    // the branch taken when the feature value is missing -- confirm.
140  void SetNumericalTestNode(int node_key, unsigned feature_id, const char* op, Value threshold,
141  bool default_left, int left_child_key, int right_child_key);
142  void SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold,
143  bool default_left, int left_child_key, int right_child_key);
    // Make node_key a categorical test on feature feature_id; presumably the
    // values in left_categories route to left_child_key and all others to
    // right_child_key -- confirm.
157  void SetCategoricalTestNode(int node_key, unsigned feature_id,
158  const std::vector<uint32_t>& left_categories, bool default_left,
159  int left_child_key, int right_child_key);
    // Make node_key a leaf holding a single output value.
166  void SetLeafNode(int node_key, Value leaf_value);
    // Make node_key a leaf holding a vector of output values.
175  void SetLeafVectorNode(int node_key, const std::vector<Value>& leaf_vector);
176 
177  private:
178  std::unique_ptr<TreeBuilderImpl> pimpl_; // Pimpl pattern
    // Presumably a non-owning back-pointer set when the tree is inserted into
    // a ModelBuilder -- hence the friend declarations below.
179  ModelBuilder* ensemble_id_; // id of ensemble (nullptr if not part of any)
180  friend class ModelBuilder;
181  friend struct ModelBuilderImpl;
182 };
183 
// Model builder: collects trees built with TreeBuilder and turns them into a
// treelite::Model via CommitModel(). State lives behind a pimpl
// (ModelBuilderImpl, forward-declared earlier in this header).
// Copy/move: the user-declared destructor suppresses the implicit move
// operations, and the std::unique_ptr member makes the implicit copy
// operations deleted, so ModelBuilder is neither copyable nor movable.
185 class ModelBuilder {
186  public:
    // num_feature / num_class / average_tree_output describe the whole model;
    // threshold_type / leaf_output_type give the TypeInfo of thresholds and
    // leaf outputs (presumably every inserted TreeBuilder must use the same
    // types -- TODO confirm in the implementation).
202  ModelBuilder(int num_feature, int num_class, bool average_tree_output,
203  TypeInfo threshold_type, TypeInfo leaf_output_type);
204  ~ModelBuilder(); // destructor
    // Set a named model parameter from a string value; the accepted names are
    // not listed in this header.
210  void SetModelParam(const char* name, const char* value);
    // Insert a tree at `index` (default -1). Presumably -1 appends and the
    // return value is the position the tree ended up at -- verify.
    // NOTE(review): takes a raw TreeBuilder*; whether the builder copies,
    // moves from, or keeps a pointer to *tree_builder is not visible here
    // (TreeBuilder::ensemble_id_ and its friend declarations suggest the
    // tree is linked back to this ModelBuilder).
223  int InsertTree(TreeBuilder* tree_builder, int index = -1);
    // Access the tree at `index`; the returned pointers are presumably
    // non-owning (the ModelBuilder keeps the tree) -- confirm.
229  TreeBuilder* GetTree(int index);
230  const TreeBuilder* GetTree(int index) const;
    // Remove the tree at `index`.
235  void DeleteTree(int index);
    // Build the final model from the inserted trees; the caller owns the
    // returned Model. Whether the builder is reusable afterwards is not
    // visible from this header.
240  std::unique_ptr<Model> CommitModel();
241 
242  private:
243  std::unique_ptr<ModelBuilderImpl> pimpl_; // Pimpl pattern
244 };
245 
246 } // namespace frontend
247 } // namespace treelite
248 
249 #include "frontend_impl.h"
250 
251 #endif // TREELITE_FRONTEND_H_
TreeBuilder
tree builder class
Definition: frontend.h:96
frontend_impl.h
Implementation for frontend.h.
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision tree ensemble.
Definition: lightgbm.cc:26
ModelBuilder
model builder class
Definition: frontend.h:185
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
Definition: xgboost.cc:29
defines configuration macros of Treelite
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:85
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble in XGBoost's JSON format.
Definition: xgboost_json.cc:45
Operator
comparison operators
Definition: base.h:26