Treelite
is_categorical_array.cc
Go to the documentation of this file.
1 
7 #include <dmlc/registry.h>
8 #include "./builder.h"
9 
10 namespace treelite {
11 namespace compiler {
12 
// Declares a dmlc registry file tag for this translation unit
// (macro provided by <dmlc/registry.h>).
// NOTE(review): presumably paired with DMLC_REGISTRY_LINK_TAG(is_categorical_array)
// elsewhere so that static linking keeps this object file — confirm against
// the dmlc registry header.
DMLC_REGISTRY_FILE_TAG(is_categorical_array);
14 
15 static void
16 scan_thresholds(ASTNode* node, std::vector<bool>* is_categorical) {
17  CategoricalConditionNode* cat_cond
18  = dynamic_cast<CategoricalConditionNode*>(node);
19  if (cat_cond) {
20  (*is_categorical)[cat_cond->split_index] = true;
21  }
22  for (ASTNode* child : node->children) {
23  scan_thresholds(child, is_categorical);
24  }
25 }
26 
27 template <typename ThresholdType, typename LeafOutputType>
28 std::vector<bool>
29 ASTBuilder<ThresholdType, LeafOutputType>::GenerateIsCategoricalArray() {
30  this->is_categorical = std::vector<bool>(this->num_feature, false);
31  scan_thresholds(this->main_node, &this->is_categorical);
32  return this->is_categorical;
33 }
34 
35 template std::vector<bool> ASTBuilder<float, uint32_t>::GenerateIsCategoricalArray();
36 template std::vector<bool> ASTBuilder<float, float>::GenerateIsCategoricalArray();
37 template std::vector<bool> ASTBuilder<double, uint32_t>::GenerateIsCategoricalArray();
38 template std::vector<bool> ASTBuilder<double, double>::GenerateIsCategoricalArray();
39 
40 } // namespace compiler
41 } // namespace treelite
AST Builder class.