// dmlc-core registry file tag for this translation unit, under the tag name `build`.
// NOTE(review): presumably paired with a DMLC_REGISTRY_LINK_TAG(build) elsewhere so the
// linker keeps this file's registrations — confirm.
// NOTE(review): this copy of the file carries stray numeric prefixes ("11 ", "13 ", ...)
// and skips lines in that numbering, so it will not compile as-is; restore from upstream.
11 DMLC_REGISTRY_FILE_TAG(build);
// Builds the builder's AST for an entire model.
//   model - source model; its num_feature, num_output_group, random_forest_flag,
//           param (global_bias and the __DICT__() dump) and trees are read.
// Effects: caches model-level flags/sizes on `this`, creates `this->main_node`,
// attaches one AccumulatorContextNode under it, hangs one subtree per tree
// (in model.trees order) under that accumulator node, and stores the model
// parameter dictionary in `this->model_param`.
// NOTE(review): the remaining AddNode<MainNode> arguments and the closing braces
// of the loop and function are missing from this copy (see the gaps in the
// embedded numbering: 19, 22-23, 29, 31) — restore from upstream.
13 void ASTBuilder::BuildAST(
const Model& model) {
// Leaf outputs are vectors only for multi-group random-forest models; the
// per-tree builder reads this flag to pick leaf_vector() vs leaf_value().
14 this->output_vector_flag
15 = (model.num_output_group > 1 && model.random_forest_flag);
// Cache model-level properties on the builder.
16 this->num_feature = model.num_feature;
17 this->num_output_group = model.num_output_group;
18 this->random_forest_flag = model.random_forest_flag;
// Create the top-level node. Its first argument is nullptr, whereas other AddNode
// calls in this file pass the parent there, so presumably the main node has no
// parent and is the AST root — confirm against AddNode's signature.
20 this->main_node = AddNode<MainNode>(
nullptr, model.param.global_bias,
21 model.random_forest_flag,
// Single accumulator-context node under the main node; every tree subtree is
// attached beneath it.
24 ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
25 this->main_node->children.push_back(ac);
// One subtree per tree, built with `ac` as the parent and appended in tree order.
// NOTE(review): `int tree_id` is compared against model.trees.size(), which is
// presumably unsigned (std::size_t) — signed/unsigned comparison warning; confirm.
26 for (
int tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
27 ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], tree_id, ac);
28 ac->children.push_back(tree_head);
// Keep a dictionary dump of the model parameters on the builder.
30 this->model_param = model.param.__DICT__();
// Convenience overload: builds the AST subtree for one tree, starting at node
// id 0, by delegating to the four-argument BuildASTFromTree overload.
// Node 0 is presumably the tree's root — confirm against Tree's node numbering.
//   tree    - tree to convert
//   tree_id - index of the tree within the model; forwarded unchanged
//   parent  - AST node the resulting subtree is attached under
// Returns the AST node built for node 0, which the caller uses as the head of the tree's subtree.
// NOTE(review): the line declaring `parent` (and the function's closing brace)
// is missing from this copy — see the gaps in the embedded numbering (34, 36).
33 ASTNode* ASTBuilder::BuildASTFromTree(
const Tree& tree,
int tree_id,
35 return BuildASTFromTree(tree, tree_id, 0, parent);
38 ASTNode* ASTBuilder::BuildASTFromTree(
const Tree& tree,
int tree_id,
int nid,
40 const Tree::Node& node = tree[nid];
41 ASTNode* ast_node =
nullptr;
43 if (this->output_vector_flag) {
44 ast_node = AddNode<OutputNode>(parent, node.leaf_vector());
46 ast_node = AddNode<OutputNode>(parent, node.leaf_value());
49 if (node.split_type() == SplitFeatureType::kNumerical) {
50 ast_node = AddNode<NumericalConditionNode>(parent,
55 ThresholdVariant(static_cast<tl_float>(node.threshold())));
57 ast_node = AddNode<CategoricalConditionNode>(parent,
60 node.left_categories(),
61 node.missing_category_to_zero());
63 if (node.has_gain()) {
64 dynamic_cast<ConditionNode*
>(ast_node)->gain = node.gain();
66 ast_node->children.push_back(BuildASTFromTree(tree, tree_id,
67 node.cleft(), ast_node));
68 ast_node->children.push_back(BuildASTFromTree(tree, tree_id,
69 node.cright(), ast_node));
71 ast_node->node_id = nid;
72 ast_node->tree_id = tree_id;
73 if (node.has_data_count()) {
74 ast_node->data_count = node.data_count();
76 if (node.has_sum_hess()) {
77 ast_node->sum_hess = node.sum_hess();