Treelite
load_data_counts.cc
Go to the documentation of this file.
1 
7 #include <cstdint>
8 #include "./builder.h"
9 
10 namespace treelite {
11 namespace compiler {
12 
13 static void load_data_counts(ASTNode* node, const std::vector<std::vector<uint64_t>>& counts) {
14  if (node->tree_id >= 0 && node->node_id >= 0) {
15  node->data_count = counts[node->tree_id][node->node_id];
16  }
17  for (ASTNode* child : node->children) {
18  load_data_counts(child, counts);
19  }
20 }
21 
22 template <typename ThresholdType, typename LeafOutputType>
23 void
24 ASTBuilder<ThresholdType, LeafOutputType>::LoadDataCounts(
25  const std::vector<std::vector<uint64_t>>& counts) {
26  load_data_counts(this->main_node, counts);
27 }
28 
29 template void ASTBuilder<float, uint32_t>::LoadDataCounts(
30  const std::vector<std::vector<uint64_t>>&);
31 template void ASTBuilder<float, float>::LoadDataCounts(const std::vector<std::vector<uint64_t>>&);
32 template void ASTBuilder<double, uint32_t>::LoadDataCounts(
33  const std::vector<std::vector<uint64_t>>&);
34 template void ASTBuilder<double, double>::LoadDataCounts(const std::vector<std::vector<uint64_t>>&);
35 
36 } // namespace compiler
37 } // namespace treelite
AST Builder class.