Treelite
predict.cc
Go to the documentation of this file.
1 
#include <treelite/gtil.h>
#include <treelite/tree.h>
#include <treelite/data.h>
#include <treelite/logging.h>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>
#include "./pred_transform.h"
19 
20 namespace {
21 
22 using PredTransformFuncType = std::size_t (*) (const treelite::Model&, const float*, float*);
23 
24 template <typename T>
25 inline int NextNode(float fvalue, T threshold, treelite::Operator op,
26  int left_child, int right_child, int default_child) {
27  if (std::isnan(fvalue)) {
28  return default_child;
29  }
30  switch (op) {
31  case treelite::Operator::kEQ:
32  return (fvalue == threshold) ? left_child : right_child;
33  case treelite::Operator::kLT:
34  return (fvalue < threshold) ? left_child : right_child;
35  case treelite::Operator::kLE:
36  return (fvalue <= threshold) ? left_child : right_child;
37  case treelite::Operator::kGT:
38  return (fvalue > threshold) ? left_child : right_child;
39  case treelite::Operator::kGE:
40  return (fvalue >= threshold) ? left_child : right_child;
41  default:
42  TREELITE_CHECK(false) << "Unrecognized comparison operator " << static_cast<int>(op);
43  return -1;
44  }
45 }
46 
47 inline int NextNodeCategorical(float fvalue, const std::vector<uint32_t>& matching_categories,
48  bool categories_list_right_child, int left_child, int right_child,
49  int default_child) {
50  if (std::isnan(fvalue)) {
51  return default_child;
52  }
53  bool is_matching_category;
54  float max_representable_int = static_cast<float>(uint32_t(1) << FLT_MANT_DIG);
55  if (fvalue < 0 || std::fabs(fvalue) > max_representable_int) {
56  is_matching_category = false;
57  } else {
58  const auto category_value = static_cast<uint32_t>(fvalue);
59  is_matching_category = (
60  std::find(matching_categories.begin(), matching_categories.end(), category_value)
61  != matching_categories.end());
62  }
63  if (categories_list_right_child) {
64  return is_matching_category ? right_child : left_child;
65  } else {
66  return is_matching_category ? left_child : right_child;
67  }
68 }
69 
70 template <typename ThresholdType, typename LeafOutputType, typename DMatrixType,
71  typename OutputFunc>
72 inline std::size_t PredictImplInner(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
73  const DMatrixType* input, float* output, bool pred_transform,
74  OutputFunc output_func) {
76  const size_t num_row = input->GetNumRow();
77  const size_t num_col = input->GetNumCol();
78  std::vector<ThresholdType> row(num_col);
79  const treelite::TaskParam task_param = model.task_param;
80  std::vector<float> sum(task_param.num_class);
81 
82  // TODO(phcho): Use parallelism
83  std::size_t output_offset = 0;
84  for (size_t row_id = 0; row_id < num_row; ++row_id) {
85  input->FillRow(row_id, row.data());
86  std::fill(sum.begin(), sum.end(), 0.0f);
87  const std::size_t num_tree = model.trees.size();
88  for (std::size_t tree_id = 0; tree_id < num_tree; ++tree_id) {
89  const TreeType& tree = model.trees[tree_id];
90  int node_id = 0;
91  while (!tree.IsLeaf(node_id)) {
92  treelite::SplitFeatureType split_type = tree.SplitType(node_id);
93  if (split_type == treelite::SplitFeatureType::kNumerical) {
94  node_id = NextNode(row[tree.SplitIndex(node_id)], tree.Threshold(node_id),
95  tree.ComparisonOp(node_id), tree.LeftChild(node_id),
96  tree.RightChild(node_id), tree.DefaultChild(node_id));
97  } else if (split_type == treelite::SplitFeatureType::kCategorical) {
98  node_id = NextNodeCategorical(row[tree.SplitIndex(node_id)],
99  tree.MatchingCategories(node_id),
100  tree.CategoriesListRightChild(node_id),
101  tree.LeftChild(node_id), tree.RightChild(node_id),
102  tree.DefaultChild(node_id));
103  } else {
104  TREELITE_CHECK(false) << "Unrecognized split type: " << static_cast<int>(split_type);
105  }
106  }
107  output_func(tree, tree_id, node_id, sum.data());
108  }
109  if (model.average_tree_output) {
110  float average_factor;
111  if (model.task_type == treelite::TaskType::kMultiClfGrovePerClass) {
112  TREELITE_CHECK(task_param.grove_per_class);
113  TREELITE_CHECK_EQ(task_param.leaf_vector_size, 1);
114  TREELITE_CHECK_GT(task_param.num_class, 1);
115  TREELITE_CHECK_EQ(num_tree % task_param.num_class, 0)
116  << "Expected the number of trees to be divisible by the number of classes";
117  int num_boosting_round = num_tree / static_cast<int>(task_param.num_class);
118  average_factor = static_cast<float>(num_boosting_round);
119  } else {
120  TREELITE_CHECK(model.task_type == treelite::TaskType::kBinaryClfRegr
121  || model.task_type == treelite::TaskType::kMultiClfProbDistLeaf);
122  TREELITE_CHECK(task_param.num_class == task_param.leaf_vector_size);
123  TREELITE_CHECK(!task_param.grove_per_class);
124  average_factor = static_cast<float>(num_tree);
125  }
126  for (unsigned int i = 0; i < task_param.num_class; ++i) {
127  sum[i] /= average_factor;
128  }
129  }
130  for (unsigned int i = 0; i < task_param.num_class; ++i) {
131  sum[i] += model.param.global_bias;
132  }
133  if (pred_transform) {
134  PredTransformFuncType pred_transform_func
135  = treelite::gtil::LookupPredTransform(model.param.pred_transform);
136  output_offset += pred_transform_func(model, sum.data(), &output[output_offset]);
137  } else {
138  for (unsigned int i = 0; i < task_param.num_class; ++i) {
139  output[output_offset + i] = sum[i];
140  }
141  output_offset += task_param.num_class;
142  }
143  input->ClearRow(row_id, row.data());
144  }
145  return output_offset;
146 }
147 
148 template <typename ThresholdType, typename LeafOutputType, typename DMatrixType>
149 inline std::size_t PredictImpl(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
150  const DMatrixType* input, float* output,
151  bool pred_transform) {
153  const treelite::TaskParam task_param = model.task_param;
154  if (task_param.num_class > 1) {
155  if (task_param.leaf_vector_size > 1) {
156  // multi-class classification with random forest
157  auto output_logic = [task_param](
158  const TreeType& tree, int, int node_id, float* sum) {
159  auto leaf_vector = tree.LeafVector(node_id);
160  for (unsigned int i = 0; i < task_param.leaf_vector_size; ++i) {
161  sum[i] += leaf_vector[i];
162  }
163  };
164  return PredictImplInner(model, input, output, pred_transform, output_logic);
165  } else {
166  // multi-class classification with gradient boosted trees
167  auto output_logic = [task_param](
168  const TreeType& tree, int tree_id, int node_id, float* sum) {
169  sum[tree_id % task_param.num_class] += tree.LeafValue(node_id);
170  };
171  return PredictImplInner(model, input, output, pred_transform, output_logic);
172  }
173  } else {
174  auto output_logic = [task_param](
175  const TreeType& tree, int tree_id, int node_id, float* sum) {
176  sum[0] += tree.LeafValue(node_id);
177  };
178  return PredictImplInner(model, input, output, pred_transform, output_logic);
179  }
180 }
181 
182 } // anonymous namespace
183 
184 namespace treelite {
185 namespace gtil {
186 
187 std::size_t Predict(const Model* model, const DMatrix* input, float* output, bool pred_transform) {
188  // Check type of DMatrix
189  const auto* d1 = dynamic_cast<const DenseDMatrixImpl<float>*>(input);
190  const auto* d2 = dynamic_cast<const CSRDMatrixImpl<float>*>(input);
191  if (d1) {
192  return model->Dispatch([d1, output, pred_transform](const auto& model) {
193  return PredictImpl(model, d1, output, pred_transform);
194  });
195  } else if (d2) {
196  return model->Dispatch([d2, output, pred_transform](const auto& model) {
197  return PredictImpl(model, d2, output, pred_transform);
198  });
199  } else {
200  TREELITE_LOG(FATAL) << "DMatrix with float64 data is not supported";
201  return 0;
202  }
203 }
204 
205 std::size_t Predict(const Model* model, const float* input, std::size_t num_row, float* output,
206  bool pred_transform) {
207  std::unique_ptr<DenseDMatrixImpl<float>> dmat =
208  std::make_unique<DenseDMatrixImpl<float>>(
209  std::vector<float>(input, input + num_row * model->num_feature),
210  std::numeric_limits<float>::quiet_NaN(),
211  num_row,
212  model->num_feature);
213  return Predict(model, dmat.get(), output, pred_transform);
214 }
215 
216 std::size_t GetPredictOutputSize(const Model* model, std::size_t num_row) {
217  return model->task_param.num_class * num_row;
218 }
219 
220 std::size_t GetPredictOutputSize(const Model* model, const DMatrix* input) {
221  return GetPredictOutputSize(model, input->GetNumRow());
222 }
223 
224 } // namespace gtil
225 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:722
SplitFeatureType
feature split type
Definition: base.h:22
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:184
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:173
bool average_tree_output
whether to average tree outputs
Definition: tree.h:718
Input data structure of Treelite.
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:199
in-memory representation of a decision tree
Definition: tree.h:214
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:192
TaskType task_type
Task type.
Definition: tree.h:716
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:746
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decision trees.
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:720
thin wrapper for tree ensemble model
Definition: tree.h:667
Operator
comparison operators
Definition: base.h:26