// Treelite
// fold_code.cc
7 #include <treelite/logging.h>
8 #include <limits>
9 #include <cmath>
10 #include "./builder.h"
11 
12 namespace treelite {
13 namespace compiler {
14 
/*! \brief Parameters and running state shared across one FoldCode() pass */
struct CodeFoldingContext {
  /*! \brief minimum value of log(root statistic) - log(node statistic)
   *         (natural log) for a node's subtree to be folded */
  double magnitude_req;
  /*! \brief log of the current tree root's data count; NaN if unavailable */
  double log_root_data_count;
  /*! \brief log of the current tree root's sum of Hessians; NaN if unavailable */
  double log_root_sum_hess;
  /*! \brief whether each folded subtree is wrapped in a new translation unit */
  bool create_new_translation_unit;
  /*! \brief index passed to the next TranslationUnitNode that gets created */
  int num_tu;
};
22 
23 template <typename ThresholdType, typename LeafOutputType>
24 bool fold_code(ASTNode* node, CodeFoldingContext* context,
26  if (node->node_id == 0) {
27  if (node->data_count) {
28  context->log_root_data_count = std::log(*node->data_count);
29  } else {
30  context->log_root_data_count = std::numeric_limits<double>::quiet_NaN();
31  }
32  if (node->sum_hess) {
33  context->log_root_sum_hess = std::log(*node->sum_hess);
34  } else {
35  context->log_root_sum_hess = std::numeric_limits<double>::quiet_NaN();
36  }
37  }
38 
39  if ( (node->data_count && !std::isnan(context->log_root_data_count)
40  && context->log_root_data_count - std::log(*node->data_count)
41  >= context->magnitude_req)
42  || (node->sum_hess && !std::isnan(context->log_root_sum_hess)
43  && context->log_root_sum_hess - std::log(*node->sum_hess)
44  >= context->magnitude_req) ) {
45  // fold the subtree whose root is [node]
46  ASTNode* parent_node = node->parent;
47  ASTNode* folder_node = nullptr;
48  ASTNode* tu_node = nullptr;
49  if (context->create_new_translation_unit) {
50  tu_node = builder->template AddNode<TranslationUnitNode>(parent_node, context->num_tu++);
51  ASTNode* ac = builder->template AddNode<AccumulatorContextNode>(tu_node);
52  folder_node = builder->template AddNode<CodeFolderNode>(ac);
53  tu_node->children.push_back(ac);
54  ac->children.push_back(folder_node);
55  } else {
56  folder_node = builder->template AddNode<CodeFolderNode>(parent_node);
57  }
58  size_t node_loc = -1; // is current node 1st child or 2nd child or so forth
59  for (size_t i = 0; i < parent_node->children.size(); ++i) {
60  if (parent_node->children[i] == node) {
61  node_loc = i;
62  break;
63  }
64  }
65  TREELITE_CHECK_NE(node_loc, -1); // parent should have a link to current node
66  parent_node->children[node_loc]
67  = context->create_new_translation_unit ? tu_node : folder_node;
68  folder_node->children.push_back(node);
69  node->parent = folder_node;
70  return true;
71  } else {
72  bool folded_at_least_once = false;
73  for (ASTNode* child : node->children) {
74  folded_at_least_once |= fold_code(child, context, builder);
75  }
76  return folded_at_least_once;
77  }
78 }
79 
// Declared here, defined in another translation unit. FoldCode() uses its
// result as the starting index for newly created TranslationUnitNodes.
// NOTE(review): presumably counts the TranslationUnitNode instances already
// present under `node`; confirm against its definition.
int count_tu_nodes(ASTNode* node);
81 
82 template <typename ThresholdType, typename LeafOutputType>
83 bool
85  double magnitude_req, bool create_new_translation_unit) {
86  CodeFoldingContext context{magnitude_req,
87  std::numeric_limits<double>::quiet_NaN(),
88  std::numeric_limits<double>::quiet_NaN(),
89  create_new_translation_unit,
90  count_tu_nodes(this->main_node)};
91  return fold_code(this->main_node, &context, this);
92 }
93 
94 template bool ASTBuilder<float, uint32_t>::FoldCode(double, bool);
95 template bool ASTBuilder<float, float>::FoldCode(double, bool);
96 template bool ASTBuilder<double, uint32_t>::FoldCode(double, bool);
97 template bool ASTBuilder<double, double>::FoldCode(double, bool);
98 
99 } // namespace compiler
100 } // namespace treelite
// <treelite/logging.h>: logging facility for Treelite.
// "./builder.h": AST builder class.