#include <dmlc/registry.h>

// Registry link tag for this translation unit.
// NOTE(review): presumably paired with a matching
// DMLC_REGISTRY_LINK_TAG(is_categorical_array) elsewhere so the linker keeps
// this object file when building a static library — confirm.
DMLC_REGISTRY_FILE_TAG(is_categorical_array);
16 scan_thresholds(ASTNode* node, std::vector<bool>* is_categorical) {
17 CategoricalConditionNode* cat_cond
18 =
dynamic_cast<CategoricalConditionNode*
>(node);
20 (*is_categorical)[cat_cond->split_index] =
true;
22 for (ASTNode* child : node->children) {
23 scan_thresholds(child, is_categorical);
27 template <
typename ThresholdType,
typename LeafOutputType>
29 ASTBuilder<ThresholdType, LeafOutputType>::GenerateIsCategoricalArray() {
30 this->is_categorical = std::vector<bool>(this->num_feature,
false);
31 scan_thresholds(this->main_node, &this->is_categorical);
32 return this->is_categorical;
35 template std::vector<bool> ASTBuilder<float, uint32_t>::GenerateIsCategoricalArray();
36 template std::vector<bool> ASTBuilder<float, float>::GenerateIsCategoricalArray();
37 template std::vector<bool> ASTBuilder<double, uint32_t>::GenerateIsCategoricalArray();
38 template std::vector<bool> ASTBuilder<double, double>::GenerateIsCategoricalArray();