Treelite
builder.h
Go to the documentation of this file.
1 
6 #ifndef TREELITE_COMPILER_AST_BUILDER_H_
7 #define TREELITE_COMPILER_AST_BUILDER_H_
8 
9 #include <treelite/tree.h>
10 #include <map>
11 #include <string>
12 #include <vector>
13 #include <ostream>
14 #include <utility>
15 #include <memory>
16 #include <cstdint>
17 #include "./ast.h"
18 
19 namespace treelite {
20 namespace compiler {
21 
22 // forward declarations (needed because ASTBuilder below declares treelite::compiler::fold_code<> a friend)
23 template <typename ThresholdType, typename LeafOutputType>
24 class ASTBuilder;
25 struct CodeFoldingContext;
26 template <typename ThresholdType, typename LeafOutputType>  // NOTE(review): the declaration introduced by this template header (source line 27) is missing from this listing; presumably the fold_code template befriended by ASTBuilder -- confirm against the real header
28 
29 template <typename ThresholdType, typename LeafOutputType>
30 class ASTBuilder {
31  public:
32  ASTBuilder() : output_vector_flag(false), quantize_threshold_flag(false), main_node(nullptr) {}
33 
34  /* \brief initially build AST from model */
35  void BuildAST(const ModelImpl<ThresholdType, LeafOutputType>& model);
36  /* \brief generate is_categorical[] array, which tells whether each feature
37  is categorical or numerical */
38  std::vector<bool> GenerateIsCategoricalArray();
39  /*
40  * \brief fold rarely visited subtrees into tight loops (don't produce
41  * if/else blocks). Rarity of each node is determined by its
42  * data count and/or hessian sum: any node is "rare" if its data count
43  * or hessian sum is lower than the prescribed threshold.
44  * \param magnitude_req all nodes whose data counts are lower than that of
45  * the root node of the decision tree by [magnitude_req]
46  * will be folded. To disable folding, set to +inf. If
47  * hessian sums are available instead of data counts,
48  * hessian sums will be used as a proxy of data counts
49  * \param create_new_translation_unit if true, place folded loops in
50  * separate translation units
51  * \return whether at least one subtree was folded
52  */
53  bool FoldCode(double magnitude_req, bool create_new_translation_unit = false);
54  /*
55  * \brief split prediction function into multiple translation units
56  * \param parallel_comp number of translation units
57  */
58  void Split(int parallel_comp);
59  /* \brief replace split thresholds with integers */
60  void QuantizeThresholds();
61  /* \brief Load data counts (e.g. as read from an annotation file); NOTE(review): counts is presumably indexed [tree][node] -- confirm */
62  void LoadDataCounts(const std::vector<std::vector<uint64_t>>& counts);
63  /*
64  * \brief Get a text representation of AST
65  */
66  std::string GetDump() const;
67 
68  inline const ASTNode* GetRootNode() {  // returns main_node, which is nullptr after construction until an AST is built; NOTE(review): could be a const member function
69  return main_node;
70  }
71 
72  private:
73  friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*,  // grants fold_code access to the private members below; NOTE(review): the remaining parameter(s) (source line 74) are missing from this listing
75 
76  template <typename NodeType, typename ...Args>
77  NodeType* AddNode(ASTNode* parent, Args&& ...args) {  // construct a NodeType from args, set its parent, keep ownership in `nodes`, return a non-owning pointer
78  std::unique_ptr<NodeType> node(new NodeType(std::forward<Args>(args)...));
79  NodeType* ref = node.get();
80  ref->parent = parent;
81  nodes.push_back(std::move(node));  // ownership moves into `nodes`; `ref` stays valid while this builder lives
82  return ref;
83  }
84 
85  ASTNode* BuildASTFromTree(const Tree<ThresholdType, LeafOutputType>& tree, int tree_id, int nid,
86  ASTNode* parent);  // NOTE(review): presumably builds AST nodes for node `nid` of `tree` (and its descendants) under `parent` -- confirm
87 
88  // keep track of all nodes built so far; owning them here via unique_ptr prevents memory leaks
89  std::vector<std::unique_ptr<ASTNode>> nodes;
90  bool output_vector_flag;
91  bool quantize_threshold_flag;
92  int num_feature;  // NOTE(review): not initialized by the constructor; presumably assigned in BuildAST -- confirm, or add a default member initializer
93  bool average_output_flag;  // NOTE(review): not initialized by the constructor; presumably assigned in BuildAST -- confirm, or add a default member initializer
94  ASTNode* main_node;  // root of the AST returned by GetRootNode(); nullptr after construction
95  std::vector<bool> is_categorical;
96  std::map<std::string, std::string> model_param;
97 };
98 
99 } // namespace compiler
100 } // namespace treelite
101 
102 #endif // TREELITE_COMPILER_AST_BUILDER_H_
model structure for tree ensemble
Definition for AST classes.
in-memory representation of a decision tree
Definition: tree.h:214