Treelite
annotator.cc
Go to the documentation of this file.
1 
8 #include <treelite/annotator.h>
9 #include <treelite/math.h>
10 #include <treelite/omp.h>
11 #include <dmlc/json.h>
12 #include <limits>
13 #include <cstdint>
14 
15 namespace {
16 
/*!
 * \brief one slot of the per-thread feature buffer used during traversal
 *
 * Both members share the same storage. A slot that is brace-initialized with
 * {-1} has `missing` set to -1, which the traversal code reads as "feature
 * absent"; assigning `fvalue` overwrites that marker with a feature value.
 * `missing` has to remain the first member so that {-1} initializes it.
 *
 * \tparam ElementType type of the feature values stored in the data matrix
 */
template <class ElementType>
union Entry {
  int missing;          // -1 marks the slot as holding no feature value
  ElementType fvalue;   // feature value, valid once it has been assigned
};
22 
23 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
24 void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
25  const Entry<ElementType>* data, int nid, size_t* out_counts) {
26  ++out_counts[nid];
27  if (!tree.IsLeaf(nid)) {
28  const unsigned split_index = tree.SplitIndex(nid);
29 
30  if (data[split_index].missing == -1) {
31  Traverse_(tree, data, tree.DefaultChild(nid), out_counts);
32  } else {
33  bool result = true;
34  if (tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
35  const ThresholdType threshold = tree.Threshold(nid);
36  const treelite::Operator op = tree.ComparisonOp(nid);
37  const auto fvalue = static_cast<ElementType>(data[split_index].fvalue);
38  result = treelite::CompareWithOp(fvalue, op, threshold);
39  } else {
40  const auto fvalue = data[split_index].fvalue;
41  const auto matching_categories = tree.MatchingCategories(nid);
42  result = (std::binary_search(matching_categories.begin(),
43  matching_categories.end(),
44  static_cast<uint32_t>(fvalue)));
45  if (tree.CategoriesListRightChild(nid)) {
46  result = !result;
47  }
48  }
49  if (result) { // left child
50  Traverse_(tree, data, tree.LeftChild(nid), out_counts);
51  } else { // right child
52  Traverse_(tree, data, tree.RightChild(nid), out_counts);
53  }
54  }
55  }
56 }
57 
58 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
59 void Traverse(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
60  const Entry<ElementType>* data, size_t* out_counts) {
61  Traverse_(tree, data, 0, out_counts);
62 }
63 
64 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
65 inline void ComputeBranchLoopImpl(
67  const treelite::DenseDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
68  const size_t* count_row_ptr, size_t* counts_tloc) {
69  std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
70  const size_t ntree = model.trees.size();
71  CHECK_LE(rbegin, rend);
72  CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
73  const size_t num_col = dmat->num_col;
74  const ElementType missing_value = dmat->missing_value;
75  const bool nan_missing = treelite::math::CheckNAN(missing_value);
76  const auto rbegin_i = static_cast<int64_t>(rbegin);
77  const auto rend_i = static_cast<int64_t>(rend);
78  #pragma omp parallel for schedule(static) num_threads(nthread)
79  for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
80  const int tid = omp_get_thread_num();
81  const ElementType* row = &dmat->data[rid * num_col];
82  const size_t off = dmat->num_col * tid;
83  const size_t off2 = count_row_ptr[ntree] * tid;
84  for (size_t j = 0; j < num_col; ++j) {
85  if (treelite::math::CheckNAN(row[j])) {
86  CHECK(nan_missing)
87  << "The missing_value argument must be set to NaN if there is any NaN in the matrix.";
88  } else if (nan_missing || row[j] != missing_value) {
89  inst[off + j].fvalue = row[j];
90  }
91  }
92  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
93  Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
94  }
95  for (size_t j = 0; j < num_col; ++j) {
96  inst[off + j].missing = -1;
97  }
98  }
99 }
100 
101 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
102 inline void ComputeBranchLoopImpl(
104  const treelite::CSRDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
105  const size_t* count_row_ptr, size_t* counts_tloc) {
106  std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
107  const size_t ntree = model.trees.size();
108  CHECK_LE(rbegin, rend);
109  CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
110  const auto rbegin_i = static_cast<int64_t>(rbegin);
111  const auto rend_i = static_cast<int64_t>(rend);
112  #pragma omp parallel for schedule(static) num_threads(nthread)
113  for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
114  const int tid = omp_get_thread_num();
115  const size_t off = dmat->num_col * tid;
116  const size_t off2 = count_row_ptr[ntree] * tid;
117  const size_t ibegin = dmat->row_ptr[rid];
118  const size_t iend = dmat->row_ptr[rid + 1];
119  for (size_t i = ibegin; i < iend; ++i) {
120  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
121  }
122  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
123  Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
124  }
125  for (size_t i = ibegin; i < iend; ++i) {
126  inst[off + dmat->col_ind[i]].missing = -1;
127  }
128  }
129 }
130 
131 template <typename ElementType>
132 class ComputeBranchLoopDispatcherWithDenseDMatrix {
133  public:
134  template <typename ThresholdType, typename LeafOutputType>
135  inline static void Dispatch(
137  const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
138  const size_t* count_row_ptr, size_t* counts_tloc) {
139  const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
140  CHECK(dmat_) << "Dangling data matrix reference detected";
141  ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
142  }
143 };
144 
145 template <typename ElementType>
146 class ComputeBranchLoopDispatcherWithCSRDMatrix {
147  public:
148  template <typename ThresholdType, typename LeafOutputType>
149  inline static void Dispatch(
151  const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
152  const size_t* count_row_ptr, size_t* counts_tloc) {
153  const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
154  CHECK(dmat_) << "Dangling data matrix reference detected";
155  ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
156  }
157 };
158 
159 template <typename ThresholdType, typename LeafOutputType>
160 inline void ComputeBranchLoop(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
161  const treelite::DMatrix* dmat, size_t rbegin,
162  size_t rend, int nthread, const size_t* count_row_ptr,
163  size_t* counts_tloc) {
164  switch (dmat->GetType()) {
165  case treelite::DMatrixType::kDense: {
166  treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
167  dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
168  break;
169  }
170  case treelite::DMatrixType::kSparseCSR: {
171  treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithCSRDMatrix>(
172  dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
173  break;
174  }
175  default:
176  LOG(FATAL)
177  << "Annotator does not support DMatrix of type " << static_cast<int>(dmat->GetType());
178  break;
179  }
180 }
181 
182 } // anonymous namespace
183 
184 namespace treelite {
185 
186 template <typename ThresholdType, typename LeafOutputType>
187 inline void
188 AnnotateImpl(
190  const treelite::DMatrix* dmat, int nthread, int verbose,
191  std::vector<std::vector<size_t>>* out_counts) {
192  std::vector<size_t> new_counts;
193  std::vector<size_t> counts_tloc;
194  std::vector<size_t> count_row_ptr;
195 
196  count_row_ptr = {0};
197  const size_t ntree = model.trees.size();
198  const int max_thread = omp_get_max_threads();
199  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
200  for (const treelite::Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
201  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
202  }
203  new_counts.resize(count_row_ptr[ntree], 0);
204  counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
205 
206  const size_t num_row = dmat->GetNumRow();
207  const size_t pstep = (num_row + 19) / 20;
208  // interval to display progress
209  for (size_t rbegin = 0; rbegin < num_row; rbegin += pstep) {
210  const size_t rend = std::min(rbegin + pstep, num_row);
211  ComputeBranchLoop(model, dmat, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0]);
212  if (verbose > 0) {
213  LOG(INFO) << rend << " of " << num_row << " rows processed";
214  }
215  }
216 
217  // perform reduction on counts
218  for (int tid = 0; tid < nthread; ++tid) {
219  const size_t off = count_row_ptr[ntree] * tid;
220  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
221  new_counts[i] += counts_tloc[off + i];
222  }
223  }
224 
225  // change layout of counts
226  std::vector<std::vector<size_t>>& counts = *out_counts;
227  for (size_t i = 0; i < ntree; ++i) {
228  counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
229  }
230 }
231 
232 void
233 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) {
234  TypeInfo threshold_type = model.GetThresholdType();
235  model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) {
236  AnnotateImpl(handle, dmat, nthread, verbose, &this->counts);
237  });
238 }
239 
240 void
241 BranchAnnotator::Load(dmlc::Stream* fi) {
242  dmlc::istream is(fi);
243  std::unique_ptr<dmlc::JSONReader> reader(new dmlc::JSONReader(&is));
244  reader->Read(&counts);
245 }
246 
247 void
248 BranchAnnotator::Save(dmlc::Stream* fo) const {
249  dmlc::ostream os(fo);
250  std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&os));
251  writer->Write(counts);
252 }
253 
254 } // namespace treelite
Some useful math utilities.
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:388
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:118
bool CheckNAN(T value)
check for NaN (Not a Number)
Definition: math.h:43
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:59
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
Definition: tree.h:323
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform a comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:63
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
Definition: data.h:114
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:112
std::vector< ElementType > data
feature values
Definition: data.h:57
in-memory representation of a decision tree
Definition: tree.h:191
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:330
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< ElementType > data
feature values
Definition: data.h:110
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:673
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:475
int num_nodes
number of nodes
Definition: tree.h:289
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:424
int LeftChild(int nid) const
Getters.
Definition: tree.h:309
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:399
compatibility wrapper for systems that don't support OpenMP
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:316
Branch annotation tools.
thin wrapper for tree ensemble model
Definition: tree.h:615
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:344
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:381
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:63
Operator
comparison operators
Definition: base.h:26