7 #include <dmlc/registry.h> 14 DMLC_REGISTRY_FILE_TAG(load_data_counts);
16 static void load_data_counts(ASTNode* node,
const std::vector<std::vector<size_t>>& counts) {
17 if (node->tree_id >= 0 && node->node_id >= 0) {
18 node->data_count = counts[node->tree_id][node->node_id];
20 for (ASTNode* child : node->children) {
21 load_data_counts(child, counts);
25 template <
typename ThresholdType,
typename LeafOutputType>
27 ASTBuilder<ThresholdType, LeafOutputType>::LoadDataCounts(
28 const std::vector<std::vector<size_t>>& counts) {
29 load_data_counts(this->main_node, counts);
32 template void ASTBuilder<float, uint32_t>::LoadDataCounts(
const std::vector<std::vector<size_t>>&);
33 template void ASTBuilder<float, float>::LoadDataCounts(
const std::vector<std::vector<size_t>>&);
34 template void ASTBuilder<double, uint32_t>::LoadDataCounts(
const std::vector<std::vector<size_t>>&);
35 template void ASTBuilder<double, double>::LoadDataCounts(
const std::vector<std::vector<size_t>>&);