// treelite
// fold_code.cc
7 #include <cmath>
8 #include <limits>
9 #include "./builder.h"
10 
11 namespace treelite {
12 namespace compiler {
13 
// Registry file tag for this source file, provided by the dmlc registry macro.
// NOTE(review): presumably paired with a matching DMLC_REGISTRY_LINK_TAG
// elsewhere so this translation unit is not dropped when linking statically;
// confirm against dmlc/registry.h.
DMLC_REGISTRY_FILE_TAG(fold_code);
15 
17  double magnitude_req;
18  double log_root_data_count;
19  double log_root_sum_hess;
20  bool create_new_translation_unit;
21  int num_tu;
22 };
23 
24 bool fold_code(ASTNode* node, CodeFoldingContext* context,
25  ASTBuilder* builder) {
26  if (node->node_id == 0) {
27  if (node->data_count) {
28  context->log_root_data_count = std::log(node->data_count.value());
29  } else {
30  context->log_root_data_count = std::numeric_limits<double>::quiet_NaN();
31  }
32  if (node->sum_hess) {
33  context->log_root_sum_hess = std::log(node->sum_hess.value());
34  } else {
35  context->log_root_sum_hess = std::numeric_limits<double>::quiet_NaN();
36  }
37  }
38 
39  if ( (node->data_count && !std::isnan(context->log_root_data_count)
40  && context->log_root_data_count - std::log(node->data_count.value())
41  >= context->magnitude_req)
42  || (node->sum_hess && !std::isnan(context->log_root_sum_hess)
43  && context->log_root_sum_hess - std::log(node->sum_hess.value())
44  >= context->magnitude_req) ) {
45  // fold the subtree whose root is [node]
46  ASTNode* parent_node = node->parent;
47  ASTNode* folder_node = nullptr;
48  ASTNode* tu_node = nullptr;
49  if (context->create_new_translation_unit) {
50  tu_node
51  = builder->AddNode<TranslationUnitNode>(parent_node, context->num_tu++);
52  ASTNode* ac = builder->AddNode<AccumulatorContextNode>(tu_node);
53  folder_node = builder->AddNode<CodeFolderNode>(ac);
54  tu_node->children.push_back(ac);
55  ac->children.push_back(folder_node);
56  } else {
57  folder_node = builder->AddNode<CodeFolderNode>(parent_node);
58  }
59  size_t node_loc = -1; // is current node 1st child or 2nd child or so forth
60  for (size_t i = 0; i < parent_node->children.size(); ++i) {
61  if (parent_node->children[i] == node) {
62  node_loc = i;
63  break;
64  }
65  }
66  CHECK_NE(node_loc, -1); // parent should have a link to current node
67  parent_node->children[node_loc]
68  = context->create_new_translation_unit ? tu_node : folder_node;
69  folder_node->children.push_back(node);
70  node->parent = folder_node;
71  return true;
72  } else {
73  bool folded_at_least_once = false;
74  for (ASTNode* child : node->children) {
75  folded_at_least_once |= fold_code(child, context, builder);
76  }
77  return folded_at_least_once;
78  }
79 }
80 
81 int count_tu_nodes(ASTNode* node);
82 
83 bool ASTBuilder::FoldCode(double magnitude_req,
84  bool create_new_translation_unit) {
85  CodeFoldingContext context{magnitude_req,
86  std::numeric_limits<double>::quiet_NaN(),
87  std::numeric_limits<double>::quiet_NaN(),
88  create_new_translation_unit,
89  count_tu_nodes(this->main_node)};
90  return fold_code(this->main_node, &context, this);
91 }
92 
93 } // namespace compiler
94 } // namespace treelite
AST Builder class.