// Builds the AST for a whole model.
//
// Copies a few model-level settings onto the builder, creates the root
// MainNode, hangs a single AccumulatorContextNode under it, and then attaches
// one sub-AST per tree in `model.trees` beneath that accumulator node.
//
// @param model  Model whose trees and parameters are translated into AST nodes.
//
// NOTE(review): this listing appears to have lost some lines (the return type,
// the tail of the AddNode<MainNode>(...) call and the closing braces are not
// visible here) -- verify against the original file before editing the code.
11 template <
typename ThresholdType,
typename LeafOutputType>
13 ASTBuilder<ThresholdType, LeafOutputType>::BuildAST(
14 const ModelImpl<ThresholdType, LeafOutputType>& model) {
// Leaf outputs are treated as vectors when leaf_vector_size is greater than 1.
15 this->output_vector_flag = (model.task_param.leaf_vector_size > 1);
16 this->num_feature = model.num_feature;
17 this->average_output_flag = model.average_tree_output;
// Root node: created with a null first argument (presumably "no parent") and
// given the global bias, the averaging flag and the number of trees.
19 this->main_node = AddNode<MainNode>(
nullptr, model.param.global_bias,
20 model.average_tree_output,
21 static_cast<int>(model.trees.size()),
// NOTE(review): the remainder of the call above is not visible in this listing.
// Single accumulator-context node under the root; every tree is attached to it.
23 ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
24 this->main_node->children.push_back(ac);
// Translate each tree starting from node id 0, using `ac` as the parent.
25 for (
size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
26 ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], static_cast<int>(tree_id), 0, ac);
27 ac->children.push_back(tree_head);
// Keep the model parameters in the form returned by __DICT__().
29 this->model_param = model.param.__DICT__();
// Recursively converts the subtree of `tree` rooted at node `nid` into AST
// nodes and yields the node created for `nid`.
//
// @param tree     Tree being translated.
// @param tree_id  Index of the tree; stored on every node that is created.
// @param nid      Id of the tree node to translate.
// @param parent   AST node passed to AddNode for the newly created node.
//
// NOTE(review): this listing appears to have lost some lines (the return type,
// the `else` branches, the leading arguments of the condition-node
// constructors, the closing braces and the final return are not visible here)
// -- verify against the original file before editing the code.
32 template <
typename ThresholdType,
typename LeafOutputType>
34 ASTBuilder<ThresholdType, LeafOutputType>::BuildASTFromTree(
35 const Tree<ThresholdType, LeafOutputType>& tree,
int tree_id,
int nid, ASTNode* parent) {
36 ASTNode* ast_node =
nullptr;
// Leaf: emit an OutputNode built from the leaf vector when output_vector_flag
// is set; the scalar LeafValue form below is presumably the `else` branch.
37 if (tree.IsLeaf(nid)) {
38 if (this->output_vector_flag) {
39 ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafVector(nid));
41 ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafValue(nid));
// Internal node: a numerical split becomes a NumericalConditionNode carrying
// the default direction, comparison operator and threshold.
44 if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
45 ast_node = AddNode<NumericalConditionNode<ThresholdType>>(
48 tree.DefaultLeft(nid),
50 tree.ComparisonOp(nid),
51 ThresholdVariant<ThresholdType>(tree.Threshold(nid)));
// Presumably the non-numerical branch: a CategoricalConditionNode carrying the
// default direction, the matching categories and the right-child flag.
53 ast_node = AddNode<CategoricalConditionNode>(
56 tree.DefaultLeft(nid),
57 tree.MatchingCategories(nid),
58 tree.CategoriesListRightChild(nid));
// Copy the split gain when the tree provides one.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
60 if (tree.HasGain(nid)) {
61 dynamic_cast<ConditionNode*
>(ast_node)->gain = tree.Gain(nid);
// Recurse into the left child, then the right child, with this node as parent.
63 ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.LeftChild(nid), ast_node));
64 ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.RightChild(nid), ast_node));
// Bookkeeping on the new node: its ids, plus optional statistics when present.
66 ast_node->node_id = nid;
67 ast_node->tree_id = tree_id;
68 if (tree.HasDataCount(nid)) {
69 ast_node->data_count = tree.DataCount(nid);
71 if (tree.HasSumHess(nid)) {
72 ast_node->sum_hess = tree.SumHess(nid);
79 template void ASTBuilder<float, uint32_t>::BuildAST(
const ModelImpl<float, uint32_t>&);
80 template void ASTBuilder<float, float>::BuildAST(
const ModelImpl<float, float>&);
81 template void ASTBuilder<double, uint32_t>::BuildAST(
const ModelImpl<double, uint32_t>&);
82 template void ASTBuilder<double, double>::BuildAST(
const ModelImpl<double, double>&);
83 template ASTNode* ASTBuilder<float, uint32_t>::BuildASTFromTree(
84 const Tree<float, uint32_t>&,
int,
int, ASTNode*);
85 template ASTNode* ASTBuilder<float, float>::BuildASTFromTree(
86 const Tree<float, float>&,
int,
int, ASTNode*);
87 template ASTNode* ASTBuilder<double, uint32_t>::BuildASTFromTree(
88 const Tree<double, uint32_t>&,
int,
int, ASTNode*);
89 template ASTNode* ASTBuilder<double, double>::BuildASTFromTree(
90 const Tree<double, double>&,
int,
int, ASTNode*);