Treelite
frontend.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_FRONTEND_H_
8 #define TREELITE_FRONTEND_H_
9 
10 #include <treelite/base.h>
11 #include <string>
12 #include <memory>
13 #include <vector>
14 #include <cstdint>
15 
16 namespace treelite {
17 
18 class Model; // forward declaration
19 
20 namespace frontend {
21 
22 //--------------------------------------------------------------------------
23 // model loader interface: read from the disk
24 //--------------------------------------------------------------------------
// Load a model file generated by LightGBM (Microsoft/LightGBM).
//   filename: name of the model file to read.
//   returns:  owning pointer to the loaded treelite::Model.
// NOTE(review): behaviour on a missing or malformed file is not visible from this
// declaration -- confirm against the implementation.
31 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char *filename);
// Load a model file generated by XGBoost (dmlc/xgboost).
//   filename: name of the model file to read.
//   returns:  owning pointer to the loaded treelite::Model.
// NOTE(review): which XGBoost on-disk format is expected here is not visible from this
// declaration -- confirm.
38 std::unique_ptr<treelite::Model> LoadXGBoostModel(const char* filename);
// Overload of LoadXGBoostModel that takes a raw buffer instead of a file name.
//   buf: pointer to the input data (read-only).
//   len: size associated with buf.
//   returns: owning pointer to the loaded treelite::Model.
// NOTE(review): presumably buf holds a serialized XGBoost model and len is its size in
// bytes -- confirm against the implementation.
45 std::unique_ptr<treelite::Model> LoadXGBoostModel(const void* buf, size_t len);
// Load a model file generated by XGBoost (dmlc/xgboost).
//   filename: name of the model file to read.
//   returns:  owning pointer to the loaded treelite::Model.
// NOTE(review): going by the function name, the file is presumably in XGBoost's JSON
// format -- confirm.
52 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename);
// Load an XGBoost model from a JSON string.
//   json_str: the JSON text describing the model.
//   length:   length associated with json_str.
//   returns:  owning pointer to the loaded treelite::Model.
// NOTE(review): presumably length is the number of characters in json_str; whether a
// terminating NUL is counted is not visible here -- confirm.
59 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length);
// Load a scikit-learn random forest regressor model from a collection of arrays.
// Refer to the scikit-learn documentation for the meaning of the arrays.
//   n_estimators, n_features: scalar counts.
//   node_count: array of int64 values.
//   children_left, children_right, feature, n_node_samples: arrays of pointers to int64 data.
//   threshold, value, impurity: arrays of pointers to double data.
//   returns: owning pointer to the resulting treelite::Model.
// NOTE(review): presumably every pointer-to-pointer argument has one entry per estimator,
// each pointing at node_count[i] elements -- confirm against the implementation.
84 std::unique_ptr<treelite::Model> LoadSKLearnRandomForestRegressor(
85  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
86  const int64_t** children_right, const int64_t** feature, const double** threshold,
87  const double** value, const int64_t** n_node_samples, const double** impurity);
// Load a scikit-learn random forest classifier model from a collection of arrays.
// Refer to the scikit-learn documentation for the meaning of the arrays.
//   n_estimators, n_features, n_classes: scalar counts.
//   node_count: array of int64 values.
//   children_left, children_right, feature, n_node_samples: arrays of pointers to int64 data.
//   threshold, value, impurity: arrays of pointers to double data.
//   returns: owning pointer to the resulting treelite::Model.
// NOTE(review): presumably every pointer-to-pointer argument has one entry per estimator,
// each pointing at node_count[i] elements; how n_classes shapes `value` is not visible
// here -- confirm against the implementation.
113 std::unique_ptr<treelite::Model> LoadSKLearnRandomForestClassifier(
114  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
115  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
116  const double** threshold, const double** value, const int64_t** n_node_samples,
117  const double** impurity);
// Load a scikit-learn gradient boosting regressor model from a collection of arrays.
// Refer to
// https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
// to learn the meaning of the arrays in detail.
//   n_estimators, n_features: scalar counts.
//   node_count: array of int64 values.
//   children_left, children_right, feature, n_node_samples: arrays of pointers to int64 data.
//   threshold, value, impurity: arrays of pointers to double data.
//   returns: owning pointer to the resulting treelite::Model.
// NOTE(review): presumably every pointer-to-pointer argument has one entry per estimator,
// each pointing at node_count[i] elements -- confirm against the implementation.
141 std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingRegressor(
142  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
143  const int64_t** children_right, const int64_t** feature, const double** threshold,
144  const double** value, const int64_t** n_node_samples, const double** impurity);
// Load a scikit-learn gradient boosting classifier model from a collection of arrays.
// Refer to
// https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
// to learn the meaning of the arrays in detail.
//   n_estimators, n_features, n_classes: scalar counts.
//   node_count: array of int64 values.
//   children_left, children_right, feature, n_node_samples: arrays of pointers to int64 data.
//   threshold, value, impurity: arrays of pointers to double data.
//   returns: owning pointer to the resulting treelite::Model.
// NOTE(review): presumably every pointer-to-pointer argument has one entry per estimator,
// each pointing at node_count[i] elements; how n_classes shapes `value` is not visible
// here -- confirm against the implementation.
169 std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingClassifier(
170  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
171  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
172  const double** threshold, const double** value, const int64_t** n_node_samples,
173  const double** impurity);
174 
175 //--------------------------------------------------------------------------
176 // model builder interface: build trees incrementally
177 //--------------------------------------------------------------------------
178 
179 /* forward declarations */
180 struct TreeBuilderImpl;
181 struct ModelBuilderImpl;
182 class ModelBuilder;
183 
// Holder for a single value whose concrete type is chosen at run time.
// The payload sits behind an untyped (void) shared pointer and is accompanied by a
// TypeInfo tag; TypeInfo enumerates the types used by thresholds and leaf outputs.
// The builder interface in this header passes thresholds and leaf values as Value.
184 class Value {
185  private:
     // Untyped pointer to the stored payload.
186  std::shared_ptr<void> handle_;
     // Type tag reported by GetValueType().
187  TypeInfo type_;
188  public:
     // Default constructor; defined out of line.
     // NOTE(review): the state of a default-constructed Value is not visible here -- confirm.
189  Value();
190  ~Value() = default;
     // Copy and move are compiler-generated. A copy therefore shares handle_ with its
     // source (std::shared_ptr copy semantics) rather than duplicating the payload.
191  Value(const Value&) = default;
192  Value(Value&&) noexcept = default;
193  Value& operator=(const Value&) = default;
194  Value& operator=(Value&&) noexcept = default;
     // Factory: build a Value from an initial value of static type T.
     // NOTE(review): presumably T must be one of the types representable by TypeInfo -- confirm.
195  template <typename T>
196  static Value Create(T init_value);
     // Factory: build a Value from untyped memory plus an explicit type tag.
     // NOTE(review): presumably init_value points at an object of the type named by
     // `type` and is copied -- confirm.
197  static Value Create(const void* init_value, TypeInfo type);
     // Access the stored payload as T (mutable and const variants).
     // NOTE(review): what happens when T does not match the stored type is not visible
     // here -- confirm.
198  template <typename T>
199  T& Get();
200  template <typename T>
201  const T& Get() const;
     // Take a callable and return whatever it yields (return type is deduced).
     // NOTE(review): presumably func is invoked with the payload under its concrete
     // type, selected from type_ -- confirm in the template definitions.
202  template <typename Func>
203  inline auto Dispatch(Func func);
204  template <typename Func>
205  inline auto Dispatch(Func func) const;
     // Report the TypeInfo associated with this Value.
206  TypeInfo GetValueType() const;
207 };
208 
// Tree builder class: part of the "build trees incrementally" interface.
// Nodes are referred to by caller-chosen integer keys (node_key). The implementation
// is hidden behind a pimpl pointer.
210 class TreeBuilder {
211  public:
     // threshold_type / leaf_output_type: TypeInfo tags for thresholds and leaf outputs.
219  TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type); // constructor
220  ~TreeBuilder(); // destructor
221  // this class is only move-constructible and move-assignable
222  TreeBuilder(const TreeBuilder&) = delete;
223  TreeBuilder(TreeBuilder&&) = default;
224  TreeBuilder& operator=(const TreeBuilder&) = delete;
225  TreeBuilder& operator=(TreeBuilder&&) = default;
     // Create a node identified by node_key.
     // NOTE(review): behaviour for an already-used key is not visible here -- confirm.
