7 DMLC_REGISTRY_FILE_TAG(quantize);
10 scan_thresholds(ASTNode* node,
11 std::vector<std::set<tl_float>>* cut_pts,
12 std::vector<bool>* is_categorical) {
13 NumericalConditionNode* num_cond;
14 CategoricalConditionNode* cat_cond;
15 if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
16 CHECK(!num_cond->quantized) <<
"should not be already quantized";
17 const tl_float threshold = num_cond->threshold.float_val;
18 if (std::isfinite(threshold)) {
19 (*cut_pts)[num_cond->split_index].insert(threshold);
21 }
else if ( (cat_cond = dynamic_cast<CategoricalConditionNode*>(node)) ) {
22 (*is_categorical)[cat_cond->split_index] =
true;
24 for (ASTNode* child : node->children) {
25 scan_thresholds(child, cut_pts, is_categorical);
30 rewrite_thresholds(ASTNode* node,
31 const std::vector<std::vector<tl_float>>& cut_pts) {
32 NumericalConditionNode* num_cond;
33 if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
34 CHECK(!num_cond->quantized) <<
"should not be already quantized";
35 const tl_float threshold = num_cond->threshold.float_val;
36 if (std::isfinite(threshold)) {
37 const auto& v = cut_pts[num_cond->split_index];
38 auto loc = common::binary_search(v.begin(), v.end(), threshold);
39 CHECK(loc != v.end());
40 num_cond->threshold.int_val =
static_cast<size_t>(loc - v.begin()) * 2;
41 num_cond->quantized =
true;
44 for (ASTNode* child : node->children) {
45 rewrite_thresholds(child, cut_pts);
49 void ASTBuilder::QuantizeThresholds() {
50 this->quantize_threshold_flag =
true;
51 std::vector<std::set<tl_float>> cut_pts;
52 std::vector<std::vector<tl_float>> cut_pts_vec;
53 std::vector<bool> is_categorical(this->num_feature,
false);
54 cut_pts.resize(this->num_feature);
55 cut_pts_vec.resize(this->num_feature);
56 scan_thresholds(this->main_node, &cut_pts, &is_categorical);
58 for (
int i = 0; i < this->num_feature; ++i) {
59 std::copy(cut_pts[i].begin(), cut_pts[i].end(),
60 std::back_inserter(cut_pts_vec[i]));
64 rewrite_thresholds(this->main_node, cut_pts_vec);
66 CHECK_EQ(this->main_node->children.size(), 1);
67 ASTNode* top_ac_node = this->main_node->children[0];
68 CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
72 ASTNode* quantizer_node = AddNode<QuantizerNode>(this->main_node,
73 std::move(cut_pts_vec),
74 std::move(is_categorical));
75 quantizer_node->children.push_back(top_ac_node);
76 top_ac_node->parent = quantizer_node;
77 this->main_node->children[0] = quantizer_node;
// tl_float: the floating-point type used internally (declared as float).