Treelite
builder.h
Go to the documentation of this file.
1 
6 #ifndef TREELITE_COMPILER_AST_BUILDER_H_
7 #define TREELITE_COMPILER_AST_BUILDER_H_
8 
9 #include <treelite/tree.h>
10 #include <map>
11 #include <string>
12 #include <vector>
13 #include <ostream>
14 #include <utility>
15 #include <memory>
16 #include "./ast.h"
17 
18 namespace treelite {
19 namespace compiler {
20 
21 // forward declarations
22 template <typename ThresholdType, typename LeafOutputType>
23 class ASTBuilder;
24 struct CodeFoldingContext;
25 template <typename ThresholdType, typename LeafOutputType>
27 
28 template <typename ThresholdType, typename LeafOutputType>
29 class ASTBuilder {
30  public:
31  ASTBuilder() : output_vector_flag(false), quantize_threshold_flag(false), main_node(nullptr) {}
32 
33  /* \brief initially build AST from model */
34  void BuildAST(const ModelImpl<ThresholdType, LeafOutputType>& model);
35  /* \brief generate is_categorical[] array, which tells whether each feature
36  is categorical or numerical */
37  std::vector<bool> GenerateIsCategoricalArray();
38  /*
39  * \brief fold rarely visited subtrees into tight loops (don't produce
40  * if/else blocks). Rarity of each node is determined by its
41  * data count and/or hessian sum: any node is "rare" if its data count
42  * or hessian sum is lower than the prescribed threshold.
43  * \param magnitude_req all nodes whose data counts are lower than that of
44  * the root node of the decision tree by [magnitude_req]
45  * will be folded. To disable folding, set to +inf. If
46  * hessian sums are available instead of data counts,
47  * hessian sums will be used as a proxy of data counts
48  * \param create_new_translation_unit if true, place folded loops in
49  * separate translation units
50  * \return whether at least one subtree was folded
51  */
52  bool FoldCode(double magnitude_req, bool create_new_translation_unit = false);
53  /*
54  * \brief split prediction function into multiple translation units
55  * \param parallel_comp number of translation units
56  */
57  void Split(int parallel_comp);
58  /* \brief replace split thresholds with integers */
59  void QuantizeThresholds();
60  /* \brief Load data counts from annotation file */
61  void LoadDataCounts(const std::vector<std::vector<size_t>>& counts);
62  /*
63  * \brief Get a text representation of AST
64  */
65  std::string GetDump() const;
66 /* \brief return main_node, which the constructor initializes to nullptr */
67  inline const ASTNode* GetRootNode() {
68  return main_node;
69  }
70 
71  private:
72  friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*,
74 /* \brief construct a NodeType from args, set its parent, store ownership in nodes, and return a non-owning pointer to it */
75  template <typename NodeType, typename ...Args>
76  NodeType* AddNode(ASTNode* parent, Args&& ...args) {
77  std::unique_ptr<NodeType> node(new NodeType(std::forward<Args>(args)...));
78  NodeType* ref = node.get();
79  ref->parent = parent;
80  nodes.push_back(std::move(node));
81  return ref;
82  }
83 
84  ASTNode* BuildASTFromTree(const Tree<ThresholdType, LeafOutputType>& tree, int tree_id, int nid,
85  ASTNode* parent);
86 
87  // keep track of all nodes built so far, to prevent memory leak
88  std::vector<std::unique_ptr<ASTNode>> nodes;
89  bool output_vector_flag;
90  bool quantize_threshold_flag;
91  int num_feature;  // NOTE(review): not set by the constructor's initializer list; presumably assigned in BuildAST -- confirm
92  bool average_output_flag;  // NOTE(review): not set by the constructor's initializer list; presumably assigned in BuildAST -- confirm
93  ASTNode* main_node;
94  std::vector<bool> is_categorical;
95  std::map<std::string, std::string> model_param;
96 };
97 
98 } // namespace compiler
99 } // namespace treelite
100 
101 #endif // TREELITE_COMPILER_AST_BUILDER_H_
model structure for tree ensemble
Definition for AST classes.
in-memory representation of a decision tree
Definition: tree.h:191