230  void CreateNode(int node_key);
     // Delete the node identified by node_key.
235  void DeleteNode(int node_key);
     // Designate the node identified by node_key as the root of the tree.
240  void SetRootNode(int node_key);
     // Turn node node_key into a numerical test node on feature feature_id, comparing
     // against `threshold` with operator `op`; left_child_key / right_child_key name the
     // two children. Two overloads: `op` as a C string or as an Operator enum value.
     // NOTE(review): presumably default_left selects the branch taken for a missing
     // feature value, and the string form is parsed into an Operator -- confirm.
254  void SetNumericalTestNode(int node_key, unsigned feature_id, const char* op, Value threshold,
255  bool default_left, int left_child_key, int right_child_key);
256  void SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold,
257  bool default_left, int left_child_key, int right_child_key);
     // Turn node node_key into a categorical test node on feature feature_id.
     // NOTE(review): presumably categories listed in left_categories are routed to the
     // left child and all others to the right -- confirm.
271  void SetCategoricalTestNode(int node_key, unsigned feature_id,
272  const std::vector<uint32_t>& left_categories, bool default_left,
273  int left_child_key, int right_child_key);
     // Turn node node_key into a leaf carrying a single value.
280  void SetLeafNode(int node_key, Value leaf_value);
     // Turn node node_key into a leaf carrying a vector of values.
