// NOTE(review): DMLC_REGISTRY_FILE_TAG(build) presumably tags this translation
// unit for dmlc-core's registry so that registrations made here get linked in;
// confirm against dmlc-core's registry.h.
6 #include <dmlc/registry.h> 12 DMLC_REGISTRY_FILE_TAG(build);
/*!
 * \brief Build the complete AST for a model (this member returns void).
 *
 * Copies model-wide settings into the builder: whether leaves output vectors,
 * the feature count, and whether tree outputs are averaged. It then creates
 * this->main_node with a null parent and places one AccumulatorContextNode
 * under it. Next, it attaches one subtree per entry of model.trees beneath
 * that accumulator context, in tree order. Finally it stores
 * model.param.__DICT__() in this->model_param.
 *
 * \param model source model; each tree is converted starting from node id 0
 */
14 template <
typename ThresholdType,
typename LeafOutputType>
16 ASTBuilder<ThresholdType, LeafOutputType>::BuildAST(
17 const ModelImpl<ThresholdType, LeafOutputType>& model) {
// Leaves carry vector outputs whenever leaf_vector_size exceeds 1.
18 this->output_vector_flag = (model.task_param.leaf_vector_size > 1);
19 this->num_feature = model.num_feature;
20 this->average_output_flag = model.average_tree_output;
// Top-level node: created with a null parent. Among its arguments are the
// global bias, the averaging flag, and the number of trees.
22 this->main_node = AddNode<MainNode>(
nullptr, model.param.global_bias,
23 model.average_tree_output,
24 static_cast<int>(model.trees.size()),
26 ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
// The accumulator context is attached as a child of main_node; every tree's
// subtree is attached under it below.
27 this->main_node->children.push_back(ac);
// Convert each tree starting from node id 0 and append the result under ac.
// tree_id is narrowed to int to match BuildASTFromTree's parameter type.
28 for (
size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
29 ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], static_cast<int>(tree_id), 0, ac);
30 ac->children.push_back(tree_head);
// Keep a dictionary form of the model parameters on the builder.
32 this->model_param = model.param.__DICT__();
/*!
 * \brief Recursively convert node `nid` of `tree`, and its descendants,
 *        into AST nodes created under `parent`.
 *
 * A leaf becomes an OutputNode holding one of two values:
 *   - the leaf vector, when output_vector_flag is set;
 *   - the scalar leaf value, otherwise.
 *
 * A split node becomes one of two condition nodes:
 *   - a NumericalConditionNode for numerical splits;
 *   - a CategoricalConditionNode otherwise.
 * The split's gain is copied when the tree has one. Its left and right
 * children are then built recursively, in that order.
 *
 * The node id and tree id are recorded on the AST node. The data count and
 * sum of hessians are also copied when the tree records them.
 *
 * NOTE(review): recursion depth equals the depth of the tree, so very deep
 * trees use proportional call-stack space.
 *
 * \param tree    tree being converted
 * \param tree_id index of `tree` within the model, stored on each AST node
 * \param nid     id of the tree node to convert
 * \param parent  passed to AddNode as the new node's parent
 * \return the ASTNode* built for `nid` (presumably `ast_node`); callers push
 *         it into their own children list
 */
35 template <
typename ThresholdType,
typename LeafOutputType>
37 ASTBuilder<ThresholdType, LeafOutputType>::BuildASTFromTree(
38 const Tree<ThresholdType, LeafOutputType>& tree,
int tree_id,
int nid, ASTNode* parent) {
39 ASTNode* ast_node =
nullptr;
// Leaf node: emit an OutputNode with the leaf's output.
40 if (tree.IsLeaf(nid)) {
// Vector output when leaf_vector_size > 1, otherwise the scalar leaf value.
41 if (this->output_vector_flag) {
42 ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafVector(nid));
44 ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafValue(nid));
// Split node: numerical splits get a threshold comparison node; other split
// types get a categorical node built from the matching-category list.
47 if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
48 ast_node = AddNode<NumericalConditionNode<ThresholdType>>(
51 tree.DefaultLeft(nid),
53 tree.ComparisonOp(nid),
54 ThresholdVariant<ThresholdType>(tree.Threshold(nid)));
56 ast_node = AddNode<CategoricalConditionNode>(
59 tree.DefaultLeft(nid),
60 tree.MatchingCategories(nid),
61 tree.CategoriesListRightChild(nid));
// Gain is optional. The dynamic_cast result is not null-checked, so this
// assumes both condition node types derive from ConditionNode.
63 if (tree.HasGain(nid)) {
64 dynamic_cast<ConditionNode*
>(ast_node)->gain = tree.Gain(nid);
// Left child is pushed first, then the right child, so children[0] holds the
// left subtree and children[1] the right subtree.
66 ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.LeftChild(nid), ast_node));
67 ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.RightChild(nid), ast_node));
// Record where this AST node came from in the original model.
69 ast_node->node_id = nid;
70 ast_node->tree_id = tree_id;
// Optional per-node statistics are copied only when the tree has them.
71 if (tree.HasDataCount(nid)) {
72 ast_node->data_count = tree.DataCount(nid);
74 if (tree.HasSumHess(nid)) {
75 ast_node->sum_hess = tree.SumHess(nid);
82 template void ASTBuilder<float, uint32_t>::BuildAST(
const ModelImpl<float, uint32_t>&);
83 template void ASTBuilder<float, float>::BuildAST(
const ModelImpl<float, float>&);
84 template void ASTBuilder<double, uint32_t>::BuildAST(
const ModelImpl<double, uint32_t>&);
85 template void ASTBuilder<double, double>::BuildAST(
const ModelImpl<double, double>&);
86 template ASTNode* ASTBuilder<float, uint32_t>::BuildASTFromTree(
87 const Tree<float, uint32_t>&,
int,
int, ASTNode*);
88 template ASTNode* ASTBuilder<float, float>::BuildASTFromTree(
89 const Tree<float, float>&,
int,
int, ASTNode*);
90 template ASTNode* ASTBuilder<double, uint32_t>::BuildASTFromTree(
91 const Tree<double, uint32_t>&,
int,
int, ASTNode*);
92 template ASTNode* ASTBuilder<double, double>::BuildASTFromTree(
93 const Tree<double, double>&,
int,
int, ASTNode*);