Treelite
quantize.cc
Go to the documentation of this file.
1 
6 #include <treelite/math.h>
7 #include <dmlc/registry.h>
8 #include <cmath>
9 #include "./builder.h"
10 
11 namespace treelite {
12 namespace compiler {
13 
// Registry file tag for this translation unit. NOTE(review): presumably paired
// with a DMLC_REGISTRY_LINK_TAG(quantize) elsewhere so the linker keeps this
// file's symbols — confirm against the registry setup in the compiler module.
DMLC_REGISTRY_FILE_TAG(quantize);
15 
16 static void
17 scan_thresholds(ASTNode* node,
18  std::vector<std::set<tl_float>>* cut_pts) {
19  NumericalConditionNode* num_cond;
20  CategoricalConditionNode* cat_cond;
21  if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
22  CHECK(!num_cond->quantized) << "should not be already quantized";
23  const tl_float threshold = num_cond->threshold.float_val;
24  if (std::isfinite(threshold)) {
25  (*cut_pts)[num_cond->split_index].insert(threshold);
26  }
27  }
28  for (ASTNode* child : node->children) {
29  scan_thresholds(child, cut_pts);
30  }
31 }
32 
33 static void
34 rewrite_thresholds(ASTNode* node,
35  const std::vector<std::vector<tl_float>>& cut_pts) {
36  NumericalConditionNode* num_cond;
37  if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
38  CHECK(!num_cond->quantized) << "should not be already quantized";
39  const tl_float threshold = num_cond->threshold.float_val;
40  if (std::isfinite(threshold)) {
41  const auto& v = cut_pts[num_cond->split_index];
42  auto loc = math::binary_search(v.begin(), v.end(), threshold);
43  CHECK(loc != v.end());
44  num_cond->threshold.int_val = static_cast<int>(loc - v.begin()) * 2;
45  num_cond->quantized = true;
46  } // splits with infinite thresholds will not be quantized
47  }
48  for (ASTNode* child : node->children) {
49  rewrite_thresholds(child, cut_pts);
50  }
51 }
52 
53 void ASTBuilder::QuantizeThresholds() {
54  this->quantize_threshold_flag = true;
55  std::vector<std::set<tl_float>> cut_pts;
56  std::vector<std::vector<tl_float>> cut_pts_vec;
57  cut_pts.resize(this->num_feature);
58  cut_pts_vec.resize(this->num_feature);
59  scan_thresholds(this->main_node, &cut_pts);
60  // convert cut_pts into std::vector
61  for (int i = 0; i < this->num_feature; ++i) {
62  std::copy(cut_pts[i].begin(), cut_pts[i].end(),
63  std::back_inserter(cut_pts_vec[i]));
64  }
65 
66  /* revise all numerical splits by quantizing thresholds */
67  rewrite_thresholds(this->main_node, cut_pts_vec);
68 
69  CHECK_EQ(this->main_node->children.size(), 1);
70  ASTNode* top_ac_node = this->main_node->children[0];
71  CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
72  /* dynamic_cast<> is used here to check node types. This is to ensure
73  that we don't accidentally call QuantizeThresholds() twice. */
74 
75  ASTNode* quantizer_node = AddNode<QuantizerNode>(this->main_node,
76  std::move(cut_pts_vec));
77  quantizer_node->children.push_back(top_ac_node);
78  top_ac_node->parent = quantizer_node;
79  this->main_node->children[0] = quantizer_node;
80 }
81 
82 } // namespace compiler
83 } // namespace treelite
Some useful math utilities.
float tl_float
float type to be used internally
Definition: base.h:18
AST Builder class.