8 #ifndef TREELITE_MODEL_BUILDER_H_
9 #define TREELITE_MODEL_BUILDER_H_
29 namespace model_builder {
73 virtual void NumericalTest(std::int32_t split_index,
double threshold,
bool default_left,
74 Operator cmp,
int left_child_key,
int right_child_key)
88 std::vector<std::uint32_t>
const& category_list,
bool category_list_right_child,
89 int left_child_key,
int right_child_key)
101 virtual void LeafVector(std::vector<float>
const& leaf_vector) = 0;
106 virtual void LeafVector(std::vector<double>
const& leaf_vector) = 0;
112 virtual void Gain(
double gain) = 0;
137 std::optional<std::string>
const& attributes)
170 std::vector<std::int32_t>
const&
class_id);
183 std::map<std::string, PostProcessorConfigParam>
config{};
196 std::string
const&
name, std::map<std::string, PostProcessorConfigParam>
const&
config);
239 std::optional<std::string>
const& attributes = std::nullopt);
Model builder interface.
Definition: model_builder.h:40
virtual void Gain(double gain)=0
Specify the gain (loss reduction) that resulted from the current split.
virtual void NumericalTest(std::int32_t split_index, double threshold, bool default_left, Operator cmp, int left_child_key, int right_child_key)=0
Declare the current node as a numerical test node, where the test is of form [feature value] [cmp] [t...
virtual void CategoricalTest(std::int32_t split_index, bool default_left, std::vector< std::uint32_t > const &category_list, bool category_list_right_child, int left_child_key, int right_child_key)=0
Declare the current node as a categorical test node, where the test is of form [feature value] \in [c...
virtual std::unique_ptr< Model > CommitModel()=0
Conclude model building and obtain the final model object.
virtual void StartNode(int node_key)=0
Start a new node.
virtual void InitializeMetadata(Metadata const &metadata, TreeAnnotation const &tree_annotation, PostProcessorFunc const &postprocessor, std::vector< double > const &base_scores, std::optional< std::string > const &attributes)=0
Specify metadata for this model, if no metadata was previously specified.
virtual void StartTree()=0
Start a new tree.
virtual void LeafVector(std::vector< float > const &leaf_vector)=0
Declare the current node as a leaf node with a vector output.
virtual void LeafScalar(double leaf_value)=0
Declare the current node as a leaf node with a scalar output.
virtual void LeafVector(std::vector< double > const &leaf_vector)=0
Declare the current node as a leaf node with a vector output.
virtual void SumHess(double sum_hess)=0
Specify the weighted sample count or the sum of Hessians for the data points that are mapped to the c...
virtual void EndTree()=0
End the current tree.
virtual void DataCount(std::uint64_t data_count)=0
Specify the number of data points (samples) that are mapped to the current node.
virtual void EndNode()=0
End the current node.
virtual ~ModelBuilder()=default
std::variant< std::int64_t, double, std::string > PostProcessorConfigParam
Parameter type used to configure postprocessor functions.
Definition: model_builder.h:176
std::unique_ptr< ModelBuilder > GetModelBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type, Metadata const &metadata, TreeAnnotation const &tree_annotation, PostProcessorFunc const &postprocessor, std::vector< double > const &base_scores, std::optional< std::string > const &attributes=std::nullopt)
Initialize a model builder object with a given set of metadata.
Definition: contiguous_array.h:14
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
TaskType
Enum type representing the task type.
Definition: task_type.h:19
Operator
Type of comparison operators used in numerical test nodes.
Definition: operator.h:17
Define enum type Operator.
Specification for postprocessor of prediction outputs.
Definition: model_builder.h:181
std::string name
Definition: model_builder.h:182
std::map< std::string, PostProcessorConfigParam > config
Definition: model_builder.h:183
PostProcessorFunc(std::string const &name, std::map< std::string, PostProcessorConfigParam > const &config)
Constructor for PostProcessorFunc object.
PostProcessorFunc(std::string const &name)
Constructor for PostProcessorFunc object, with no configuration parameters.
Annotation for individual trees. Use this object to look up which target and class each tree is assoc...
Definition: model_builder.h:159
TreeAnnotation(std::int32_t num_tree, std::vector< std::int32_t > const &target_id, std::vector< std::int32_t > const &class_id)
Constructor for TreeAnnotation object.
std::int32_t num_tree
Definition: model_builder.h:160
std::vector< std::int32_t > class_id
Definition: model_builder.h:162
std::vector< std::int32_t > target_id
Definition: model_builder.h:161
Define enum type TaskType.
Define enum type NodeType.
Defines enum type TypeInfo.