Treelite
src
compiler
native
main_template.h
Go to the documentation of this file.
1
8
#ifndef TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
#define TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_

namespace treelite {
namespace compiler {
namespace native {

// Source-code templates for the main translation unit emitted by the native
// compiler backend.
//
// * Each template is a raw string literal whose text is emitted verbatim
//   (including whitespace and the leading newline after the opening
//   delimiter).
// * `{name}` marks a placeholder to be substituted by the caller.
// * Every literal brace meant for the generated C code is written doubled
//   (`{{` / `}}`), consistent with fmt-style named-argument formatting.
//   NOTE(review): the formatter itself is not visible in this header;
//   confirm against the call sites.
//
// Linkage: the constants are `const char* const` at namespace scope. They
// therefore have internal linkage, and each translation unit that includes
// this header gets its own copy without multiple-definition link errors.

/*!
 * \brief Definitions of the query functions exported by the generated model:
 *        get_num_class, get_num_feature, get_pred_transform,
 *        get_sigmoid_alpha, get_global_bias, get_threshold_type and
 *        get_leaf_output_type.
 *
 * Placeholders:
 *   {num_class}, {num_feature}, {pred_transform}, {sigmoid_alpha},
 *   {global_bias}, {threshold_type_str}, {leaf_output_type_str}.
 *
 * {pred_transform}, {threshold_type_str} and {leaf_output_type_str} are
 * wrapped in double quotes by the template. The substituted text is
 * therefore emitted inside a C string literal and must not contain
 * unescaped quotes.
 */
const char* const query_functions_definition_template =
R"TREELITETEMPLATE(
size_t get_num_class(void) {{
  return {num_class};
}}

size_t get_num_feature(void) {{
  return {num_feature};
}}

const char* get_pred_transform(void) {{
  return "{pred_transform}";
}}

float get_sigmoid_alpha(void) {{
  return {sigmoid_alpha};
}}

float get_global_bias(void) {{
  return {global_bias};
}}

const char* get_threshold_type(void) {{
  return "{threshold_type_str}";
}}

const char* get_leaf_output_type(void) {{
  return "{leaf_output_type_str}";
}}
)TREELITETEMPLATE";

/*!
 * \brief Opening part of the main source file.
 *
 * The emitted text, in order:
 *   1. `#include "header.h"`.
 *   2. {array_is_categorical} followed by `;`.
 *   3. {query_functions_definition}; presumably the expansion of
 *      query_functions_definition_template (confirm with the caller).
 *   4. {pred_transform_function}.
 *   5. {predict_function_signature} followed by an opening `{`.
 *
 * The prediction function's body is deliberately left open. Both
 * main_end_template and main_end_multiclass_template end with the matching
 * closing brace, so exactly one of them must be appended after the body code.
 */
const char* const main_start_template =
R"TREELITETEMPLATE(
#include "header.h"

{array_is_categorical};

{query_functions_definition}

{pred_transform_function}
{predict_function_signature} {{
)TREELITETEMPLATE";

/*!
 * \brief Closing part of the prediction function when there is one output
 *        per class.
 *
 * For each i in [0, {num_class}) the template writes
 *   result[i] = sum[i]{optional_average_field}
 *               + ({leaf_output_type})({global_bias});
 *
 * It then returns `pred_transform(result)` when `pred_margin` is false, and
 * returns {num_class} otherwise. Finally it closes the function body opened
 * by main_start_template.
 *
 * NOTE(review): assumes the generated function already declares `sum`,
 * `result` and `pred_margin`, and that `pred_transform` comes from
 * {pred_transform_function}. Confirm against the code generator.
 */
const char* const main_end_multiclass_template =
R"TREELITETEMPLATE(
  for (int i = 0; i < {num_class}; ++i) {{
    result[i] = sum[i]{optional_average_field} + ({leaf_output_type})({global_bias});
  }}
  if (!pred_margin) {{
    return pred_transform(result);
  }} else {{
    return {num_class};
  }}
}}
)TREELITETEMPLATE";  // only for multiclass classification

/*!
 * \brief Closing part of the prediction function when there is a single
 *        scalar output.
 *
 * The template first updates the accumulator:
 *   sum = sum{optional_average_field}
 *         + ({leaf_output_type})({global_bias});
 *
 * It then returns `pred_transform(sum)` when `pred_margin` is false, and the
 * raw `sum` otherwise. Finally it closes the function body opened by
 * main_start_template.
 *
 * NOTE(review): assumes the generated function already declares `sum` and
 * `pred_margin`. Confirm against the code generator.
 */
const char* const main_end_template =
R"TREELITETEMPLATE(
  sum = sum{optional_average_field} + ({leaf_output_type})({global_bias});
  if (!pred_margin) {{
    return pred_transform(sum);
  }} else {{
    return sum;
  }}
}}
)TREELITETEMPLATE";

}  // namespace native
}  // namespace compiler
}  // namespace treelite

#endif  // TREELITE_COMPILER_NATIVE_MAIN_TEMPLATE_H_
treelite
Definition:
annotator.h:18
Generated by Doxygen 1.8.13