Treelite
annotator.cc
Go to the documentation of this file.
1 
8 #include <treelite/logging.h>
9 #include <treelite/annotator.h>
10 #include <treelite/math.h>
11 #include <rapidjson/istreamwrapper.h>
12 #include <rapidjson/ostreamwrapper.h>
13 #include <rapidjson/writer.h>
14 #include <rapidjson/document.h>
15 #include <limits>
16 #include <thread>
17 #include <cstdint>
19 
20 namespace {
21 
/*!
 * \brief Integer type used for the "missing" marker of Entry<ElementType>.
 *
 * For 8-byte element types (e.g. double) a 64-bit integer is used, so that storing -1 sets
 * every byte of the slot. With a 4-byte int only half of the slot would be marked, and a
 * genuine value whose 4 bytes shared with `missing` happen to be all ones would then be
 * misread as missing.
 */
template <std::size_t kElementSize>
struct EntryMissingFlag {
  using type = int;
};

template <>
struct EntryMissingFlag<8> {
  using type = std::int64_t;
};

/*!
 * \brief One feature slot of the row currently being traversed.
 *
 * A slot is treated as missing when `missing == -1`; otherwise `fvalue` holds the feature
 * value. `missing` must remain the first member: brace-initialization such as `{-1}`
 * initializes the first member of a union.
 * NOTE(review): callers read `missing` after writing `fvalue` (type punning through the
 * union); this relies on compiler support for union punning.
 */
template <typename ElementType>
union Entry {
  typename EntryMissingFlag<sizeof(ElementType)>::type missing;
  ElementType fvalue;
};
27 
28 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
29 void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
30  const Entry<ElementType>* data, int nid, uint64_t* out_counts) {
31  ++out_counts[nid];
32  if (!tree.IsLeaf(nid)) {
33  const unsigned split_index = tree.SplitIndex(nid);
34 
35  if (data[split_index].missing == -1) {
36  Traverse_(tree, data, tree.DefaultChild(nid), out_counts);
37  } else {
38  bool result = true;
39  if (tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
40  const ThresholdType threshold = tree.Threshold(nid);
41  const treelite::Operator op = tree.ComparisonOp(nid);
42  const auto fvalue = static_cast<ElementType>(data[split_index].fvalue);
43  result = treelite::CompareWithOp(fvalue, op, threshold);
44  } else {
45  const auto fvalue = data[split_index].fvalue;
46  const auto matching_categories = tree.MatchingCategories(nid);
47  result = (std::binary_search(matching_categories.begin(),
48  matching_categories.end(),
49  static_cast<uint32_t>(fvalue)));
50  if (tree.CategoriesListRightChild(nid)) {
51  result = !result;
52  }
53  }
54  if (result) { // left child
55  Traverse_(tree, data, tree.LeftChild(nid), out_counts);
56  } else { // right child
57  Traverse_(tree, data, tree.RightChild(nid), out_counts);
58  }
59  }
60  }
61 }
62 
63 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
64 void Traverse(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
65  const Entry<ElementType>* data, uint64_t* out_counts) {
66  Traverse_(tree, data, 0, out_counts);
67 }
68 
69 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
70 inline void ComputeBranchLoopImpl(
72  const treelite::DenseDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
73  const size_t* count_row_ptr, uint64_t* counts_tloc) {
74  std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
75  size_t ntree = model.trees.size();
76  TREELITE_CHECK_LE(rbegin, rend);
77  size_t num_col = dmat->num_col;
78  ElementType missing_value = dmat->missing_value;
79  bool nan_missing = treelite::math::CheckNAN(missing_value);
80  treelite::threading_utils::ParallelFor(rbegin, rend, nthread,
81  [&](std::size_t rid, std::size_t thread_id) {
82  const ElementType* row = &dmat->data[rid * num_col];
83  const size_t off = dmat->num_col * thread_id;
84  const size_t off2 = count_row_ptr[ntree] * thread_id;
85  for (size_t j = 0; j < num_col; ++j) {
86  if (treelite::math::CheckNAN(row[j])) {
87  TREELITE_CHECK(nan_missing)
88  << "The missing_value argument must be set to NaN if there is any NaN in the matrix.";
89  } else if (nan_missing || row[j] != missing_value) {
90  inst[off + j].fvalue = row[j];
91  }
92  }
93  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
94  Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
95  }
96  for (size_t j = 0; j < num_col; ++j) {
97  inst[off + j].missing = -1;
98  }
99  });
100 }
101 
102 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
103 inline void ComputeBranchLoopImpl(
105  const treelite::CSRDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
106  const size_t* count_row_ptr, uint64_t* counts_tloc) {
107  std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
108  size_t ntree = model.trees.size();
109  TREELITE_CHECK_LE(rbegin, rend);
110  treelite::threading_utils::ParallelFor(rbegin, rend, nthread,
111  [&](std::size_t rid, std::size_t thread_id) {
112  const size_t off = dmat->num_col * thread_id;
113  const size_t off2 = count_row_ptr[ntree] * thread_id;
114  const size_t ibegin = dmat->row_ptr[rid];
115  const size_t iend = dmat->row_ptr[rid + 1];
116  for (size_t i = ibegin; i < iend; ++i) {
117  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
118  }
119  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
120  Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
121  }
122  for (size_t i = ibegin; i < iend; ++i) {
123  inst[off + dmat->col_ind[i]].missing = -1;
124  }
125  });
126 }
127 
128 template <typename ElementType>
129 class ComputeBranchLoopDispatcherWithDenseDMatrix {
130  public:
131  template <typename ThresholdType, typename LeafOutputType>
132  inline static void Dispatch(
134  const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
135  const size_t* count_row_ptr, uint64_t* counts_tloc) {
136  const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
137  TREELITE_CHECK(dmat_) << "Dangling data matrix reference detected";
138  ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
139  }
140 };
141 
142 template <typename ElementType>
143 class ComputeBranchLoopDispatcherWithCSRDMatrix {
144  public:
145  template <typename ThresholdType, typename LeafOutputType>
146  inline static void Dispatch(
148  const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
149  const size_t* count_row_ptr, uint64_t* counts_tloc) {
150  const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
151  TREELITE_CHECK(dmat_) << "Dangling data matrix reference detected";
152  ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
153  }
154 };
155 
156 template <typename ThresholdType, typename LeafOutputType>
157 inline void ComputeBranchLoop(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
158  const treelite::DMatrix* dmat, size_t rbegin,
159  size_t rend, int nthread, const size_t* count_row_ptr,
160  uint64_t* counts_tloc) {
161  switch (dmat->GetType()) {
162  case treelite::DMatrixType::kDense: {
163  treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
164  dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
165  break;
166  }
167  case treelite::DMatrixType::kSparseCSR: {
168  treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithCSRDMatrix>(
169  dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
170  break;
171  }
172  default:
173  TREELITE_LOG(FATAL)
174  << "Annotator does not support DMatrix of type " << static_cast<int>(dmat->GetType());
175  break;
176  }
177 }
178 
179 } // anonymous namespace
180 
181 namespace treelite {
182 
183 template <typename ThresholdType, typename LeafOutputType>
184 inline void
185 AnnotateImpl(
187  const treelite::DMatrix* dmat, int nthread, int verbose,
188  std::vector<std::vector<uint64_t>>* out_counts) {
189  std::vector<uint64_t> new_counts;
190  std::vector<uint64_t> counts_tloc;
191  std::vector<size_t> count_row_ptr;
192 
193  count_row_ptr = {0};
194  const size_t ntree = model.trees.size();
195  const int max_thread = static_cast<int>(std::thread::hardware_concurrency());
196  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
197  for (const treelite::Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
198  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
199  }
200  new_counts.resize(count_row_ptr[ntree], 0);
201  counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
202 
203  const size_t num_row = dmat->GetNumRow();
204  const size_t pstep = (num_row + 19) / 20;
205  // interval to display progress
206  for (size_t rbegin = 0; rbegin < num_row; rbegin += pstep) {
207  const size_t rend = std::min(rbegin + pstep, num_row);
208  ComputeBranchLoop(model, dmat, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0]);
209  if (verbose > 0) {
210  TREELITE_LOG(INFO) << rend << " of " << num_row << " rows processed";
211  }
212  }
213 
214  // perform reduction on counts
215  for (int tid = 0; tid < nthread; ++tid) {
216  const size_t off = count_row_ptr[ntree] * tid;
217  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
218  new_counts[i] += counts_tloc[off + i];
219  }
220  }
221 
222  // change layout of counts
223  std::vector<std::vector<uint64_t>>& counts = *out_counts;
224  for (size_t i = 0; i < ntree; ++i) {
225  counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
226  }
227 }
228 
229 void
230 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) {
231  TypeInfo threshold_type = model.GetThresholdType();
232  model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) {
233  AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_);
234  });
235 }
236 
237 void
238 BranchAnnotator::Load(std::istream& fi) {
239  rapidjson::IStreamWrapper is(fi);
240 
241  rapidjson::Document doc;
242  doc.ParseStream(is);
243 
244  std::string err_msg = "JSON file must contain a list of lists of integers";
245  TREELITE_CHECK(doc.IsArray()) << err_msg;
246  counts_.clear();
247  for (const auto& node_cnt : doc.GetArray()) {
248  TREELITE_CHECK(node_cnt.IsArray()) << err_msg;
249  counts_.emplace_back();
250  for (const auto& e : node_cnt.GetArray()) {
251  counts_.back().push_back(e.GetUint64());
252  }
253  }
254 }
255 
256 void
257 BranchAnnotator::Save(std::ostream& fo) const {
258  rapidjson::OStreamWrapper os(fo);
259  rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
260 
261  writer.StartArray();
262  for (const auto& node_cnt : counts_) {
263  writer.StartArray();
264  for (auto e : node_cnt) {
265  writer.Uint64(e);
266  }
267  writer.EndArray();
268  }
269  writer.EndArray();
270 }
271 
272 } // namespace treelite
Some useful math utilities.
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:431
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:120
bool CheckNAN(T value)
check for NaN (Not a Number)
Definition: math.h:43
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:58
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
Definition: tree.h:366
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:77
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
Definition: data.h:116
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:114
std::vector< ElementType > data
feature values
Definition: data.h:56
in-memory representation of a decision tree
Definition: tree.h:214
logging facility for Treelite
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:373
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< ElementType > data
feature values
Definition: data.h:112
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:746
Implementation of parallel for loop.
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:521
int num_nodes
number of nodes
Definition: tree.h:332
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:470
int LeftChild(int nid) const
index of the node's left child
Definition: tree.h:352
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:442
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:359
Branch annotation tools.
thin wrapper for tree ensemble model
Definition: tree.h:667
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:387
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:424
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:62
Operator
comparison operators
Definition: base.h:26