Treelite
frontend.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_FRONTEND_H_
8 #define TREELITE_FRONTEND_H_
9 
10 #include <treelite/base.h>
11 #include <string>
12 #include <memory>
13 #include <vector>
14 #include <cstdint>
15 
16 namespace treelite {
17 
class Model;  // defined elsewhere; the declarations below only use it via std::unique_ptr<Model>
19 
20 namespace frontend {
21 
22 //--------------------------------------------------------------------------
23 // model loader interface: read from the disk
24 //--------------------------------------------------------------------------
31 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char *filename);
38 std::unique_ptr<treelite::Model> LoadLightGBMModelFromString(const char* model_str);
45 std::unique_ptr<treelite::Model> LoadXGBoostModel(const char* filename);
52 std::unique_ptr<treelite::Model> LoadXGBoostModel(const void* buf, size_t len);
59 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename);
66 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length);
93 std::unique_ptr<treelite::Model> LoadSKLearnRandomForestRegressor(
94  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
95  const int64_t** children_right, const int64_t** feature, const double** threshold,
96  const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples,
97  const double** impurity);
123 std::unique_ptr<treelite::Model> LoadSKLearnIsolationForest(
124  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
125  const int64_t** children_right, const int64_t** feature, const double** threshold,
126  const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples,
127  const double** impurity, const double ratio_c);
155 std::unique_ptr<treelite::Model> LoadSKLearnRandomForestClassifier(
156  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
157  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
158  const double** threshold, const double** value, const int64_t** n_node_samples,
159  const double** weighted_n_node_samples, const double** impurity);
185 std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingRegressor(
186  int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
187  const int64_t** children_right, const int64_t** feature, const double** threshold,
188  const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples,
189  const double** impurity);
216 std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingClassifier(
217  int n_estimators, int n_features, int n_classes, const int64_t* node_count,
218  const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
219  const double** threshold, const double** value, const int64_t** n_node_samples,
220  const double** weighted_n_node_samples, const double** impurity);
221 
//--------------------------------------------------------------------------
// model builder interface: build trees incrementally
//--------------------------------------------------------------------------

// Forward declarations. The *Impl structs hold the private state behind the
// pimpl_ pointers of TreeBuilder and ModelBuilder; ModelBuilder is needed early
// because TreeBuilder stores a ModelBuilder* and befriends the class.
class ModelBuilder;
struct ModelBuilderImpl;
struct TreeBuilderImpl;
230 
231 class Value {
232  private:
233  std::shared_ptr<void> handle_;
234  TypeInfo type_;
235  public:
236  Value();
237  ~Value() = default;
238  Value(const Value&) = default;
239  Value(Value&&) noexcept = default;
240  Value& operator=(const Value&) = default;
241  Value& operator=(Value&&) noexcept = default;
242  template <typename T>
243  static Value Create(T init_value);
244  static Value Create(const void* init_value, TypeInfo type);
245  template <typename T>
246  T& Get();
247  template <typename T>
248  const T& Get() const;
249  template <typename Func>
250  inline auto Dispatch(Func func);
251  template <typename Func>
252  inline auto Dispatch(Func func) const;
253  TypeInfo GetValueType() const;
254 };
255 
257 class TreeBuilder {
258  public:
266  TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type); // constructor
267  ~TreeBuilder(); // destructor
268  // this class is only move-constructible and move-assignable
269  TreeBuilder(const TreeBuilder&) = delete;
270  TreeBuilder(TreeBuilder&&) = default;
271  TreeBuilder& operator=(const TreeBuilder&) = delete;
272  TreeBuilder& operator=(TreeBuilder&&) = default;
277  void CreateNode(int node_key);
282  void DeleteNode(int node_key);
287  void SetRootNode(int node_key);
301  void SetNumericalTestNode(int node_key, unsigned feature_id, const char* op, Value threshold,
302  bool default_left, int left_child_key, int right_child_key);
303  void SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold,
304  bool default_left, int left_child_key, int right_child_key);
318  void SetCategoricalTestNode(int node_key, unsigned feature_id,
319  const std::vector<uint32_t>& left_categories, bool default_left,
320  int left_child_key, int right_child_key);
327  void SetLeafNode(int node_key, Value leaf_value);
336  void SetLeafVectorNode(int node_key, const std::vector<Value>& leaf_vector);
337 
338  private:
339  std::unique_ptr<TreeBuilderImpl> pimpl_; // Pimpl pattern
340  ModelBuilder* ensemble_id_; // id of ensemble (nullptr if not part of any)
341  friend class ModelBuilder;
342  friend struct ModelBuilderImpl;
343 };
344 
347  public:
363  ModelBuilder(int num_feature, int num_class, bool average_tree_output,
364  TypeInfo threshold_type, TypeInfo leaf_output_type);
365  ~ModelBuilder(); // destructor
371  void SetModelParam(const char* name, const char* value);
384  int InsertTree(TreeBuilder* tree_builder, int index = -1);
390  TreeBuilder* GetTree(int index);
391  const TreeBuilder* GetTree(int index) const;
396  void DeleteTree(int index);
401  std::unique_ptr<Model> CommitModel();
402 
403  private:
404  std::unique_ptr<ModelBuilderImpl> pimpl_; // Pimpl pattern
405 };
406 
407 } // namespace frontend
408 } // namespace treelite
409 
410 #include "frontend_impl.h"
411 
412 #endif // TREELITE_FRONTEND_H_
tree builder class
Definition: frontend.h:257
Implementation for frontend.h.
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:28
std::unique_ptr< treelite::Model > LoadLightGBMModelFromString(const char *model_str)
Load a LightGBM model from a string. The string should be created with the model_to_string() method i...
Definition: lightgbm.cc:33
model builder class
Definition: frontend.h:346
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:29
std::unique_ptr< treelite::Model > LoadSKLearnRandomForestClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity)
Load a scikit-learn random forest classifier model from a collection of arrays. Refer to https://scik...
Definition: sklearn.cc:195
defines configuration macros of Treelite
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:82
std::unique_ptr< treelite::Model > LoadSKLearnRandomForestRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity)
Load a scikit-learn random forest regressor model from a collection of arrays. Refer to https://sciki...
Definition: sklearn.cc:77
std::unique_ptr< treelite::Model > LoadSKLearnGradientBoostingRegressor(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity)
Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: sklearn.cc:212
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost_json.cc:42
std::unique_ptr< treelite::Model > LoadSKLearnIsolationForest(int n_estimators, int n_features, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity, const double ratio_c)
Load a scikit-learn isolation forest model from a collection of arrays. Refer to https://scikit-learn...
Definition: sklearn.cc:103
std::unique_ptr< treelite::Model > LoadSKLearnGradientBoostingClassifier(int n_estimators, int n_features, int n_classes, const int64_t *node_count, const int64_t **children_left, const int64_t **children_right, const int64_t **feature, const double **threshold, const double **value, const int64_t **n_node_samples, const double **weighted_n_node_samples, const double **impurity)
Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to learn the meaning of the arrays in detail.
Definition: sklearn.cc:290
Operator
comparison operators
Definition: base.h:26