Treelite
builder.h
Go to the documentation of this file.
1 
6 #ifndef TREELITE_COMPILER_AST_BUILDER_H_
7 #define TREELITE_COMPILER_AST_BUILDER_H_
8 
9 #include <treelite/tree.h>
10 #include <map>
11 #include <string>
12 #include <vector>
13 #include <ostream>
14 #include <utility>
15 #include <memory>
16 #include "./ast.h"
17 
18 namespace treelite {
19 namespace compiler {
20 
21 // forward declaration
22 class ASTBuilder;
23 struct CodeFoldingContext;
24 bool fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*);
25 bool breakup(ASTNode*, int, int*, ASTBuilder*);
26 
27 class ASTBuilder {
28  public:
29  ASTBuilder() : output_vector_flag(false), main_node(nullptr),
30  quantize_threshold_flag(false) {}
31 
32  /* \brief initially build AST from model */
33  void BuildAST(const Model& model);
34  /* \brief generate is_categorical[] array, which tells whether each feature
35  is categorical or numerical */
36  std::vector<bool> GenerateIsCategoricalArray();
37  /*
38  * \brief fold rarely visited subtrees into tight loops (don't produce
39  * if/else blocks). Rarity of each node is determined by its
40  * data count and/or hessian sum: any node is "rare" if its data count
41  * or hessian sum is lower than the proscribed threshold.
42  * \param magnitude_req all nodes whose data counts are lower than that of
43  * the root node of the decision tree by [magnitude_req]
44  * will be folded. To diable folding, set to +inf. If
45  * hessian sums are available instead of data counts,
46  * hessian sums will be used as a proxy of data counts
47  * \param create_new_translation_unit if true, place folded loops in
48  * separate translation units
49  * \param whether at least one subtree was folded
50  */
51  bool FoldCode(double magnitude_req, bool create_new_translation_unit = false);
52  /*
53  * \brief split prediction function into multiple translation units
54  * \param parallel_comp number of translation units
55  */
56  void Split(int parallel_comp);
57  /* \brief replace split thresholds with integers */
58  void QuantizeThresholds();
59  /* \brief Load data counts from annotation file */
60  void LoadDataCounts(const std::vector<std::vector<size_t>>& counts);
61  /*
62  * \brief Get a text representation of AST
63  */
64  std::string GetDump() const;
65 
66  inline const ASTNode* GetRootNode() {
67  return main_node;
68  }
69 
70  private:
71  friend bool treelite::compiler::fold_code(ASTNode*, CodeFoldingContext*,
72  ASTBuilder*);
73 
74  template <typename NodeType, typename ...Args>
75  NodeType* AddNode(ASTNode* parent, Args&& ...args) {
76  std::unique_ptr<NodeType> node(new NodeType(std::forward<Args>(args)...));
77  NodeType* ref = node.get();
78  ref->parent = parent;
79  nodes.push_back(std::move(node));
80  return ref;
81  }
82  ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, ASTNode* parent);
83  ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, int nid,
84  ASTNode* parent);
85 
86  // keep tract of all nodes built so far, to prevent memory leak
87  std::vector<std::unique_ptr<ASTNode>> nodes;
88  bool output_vector_flag;
89  bool quantize_threshold_flag;
90  int num_feature;
91  int num_output_group;
92  bool random_forest_flag;
93  ASTNode* main_node;
94  std::vector<bool> is_categorical;
95  std::map<std::string, std::string> model_param;
96 };
97 
98 } // namespace compiler
99 } // namespace treelite
100 
101 #endif // TREELITE_COMPILER_AST_BUILDER_H_
thin wrapper for tree ensemble model
Definition: tree.h:409
model structure for tree ensemble
Definition for AST classes.
in-memory representation of a decision tree
Definition: tree.h:80