Treelite
main_template.h
Go to the documentation of this file.
1 
8 #ifndef TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
9 #define TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
10 
11 namespace treelite {
12 namespace compiler {
13 namespace native {
14 
15 const char* const query_functions_definition_template =
16 R"TREELITETEMPLATE(
17 size_t get_num_class(void) {{
18  return {num_class};
19 }}
20 
21 size_t get_num_feature(void) {{
22  return {num_feature};
23 }}
24 
25 const char* get_pred_transform(void) {{
26  return "{pred_transform}";
27 }}
28 
29 float get_sigmoid_alpha(void) {{
30  return {sigmoid_alpha};
31 }}
32 
33 float get_global_bias(void) {{
34  return {global_bias};
35 }}
36 
37 const char* get_threshold_type(void) {{
38  return "{threshold_type_str}";
39 }}
40 
41 const char* get_leaf_output_type(void) {{
42  return "{leaf_output_type_str}";
43 }}
44 )TREELITETEMPLATE";
45 
46 const char* const main_start_template =
47 R"TREELITETEMPLATE(
48 #include "header.h"
49 
50 {array_is_categorical};
51 
52 {query_functions_definition}
53 
54 {pred_transform_function}
55 {predict_function_signature} {{
56 )TREELITETEMPLATE";
57 
58 const char* const main_end_multiclass_template =
59 R"TREELITETEMPLATE(
60  for (int i = 0; i < {num_class}; ++i) {{
61  result[i] = sum[i]{optional_average_field} + ({leaf_output_type})({global_bias});
62  }}
63  if (!pred_margin) {{
64  return pred_transform(result);
65  }} else {{
66  return {num_class};
67  }}
68 }}
69 )TREELITETEMPLATE"; // only for multiclass classification
70 
71 const char* const main_end_template =
72 R"TREELITETEMPLATE(
73  sum = sum{optional_average_field} + ({leaf_output_type})({global_bias});
74  if (!pred_margin) {{
75  return pred_transform(sum);
76  }} else {{
77  return sum;
78  }}
79 }}
80 )TREELITETEMPLATE";
81 
82 } // namespace native
83 } // namespace compiler
84 } // namespace treelite
85 #endif // TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_