// Declare a dmlc registry file tag named `annotate` for this translation unit.
// NOTE(review): presumably this lets a matching DMLC_REGISTRY_LINK_TAG(annotate)
// elsewhere force this file to be linked when building a static library —
// confirm against dmlc/registry.h.
DMLC_REGISTRY_FILE_TAG(annotate);
9 static void annotate(ASTNode* node,
10 const std::vector<std::vector<size_t>>& counts) {
11 const int tree_id = node->tree_id;
12 ConditionNode* cond_node;
13 if ( (cond_node = dynamic_cast<ConditionNode*>(node)) ) {
14 const std::vector<ASTNode*>& children = cond_node->children;
15 CHECK_EQ(children.size(), 2);
16 CHECK_EQ(tree_id, children[0]->tree_id);
17 CHECK_EQ(tree_id, children[1]->tree_id);
18 const size_t left_count = counts[tree_id][children[0]->node_id];
19 const size_t right_count = counts[tree_id][children[1]->node_id];
20 cond_node->branch_hint = (left_count > right_count) ? BranchHint::kLikely
21 : BranchHint::kUnlikely;
24 for (ASTNode* child : node->children) {
25 annotate(child, counts);
30 ASTBuilder::AnnotateBranches(
const std::vector<std::vector<size_t>>& counts) {
31 annotate(this->main_node, counts);