treelite
builder.h
Go to the documentation of this file.
1 
6 #ifndef TREELITE_COMPILER_AST_BUILDER_H_
7 #define TREELITE_COMPILER_AST_BUILDER_H_
8 
9 #include <treelite/common.h>
10 #include <treelite/tree.h>
11 #include <map>
12 #include <string>
13 #include <vector>
14 #include <ostream>
15 #include <utility>
16 #include <memory>
17 #include "./ast.h"
18 
19 namespace treelite {
20 namespace compiler {
21 
22 // forward declarations: both helpers take an ASTBuilder*, and fold_code must be declared before ASTBuilder names it as a friend
23 class ASTBuilder;
24 struct CodeFoldingContext;
25 bool fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*);
26 bool breakup(ASTNode*, int, int*, ASTBuilder*);
27 
28 class ASTBuilder {
29  public:
30  ASTBuilder() : output_vector_flag(false), main_node(nullptr),
31  quantize_threshold_flag(false) {}
32 
33  /* \brief initially build AST from model */
34  void BuildAST(const Model& model);
35  /* \brief generate is_categorical[] array, which tells whether each feature
36  is categorical or numerical */
37  std::vector<bool> GenerateIsCategoricalArray();
38  /*
39  * \brief fold rarely visited subtrees into tight loops (don't produce
40  * if/else blocks). Rarity of each node is determined by its
41  * data count and/or hessian sum: any node is "rare" if its data count
42  * or hessian sum is lower than the proscribed threshold.
43  * \param magnitude_req all nodes whose data counts are lower than that of
44  * the root node of the decision tree by [magnitude_req]
45  * will be folded. To diable folding, set to +inf. If
46  * hessian sums are available instead of data counts,
47  * hessian sums will be used as a proxy of data counts
48  * \param create_new_translation_unit if true, place folded loops in
49  * separate translation units
50  * \param whether at least one subtree was folded
51  */
52  bool FoldCode(double magnitude_req, bool create_new_translation_unit = false);
53  /*
54  * \brief split prediction function into multiple translation units
55  * \param parallel_comp number of translation units
56  */
57  void Split(int parallel_comp);
58  /* \brief replace split thresholds with integers */
59  void QuantizeThresholds();
60  /* \brief Load data counts from annotation file */
61  void LoadDataCounts(const std::vector<std::vector<size_t>>& counts);
62 
63  inline const ASTNode* GetRootNode() {
64  return main_node;
65  }
66 
67  private:
68  friend bool treelite::compiler::fold_code(ASTNode*, CodeFoldingContext*,
69  ASTBuilder*);
70 
71  template <typename NodeType, typename ...Args>
72  NodeType* AddNode(ASTNode* parent, Args&& ...args) {
73  std::unique_ptr<NodeType> node
74  = common::make_unique<NodeType>(std::forward<Args>(args)...);
75  NodeType* ref = node.get();
76  ref->parent = parent;
77  nodes.push_back(std::move(node));
78  return ref;
79  }
80  ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, ASTNode* parent);
81  ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, int nid,
82  ASTNode* parent);
83 
84  // keep tract of all nodes built so far, to prevent memory leak
85  std::vector<std::unique_ptr<ASTNode>> nodes;
86  bool output_vector_flag;
87  bool quantize_threshold_flag;
88  int num_feature;
89  int num_output_group;
90  bool random_forest_flag;
91  ASTNode* main_node;
92  std::vector<bool> is_categorical;
93  std::map<std::string, std::string> model_param;
94 };
95 
96 } // namespace compiler
97 } // namespace treelite
98 
99 #endif // TREELITE_COMPILER_AST_BUILDER_H_
thin wrapper for tree ensemble model
Definition: tree.h:428
model structure for tree
Definition for AST classes.
in-memory representation of a decision tree
Definition: tree.h:23