15 template <
typename ThresholdType>
17 scan_thresholds(ASTNode* node, std::vector<std::set<ThresholdType>>* cut_pts) {
18 NumericalConditionNode<ThresholdType>* num_cond;
19 if ( (num_cond =
dynamic_cast<NumericalConditionNode<ThresholdType>*
>(node)) ) {
20 TREELITE_CHECK(!num_cond->quantized) <<
"should not be already quantized";
21 const ThresholdType threshold = num_cond->threshold.float_val;
22 if (std::isfinite(threshold)) {
23 (*cut_pts)[num_cond->split_index].insert(threshold);
26 for (ASTNode* child : node->children) {
27 scan_thresholds(child, cut_pts);
31 template <
typename ThresholdType>
33 rewrite_thresholds(ASTNode* node,
const std::vector<std::vector<ThresholdType>>& cut_pts) {
34 NumericalConditionNode<ThresholdType>* num_cond;
35 if ( (num_cond =
dynamic_cast<NumericalConditionNode<ThresholdType>*
>(node)) ) {
36 TREELITE_CHECK(!num_cond->quantized) <<
"should not be already quantized";
37 const ThresholdType threshold = num_cond->threshold.float_val;
38 if (std::isfinite(threshold)) {
39 const auto& v = cut_pts[num_cond->split_index];
41 auto loc = math::binary_search(v.begin(), v.end(), threshold);
42 TREELITE_CHECK(loc != v.end());
43 num_cond->threshold.int_val =
static_cast<int>(loc - v.begin()) * 2;
46 ThresholdType zero =
static_cast<ThresholdType
>(0);
47 auto loc = std::lower_bound(v.begin(), v.end(), zero);
48 num_cond->zero_quantized =
static_cast<int>(loc - v.begin()) * 2;
49 if (loc != v.end() && zero != *loc) {
50 --num_cond->zero_quantized;
53 num_cond->quantized =
true;
56 for (ASTNode* child : node->children) {
57 rewrite_thresholds(child, cut_pts);
61 template <
typename ThresholdType,
typename LeafOutputType>
63 ASTBuilder<ThresholdType, LeafOutputType>::QuantizeThresholds() {
64 this->quantize_threshold_flag =
true;
65 std::vector<std::set<ThresholdType>> cut_pts;
66 std::vector<std::vector<ThresholdType>> cut_pts_vec;
67 cut_pts.resize(this->num_feature);
68 cut_pts_vec.resize(this->num_feature);
69 scan_thresholds(this->main_node, &cut_pts);
71 for (
int i = 0; i < this->num_feature; ++i) {
72 std::copy(cut_pts[i].begin(), cut_pts[i].end(), std::back_inserter(cut_pts_vec[i]));
76 rewrite_thresholds(this->main_node, cut_pts_vec);
78 TREELITE_CHECK_EQ(this->main_node->children.size(), 1);
79 ASTNode* top_ac_node = this->main_node->children[0];
80 TREELITE_CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
84 ASTNode* quantizer_node
85 = AddNode<QuantizerNode<ThresholdType>>(this->main_node, std::move(cut_pts_vec));
86 quantizer_node->children.push_back(top_ac_node);
87 top_ac_node->parent = quantizer_node;
88 this->main_node->children[0] = quantizer_node;
91 template void ASTBuilder<float, uint32_t>::QuantizeThresholds();
92 template void ASTBuilder<float, float>::QuantizeThresholds();
93 template void ASTBuilder<double, uint32_t>::QuantizeThresholds();
94 template void ASTBuilder<double, double>::QuantizeThresholds();
// Some useful math utilities.
// Logging facility for Treelite.