289  void SetLeafVectorNode(int node_key, const std::vector<Value>& leaf_vector);
290 
291  private:
292  std::unique_ptr<TreeBuilderImpl> pimpl_; // Pimpl pattern
293  ModelBuilder* ensemble_id_; // id of ensemble (nullptr if not part of any)
     // ModelBuilder and its implementation struct get access to the private members.
294  friend class ModelBuilder;
295  friend struct ModelBuilderImpl;
296 };
297 
// Model builder class: part of the "build trees incrementally" interface.
// Holds TreeBuilder objects addressed by integer index and turns them into a
// treelite::Model via CommitModel(). The implementation is hidden behind a pimpl pointer.
// The opening line of the class definition (listing line 299) is restored here; without
// it the access specifier below and the closing brace have no enclosing class.
299 class ModelBuilder {
300  public:
     // num_feature, num_class: scalar counts for the model being built.
     // average_tree_output: flag stored for the model.
     // threshold_type / leaf_output_type: TypeInfo tags for thresholds and leaf outputs.
     // NOTE(review): presumably average_tree_output selects averaging instead of summing
     // the outputs of the trees -- confirm.
316  ModelBuilder(int num_feature, int num_class, bool average_tree_output,
317  TypeInfo threshold_type, TypeInfo leaf_output_type);
318  ~ModelBuilder(); // destructor
     // Set the model parameter called `name` to `value`; both are passed as C strings.
     // NOTE(review): the set of accepted names is not visible here -- confirm.
324  void SetModelParam(const char* name, const char* value);
     // Insert the tree described by tree_builder at position `index`.
     // NOTE(review): presumably the default index of -1 means "append at the end", the
     // return value is the index actually used, and the builder's contents are moved
     // from -- confirm.
337  int InsertTree(TreeBuilder* tree_builder, int index = -1);
     // Access the tree builder stored at `index` (mutable and const variants).
343  TreeBuilder* GetTree(int index);
344  const TreeBuilder* GetTree(int index) const;
     // Remove the tree stored at `index`.
349  void DeleteTree(int index);
     // Produce a treelite::Model from the current builder state; the caller owns the result.
     // NOTE(review): whether the builder remains usable afterwards is not visible here
     // -- confirm.
354  std::unique_ptr<Model> CommitModel();
355 
356  private:
357  std::unique_ptr<ModelBuilderImpl> pimpl_; // Pimpl pattern
358 };
359 
360 } // namespace frontend
361 } // namespace treelite
362 
363 #include "frontend_impl.h"
364 
365 #endif // TREELITE_FRONTEND_H_
std::unique_ptr< treelite::Model > LoadSKLearnRandomForestRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn random forest regressor model from a collection of arrays. Refer to https://sciki...
Definition: sklearn.cc:74
tree builder class
Definition: frontend.h:210
Implementation for frontend.h.
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:27
std::unique_ptr< treelite::Model > LoadSKLearnRandomForestClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn random forest classifier model from a collection of arrays. Refer to https://scik...
Definition: sklearn.cc:162
std::unique_ptr< treelite::Model > LoadSKLearnGradientBoostingClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: sklearn.cc:255
model builder class
Definition: frontend.h:299
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::unique_ptr< treelite::Model > LoadSKLearnGradientBoostingRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: sklearn.cc:178
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:29
defines configuration macros of Treelite
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:82
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost_json.cc:42
Operator
comparison operators
Definition: base.h:26