// dmlc-core registry tag for this file, named "build".
// NOTE(review): presumably paired with a DMLC_REGISTRY_LINK_TAG(build) elsewhere
// so this translation unit's registrations survive static linking — confirm
// against dmlc/registry.h and the corresponding link-tag site.
6 DMLC_REGISTRY_FILE_TAG(build);
// Builds the AST for `model`.
//
// Structure produced (as visible here):
//   MainNode (root, no parent)
//     └─ AccumulatorContextNode
//          └─ one subtree per entry of model.trees, produced by WalkTree()
// Also records builder-wide settings (output_vector_flag, num_feature) that
// WalkTree() consults while creating nodes.
//
// @param model  source model; its trees, global bias, random-forest flag,
//               output-group count and feature count are read.
8 void ASTBuilder::Build(
const Model& model) {
// Leaves emit a vector (leaf_vector) instead of a scalar (leaf_value) only
// when the model has more than one output group AND is a random forest;
// WalkTree() branches on this flag when creating OutputNodes.
9 this->output_vector_flag
10 = (model.num_output_group > 1 && model.random_forest_flag);
11 this->num_feature = model.num_feature;
// Root node: parent is nullptr; it stores the model's global bias and
// random-forest flag (remaining constructor arguments are not visible here).
13 this->main_node = AddNode<MainNode>(
nullptr, model.param.global_bias,
14 model.random_forest_flag,
// Single accumulator under the root; every tree's subtree is attached to it.
17 ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
18 this->main_node->children.push_back(ac);
// NOTE(review): `int tree_id` is compared against model.trees.size(); if that
// returns an unsigned type (e.g. std::vector::size), this is a signed/unsigned
// comparison — consider std::size_t or a range-based index.
19 for (
int tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
// Build the subtree for this tree with the accumulator as its parent.
20 ASTNode* tree_head = WalkTree(model.trees[tree_id], ac);
// Lambda that stamps `tree_id` onto a node and then visits its children.
// It is held in a std::function and captures itself (&func) by reference so
// that it can call itself recursively (the recursive call and the call on
// tree_head are not visible in this excerpt — confirm).
22 std::function<void(ASTNode*)> func;
23 func = [tree_id, &func](ASTNode* node) ->
void {
24 node->tree_id = tree_id;
25 for (ASTNode* child : node->children) {
// Attach the finished subtree to the accumulator.
30 ac->children.push_back(tree_head);
// Convenience overload: builds the AST subtree for `tree` starting at node
// id 0 and forwards to WalkTree(tree, nid, parent).
//
// @param tree    tree to convert.
// @param parent  AST node the new subtree is created under.
// @return        whatever the three-argument overload returns for node 0.
34 ASTNode* ASTBuilder::WalkTree(
const Tree& tree, ASTNode* parent) {
// Node id 0 is used as the starting (root) node of the tree.
35 return WalkTree(tree, 0, parent);
38 ASTNode* ASTBuilder::WalkTree(
const Tree& tree,
int nid, ASTNode* parent) {
39 const Tree::Node& node = tree[nid];
40 ASTNode* ast_node =
nullptr;
42 if (this->output_vector_flag) {
43 ast_node = AddNode<OutputNode>(parent, node.leaf_vector());
45 ast_node = AddNode<OutputNode>(parent, node.leaf_value());
48 if (node.split_type() == SplitFeatureType::kNumerical) {
49 ast_node = AddNode<NumericalConditionNode>(parent,
54 ThresholdVariant(static_cast<tl_float>(node.threshold())));
56 ast_node = AddNode<CategoricalConditionNode>(parent,
59 node.left_categories());
61 ast_node->children.push_back(WalkTree(tree, node.cleft(), ast_node));
62 ast_node->children.push_back(WalkTree(tree, node.cright(), ast_node));
64 ast_node->node_id = nid;