Treelite
split.cc
Go to the documentation of this file.
1 
6 #include <treelite/logging.h>
7 #include "./builder.h"
8 
9 namespace treelite {
10 namespace compiler {
11 
12 int count_tu_nodes(ASTNode* node) {
13  int accum = (dynamic_cast<TranslationUnitNode*>(node)) ? 1 : 0;
14  for (ASTNode* child : node->children) {
15  accum += count_tu_nodes(child);
16  }
17  return accum;
18 }
19 
20 template <typename ThresholdType, typename LeafOutputType>
21 void
22 ASTBuilder<ThresholdType, LeafOutputType>::Split(int parallel_comp) {
23  if (parallel_comp <= 0) {
24  TREELITE_LOG(INFO) << "Parallel compilation disabled; all member trees will be "
25  << "dumped to a single source file. This may increase "
26  << "compilation time and memory usage.";
27  return;
28  }
29  TREELITE_LOG(INFO) << "Parallel compilation enabled; member trees will be "
30  << "divided into " << parallel_comp << " translation units.";
31  TREELITE_CHECK_EQ(this->main_node->children.size(), 1);
32  ASTNode* top_ac_node = this->main_node->children[0];
33  TREELITE_CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
34 
35  /* tree_head[i] stores reference to head of tree i */
36  std::vector<ASTNode*> tree_head;
37  for (ASTNode* node : top_ac_node->children) {
38  TREELITE_CHECK(dynamic_cast<ConditionNode*>(node)
39  || dynamic_cast<OutputNode<LeafOutputType>*>(node)
40  || dynamic_cast<CodeFolderNode*>(node));
41  tree_head.push_back(node);
42  }
43  /* dynamic_cast<> is used here to check node types. This is to ensure
44  that we don't accidentally call Split() twice. */
45 
46  const int ntree = static_cast<int>(tree_head.size());
47  const int nunit = parallel_comp;
48  const int unit_size = (ntree + nunit - 1) / nunit;
49  std::vector<ASTNode*> tu_list; // list of translation units
50  const int current_num_tu = count_tu_nodes(this->main_node);
51  for (int unit_id = 0; unit_id < nunit; ++unit_id) {
52  const int tree_begin = unit_id * unit_size;
53  const int tree_end = std::min((unit_id + 1) * unit_size, ntree);
54  if (tree_begin < tree_end) {
55  TranslationUnitNode* tu
56  = AddNode<TranslationUnitNode>(top_ac_node, current_num_tu + unit_id);
57  tu_list.push_back(tu);
58  AccumulatorContextNode* ac = AddNode<AccumulatorContextNode>(tu);
59  tu->children.push_back(ac);
60  for (int tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
61  ASTNode* tree_head_node = tree_head[tree_id];
62  tree_head_node->parent = ac;
63  ac->children.push_back(tree_head_node);
64  }
65  }
66  }
67  top_ac_node->children = tu_list;
68 }
69 
70 template void ASTBuilder<float, uint32_t>::Split(int);
71 template void ASTBuilder<float, float>::Split(int);
72 template void ASTBuilder<double, uint32_t>::Split(int);
73 template void ASTBuilder<double, double>::Split(int);
74 
75 } // namespace compiler
76 } // namespace treelite
Logging facility for Treelite.
AST Builder class.