treelite
builder.h
Go to the documentation of this file.
1 
6 #ifndef TREELITE_COMPILER_AST_BUILDER_H_
7 #define TREELITE_COMPILER_AST_BUILDER_H_
8 
#include <treelite/common.h>
#include <treelite/tree.h>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "./ast.h"
16 
17 namespace treelite {
18 namespace compiler {
19 
20 // forward declaration
21 class ASTBuilder;
22 struct CodeFoldingContext;
23 bool fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*);
24 bool breakup(ASTNode*, int, int*, ASTBuilder*);
25 
26 class ASTBuilder {
27  public:
28  ASTBuilder() : output_vector_flag(false), main_node(nullptr),
29  quantize_threshold_flag(false) {}
30 
31  /* \brief initially build AST from model */
32  void BuildAST(const Model& model);
33  /* \brief generate is_categorical[] array, which tells whether each feature
34  is categorical or numerical */
35  std::vector<bool> GenerateIsCategoricalArray();
36  /*
37  * \brief fold rarely visited subtrees into tight loops (don't produce
38  * if/else blocks). Rarity of each node is determined by its
39  * data count and/or hessian sum: any node is "rare" if its data count
40  * or hessian sum is lower than the proscribed threshold.
41  * \param magnitude_req all nodes whose data counts are lower than that of
42  * the root node of the decision tree by [magnitude_req]
43  * will be folded. To diable folding, set to +inf. If
44  * hessian sums are available instead of data counts,
45  * hessian sums will be used as a proxy of data counts
46  * \param create_new_translation_unit if true, place folded loops in
47  * separate translation units
48  * \param whether at least one subtree was folded
49  */
50  bool FoldCode(double magnitude_req, bool create_new_translation_unit = false);
51  /*
52  * \brief split prediction function into multiple translation units
53  * \param parallel_comp number of translation units
54  */
55  void Split(int parallel_comp);
56  /* \brief replace split thresholds with integers */
57  void QuantizeThresholds();
58  /* \brief Load data counts from annotation file */
59  void LoadDataCounts(const std::vector<std::vector<size_t>>& counts);
60 
61  inline const ASTNode* GetRootNode() {
62  return main_node;
63  }
64 
65  private:
66  friend bool treelite::compiler::fold_code(ASTNode*, CodeFoldingContext*,
67  ASTBuilder*);
68 
69  template <typename NodeType, typename ...Args>
70  NodeType* AddNode(ASTNode* parent, Args&& ...args) {
71  std::unique_ptr<NodeType> node
72  = common::make_unique<NodeType>(std::forward<Args>(args)...);
73  NodeType* ref = node.get();
74  ref->parent = parent;
75  nodes.push_back(std::move(node));
76  return ref;
77  }
78  ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, ASTNode* parent);
79  ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, int nid,
80  ASTNode* parent);
81 
82  // keep tract of all nodes built so far, to prevent memory leak
83  std::vector<std::unique_ptr<ASTNode>> nodes;
84  bool output_vector_flag;
85  bool quantize_threshold_flag;
86  int num_feature;
87  int num_output_group;
88  bool random_forest_flag;
89  ASTNode* main_node;
90  std::vector<bool> is_categorical;
91  std::map<std::string, std::string> model_param;
92 };
93 
94 } // namespace compiler
95 } // namespace treelite
96 
97 #endif // TREELITE_COMPILER_AST_BUILDER_H_
thin wrapper for tree ensemble model
Definition: tree.h:415
model structure for tree
Definition for AST classes.
in-memory representation of a decision tree
Definition: tree.h:22