treelite
model_builder.h
Go to the documentation of this file.
1 
8 #ifndef TREELITE_MODEL_BUILDER_H_
9 #define TREELITE_MODEL_BUILDER_H_
10 
11 #include <treelite/enum/operator.h>
14 #include <treelite/enum/typeinfo.h>
15 
16 #include <array>
17 #include <cstdint>
18 #include <map>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <variant>
23 #include <vector>
24 
25 namespace treelite {
26 
27 class Model;
28 
29 namespace model_builder {
30 
31 class Metadata;
32 class TreeAnnotation;
33 class PostProcessorFunc;
34 
// Model builder interface.
// Every operation below is pure virtual, so this class only defines the call surface that a
// concrete builder has to provide. Nodes are referred to by integer keys (see StartNode and the
// *_child_key parameters of the test functions).
40 class ModelBuilder {
41  public:
    // Start a new tree.
45  virtual void StartTree() = 0;
    // End the current tree.
49  virtual void EndTree() = 0;
50 
    // Start a new node, identified by node_key.
55  virtual void StartNode(int node_key) = 0;
    // End the current node.
59  virtual void EndNode() = 0;
60 
    // Declare the current node as a numerical test node: the test compares a feature value
    // against `threshold` using the comparison operator `cmp`.
    // left_child_key / right_child_key are the keys of the two child nodes.
    // NOTE(review): split_index is presumably the index of the feature being tested, and
    // default_left presumably selects the branch taken for missing values — confirm.
73  virtual void NumericalTest(std::int32_t split_index, double threshold, bool default_left,
74  Operator cmp, int left_child_key, int right_child_key)
75  = 0;
    // Declare the current node as a categorical test node: the test checks whether the feature
    // value belongs to `category_list`.
    // NOTE(review): category_list_right_child presumably tells which child the listed
    // categories are routed to — confirm against the header documentation.
87  virtual void CategoricalTest(std::int32_t split_index, bool default_left,
88  std::vector<std::uint32_t> const& category_list, bool category_list_right_child,
89  int left_child_key, int right_child_key)
90  = 0;
91 
    // Declare the current node as a leaf node with a scalar output.
96  virtual void LeafScalar(double leaf_value) = 0;
    // Declare the current node as a leaf node with a vector output (float elements).
101  virtual void LeafVector(std::vector<float> const& leaf_vector) = 0;
    // Declare the current node as a leaf node with a vector output (double elements).
106  virtual void LeafVector(std::vector<double> const& leaf_vector) = 0;
107 
    // Specify the gain (loss reduction) that resulted from the current split.
112  virtual void Gain(double gain) = 0;
    // Specify the number of data points (samples) that are mapped to the current node.
117  virtual void DataCount(std::uint64_t data_count) = 0;
    // Specify the weighted sample count or the sum of Hessians for the data points mapped to
    // the current node.
123  virtual void SumHess(double sum_hess) = 0;
124 
    // Specify metadata for this model, if no metadata was previously specified.
    // `attributes` is optional; an empty std::optional means no attributes string is given.
135  virtual void InitializeMetadata(Metadata const& metadata, TreeAnnotation const& tree_annotation,
136  PostProcessorFunc const& postprocessor, std::vector<double> const& base_scores,
137  std::optional<std::string> const& attributes)
138  = 0;
    // Conclude model building and obtain the final model object; ownership of the Model is
    // handed to the caller through the returned std::unique_ptr.
143  virtual std::unique_ptr<Model> CommitModel() = 0;
144 
    // Virtual destructor: builders are handled through pointers to this base class.
145  virtual ~ModelBuilder() = default;
146 };
147 
// TreeAnnotation: annotation for individual trees, used to look up which target and class each
// tree belongs to.
// NOTE(review): the opening line of this definition (source line 159) is not present in this
// listing; only the members and the closing brace are shown.
    // Number of trees; defaults to 0.
160  std::int32_t num_tree{0};
    // NOTE(review): presumably one target index per tree — confirm against the header docs.
161  std::vector<std::int32_t> target_id{};
    // NOTE(review): presumably one class index per tree — confirm against the header docs.
162  std::vector<std::int32_t> class_id{};
    // Constructor for TreeAnnotation object.
169  TreeAnnotation(std::int32_t num_tree, std::vector<std::int32_t> const& target_id,
170  std::vector<std::int32_t> const& class_id);
171 };
172 
// Parameter type used to configure postprocessor functions: a value is exactly one of a 64-bit
// signed integer, a double, or a string.
176 using PostProcessorConfigParam = std::variant<std::int64_t, double, std::string>;
177 
// PostProcessorFunc: specification for the postprocessor of prediction outputs.
// NOTE(review): the opening line of this definition (source line 181) is not present in this
// listing.
    // Name of the postprocessor.
182  std::string name{};
    // Configuration parameters, keyed by parameter name; empty by default.
183  std::map<std::string, PostProcessorConfigParam> config{};
    // Constructor for PostProcessorFunc object, with no configuration parameters.
    // Marked explicit, so a std::string does not convert implicitly to a PostProcessorFunc.
188  explicit PostProcessorFunc(std::string const& name);
    // Constructor for PostProcessorFunc object, taking a name and configuration parameters.
    // NOTE(review): the first line of this declaration (source line 195) is not present in
    // this listing; the line below is its parameter list.
196  std::string const& name, std::map<std::string, PostProcessorConfigParam> const& config);
197 };
198 
// Metadata object, consisting of metadata information about the model at large.
202 struct Metadata {
    // Number of features; defaults to 0.
203  std::int32_t num_feature{0};
    // NOTE(review): source line 204 is not present in this listing; the index on this page
    // places the member `TaskType task_type` there. Its default initializer is not visible.
    // Defaults to false.
    // NOTE(review): presumably selects averaging of the per-tree outputs — confirm.
205  bool average_tree_output{false};
    // Number of targets; defaults to 1.
206  std::int32_t num_target{1};
    // Brace-initialization of a vector picks the initializer-list constructor, so the default is
    // a one-element vector holding the value 1.
207  std::vector<std::int32_t> num_class{1};
    // Two-element shape of the leaf vector; defaults to {1, 1}.
208  std::array<std::int32_t, 2> leaf_vector_shape{1, 1};
    // Constructor for Metadata object.
    // NOTE(review): the first line of this declaration (source line 218) is not present in this
    // listing; per the index on this page the leading parameters are num_feature, task_type and
    // average_tree_output, followed by the parameters shown below.
219  std::int32_t num_target, std::vector<std::int32_t> const& num_class,
220  std::array<std::int32_t, 2> const& leaf_vector_shape);
221 };
222 
// Initialize a model builder object with a given set of metadata.
// threshold_type and leaf_output_type give the types used by thresholds and leaf outputs.
// `attributes` is optional and defaults to std::nullopt.
// Returns the builder through a std::unique_ptr, which owns it.
236 std::unique_ptr<ModelBuilder> GetModelBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type,
237  Metadata const& metadata, TreeAnnotation const& tree_annotation,
238  PostProcessorFunc const& postprocessor, std::vector<double> const& base_scores,
239  std::optional<std::string> const& attributes = std::nullopt);
240 
// Overload that takes only the threshold and leaf output types; no metadata is passed here.
// NOTE(review): presumably the metadata is then supplied through
// ModelBuilder::InitializeMetadata() — confirm against the header documentation.
248 std::unique_ptr<ModelBuilder> GetModelBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type);
249 
// Overload that takes a single JSON string instead of typed arguments.
// NOTE(review): the expected JSON schema is not visible in this listing — see the header docs.
256 std::unique_ptr<ModelBuilder> GetModelBuilder(std::string const& json_str);
257 // Initialize metadata from a JSON string
258 
259 } // namespace model_builder
260 } // namespace treelite
261 
262 #endif // TREELITE_MODEL_BUILDER_H_
Model builder interface.
Definition: model_builder.h:40
virtual void Gain(double gain)=0
Specify the gain (loss reduction) that resulted from the current split.
virtual void NumericalTest(std::int32_t split_index, double threshold, bool default_left, Operator cmp, int left_child_key, int right_child_key)=0
Declare the current node as a numerical test node, where the test is of form [feature value] [cmp] [t...
virtual void CategoricalTest(std::int32_t split_index, bool default_left, std::vector< std::uint32_t > const &category_list, bool category_list_right_child, int left_child_key, int right_child_key)=0
Declare the current node as a categorical test node, where the test is of form [feature value] \in [c...
virtual std::unique_ptr< Model > CommitModel()=0
Conclude model building and obtain the final model object.
virtual void StartNode(int node_key)=0
Start a new node.
virtual void InitializeMetadata(Metadata const &metadata, TreeAnnotation const &tree_annotation, PostProcessorFunc const &postprocessor, std::vector< double > const &base_scores, std::optional< std::string > const &attributes)=0
Specify metadata for this model, if no metadata was previously specified.
virtual void StartTree()=0
Start a new tree.
virtual void LeafVector(std::vector< float > const &leaf_vector)=0
Declare the current node as a leaf node with a vector output.
virtual void LeafScalar(double leaf_value)=0
Declare the current node as a leaf node with a scalar output.
virtual void LeafVector(std::vector< double > const &leaf_vector)=0
Declare the current node as a leaf node with a vector output.
virtual void SumHess(double sum_hess)=0
Specify the weighted sample count or the sum of Hessians for the data points that are mapped to the c...
virtual void EndTree()=0
End the current tree.
virtual void DataCount(std::uint64_t data_count)=0
Specify the number of data points (samples) that are mapped to the current node.
virtual void EndNode()=0
End the current node.
std::variant< std::int64_t, double, std::string > PostProcessorConfigParam
Parameter type used to configure postprocessor functions.
Definition: model_builder.h:176
std::unique_ptr< ModelBuilder > GetModelBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type, Metadata const &metadata, TreeAnnotation const &tree_annotation, PostProcessorFunc const &postprocessor, std::vector< double > const &base_scores, std::optional< std::string > const &attributes=std::nullopt)
Initialize a model builder object with a given set of metadata.
Definition: contiguous_array.h:14
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
TaskType
Enum type representing the task type.
Definition: task_type.h:19
Operator
Type of comparison operators used in numerical test nodes.
Definition: operator.h:17
Define enum type Operator.
Metadata object, consisting of metadata information about the model at large.
Definition: model_builder.h:202
std::array< std::int32_t, 2 > leaf_vector_shape
Definition: model_builder.h:208
TaskType task_type
Definition: model_builder.h:204
std::vector< std::int32_t > num_class
Definition: model_builder.h:207
std::int32_t num_feature
Definition: model_builder.h:203
std::int32_t num_target
Definition: model_builder.h:206
bool average_tree_output
Definition: model_builder.h:205
Metadata(std::int32_t num_feature, TaskType task_type, bool average_tree_output, std::int32_t num_target, std::vector< std::int32_t > const &num_class, std::array< std::int32_t, 2 > const &leaf_vector_shape)
Constructor for Metadata object.
Specification for postprocessor of prediction outputs.
Definition: model_builder.h:181
std::string name
Definition: model_builder.h:182
std::map< std::string, PostProcessorConfigParam > config
Definition: model_builder.h:183
PostProcessorFunc(std::string const &name, std::map< std::string, PostProcessorConfigParam > const &config)
Constructor for PostProcessorFunc object.
PostProcessorFunc(std::string const &name)
Constructor for PostProcessorFunc object, with no configuration parameters.
Annotation for individual trees. Use this object to look up which target and class each tree is assoc...
Definition: model_builder.h:159
TreeAnnotation(std::int32_t num_tree, std::vector< std::int32_t > const &target_id, std::vector< std::int32_t > const &class_id)
Constructor for TreeAnnotation object.
std::int32_t num_tree
Definition: model_builder.h:160
std::vector< std::int32_t > class_id
Definition: model_builder.h:162
std::vector< std::int32_t > target_id
Definition: model_builder.h:161
Define enum type TaskType.
Define enum type NodeType.
Define enum type TypeInfo.