Treelite
quantize.cc
Go to the documentation of this file.
1 
#include <treelite/math.h>
#include <treelite/logging.h>
#include <algorithm>
#include <cmath>
#include <iterator>
#include <set>
#include <utility>
#include <vector>
#include "./builder.h"
11 
12 namespace treelite {
13 namespace compiler {
14 
15 template <typename ThresholdType>
16 static void
17 scan_thresholds(ASTNode* node, std::vector<std::set<ThresholdType>>* cut_pts) {
18  NumericalConditionNode<ThresholdType>* num_cond;
19  if ( (num_cond = dynamic_cast<NumericalConditionNode<ThresholdType>*>(node)) ) {
20  TREELITE_CHECK(!num_cond->quantized) << "should not be already quantized";
21  const ThresholdType threshold = num_cond->threshold.float_val;
22  if (std::isfinite(threshold)) {
23  (*cut_pts)[num_cond->split_index].insert(threshold);
24  }
25  }
26  for (ASTNode* child : node->children) {
27  scan_thresholds(child, cut_pts);
28  }
29 }
30 
31 template <typename ThresholdType>
32 static void
33 rewrite_thresholds(ASTNode* node, const std::vector<std::vector<ThresholdType>>& cut_pts) {
34  NumericalConditionNode<ThresholdType>* num_cond;
35  if ( (num_cond = dynamic_cast<NumericalConditionNode<ThresholdType>*>(node)) ) {
36  TREELITE_CHECK(!num_cond->quantized) << "should not be already quantized";
37  const ThresholdType threshold = num_cond->threshold.float_val;
38  if (std::isfinite(threshold)) {
39  const auto& v = cut_pts[num_cond->split_index];
40  {
41  auto loc = math::binary_search(v.begin(), v.end(), threshold);
42  TREELITE_CHECK(loc != v.end());
43  num_cond->threshold.int_val = static_cast<int>(loc - v.begin()) * 2;
44  }
45  {
46  ThresholdType zero = static_cast<ThresholdType>(0);
47  auto loc = std::lower_bound(v.begin(), v.end(), zero);
48  num_cond->zero_quantized = static_cast<int>(loc - v.begin()) * 2;
49  if (loc != v.end() && zero != *loc) {
50  --num_cond->zero_quantized;
51  }
52  }
53  num_cond->quantized = true;
54  } // splits with infinite thresholds will not be quantized
55  }
56  for (ASTNode* child : node->children) {
57  rewrite_thresholds(child, cut_pts);
58  }
59 }
60 
61 template <typename ThresholdType, typename LeafOutputType>
62 void
63 ASTBuilder<ThresholdType, LeafOutputType>::QuantizeThresholds() {
64  this->quantize_threshold_flag = true;
65  std::vector<std::set<ThresholdType>> cut_pts;
66  std::vector<std::vector<ThresholdType>> cut_pts_vec;
67  cut_pts.resize(this->num_feature);
68  cut_pts_vec.resize(this->num_feature);
69  scan_thresholds(this->main_node, &cut_pts);
70  // convert cut_pts into std::vector
71  for (int i = 0; i < this->num_feature; ++i) {
72  std::copy(cut_pts[i].begin(), cut_pts[i].end(), std::back_inserter(cut_pts_vec[i]));
73  }
74 
75  /* revise all numerical splits by quantizing thresholds */
76  rewrite_thresholds(this->main_node, cut_pts_vec);
77 
78  TREELITE_CHECK_EQ(this->main_node->children.size(), 1);
79  ASTNode* top_ac_node = this->main_node->children[0];
80  TREELITE_CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
81  /* dynamic_cast<> is used here to check node types. This is to ensure
82  that we don't accidentally call QuantizeThresholds() twice. */
83 
84  ASTNode* quantizer_node
85  = AddNode<QuantizerNode<ThresholdType>>(this->main_node, std::move(cut_pts_vec));
86  quantizer_node->children.push_back(top_ac_node);
87  top_ac_node->parent = quantizer_node;
88  this->main_node->children[0] = quantizer_node;
89 }
90 
// Explicit instantiations of ASTBuilder::QuantizeThresholds() for the listed
// (ThresholdType, LeafOutputType) combinations. They are needed because the template
// definition above lives in this .cc file. Presumably these are the combinations the
// builder is used with elsewhere; verify against callers before adding or removing any.
template void ASTBuilder<float, uint32_t>::QuantizeThresholds();
template void ASTBuilder<float, float>::QuantizeThresholds();
template void ASTBuilder<double, uint32_t>::QuantizeThresholds();
template void ASTBuilder<double, double>::QuantizeThresholds();
95 
96 } // namespace compiler
97 } // namespace treelite
Some useful math utilities.
Logging facility for Treelite.
AST Builder class.