6 #include <dmlc/registry.h> 12 DMLC_REGISTRY_FILE_TAG(build);
// Builds the complete AST for `model` and stores it in this builder.
//
// Shape of the tree constructed here:
//   MainNode                      (root; created with a nullptr parent)
//     +- AccumulatorContextNode   (single child of the root)
//          +- head of tree 0
//          +- head of tree 1
//          +- ...                 (one subtree per entry of model.trees)
//
// Side effects on the builder:
//   - caches num_feature, num_output_group and random_forest_flag from the model;
//   - derives output_vector_flag (see below);
//   - sets main_node to the new root;
//   - copies the model parameters into model_param.
14 void ASTBuilder::BuildAST(
const Model& model) {
// Leaves hold a vector of outputs only when the model has more than one output
// group AND is a random forest; otherwise each leaf holds a single scalar value.
15 this->output_vector_flag
16 = (model.num_output_group > 1 && model.random_forest_flag);
17 this->num_feature = model.num_feature;
18 this->num_output_group = model.num_output_group;
19 this->random_forest_flag = model.random_forest_flag;
// Root of the AST. It has no parent and receives the model's global bias and
// random-forest flag, among other arguments.
21 this->main_node = AddNode<MainNode>(
nullptr, model.param.global_bias,
22 model.random_forest_flag,
// All trees hang off a single accumulator-context node under the root. The
// child link is appended explicitly here rather than by AddNode.
25 ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
26 this->main_node->children.push_back(ac);
// Convert every tree and attach its head node under the accumulator context,
// preserving the order of model.trees.
// NOTE(review): tree_id is `int` but is compared against trees.size(), which is
// presumably an unsigned size type (e.g. std::size_t). This is a signed/unsigned
// comparison (-Wsign-compare). Consider an unsigned index or a range-for with a
// separate counter. TODO confirm the type of model.trees.
27 for (
int tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
28 ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], tree_id, ac);
29 ac->children.push_back(tree_head);
// Snapshot the model parameters. __DICT__() presumably yields a key -> value
// mapping. TODO confirm its return type against the param struct.
31 this->model_param = model.param.__DICT__();
// Builds the AST subtree for an entire decision tree.
//
// Convenience overload. It forwards to the recursive BuildASTFromTree overload
// with node id 0, presumably the root node of `tree` (TODO confirm against Tree).
//
// Parameters:
//   tree    - the decision tree to convert.
//   tree_id - index of the tree within the model; the recursive overload
//             records it on each AST node it creates.
//   parent  - AST node that becomes the parent of the tree's head node.
// Returns the head AST node of the converted tree. Callers append the returned
// head to parent->children themselves (BuildAST does this with the
// accumulator-context node).
34 ASTNode* ASTBuilder::BuildASTFromTree(
const Tree& tree,
int tree_id,
36 return BuildASTFromTree(tree, tree_id, 0, parent);
39 ASTNode* ASTBuilder::BuildASTFromTree(
const Tree& tree,
int tree_id,
int nid,
41 ASTNode* ast_node =
nullptr;
42 if (tree.IsLeaf(nid)) {
43 if (this->output_vector_flag) {
44 ast_node = AddNode<OutputNode>(parent, tree.LeafVector(nid));
46 ast_node = AddNode<OutputNode>(parent, tree.LeafValue(nid));
49 if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
50 ast_node = AddNode<NumericalConditionNode>(parent,
52 tree.DefaultLeft(nid),
54 tree.ComparisonOp(nid),
55 ThresholdVariant(static_cast<tl_float>(
56 tree.Threshold(nid))));
58 ast_node = AddNode<CategoricalConditionNode>(parent,
60 tree.DefaultLeft(nid),
61 tree.LeftCategories(nid),
62 tree.MissingCategoryToZero(nid));
64 if (tree.HasGain(nid)) {
65 dynamic_cast<ConditionNode*
>(ast_node)->gain = tree.Gain(nid);
67 ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.LeftChild(nid), ast_node));
68 ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.RightChild(nid), ast_node));
70 ast_node->node_id = nid;
71 ast_node->tree_id = tree_id;
72 if (tree.HasDataCount(nid)) {
73 ast_node->data_count = tree.DataCount(nid);
75 if (tree.HasSumHess(nid)) {
76 ast_node->sum_hess = tree.SumHess(nid);