treelite
quantize.cc
Go to the documentation of this file.
1 
6 #include <cmath>
7 #include "./builder.h"
8 
9 namespace treelite {
10 namespace compiler {
11 
12 DMLC_REGISTRY_FILE_TAG(quantize);
13 
14 static void
15 scan_thresholds(ASTNode* node,
16  std::vector<std::set<tl_float>>* cut_pts) {
17  NumericalConditionNode* num_cond;
18  CategoricalConditionNode* cat_cond;
19  if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
20  CHECK(!num_cond->quantized) << "should not be already quantized";
21  const tl_float threshold = num_cond->threshold.float_val;
22  if (std::isfinite(threshold)) {
23  (*cut_pts)[num_cond->split_index].insert(threshold);
24  }
25  }
26  for (ASTNode* child : node->children) {
27  scan_thresholds(child, cut_pts);
28  }
29 }
30 
31 static void
32 rewrite_thresholds(ASTNode* node,
33  const std::vector<std::vector<tl_float>>& cut_pts) {
34  NumericalConditionNode* num_cond;
35  if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
36  CHECK(!num_cond->quantized) << "should not be already quantized";
37  const tl_float threshold = num_cond->threshold.float_val;
38  if (std::isfinite(threshold)) {
39  const auto& v = cut_pts[num_cond->split_index];
40  auto loc = common::binary_search(v.begin(), v.end(), threshold);
41  CHECK(loc != v.end());
42  num_cond->threshold.int_val = static_cast<size_t>(loc - v.begin()) * 2;
43  num_cond->quantized = true;
44  } // splits with infinite thresholds will not be quantized
45  }
46  for (ASTNode* child : node->children) {
47  rewrite_thresholds(child, cut_pts);
48  }
49 }
50 
51 void ASTBuilder::QuantizeThresholds() {
52  this->quantize_threshold_flag = true;
53  std::vector<std::set<tl_float>> cut_pts;
54  std::vector<std::vector<tl_float>> cut_pts_vec;
55  cut_pts.resize(this->num_feature);
56  cut_pts_vec.resize(this->num_feature);
57  scan_thresholds(this->main_node, &cut_pts);
58  // convert cut_pts into std::vector
59  for (int i = 0; i < this->num_feature; ++i) {
60  std::copy(cut_pts[i].begin(), cut_pts[i].end(),
61  std::back_inserter(cut_pts_vec[i]));
62  }
63 
64  /* revise all numerical splits by quantizing thresholds */
65  rewrite_thresholds(this->main_node, cut_pts_vec);
66 
67  CHECK_EQ(this->main_node->children.size(), 1);
68  ASTNode* top_ac_node = this->main_node->children[0];
69  CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
70  /* dynamic_cast<> is used here to check node types. This is to ensure
71  that we don't accidentally call QuantizeThresholds() twice. */
72 
73  ASTNode* quantizer_node = AddNode<QuantizerNode>(this->main_node,
74  std::move(cut_pts_vec));
75  quantizer_node->children.push_back(top_ac_node);
76  top_ac_node->parent = quantizer_node;
77  this->main_node->children[0] = quantizer_node;
78 }
79 
80 } // namespace compiler
81 } // namespace treelite
AST Builder class.
double tl_float
float type to be used internally
Definition: base.h:17