treelite
semantic.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_SEMANTIC_H_
8 #define TREELITE_SEMANTIC_H_
9 
10 #include <treelite/tree.h>
11 #include <algorithm>
12 
13 namespace treelite {
14 namespace semantic {
15 
/*!
 * \brief enum class to store branch annotation
 *
 * Used as an optional hint on an if-else block telling whether the branch is
 * likely (>50%) or unlikely (<50%) to be taken.
 */
enum class BranchHint : uint8_t {
  kNone = 0,     // no annotation
  kLikely = 1,   // branch hinted as likely (>50%)
  kUnlikely = 2  // branch hinted as unlikely (<50%)
};
22 
28 inline std::string OpName(Operator op) {
29  switch(op) {
30  case Operator::kEQ: return "==";
31  case Operator::kLT: return "<";
32  case Operator::kLE: return "<=";
33  case Operator::kGT: return ">";
34  case Operator::kGE: return ">=";
35  default: return "";
36  }
37 }
38 
47 inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) {
48  switch(op) {
49  case Operator::kEQ: return lhs == rhs;
50  case Operator::kLT: return lhs < rhs;
51  case Operator::kLE: return lhs <= rhs;
52  case Operator::kGT: return lhs > rhs;
53  case Operator::kGE: return lhs >= rhs;
54  default: LOG(FATAL) << "operator undefined";
55  }
56 }
57 
58 using common::Cloneable;
60 
65 class CodeBlock : public Cloneable {
66  public:
67  virtual ~CodeBlock() = default;
68  virtual std::vector<std::string> Compile() const = 0;
69 };
70 
73  public:
74  explicit TranslationUnit(const CodeBlock& preamble, const CodeBlock& body)
75  : preamble(preamble), body(body) {}
76  explicit TranslationUnit(CodeBlock&& preamble, CodeBlock&& body)
77  : preamble(std::move(preamble)), body(std::move(body)) {}
78  explicit TranslationUnit(const TranslationUnit& other) = delete;
79  explicit TranslationUnit(TranslationUnit&& other) = default;
80  std::vector<std::string> Compile(const std::string& header_filename) const;
81  private:
82  DeepCopyUniquePtr<CodeBlock> preamble;
83  DeepCopyUniquePtr<CodeBlock> body;
84 };
85 
90 struct SemanticModel {
91  struct FunctionEntry {
92  std::string prototype;
93  bool dll_export;
94  FunctionEntry(const std::string& prototype, bool dll_export)
95  : prototype(prototype), dll_export(dll_export) {}
96  };
97  std::unique_ptr<CodeBlock> common_header;
98  std::vector<FunctionEntry> function_registry; // list of function prototypes
99  std::vector<TranslationUnit> units;
100 };
101 
102 inline std::ostream &operator<<(std::ostream &os,
103  const SemanticModel::FunctionEntry &entry) {
104 #ifdef _WIN32
105  const std::string declspec("__declspec(dllexport) ");
106 #else
107  const std::string declspec("");
108 #endif
109  if (entry.dll_export) {
110  os << declspec << entry.prototype << ";\n";
111  } else {
112  os << entry.prototype << ";\n";
113  }
114  return os;
115 }
116 
118 class PlainBlock : public CodeBlock {
119  public:
120  explicit PlainBlock()
121  : inner_text({}) {}
122  explicit PlainBlock(const std::string& inner_text)
123  : inner_text({inner_text}) {}
124  explicit PlainBlock(const std::vector<std::string>& inner_text)
125  : inner_text(inner_text) {}
126  explicit PlainBlock(std::vector<std::string>&& inner_text)
127  : inner_text(std::move(inner_text)) {}
129  std::vector<std::string> Compile() const override;
130  private:
131  std::vector<std::string> inner_text;
132 };
133 
138 class FunctionBlock : public CodeBlock {
139  private:
141 
142  public:
143  explicit FunctionBlock(const std::string& prototype,
144  const CodeBlock& body,
145  std::vector<FunctionEntry>* p_function_registry,
146  bool dll_export = false)
147  : prototype(prototype), body(body), dll_export(dll_export) {
148  if (p_function_registry != nullptr) {
149  p_function_registry->emplace_back(this->prototype, this->dll_export);
150  }
151  }
152  explicit FunctionBlock(std::string&& prototype,
153  CodeBlock&& body,
154  std::vector<FunctionEntry>* p_function_registry,
155  bool dll_export = false)
156  : prototype(std::move(prototype)), body(std::move(body)),
157  dll_export(dll_export) {
158  if (p_function_registry != nullptr) {
159  p_function_registry->emplace_back(this->prototype, this->dll_export);
160  }
161  }
163  std::vector<std::string> Compile() const override;
164  private:
165  std::string prototype;
166  bool dll_export;
167  DeepCopyUniquePtr<CodeBlock> body;
168 };
169 
171 class SequenceBlock : public CodeBlock {
172  public:
173  explicit SequenceBlock() = default;
175  std::vector<std::string> Compile() const override;
176  void Reserve(size_t size);
177  void PushBack(const CodeBlock& block);
178  void PushBack(CodeBlock&& block);
179  private:
180  std::vector<DeepCopyUniquePtr<CodeBlock>> sequence;
181 };
182 
184 class Condition : public Cloneable {
185  public:
186  virtual ~Condition() = default;
187  virtual std::string Compile() const = 0;
188 };
189 
194 class IfElseBlock : public CodeBlock {
195  public:
196  explicit IfElseBlock(const Condition& condition,
197  const CodeBlock& if_block,
198  const CodeBlock& else_block,
199  BranchHint hint = BranchHint::kNone)
200  : condition(condition), if_block(if_block), else_block(else_block),
201  branch_hint(hint) {}
202  explicit IfElseBlock(Condition&& condition,
203  CodeBlock&& if_block,
204  CodeBlock&& else_block,
205  BranchHint hint = BranchHint::kNone)
206  : condition(std::move(condition)),
207  if_block(std::move(if_block)), else_block(std::move(else_block)),
208  branch_hint(hint) {}
210  std::vector<std::string> Compile() const override;
211  private:
212  DeepCopyUniquePtr<Condition> condition;
213  DeepCopyUniquePtr<CodeBlock> if_block;
214  DeepCopyUniquePtr<CodeBlock> else_block;
215  BranchHint branch_hint;
216 };
217 
218 } // namespace semantic
219 } // namespace treelite
220 #endif // TREELITE_SEMANTIC_H_
plain code block containing zero or more lines of code
Definition: semantic.h:118
float tl_float
float type to be used internally
Definition: base.h:17
fundamental block in semantic model. All code blocks should inherit from this class.
Definition: semantic.h:65
model structure for tree
std::string OpName(Operator op)
get string representation of a comparison operator
Definition: semantic.h:28
bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: semantic.h:47
BranchHint
enum class to store branch annotation
Definition: semantic.h:17
abstract interface for classes that can be cloned
Definition: common.h:24
a wrapper around std::unique_ptr that supports deep copying and moving.
Definition: common.h:47
function block with a prototype and code body. Its prototype can optionally be registered with a function registry.
Definition: semantic.h:138
if-else statement with a condition; it may store a branch hint (>50% or <50% likely)
Definition: semantic.h:194
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
Definition: common.h:32
a conditional expression
Definition: semantic.h:184
sequence of one or more code blocks
Definition: semantic.h:171
Definition: semantic.h:91
translation unit is abstraction of a source file
Definition: semantic.h:72
semantic model consists of a header, function registry, and a list of translation units ...
Definition: semantic.h:90
Operator
comparison operators
Definition: base.h:23