7 #include <dmlc/registry.h> 15 DMLC_REGISTRY_FILE_TAG(fold_code);
// Natural log of the root node's data_count, cached when fold_code visits the
// node with node_id 0; quiet NaN when the root carries no data_count.
19 double log_root_data_count;
// Natural log of the root node's sum_hess, cached the same way; quiet NaN when
// the root carries no sum_hess.
20 double log_root_sum_hess;
// When true, each folded subtree is wrapped as
// TranslationUnitNode -> AccumulatorContextNode -> CodeFolderNode and the
// TranslationUnitNode takes the subtree's slot in its parent; when false the
// CodeFolderNode itself takes that slot.
21 bool create_new_translation_unit;
// Code-folding pass over the AST.
// NOTE(review): this excerpt is damaged -- the function signature, the `else`
// lines, the body of the child-search loop, and closing braces are missing, and
// each line carries a stray leading number. Restore from version control before
// building.
//
// Visible behavior:
//  * At the root (node_id == 0) it caches log(data_count) and log(sum_hess) in
//    `context`, or NaN when a statistic is absent.
//  * A node is folded when data_count or sum_hess is available at both the root
//    and the node, and the node's value is smaller than the root's by at least
//    `context->magnitude_req` in natural-log units (a factor of
//    exp(magnitude_req) or more). Folding moves the node under a new
//    CodeFolderNode, optionally inside a freshly numbered translation unit.
//  * The recursive part ORs together fold_code(child, ...) over node->children
//    and returns that flag (true if any descendant subtree was folded).
25 template <
typename ThresholdType,
typename LeafOutputType>
// Root node: cache the log-magnitudes that every other node is compared against.
28 if (node->node_id == 0) {
29 if (node->data_count) {
30 context->log_root_data_count = std::log(node->data_count.value());
32 context->log_root_data_count = std::numeric_limits<double>::quiet_NaN();
35 context->log_root_sum_hess = std::log(node->sum_hess.value());
37 context->log_root_sum_hess = std::numeric_limits<double>::quiet_NaN();
// Fold test: either statistic must exist at both the root (non-NaN cache) and
// this node, and log(root) - log(node) must reach magnitude_req.
// NOTE(review): std::log(0) is -inf, so a node with a zero statistic yields
// +inf here and is always folded. A negative sum_hess yields NaN, and every
// comparison with NaN is false, so it never triggers a fold. Confirm that this
// is intended.
41 if ( (node->data_count && !std::isnan(context->log_root_data_count)
42 && context->log_root_data_count - std::log(node->data_count.value())
43 >= context->magnitude_req)
44 || (node->sum_hess && !std::isnan(context->log_root_sum_hess)
45 && context->log_root_sum_hess - std::log(node->sum_hess.value())
46 >= context->magnitude_req) ) {
// Splice new wrapper node(s) between this node and its parent.
48 ASTNode* parent_node = node->parent;
// Translation-unit path: the new unit takes the next index from
// context->num_tu, which is post-incremented.
51 if (context->create_new_translation_unit) {
52 tu_node = builder->template AddNode<TranslationUnitNode>(parent_node, context->num_tu++);
53 ASTNode* ac = builder->template AddNode<AccumulatorContextNode>(tu_node);
54 folder_node = builder->template AddNode<CodeFolderNode>(ac);
55 tu_node->children.push_back(ac);
56 ac->children.push_back(folder_node);
// Plain path (its `else` line is missing from this excerpt): the folder node is
// created directly under the parent.
58 folder_node = builder->template AddNode<CodeFolderNode>(parent_node);
// Locate this node among the parent's children so its slot can be reused.
61 for (
size_t i = 0; i < parent_node->children.size(); ++i) {
62 if (parent_node->children[i] == node) {
// Fail the CHECK if this node was not found in its parent's child list.
67 CHECK_NE(node_loc, -1);
// The outermost new wrapper takes the node's slot in the parent; the node is
// then re-parented under the folder node.
68 parent_node->children[node_loc]
69 = context->create_new_translation_unit ? tu_node : folder_node;
70 folder_node->children.push_back(node);
71 node->parent = folder_node;
// Recurse into children and report whether any subtree was folded.
74 bool folded_at_least_once =
false;
75 for (
ASTNode* child : node->children) {
76 folded_at_least_once |= fold_code(child, context, builder);
78 return folded_at_least_once;
82 int count_tu_nodes(
ASTNode* node);
84 template <
typename ThresholdType,
typename LeafOutputType>
87 double magnitude_req,
bool create_new_translation_unit) {
89 std::numeric_limits<double>::quiet_NaN(),
90 std::numeric_limits<double>::quiet_NaN(),
91 create_new_translation_unit,
92 count_tu_nodes(this->main_node)};
93 return fold_code(this->main_node, &context,
this);