Treelite
main_template.h
Go to the documentation of this file.
1 
8 #ifndef TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
9 #define TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
10 
11 namespace treelite {
12 namespace compiler {
13 namespace native {
14 
15 const char* const query_functions_definition_template =
16 R"TREELITETEMPLATE(
17 size_t get_num_class(void) {{
18  return {num_class};
19 }}
20 
21 size_t get_num_feature(void) {{
22  return {num_feature};
23 }}
24 
25 const char* get_pred_transform(void) {{
26  return "{pred_transform}";
27 }}
28 
29 float get_sigmoid_alpha(void) {{
30  return {sigmoid_alpha};
31 }}
32 
33 float get_ratio_c(void) {{
34  return {ratio_c};
35 }}
36 
37 float get_global_bias(void) {{
38  return {global_bias};
39 }}
40 
41 const char* get_threshold_type(void) {{
42  return "{threshold_type_str}";
43 }}
44 
45 const char* get_leaf_output_type(void) {{
46  return "{leaf_output_type_str}";
47 }}
48 )TREELITETEMPLATE";
49 
50 const char* const main_start_template =
51 R"TREELITETEMPLATE(
52 #include "header.h"
53 
54 {array_is_categorical};
55 
56 {query_functions_definition}
57 
58 {pred_transform_function}
59 {predict_function_signature} {{
60 )TREELITETEMPLATE";
61 
62 const char* const main_end_multiclass_template =
63 R"TREELITETEMPLATE(
64  for (int i = 0; i < {num_class}; ++i) {{
65  result[i] = sum[i]{optional_average_field} + ({leaf_output_type})({global_bias});
66  }}
67  if (!pred_margin) {{
68  return pred_transform(result);
69  }} else {{
70  return {num_class};
71  }}
72 }}
73 )TREELITETEMPLATE"; // only for multiclass classification
74 
75 const char* const main_end_template =
76 R"TREELITETEMPLATE(
77  sum = sum{optional_average_field} + ({leaf_output_type})({global_bias});
78  if (!pred_margin) {{
79  return pred_transform(sum);
80  }} else {{
81  return sum;
82  }}
83 }}
84 )TREELITETEMPLATE";
85 
86 } // namespace native
87 } // namespace compiler
88 } // namespace treelite
89 #endif // TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_