Treelite
load_data_counts.cc
Go to the documentation of this file.
1 
7 #include <dmlc/registry.h>
8 #include <cmath>
9 #include "./builder.h"
10 
11 namespace treelite {
12 namespace compiler {
13 
14 DMLC_REGISTRY_FILE_TAG(load_data_counts);
15 
16 static void load_data_counts(ASTNode* node, const std::vector<std::vector<size_t>>& counts) {
17  if (node->tree_id >= 0 && node->node_id >= 0) {
18  node->data_count = counts[node->tree_id][node->node_id];
19  }
20  for (ASTNode* child : node->children) {
21  load_data_counts(child, counts);
22  }
23 }
24 
25 template <typename ThresholdType, typename LeafOutputType>
26 void
27 ASTBuilder<ThresholdType, LeafOutputType>::LoadDataCounts(
28  const std::vector<std::vector<size_t>>& counts) {
29  load_data_counts(this->main_node, counts);
30 }
31 
32 template void ASTBuilder<float, uint32_t>::LoadDataCounts(const std::vector<std::vector<size_t>>&);
33 template void ASTBuilder<float, float>::LoadDataCounts(const std::vector<std::vector<size_t>>&);
34 template void ASTBuilder<double, uint32_t>::LoadDataCounts(const std::vector<std::vector<size_t>>&);
35 template void ASTBuilder<double, double>::LoadDataCounts(const std::vector<std::vector<size_t>>&);
36 
37 } // namespace compiler
38 } // namespace treelite
AST Builder class.