18 #include "./pred_transform.h" 22 using PredTransformFuncType = std::size_t (*) (
const treelite::Model&,
const float*,
float*);
// NextNode: picks the child to visit for a numerical split.
// The function head is not visible in this listing; it declares the name and
// the `fvalue`, `threshold` and `op` parameters. Returns `left_child` when
// `fvalue <op> threshold` holds, otherwise `right_child`.
26 int left_child,
int right_child,
int default_child) {
27 if (std::isnan(fvalue)) {
// NOTE(review): the NaN branch body and the `switch (op)` header are not
// visible here. Presumably NaN returns `default_child`; confirm.
31 case treelite::Operator::kEQ:
32 return (fvalue == threshold) ? left_child : right_child;
33 case treelite::Operator::kLT:
34 return (fvalue < threshold) ? left_child : right_child;
35 case treelite::Operator::kLE:
36 return (fvalue <= threshold) ? left_child : right_child;
37 case treelite::Operator::kGT:
38 return (fvalue > threshold) ? left_child : right_child;
39 case treelite::Operator::kGE:
40 return (fvalue >= threshold) ? left_child : right_child;
// Any other operator value fails a TREELITE_CHECK whose message includes the
// operator's integer value.
42 TREELITE_CHECK(
false) <<
"Unrecognized comparison operator " <<
static_cast<int>(op);
// NextNodeCategorical: picks the child to visit for a categorical split.
// `matching_categories` lists the category ids that satisfy the split.
// `categories_list_right_child` says which side those ids are sent to:
// right when true, left otherwise. Non-matching values go to the other side.
47 inline int NextNodeCategorical(
float fvalue,
const std::vector<uint32_t>& matching_categories,
48 bool categories_list_right_child,
int left_child,
int right_child,
// The remaining parameter line is not visible in this listing; presumably it
// is `int default_child) {`.
50 if (std::isnan(fvalue)) {
// NOTE(review): the NaN branch body is not visible here. Presumably it
// returns `default_child`; confirm.
53 bool is_matching_category;
// 2^FLT_MANT_DIG (2^24 for IEEE-754 binary32). Above this magnitude a float
// can no longer represent every integer exactly, so such values, like
// negative ones, never count as a category match.
54 float max_representable_int =
static_cast<float>(uint32_t(1) << FLT_MANT_DIG);
// NOTE(review): `std::fabs` is redundant here. The `fvalue < 0` test has
// already short-circuited, so `fvalue` is non-negative at this point.
55 if (fvalue < 0 || std::fabs(fvalue) > max_representable_int) {
56 is_matching_category =
false;
// The float-to-integer conversion truncates toward zero, so a fractional
// value such as 2.7 is looked up as category 2.
// NOTE(review): confirm that truncation (rather than rejection) is intended.
58 const auto category_value =
static_cast<uint32_t
>(fvalue);
// Linear scan: O(k) in the number of listed categories.
59 is_matching_category = (
60 std::find(matching_categories.begin(), matching_categories.end(), category_value)
61 != matching_categories.end());
63 if (categories_list_right_child) {
64 return is_matching_category ? right_child : left_child;
66 return is_matching_category ? left_child : right_child;
// PredictImplInner: the core prediction loop. The name comes from the call
// sites in PredictImpl; the `OutputFunc` template parameter and the `model`
// parameter are declared on lines not visible in this listing.
// For each row of `input`, the loop:
//   1. walks every tree of `model` down to a leaf;
//   2. lets `output_func` fold that leaf into the per-class accumulator `sum`;
//   3. averages the scores (enclosing condition not visible);
//   4. adds `model.param.global_bias`;
//   5. writes the row's result into `output` at `output_offset`.
// Returns the final `output_offset`.
70 template <
typename ThresholdType,
typename LeafOutputType,
typename DMatrixType,
73 const DMatrixType* input,
float* output,
bool pred_transform,
74 OutputFunc output_func) {
// `row` is one reusable dense buffer, sized to the input's column count.
76 const size_t num_row = input->GetNumRow();
77 const size_t num_col = input->GetNumCol();
78 std::vector<ThresholdType> row(num_col);
// One accumulator slot per class. `task_param` is declared on a line not
// visible here.
80 std::vector<float> sum(task_param.
num_class);
// Current write position in `output`; this is also the return value.
83 std::size_t output_offset = 0;
84 for (
size_t row_id = 0; row_id < num_row; ++row_id) {
// Materialize row `row_id` into `row` and reset the accumulators.
// The ClearRow call at the end of this iteration pairs with FillRow.
85 input->FillRow(row_id, row.data());
86 std::fill(sum.begin(), sum.end(), 0.0f);
87 const std::size_t num_tree = model.
trees.size();
88 for (std::size_t tree_id = 0; tree_id < num_tree; ++tree_id) {
89 const TreeType& tree = model.
trees[tree_id];
// Descend until a leaf is reached. The initial `node_id` and the per-node
// `split_type` lookup are on lines not visible in this listing.
91 while (!tree.IsLeaf(node_id)) {
93 if (split_type == treelite::SplitFeatureType::kNumerical) {
94 node_id = NextNode(row[tree.SplitIndex(node_id)], tree.Threshold(node_id),
95 tree.ComparisonOp(node_id), tree.LeftChild(node_id),
96 tree.RightChild(node_id), tree.DefaultChild(node_id));
97 }
else if (split_type == treelite::SplitFeatureType::kCategorical) {
98 node_id = NextNodeCategorical(row[tree.SplitIndex(node_id)],
99 tree.MatchingCategories(node_id),
100 tree.CategoriesListRightChild(node_id),
101 tree.LeftChild(node_id), tree.RightChild(node_id),
102 tree.DefaultChild(node_id));
// Any other split type fails a TREELITE_CHECK naming its integer value.
104 TREELITE_CHECK(
false) <<
"Unrecognized split type: " <<
static_cast<int>(split_type);
// `node_id` is now a leaf; `output_func` adds its contribution to `sum`.
107 output_func(tree, tree_id, node_id, sum.data());
// Averaging step. Its enclosing condition is not visible in this listing;
// presumably it is `average_tree_output`. NOTE(review): confirm.
110 float average_factor;
// Grove-per-class: trees are split evenly among classes, so each class
// receives num_tree / num_class trees. The divisor is therefore the number
// of boosting rounds.
111 if (model.
task_type == treelite::TaskType::kMultiClfGrovePerClass) {
114 TREELITE_CHECK_GT(task_param.
num_class, 1);
115 TREELITE_CHECK_EQ(num_tree % task_param.
num_class, 0)
116 <<
"Expected the number of trees to be divisible by the number of classes";
// NOTE(review): `num_tree` is std::size_t, so the int divisor is converted
// to size_t and the quotient is narrowed back to int. A size_t local would
// avoid the round trip.
117 int num_boosting_round = num_tree /
static_cast<int>(task_param.
num_class);
118 average_factor =
static_cast<float>(num_boosting_round);
// Otherwise (binary classification/regression, or probability-distribution
// leaves) the divisor is the total number of trees.
120 TREELITE_CHECK(model.
task_type == treelite::TaskType::kBinaryClfRegr
121 || model.
task_type == treelite::TaskType::kMultiClfProbDistLeaf);
124 average_factor =
static_cast<float>(num_tree);
126 for (
unsigned int i = 0; i < task_param.
num_class; ++i) {
127 sum[i] /= average_factor;
// The global bias is added after averaging, once to every class score.
130 for (
unsigned int i = 0; i < task_param.
num_class; ++i) {
131 sum[i] += model.
param.global_bias;
// Either apply the model's prediction transform, which reports how many
// floats it wrote, or copy the raw scores unchanged.
133 if (pred_transform) {
// NOTE(review): this lookup depends only on `model`, so it could be hoisted
// out of the row loop.
134 PredTransformFuncType pred_transform_func
135 = treelite::gtil::LookupPredTransform(model.
param.pred_transform);
136 output_offset += pred_transform_func(model, sum.data(), &output[output_offset]);
// Raw scores: one float per class. The matching advance of `output_offset`
// is on a line not visible here; presumably `+= num_class`.
138 for (
unsigned int i = 0; i < task_param.
num_class; ++i) {
139 output[output_offset + i] = sum[i];
143 input->ClearRow(row_id, row.data());
145 return output_offset;
// PredictImpl: chooses how a reached leaf contributes to the per-class
// accumulator `sum`, then forwards to PredictImplInner. The name comes from
// the calls in Predict. The function head and the conditions that select
// among the three strategies below are on lines not visible in this listing.
148 template <
typename ThresholdType,
typename LeafOutputType,
typename DMatrixType>
150 const DMatrixType* input,
float* output,
151 bool pred_transform) {
// Strategy 1: vector-valued leaves. Element i of the leaf vector is added to
// sum[i]; the loop header is not visible here.
// NOTE(review): if LeafVector returns a reference, `auto` copies it on every
// leaf visit; `const auto&` would avoid that.
157 auto output_logic = [task_param](
158 const TreeType& tree, int,
int node_id,
float* sum) {
159 auto leaf_vector = tree.LeafVector(node_id);
161 sum[i] += leaf_vector[i];
164 return PredictImplInner(model, input, output, pred_transform, output_logic);
// Strategy 2: one scalar per leaf, routed to class `tree_id % num_class`.
// This means trees are interleaved across classes.
167 auto output_logic = [task_param](
168 const TreeType& tree,
int tree_id,
int node_id,
float* sum) {
169 sum[tree_id % task_param.
num_class] += tree.LeafValue(node_id);
171 return PredictImplInner(model, input, output, pred_transform, output_logic);
// Strategy 3: one scalar per leaf, all accumulated into sum[0].
// NOTE(review): `task_param` is captured but not used in this lambda.
174 auto output_logic = [task_param](
175 const TreeType& tree,
int tree_id,
int node_id,
float* sum) {
176 sum[0] += tree.LeafValue(node_id);
178 return PredictImplInner(model, input, output, pred_transform, output_logic);
// Predict (DMatrix overload): runs prediction on every row of `input` and
// writes the results into `output`. Only two matrix types are handled:
// float32 dense (DenseDMatrixImpl<float>) and float32 CSR (CSRDMatrixImpl<float>).
// The concrete matrix type is recovered with dynamic_cast, and
// model->Dispatch supplies the concrete model type to PredictImpl. The
// conditions that choose between d1, d2 and the fallback are on lines not
// visible in this listing.
187 std::size_t Predict(
const Model* model,
const DMatrix* input,
float* output,
bool pred_transform) {
189 const auto* d1 =
dynamic_cast<const DenseDMatrixImpl<float>*
>(input);
190 const auto* d2 =
dynamic_cast<const CSRDMatrixImpl<float>*
>(input);
// The generic lambda's `model` parameter (the concrete model) shadows the
// outer `const Model* model`.
192 return model->Dispatch([d1, output, pred_transform](
const auto& model) {
193 return PredictImpl(model, d1, output, pred_transform);
196 return model->Dispatch([d2, output, pred_transform](
const auto& model) {
197 return PredictImpl(model, d2, output, pred_transform);
// Fallback: a fatal log. NOTE(review): the message names float64, but any
// DMatrix that is neither float32 type above would presumably land here too.
200 TREELITE_LOG(FATAL) <<
"DMatrix with float64 data is not supported";
// Predict (raw-buffer overload): convenience entry point for a plain float
// buffer of `num_row * model->num_feature` values.
// NOTE(review): the buffer layout is presumably row-major; confirm.
// The values are wrapped in a DenseDMatrixImpl<float> and forwarded to the
// DMatrix overload. The remaining constructor arguments are on lines not
// visible in this listing.
205 std::size_t Predict(
const Model* model,
const float* input, std::size_t num_row,
float* output,
206 bool pred_transform) {
207 std::unique_ptr<DenseDMatrixImpl<float>> dmat =
208 std::make_unique<DenseDMatrixImpl<float>>(
// Copies the whole input buffer into the matrix, so `input` must hold at
// least num_row * num_feature floats.
209 std::vector<float>(input, input + num_row * model->num_feature),
// Quiet NaN, presumably the matrix's missing-value marker.
210 std::numeric_limits<float>::quiet_NaN(),
213 return Predict(model, dmat.get(), output, pred_transform);
216 std::size_t GetPredictOutputSize(
const Model* model, std::size_t num_row) {
217 return model->task_param.num_class * num_row;
220 std::size_t GetPredictOutputSize(
const Model* model,
const DMatrix* input) {
221 return GetPredictOutputSize(model, input->GetNumRow());
ModelParam param
extra parameters
SplitFeatureType
feature split type
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
bool average_tree_output
whether to average tree outputs
Input data structure of Treelite.
Model structure for a tree ensemble.
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
In-memory representation of a decision tree.
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
TaskType task_type
Task type.
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decision trees.
TaskParam task_param
Group of parameters that are specific to the particular task type.
Thin wrapper for a tree ensemble model.
Operator
comparison operators