Treelite
frontend.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_FRONTEND_H_
8 #define TREELITE_FRONTEND_H_
9 
10 #include <dmlc/logging.h>
11 #include <treelite/base.h>
12 #include <string>
13 #include <memory>
14 #include <vector>
15 #include <cstdint>
16 
17 namespace treelite {
18 
19 class Model; // forward declaration
20 
21 namespace frontend {
22 
23 //--------------------------------------------------------------------------
24 // model loader interface: read from the disk
25 //--------------------------------------------------------------------------
// Load a model file generated by LightGBM (Microsoft/LightGBM) from the path `filename`.
// The loaded model is returned through a std::unique_ptr, so the caller owns it.
32 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char *filename);
// Load a model file generated by XGBoost (dmlc/xgboost) from the path `filename`.
// The loaded model is returned through a std::unique_ptr, so the caller owns it.
39 std::unique_ptr<treelite::Model> LoadXGBoostModel(const char* filename);
// Overload of LoadXGBoostModel that takes a pointer and a length instead of a file name.
// NOTE(review): presumably `buf` points at the same serialized XGBoost model format held
// in memory and `len` is its size in bytes -- confirm against the implementation.
46 std::unique_ptr<treelite::Model> LoadXGBoostModel(const void* buf, size_t len);
// Load a model file generated by XGBoost (dmlc/xgboost) from the path `filename`.
// NOTE(review): going by the function name, the file is expected to be in XGBoost's
// JSON format -- confirm against xgboost_json.cc.
53 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename);
// Load an XGBoost model from a JSON string (`json_str`).
// NOTE(review): `length` is presumably the size of `json_str` in bytes -- confirm whether
// a terminating NUL is required or counted.
60 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length);
// Load a scikit-learn random forest regressor model from a collection of arrays.
// For the meaning of the arrays, refer to
// https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
// NOTE(review): each `const T**` argument presumably holds one array per tree
// (n_estimators entries), with node_count[i] giving the length of tree i's arrays --
// confirm against sklearn.cc.
85 std::unique_ptr<treelite::Model> LoadSKLearnRandomForestRegressor(
86  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
87  const int64_t** children_right, const int64_t** feature, const double** threshold,
88  const double** value, const int64_t** n_node_samples, const double** impurity);
// Load a scikit-learn random forest classifier model from a collection of arrays.
// Same parameter list as the regressor loader, plus `n_classes`.
// For the meaning of the arrays, refer to
// https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
// NOTE(review): each `const T**` argument presumably holds one array per tree
// (n_estimators entries), with node_count[i] giving the length of tree i's arrays --
// confirm against sklearn.cc.
114 std::unique_ptr<treelite::Model> LoadSKLearnRandomForestClassifier(
115  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
116  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
117  const double** threshold, const double** value, const int64_t** n_node_samples,
118  const double** impurity);
// Load a scikit-learn gradient boosting regressor model from a collection of arrays.
// For the meaning of the arrays, refer to
// https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
// NOTE(review): each `const T**` argument presumably holds one array per tree
// (n_estimators entries), with node_count[i] giving the length of tree i's arrays --
// confirm against sklearn.cc.
142 std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingRegressor(
143  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
144  const int64_t** children_right, const int64_t** feature, const double** threshold,
145  const double** value, const int64_t** n_node_samples, const double** impurity);
// Load a scikit-learn gradient boosting classifier model from a collection of arrays.
// Same parameter list as the gradient boosting regressor loader, plus `n_classes`.
// For the meaning of the arrays, refer to
// https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
// NOTE(review): each `const T**` argument presumably holds one array per tree
// (n_estimators entries), with node_count[i] giving the length of tree i's arrays --
// confirm against sklearn.cc.
170 std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingClassifier(
171  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
172  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
173  const double** threshold, const double** value, const int64_t** n_node_samples,
174  const double** impurity);
175 
176 //--------------------------------------------------------------------------
177 // model builder interface: build trees incrementally
178 //--------------------------------------------------------------------------
179 
180 /* forward declarations */
181 struct TreeBuilderImpl;
182 struct ModelBuilderImpl;
183 class ModelBuilder;
184 
// Value: pairs a type-erased handle (std::shared_ptr<void>) with a TypeInfo tag.
// TypeInfo enumerates the types used by thresholds and leaf outputs; TreeBuilder takes
// Value arguments for thresholds and leaf values.
// All copy/move operations are defaulted, so copying a Value copies the shared_ptr.
// NOTE(review): whether copies therefore alias the same underlying storage depends on
// how Create() allocates -- confirm in frontend_impl.h.
185 class Value {
186  private:
187  std::shared_ptr<void> handle_;
188  TypeInfo type_;
189  public:
190  Value();
191  ~Value() = default;
192  Value(const Value&) = default;
193  Value(Value&&) noexcept = default;
194  Value& operator=(const Value&) = default;
195  Value& operator=(Value&&) noexcept = default;
// Factory taking a typed initial value.
196  template <typename T>
197  static Value Create(T init_value);
// Factory taking an untyped pointer plus the TypeInfo describing it.
// NOTE(review): presumably `init_value` must point at an object of the type named by
// `type` -- confirm.
198  static Value Create(const void* init_value, TypeInfo type);
// Typed access to the stored value, mutable and const.
// NOTE(review): behavior when T does not match the stored type is not visible here.
199  template <typename T>
200  T& Get();
201  template <typename T>
202  const T& Get() const;
// Invoke `func` based on the stored type; the return type is deduced (`auto`).
// NOTE(review): the exact calling convention for `func` is defined with the template
// body, which is not in this listing -- see frontend_impl.h.
203  template <typename Func>
204  inline auto Dispatch(Func func);
205  template <typename Func>
206  inline auto Dispatch(Func func) const;
// Return the TypeInfo tag of this value.
207  TypeInfo GetValueType() const;
208 };
209 
// Tree builder class: part of the "build trees incrementally" interface.
// Nodes are referred to by caller-chosen integer keys (`node_key`). State lives behind a
// pimpl (TreeBuilderImpl); ModelBuilder and ModelBuilderImpl are friends.
211 class TreeBuilder {
212  public:
// The two TypeInfo arguments name the types used for thresholds and for leaf outputs.
220  TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type); // constructor
221  ~TreeBuilder(); // destructor
222  // this class is only move-constructible and move-assignable
// NOTE(review): the defaulted move operations below are not marked noexcept.
223  TreeBuilder(const TreeBuilder&) = delete;
224  TreeBuilder(TreeBuilder&&) = default;
225  TreeBuilder& operator=(const TreeBuilder&) = delete;
226  TreeBuilder& operator=(TreeBuilder&&) = default;
// Node lifecycle, addressed by key. NOTE(review): behavior for duplicate or unknown
// keys is not visible in this header.
231  void CreateNode(int node_key);
236  void DeleteNode(int node_key);
241  void SetRootNode(int node_key);
// Turn node `node_key` into a numerical test on feature `feature_id` against
// `threshold`, with the given child keys. Two overloads: the comparison operator is
// given either as a C string or as an Operator enum value.
// NOTE(review): `default_left` presumably selects the branch taken for missing
// values -- confirm.
255  void SetNumericalTestNode(int node_key, unsigned feature_id, const char* op, Value threshold,
256  bool default_left, int left_child_key, int right_child_key);
257  void SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold,
258  bool default_left, int left_child_key, int right_child_key);
// Turn node `node_key` into a categorical test on feature `feature_id`.
// NOTE(review): `left_categories` presumably lists the category values routed to the
// left child -- confirm.
272  void SetCategoricalTestNode(int node_key, unsigned feature_id,
273  const std::vector<uint32_t>& left_categories, bool default_left,
274  int left_child_key, int right_child_key);
// Turn node `node_key` into a leaf carrying a single Value or a vector of Values.
281  void SetLeafNode(int node_key, Value leaf_value);
290  void SetLeafVectorNode(int node_key, const std::vector<Value>& leaf_vector);
291 
292  private:
293  std::unique_ptr<TreeBuilderImpl> pimpl_; // Pimpl pattern
294  ModelBuilder* ensemble_id_; // id of ensemble (nullptr if not part of any)
295  friend class ModelBuilder;
296  friend struct ModelBuilderImpl;
297 };
298 
// NOTE(review): the `class ModelBuilder {` opener is missing from this listing; the
// index entry "model builder class / Definition: frontend.h:300" places it on source
// line 300, immediately before the `public:` line below.
// Model builder class: holds trees (TreeBuilder) by index and produces a Model via
// CommitModel(). State lives behind a pimpl (ModelBuilderImpl).
301  public:
// `threshold_type` / `leaf_output_type` mirror the TreeBuilder constructor arguments.
// NOTE(review): presumably inserted trees must use matching types -- confirm.
317  ModelBuilder(int num_feature, int num_class, bool average_tree_output,
318  TypeInfo threshold_type, TypeInfo leaf_output_type);
319  ~ModelBuilder(); // destructor
// Set a model parameter given as a name/value pair of C strings.
325  void SetModelParam(const char* name, const char* value);
// Insert a tree at `index` and return an int.
// NOTE(review): the default index of -1 presumably means "append at the end", and the
// return value is presumably the index of the inserted tree -- confirm. Ownership of
// `tree_builder` after the call is not visible in this header.
338  int InsertTree(TreeBuilder* tree_builder, int index = -1);
// Access the tree at `index` (mutable and const overloads); returns a raw pointer.
344  TreeBuilder* GetTree(int index);
345  const TreeBuilder* GetTree(int index) const;
// Remove the tree at `index`.
350  void DeleteTree(int index);
// Produce the finished Model, returned through a std::unique_ptr.
// NOTE(review): whether the builder remains usable afterwards is not visible here.
355  std::unique_ptr<Model> CommitModel();
356 
357  private:
358  std::unique_ptr<ModelBuilderImpl> pimpl_; // Pimpl pattern
359 };
360 
361 } // namespace frontend
362 } // namespace treelite
363 
364 #include "frontend_impl.h"
365 
366 #endif // TREELITE_FRONTEND_H_
std::unique_ptr< treelite::Model > LoadSKLearnRandomForestRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn random forest regressor model from a collection of arrays. Refer to https://sciki...
Definition: sklearn.cc:73
tree builder class
Definition: frontend.h:211
Implementation for frontend.h.
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:26
std::unique_ptr< treelite::Model > LoadSKLearnRandomForestClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn random forest classifier model from a collection of arrays. Refer to https://scik...
Definition: sklearn.cc:161
std::unique_ptr< treelite::Model > LoadSKLearnGradientBoostingClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: sklearn.cc:254
model builder class
Definition: frontend.h:300
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::unique_ptr< treelite::Model > LoadSKLearnGradientBoostingRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **impurity)
Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: sklearn.cc:177
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:29
defines configuration macros of Treelite
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:85
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost_json.cc:45
Operator
comparison operators
Definition: base.h:26