treelite
quantize.cc
1 #include "./builder.h"
2 #include <cmath>
3 
4 namespace treelite {
5 namespace compiler {
6 
7 DMLC_REGISTRY_FILE_TAG(quantize);
8 
9 static void
10 scan_thresholds(ASTNode* node,
11  std::vector<std::set<tl_float>>* cut_pts,
12  std::vector<bool>* is_categorical) {
13  NumericalConditionNode* num_cond;
14  CategoricalConditionNode* cat_cond;
15  if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
16  CHECK(!num_cond->quantized) << "should not be already quantized";
17  const tl_float threshold = num_cond->threshold.float_val;
18  if (std::isfinite(threshold)) {
19  (*cut_pts)[num_cond->split_index].insert(threshold);
20  }
21  } else if ( (cat_cond = dynamic_cast<CategoricalConditionNode*>(node)) ) {
22  (*is_categorical)[cat_cond->split_index] = true;
23  }
24  for (ASTNode* child : node->children) {
25  scan_thresholds(child, cut_pts, is_categorical);
26  }
27 }
28 
29 static void
30 rewrite_thresholds(ASTNode* node,
31  const std::vector<std::vector<tl_float>>& cut_pts) {
32  NumericalConditionNode* num_cond;
33  if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
34  CHECK(!num_cond->quantized) << "should not be already quantized";
35  const tl_float threshold = num_cond->threshold.float_val;
36  if (std::isfinite(threshold)) {
37  const auto& v = cut_pts[num_cond->split_index];
38  auto loc = common::binary_search(v.begin(), v.end(), threshold);
39  CHECK(loc != v.end());
40  num_cond->threshold.int_val = static_cast<size_t>(loc - v.begin()) * 2;
41  num_cond->quantized = true;
42  } // splits with infinite thresholds will not be quantized
43  }
44  for (ASTNode* child : node->children) {
45  rewrite_thresholds(child, cut_pts);
46  }
47 }
48 
49 void ASTBuilder::QuantizeThresholds() {
50  this->quantize_threshold_flag = true;
51  std::vector<std::set<tl_float>> cut_pts;
52  std::vector<std::vector<tl_float>> cut_pts_vec;
53  std::vector<bool> is_categorical(this->num_feature, false);
54  cut_pts.resize(this->num_feature);
55  cut_pts_vec.resize(this->num_feature);
56  scan_thresholds(this->main_node, &cut_pts, &is_categorical);
57  // convert cut_pts into std::vector
58  for (int i = 0; i < this->num_feature; ++i) {
59  std::copy(cut_pts[i].begin(), cut_pts[i].end(),
60  std::back_inserter(cut_pts_vec[i]));
61  }
62 
63  /* revise all numerical splits by quantizing thresholds */
64  rewrite_thresholds(this->main_node, cut_pts_vec);
65 
66  CHECK_EQ(this->main_node->children.size(), 1);
67  ASTNode* top_ac_node = this->main_node->children[0];
68  CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
69  /* dynamic_cast<> is used here to check node types. This is to ensure
70  that we don't accidentally call QuantizeThresholds() twice. */
71 
72  ASTNode* quantizer_node = AddNode<QuantizerNode>(this->main_node,
73  std::move(cut_pts_vec),
74  std::move(is_categorical));
75  quantizer_node->children.push_back(top_ac_node);
76  top_ac_node->parent = quantizer_node;
77  this->main_node->children[0] = quantizer_node;
78 }
79 
80 } // namespace compiler
81 } // namespace treelite
float tl_float
float type to be used internally
Definition: base.h:17