// Treelite: src/compiler/native/main_template.h
#ifndef TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
#define TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_

namespace treelite {
namespace compiler {
namespace native {

// Text templates for the main C source file emitted by the native compiler.
//
// Each template is a raw string literal containing `{name}` placeholders that
// the caller substitutes. Braces that must appear literally in the emitted C
// code are written doubled (`{{` / `}}`).
// NOTE(review): the doubled braces assume a str.format / fmt-style formatter
// on the caller side; confirm against the code that consumes these templates.
//
// Every template starts with a newline, because the raw string opens directly
// after `R"TREELITETEMPLATE(` and the text begins on the following line.

// Definitions of the model query functions emitted into the generated source:
// get_num_class, get_num_feature, get_pred_transform, get_sigmoid_alpha,
// get_ratio_c, get_global_bias, get_threshold_type, get_leaf_output_type.
// Placeholders: {num_class}, {num_feature}, {pred_transform},
// {sigmoid_alpha}, {ratio_c}, {global_bias}, {threshold_type_str},
// {leaf_output_type_str}. The *_str / {pred_transform} values are emitted
// inside C string literals.
const char* const query_functions_definition_template =
R"TREELITETEMPLATE(
size_t get_num_class(void) {{
  return {num_class};
}}

size_t get_num_feature(void) {{
  return {num_feature};
}}

const char* get_pred_transform(void) {{
  return "{pred_transform}";
}}

float get_sigmoid_alpha(void) {{
  return {sigmoid_alpha};
}}

float get_ratio_c(void) {{
  return {ratio_c};
}}

float get_global_bias(void) {{
  return {global_bias};
}}

const char* get_threshold_type(void) {{
  return "{threshold_type_str}";
}}

const char* get_leaf_output_type(void) {{
  return "{leaf_output_type_str}";
}}
)TREELITETEMPLATE";

// Opening part of the generated main source file. It does the following:
//   * includes "header.h";
//   * emits {array_is_categorical} as a statement;
//   * splices in {query_functions_definition} and {pred_transform_function};
//   * opens the body of {predict_function_signature}.
// The function body is deliberately left open; one of the main_end_* templates
// below supplies the closing brace. The code in between is presumably emitted
// by the caller (TODO confirm).
const char* const main_start_template =
R"TREELITETEMPLATE(
#include "header.h"

{array_is_categorical};

{query_functions_definition}

{pred_transform_function}
{predict_function_signature} {{
)TREELITETEMPLATE";

// Closing part of the predict function, only for multiclass classification.
// The generated function must already have `sum`, `result`, `pred_margin` and
// `pred_transform` in scope.
// For each of the {num_class} classes it stores
//   sum[i]{optional_average_field} + ({leaf_output_type})({global_bias})
// into result[i]. It then returns pred_transform(result), or {num_class}
// itself when pred_margin is set.
// NOTE(review): {optional_average_field} is presumably empty unless outputs
// are averaged; confirm with the caller.
const char* const main_end_multiclass_template =
R"TREELITETEMPLATE(
  for (int i = 0; i < {num_class}; ++i) {{
    result[i] = sum[i]{optional_average_field} + ({leaf_output_type})({global_bias});
  }}
  if (!pred_margin) {{
    return pred_transform(result);
  }} else {{
    return {num_class};
  }}
}}
)TREELITETEMPLATE";

// Closing part of the predict function for single-output models.
// It adds the global bias to `sum` (after {optional_average_field}). It then
// returns pred_transform(sum), or the raw `sum` when pred_margin is set.
// The generated function must already have `sum`, `pred_margin` and
// `pred_transform` in scope.
const char* const main_end_template =
R"TREELITETEMPLATE(
  sum = sum{optional_average_field} + ({leaf_output_type})({global_bias});
  if (!pred_margin) {{
    return pred_transform(sum);
  }} else {{
    return sum;
  }}
}}
)TREELITETEMPLATE";

}  // namespace native
}  // namespace compiler
}  // namespace treelite
#endif  // TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
// Doxygen cross-reference: namespace treelite is defined at annotator.h:18.
// Listing generated by Doxygen 1.8.13.