treelite
split.cc
Go to the documentation of this file.
1 
#include "./builder.h"

#include <algorithm>
#include <vector>
7 
8 namespace treelite {
9 namespace compiler {
10 
// Registry file tag for this translation unit. NOTE(review): presumably lets
// another TU force-link this file via DMLC_REGISTRY_LINK_TAG(split) — confirm
// against the dmlc-core registry usage elsewhere in the compiler.
DMLC_REGISTRY_FILE_TAG(split);
12 
13 int count_tu_nodes(ASTNode* node) {
14  int accum = (dynamic_cast<TranslationUnitNode*>(node)) ? 1 : 0;
15  for (ASTNode* child : node->children) {
16  accum += count_tu_nodes(child);
17  }
18  return accum;
19 }
20 
21 void ASTBuilder::Split(int parallel_comp) {
22  if (parallel_comp <= 0) {
23  LOG(INFO) << "Parallel compilation disabled; all member trees will be "
24  << "dumped to a single source file. This may increase "
25  << "compilation time and memory usage.";
26  return;
27  }
28  LOG(INFO) << "Parallel compilation enabled; member trees will be "
29  << "divided into " << parallel_comp << " translation units.";
30  CHECK_EQ(this->main_node->children.size(), 1);
31  ASTNode* top_ac_node = this->main_node->children[0];
32  CHECK(dynamic_cast<AccumulatorContextNode*>(top_ac_node));
33 
34  /* tree_head[i] stores reference to head of tree i */
35  std::vector<ASTNode*> tree_head;
36  for (ASTNode* node : top_ac_node->children) {
37  CHECK(dynamic_cast<ConditionNode*>(node) || dynamic_cast<OutputNode*>(node));
38  tree_head.push_back(node);
39  }
40  /* dynamic_cast<> is used here to check node types. This is to ensure
41  that we don't accidentally call Split() twice. */
42 
43  const int ntree = static_cast<int>(tree_head.size());
44  const int nunit = parallel_comp;
45  const int unit_size = (ntree + nunit - 1) / nunit;
46  std::vector<ASTNode*> tu_list; // list of translation units
47  const int current_num_tu = count_tu_nodes(this->main_node);
48  for (int unit_id = 0; unit_id < nunit; ++unit_id) {
49  const int tree_begin = unit_id * unit_size;
50  const int tree_end = std::min((unit_id + 1) * unit_size, ntree);
51  if (tree_begin < tree_end) {
52  TranslationUnitNode* tu
53  = AddNode<TranslationUnitNode>(top_ac_node, current_num_tu + unit_id);
54  tu_list.push_back(tu);
55  AccumulatorContextNode* ac = AddNode<AccumulatorContextNode>(tu);
56  tu->children.push_back(ac);
57  for (int tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
58  ASTNode* tree_head_node = tree_head[tree_id];
59  tree_head_node->parent = ac;
60  ac->children.push_back(tree_head_node);
61  }
62  }
63  }
64  top_ac_node->children = tu_list;
65 }
66 
67 } // namespace compiler
68 } // namespace treelite
AST Builder class.