Treelite
fold_code.cc
1 
7 #include <dmlc/registry.h>
8 #include <cmath>
9 #include <limits>
10 #include "./builder.h"
11 
12 namespace treelite {
13 namespace compiler {
14 
// Emits the dmlc registry file tag named `fold_code` for this source file.
// NOTE(review): presumably paired with a DMLC_REGISTRY_LINK_TAG(fold_code) elsewhere so the
// linker keeps this object file -- confirm against dmlc/registry.h.
15 DMLC_REGISTRY_FILE_TAG(fold_code);
16 
/*!
 * \brief State threaded through the recursive fold_code() walk.
 *
 * Kept as a plain aggregate: FoldCode() brace-initializes it with the members
 * in exactly this order, so do not reorder the fields.
 */
struct CodeFoldingContext {
  /*! \brief minimum gap between log(root statistic) and log(node statistic)
   *         for a subtree to be folded */
  double magnitude_req;
  /*! \brief log of data_count at the last visited node with node_id == 0;
   *         NaN when that node carries no data_count */
  double log_root_data_count;
  /*! \brief log of sum_hess at the last visited node with node_id == 0;
   *         NaN when that node carries no sum_hess */
  double log_root_sum_hess;
  /*! \brief when true, each folded subtree is additionally wrapped in a
   *         TranslationUnitNode and an AccumulatorContextNode */
  bool create_new_translation_unit;
  /*! \brief value handed to the next TranslationUnitNode created;
   *         incremented after each use */
  int num_tu;
};
24 
25 template <typename ThresholdType, typename LeafOutputType>
26 bool fold_code(ASTNode* node, CodeFoldingContext* context,
28  if (node->node_id == 0) {
29  if (node->data_count) {
30  context->log_root_data_count = std::log(node->data_count.value());
31  } else {
32  context->log_root_data_count = std::numeric_limits<double>::quiet_NaN();
33  }
34  if (node->sum_hess) {
35  context->log_root_sum_hess = std::log(node->sum_hess.value());
36  } else {
37  context->log_root_sum_hess = std::numeric_limits<double>::quiet_NaN();
38  }
39  }
40 
41  if ( (node->data_count && !std::isnan(context->log_root_data_count)
42  && context->log_root_data_count - std::log(node->data_count.value())
43  >= context->magnitude_req)
44  || (node->sum_hess && !std::isnan(context->log_root_sum_hess)
45  && context->log_root_sum_hess - std::log(node->sum_hess.value())
46  >= context->magnitude_req) ) {
47  // fold the subtree whose root is [node]
48  ASTNode* parent_node = node->parent;
49  ASTNode* folder_node = nullptr;
50  ASTNode* tu_node = nullptr;
51  if (context->create_new_translation_unit) {
52  tu_node = builder->template AddNode<TranslationUnitNode>(parent_node, context->num_tu++);
53  ASTNode* ac = builder->template AddNode<AccumulatorContextNode>(tu_node);
54  folder_node = builder->template AddNode<CodeFolderNode>(ac);
55  tu_node->children.push_back(ac);
56  ac->children.push_back(folder_node);
57  } else {
58  folder_node = builder->template AddNode<CodeFolderNode>(parent_node);
59  }
60  size_t node_loc = -1; // is current node 1st child or 2nd child or so forth
61  for (size_t i = 0; i < parent_node->children.size(); ++i) {
62  if (parent_node->children[i] == node) {
63  node_loc = i;
64  break;
65  }
66  }
67  CHECK_NE(node_loc, -1); // parent should have a link to current node
68  parent_node->children[node_loc]
69  = context->create_new_translation_unit ? tu_node : folder_node;
70  folder_node->children.push_back(node);
71  node->parent = folder_node;
72  return true;
73  } else {
74  bool folded_at_least_once = false;
75  for (ASTNode* child : node->children) {
76  folded_at_least_once |= fold_code(child, context, builder);
77  }
78  return folded_at_least_once;
79  }
80 }
81 
// Forward declaration; the definition is not in this file. Its result seeds
// CodeFoldingContext::num_tu in FoldCode() below.
// NOTE(review): presumably counts the TranslationUnitNode instances reachable from [node] --
// confirm against the definition.
82 int count_tu_nodes(ASTNode* node);
83 
84 template <typename ThresholdType, typename LeafOutputType>
85 bool
87  double magnitude_req, bool create_new_translation_unit) {
88  CodeFoldingContext context{magnitude_req,
89  std::numeric_limits<double>::quiet_NaN(),
90  std::numeric_limits<double>::quiet_NaN(),
91  create_new_translation_unit,
92  count_tu_nodes(this->main_node)};
93  return fold_code(this->main_node, &context, this);
94 }
95 
// Explicit instantiations of ASTBuilder::FoldCode for the four
// (ThresholdType, LeafOutputType) combinations listed below, so the member template
// defined in this file is emitted for each of them.
96 template bool ASTBuilder<float, uint32_t>::FoldCode(double, bool);
97 template bool ASTBuilder<float, float>::FoldCode(double, bool);
98 template bool ASTBuilder<double, uint32_t>::FoldCode(double, bool);
99 template bool ASTBuilder<double, double>::FoldCode(double, bool);
100 
101 } // namespace compiler
102 } // namespace treelite
AST Builder class.