/*!
 * Treelite
 * \file is_categorical_array.cc
 * Go to the documentation of this file.
 */
1 
7 #include "./builder.h"
8 
9 namespace treelite {
10 namespace compiler {
11 
12 static void
13 scan_thresholds(ASTNode* node, std::vector<bool>* is_categorical) {
14  CategoricalConditionNode* cat_cond
15  = dynamic_cast<CategoricalConditionNode*>(node);
16  if (cat_cond) {
17  (*is_categorical)[cat_cond->split_index] = true;
18  }
19  for (ASTNode* child : node->children) {
20  scan_thresholds(child, is_categorical);
21  }
22 }
23 
24 template <typename ThresholdType, typename LeafOutputType>
25 std::vector<bool>
26 ASTBuilder<ThresholdType, LeafOutputType>::GenerateIsCategoricalArray() {
27  this->is_categorical = std::vector<bool>(this->num_feature, false);
28  scan_thresholds(this->main_node, &this->is_categorical);
29  return this->is_categorical;
30 }
31 
32 template std::vector<bool> ASTBuilder<float, uint32_t>::GenerateIsCategoricalArray();
33 template std::vector<bool> ASTBuilder<float, float>::GenerateIsCategoricalArray();
34 template std::vector<bool> ASTBuilder<double, uint32_t>::GenerateIsCategoricalArray();
35 template std::vector<bool> ASTBuilder<double, double>::GenerateIsCategoricalArray();
36 
37 } // namespace compiler
38 } // namespace treelite
// AST Builder class.