7 #include <dmlc/registry.h> 14 DMLC_REGISTRY_FILE_TAG(quantize);
17 scan_thresholds(ASTNode* node,
18 std::vector<std::set<tl_float>>* cut_pts) {
19 NumericalConditionNode* num_cond;
20 CategoricalConditionNode* cat_cond;
21 if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
22 CHECK(!num_cond->quantized) <<
"should not be already quantized";
23 const tl_float threshold = num_cond->threshold.float_val;
24 if (std::isfinite(threshold)) {
25 (*cut_pts)[num_cond->split_index].insert(threshold);
28 for (ASTNode* child : node->children) {
29 scan_thresholds(child, cut_pts);
34 rewrite_thresholds(ASTNode* node,
35 const std::vector<std::vector<tl_float>>& cut_pts) {
36 NumericalConditionNode* num_cond;
37 if ( (num_cond = dynamic_cast<NumericalConditionNode*>(node)) ) {
38 CHECK(!num_cond->quantized) <<
"should not be already quantized";
39 const tl_float threshold = num_cond->threshold.float_val;
40 if (std::isfinite(threshold)) {
41 const auto& v = cut_pts[num_cond->split_index];
42 auto loc = math::binary_search(v.begin(), v.end(), threshold);
43 CHECK(loc != v.end());
44 num_cond->threshold.int_val =
static_cast<int>(loc - v.begin()) * 2;
45 num_cond->quantized =
true;
48 for (ASTNode* child : node->children) {
49 rewrite_thresholds(child, cut_pts);
53 void ASTBuilder::QuantizeThresholds() {
54 this->quantize_threshold_flag =
true;
55 std::vector<std::set<tl_float>> cut_pts;
56 std::vector<std::vector<tl_float>> cut_pts_vec;
57 cut_pts.resize(this->num_feature);
58 cut_pts_vec.resize(this->num_feature);
59 scan_thresholds(this->main_node, &cut_pts);
61 for (
int i = 0; i < this->num_feature; ++i) {
62 std::copy(cut_pts[i].begin(), cut_pts[i].end(),
63 std::back_inserter(cut_pts_vec[i]));
67 rewrite_thresholds(this->main_node, cut_pts_vec);
69 CHECK_EQ(this->main_node->children.size(), 1);
70 ASTNode* top_ac_node = this->main_node->children[0];
71 CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
75 ASTNode* quantizer_node = AddNode<QuantizerNode>(this->main_node,
76 std::move(cut_pts_vec));
77 quantizer_node->children.push_back(top_ac_node);
78 top_ac_node->parent = quantizer_node;
79 this->main_node->children[0] = quantizer_node;
// Some useful math utilities.
// tl_float: alias for float, the floating-point type used internally.