Treelite
fold_code.cc
1 
7 #include <dmlc/registry.h>
8 #include <cmath>
9 #include <limits>
10 #include "./builder.h"
11 
12 namespace treelite {
13 namespace compiler {
14 
15 DMLC_REGISTRY_FILE_TAG(fold_code);
16 
18  double magnitude_req;
19  double log_root_data_count;
20  double log_root_sum_hess;
21  bool create_new_translation_unit;
22  int num_tu;
23 };
24 
25 bool fold_code(ASTNode* node, CodeFoldingContext* context,
26  ASTBuilder* builder) {
27  if (node->node_id == 0) {
28  if (node->data_count) {
29  context->log_root_data_count = std::log(node->data_count.value());
30  } else {
31  context->log_root_data_count = std::numeric_limits<double>::quiet_NaN();
32  }
33  if (node->sum_hess) {
34  context->log_root_sum_hess = std::log(node->sum_hess.value());
35  } else {
36  context->log_root_sum_hess = std::numeric_limits<double>::quiet_NaN();
37  }
38  }
39 
40  if ( (node->data_count && !std::isnan(context->log_root_data_count)
41  && context->log_root_data_count - std::log(node->data_count.value())
42  >= context->magnitude_req)
43  || (node->sum_hess && !std::isnan(context->log_root_sum_hess)
44  && context->log_root_sum_hess - std::log(node->sum_hess.value())
45  >= context->magnitude_req) ) {
46  // fold the subtree whose root is [node]
47  ASTNode* parent_node = node->parent;
48  ASTNode* folder_node = nullptr;
49  ASTNode* tu_node = nullptr;
50  if (context->create_new_translation_unit) {
51  tu_node
52  = builder->AddNode<TranslationUnitNode>(parent_node, context->num_tu++);
53  ASTNode* ac = builder->AddNode<AccumulatorContextNode>(tu_node);
54  folder_node = builder->AddNode<CodeFolderNode>(ac);
55  tu_node->children.push_back(ac);
56  ac->children.push_back(folder_node);
57  } else {
58  folder_node = builder->AddNode<CodeFolderNode>(parent_node);
59  }
60  size_t node_loc = -1; // is current node 1st child or 2nd child or so forth
61  for (size_t i = 0; i < parent_node->children.size(); ++i) {
62  if (parent_node->children[i] == node) {
63  node_loc = i;
64  break;
65  }
66  }
67  CHECK_NE(node_loc, -1); // parent should have a link to current node
68  parent_node->children[node_loc]
69  = context->create_new_translation_unit ? tu_node : folder_node;
70  folder_node->children.push_back(node);
71  node->parent = folder_node;
72  return true;
73  } else {
74  bool folded_at_least_once = false;
75  for (ASTNode* child : node->children) {
76  folded_at_least_once |= fold_code(child, context, builder);
77  }
78  return folded_at_least_once;
79  }
80 }
81 
82 int count_tu_nodes(ASTNode* node);
83 
84 bool ASTBuilder::FoldCode(double magnitude_req,
85  bool create_new_translation_unit) {
86  CodeFoldingContext context{magnitude_req,
87  std::numeric_limits<double>::quiet_NaN(),
88  std::numeric_limits<double>::quiet_NaN(),
89  create_new_translation_unit,
90  count_tu_nodes(this->main_node)};
91  return fold_code(this->main_node, &context, this);
92 }
93 
94 } // namespace compiler
95 } // namespace treelite
AST Builder class.