13 static void load_data_counts(ASTNode* node,
const std::vector<std::vector<uint64_t>>& counts) {
14 if (node->tree_id >= 0 && node->node_id >= 0) {
15 node->data_count = counts[node->tree_id][node->node_id];
17 for (ASTNode* child : node->children) {
18 load_data_counts(child, counts);
22 template <
typename ThresholdType,
typename LeafOutputType>
24 ASTBuilder<ThresholdType, LeafOutputType>::LoadDataCounts(
25 const std::vector<std::vector<uint64_t>>& counts) {
26 load_data_counts(this->main_node, counts);
29 template void ASTBuilder<float, uint32_t>::LoadDataCounts(
30 const std::vector<std::vector<uint64_t>>&);
31 template void ASTBuilder<float, float>::LoadDataCounts(
const std::vector<std::vector<uint64_t>>&);
32 template void ASTBuilder<double, uint32_t>::LoadDataCounts(
33 const std::vector<std::vector<uint64_t>>&);
34 template void ASTBuilder<double, double>::LoadDataCounts(
const std::vector<std::vector<uint64_t>>&);