13 scan_thresholds(ASTNode* node, std::vector<bool>* is_categorical) {
14 CategoricalConditionNode* cat_cond
15 =
dynamic_cast<CategoricalConditionNode*
>(node);
17 (*is_categorical)[cat_cond->split_index] =
true;
19 for (ASTNode* child : node->children) {
20 scan_thresholds(child, is_categorical);
24 template <
typename ThresholdType,
typename LeafOutputType>
26 ASTBuilder<ThresholdType, LeafOutputType>::GenerateIsCategoricalArray() {
27 this->is_categorical = std::vector<bool>(this->num_feature,
false);
28 scan_thresholds(this->main_node, &this->is_categorical);
29 return this->is_categorical;
32 template std::vector<bool> ASTBuilder<float, uint32_t>::GenerateIsCategoricalArray();
33 template std::vector<bool> ASTBuilder<float, float>::GenerateIsCategoricalArray();
34 template std::vector<bool> ASTBuilder<double, uint32_t>::GenerateIsCategoricalArray();
35 template std::vector<bool> ASTBuilder<double, double>::GenerateIsCategoricalArray();