// treelite
// annotate.cc
1 #include "./builder.h"
2 #include <cmath>
3 
4 namespace treelite {
5 namespace compiler {
6 
7 DMLC_REGISTRY_FILE_TAG(annotate);
8 
9 static void annotate(ASTNode* node,
10  const std::vector<std::vector<size_t>>& counts) {
11  const int tree_id = node->tree_id;
12  ConditionNode* cond_node;
13  if ( (cond_node = dynamic_cast<ConditionNode*>(node)) ) {
14  const std::vector<ASTNode*>& children = cond_node->children;
15  CHECK_EQ(children.size(), 2);
16  CHECK_EQ(tree_id, children[0]->tree_id);
17  CHECK_EQ(tree_id, children[1]->tree_id);
18  const size_t left_count = counts[tree_id][children[0]->node_id];
19  const size_t right_count = counts[tree_id][children[1]->node_id];
20  cond_node->branch_hint = (left_count > right_count) ? BranchHint::kLikely
21  : BranchHint::kUnlikely;
22  }
23 
24  for (ASTNode* child : node->children) {
25  annotate(child, counts);
26  }
27 }
28 
29 void
30 ASTBuilder::AnnotateBranches(const std::vector<std::vector<size_t>>& counts) {
31  annotate(this->main_node, counts);
32 }
33 
34 } // namespace compiler
35 } // namespace treelite