Treelite
annotator.cc
Go to the documentation of this file.
1 
8 #include <treelite/logging.h>
9 #include <treelite/annotator.h>
10 #include <treelite/math.h>
11 #include <rapidjson/istreamwrapper.h>
12 #include <rapidjson/ostreamwrapper.h>
13 #include <rapidjson/writer.h>
14 #include <rapidjson/document.h>
15 #include <limits>
16 #include <thread>
17 #include <cstdint>
19 
20 namespace {
21 
23 
/*!
 * \brief Scratch slot holding one feature of the row currently being traversed.
 *
 * Traverse_() treats a slot as absent when `missing == -1`; storing a value into `fvalue`
 * overwrites those bytes. The member order is significant: brace-initialization such as
 * `{-1}` initializes the first member, i.e. `missing`.
 */
template <class ElementType>
union Entry {
  int missing;         // -1 marks the feature as absent for the current row
  ElementType fvalue;  // feature value, meaningful once written
};
29 
30 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
31 void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
32  const Entry<ElementType>* data, int nid, uint64_t* out_counts) {
33  ++out_counts[nid];
34  if (!tree.IsLeaf(nid)) {
35  const unsigned split_index = tree.SplitIndex(nid);
36 
37  if (data[split_index].missing == -1) {
38  Traverse_(tree, data, tree.DefaultChild(nid), out_counts);
39  } else {
40  bool result = true;
41  if (tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
42  const ThresholdType threshold = tree.Threshold(nid);
43  const treelite::Operator op = tree.ComparisonOp(nid);
44  const auto fvalue = static_cast<ElementType>(data[split_index].fvalue);
45  result = treelite::CompareWithOp(fvalue, op, threshold);
46  } else {
47  const auto fvalue = data[split_index].fvalue;
48  const auto matching_categories = tree.MatchingCategories(nid);
49  result = (std::binary_search(matching_categories.begin(),
50  matching_categories.end(),
51  static_cast<uint32_t>(fvalue)));
52  if (tree.CategoriesListRightChild(nid)) {
53  result = !result;
54  }
55  }
56  if (result) { // left child
57  Traverse_(tree, data, tree.LeftChild(nid), out_counts);
58  } else { // right child
59  Traverse_(tree, data, tree.RightChild(nid), out_counts);
60  }
61  }
62  }
63 }
64 
65 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
66 void Traverse(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
67  const Entry<ElementType>* data, uint64_t* out_counts) {
68  Traverse_(tree, data, 0, out_counts);
69 }
70 
71 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
72 inline void ComputeBranchLoopImpl(
74  const treelite::DenseDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend,
75  const ThreadConfig& thread_config, const size_t* count_row_ptr, uint64_t* counts_tloc) {
76  std::vector<Entry<ElementType>> inst(thread_config.nthread * dmat->num_col, {-1});
77  size_t ntree = model.trees.size();
78  TREELITE_CHECK_LE(rbegin, rend);
79  size_t num_col = dmat->num_col;
80  ElementType missing_value = dmat->missing_value;
81  bool nan_missing = treelite::math::CheckNAN(missing_value);
82  auto sched = treelite::threading_utils::ParallelSchedule::Static();
83  treelite::threading_utils::ParallelFor(rbegin, rend, thread_config, sched,
84  [&](std::size_t rid, int thread_id) {
85  const ElementType* row = &dmat->data[rid * num_col];
86  const size_t off = dmat->num_col * thread_id;
87  const size_t off2 = count_row_ptr[ntree] * thread_id;
88  for (size_t j = 0; j < num_col; ++j) {
89  if (treelite::math::CheckNAN(row[j])) {
90  TREELITE_CHECK(nan_missing)
91  << "The missing_value argument must be set to NaN if there is any NaN in the matrix.";
92  } else if (nan_missing || row[j] != missing_value) {
93  inst[off + j].fvalue = row[j];
94  }
95  }
96  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
97  Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
98  }
99  for (size_t j = 0; j < num_col; ++j) {
100  inst[off + j].missing = -1;
101  }
102  });
103 }
104 
105 template <typename ElementType, typename ThresholdType, typename LeafOutputType>
106 inline void ComputeBranchLoopImpl(
108  const treelite::CSRDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend,
109  const ThreadConfig& thread_config, const size_t* count_row_ptr, uint64_t* counts_tloc) {
110  std::vector<Entry<ElementType>> inst(thread_config.nthread * dmat->num_col, {-1});
111  size_t ntree = model.trees.size();
112  TREELITE_CHECK_LE(rbegin, rend);
113  auto sched = treelite::threading_utils::ParallelSchedule::Static();
114  treelite::threading_utils::ParallelFor(rbegin, rend, thread_config, sched,
115  [&](std::size_t rid, int thread_id) {
116  const size_t off = dmat->num_col * thread_id;
117  const size_t off2 = count_row_ptr[ntree] * thread_id;
118  const size_t ibegin = dmat->row_ptr[rid];
119  const size_t iend = dmat->row_ptr[rid + 1];
120  for (size_t i = ibegin; i < iend; ++i) {
121  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
122  }
123  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
124  Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
125  }
126  for (size_t i = ibegin; i < iend; ++i) {
127  inst[off + dmat->col_ind[i]].missing = -1;
128  }
129  });
130 }
131 
132 template <typename ElementType>
133 class ComputeBranchLoopDispatcherWithDenseDMatrix {
134  public:
135  template <typename ThresholdType, typename LeafOutputType>
136  inline static void Dispatch(
138  const treelite::DMatrix* dmat, size_t rbegin, size_t rend, const ThreadConfig& thread_config,
139  const size_t* count_row_ptr, uint64_t* counts_tloc) {
140  const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
141  TREELITE_CHECK(dmat_) << "Dangling data matrix reference detected";
142  ComputeBranchLoopImpl(model, dmat_, rbegin, rend, thread_config, count_row_ptr, counts_tloc);
143  }
144 };
145 
146 template <typename ElementType>
147 class ComputeBranchLoopDispatcherWithCSRDMatrix {
148  public:
149  template <typename ThresholdType, typename LeafOutputType>
150  inline static void Dispatch(
152  const treelite::DMatrix* dmat, size_t rbegin, size_t rend, const ThreadConfig& thread_config,
153  const size_t* count_row_ptr, uint64_t* counts_tloc) {
154  const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
155  TREELITE_CHECK(dmat_) << "Dangling data matrix reference detected";
156  ComputeBranchLoopImpl(model, dmat_, rbegin, rend, thread_config, count_row_ptr, counts_tloc);
157  }
158 };
159 
160 template <typename ThresholdType, typename LeafOutputType>
161 inline void ComputeBranchLoop(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
162  const treelite::DMatrix* dmat, size_t rbegin,
163  size_t rend, const ThreadConfig& thread_config,
164  const size_t* count_row_ptr, uint64_t* counts_tloc) {
165  switch (dmat->GetType()) {
166  case treelite::DMatrixType::kDense: {
167  treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
168  dmat->GetElementType(), model, dmat, rbegin, rend, thread_config, count_row_ptr,
169  counts_tloc);
170  break;
171  }
172  case treelite::DMatrixType::kSparseCSR: {
173  treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithCSRDMatrix>(
174  dmat->GetElementType(), model, dmat, rbegin, rend, thread_config, count_row_ptr,
175  counts_tloc);
176  break;
177  }
178  default:
179  TREELITE_LOG(FATAL)
180  << "Annotator does not support DMatrix of type " << static_cast<int>(dmat->GetType());
181  break;
182  }
183 }
184 
185 } // anonymous namespace
186 
187 namespace treelite {
188 
189 template <typename ThresholdType, typename LeafOutputType>
190 inline void
191 AnnotateImpl(
193  const treelite::DMatrix* dmat, int nthread, int verbose,
194  std::vector<std::vector<uint64_t>>* out_counts) {
195  std::vector<uint64_t> new_counts;
196  std::vector<uint64_t> counts_tloc;
197  std::vector<size_t> count_row_ptr;
198 
199  count_row_ptr = {0};
200  const size_t ntree = model.trees.size();
201  ThreadConfig thread_config = threading_utils::ConfigureThreadConfig(nthread);
202  for (const treelite::Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
203  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
204  }
205  new_counts.resize(count_row_ptr[ntree], 0);
206  counts_tloc.resize(count_row_ptr[ntree] * thread_config.nthread, 0);
207 
208  const size_t num_row = dmat->GetNumRow();
209  const size_t pstep = (num_row + 19) / 20;
210  // interval to display progress
211  for (size_t rbegin = 0; rbegin < num_row; rbegin += pstep) {
212  const size_t rend = std::min(rbegin + pstep, num_row);
213  ComputeBranchLoop(model, dmat, rbegin, rend, thread_config, &count_row_ptr[0],
214  &counts_tloc[0]);
215  if (verbose > 0) {
216  TREELITE_LOG(INFO) << rend << " of " << num_row << " rows processed";
217  }
218  }
219 
220  // perform reduction on counts
221  for (std::uint32_t tid = 0; tid < thread_config.nthread; ++tid) {
222  const size_t off = count_row_ptr[ntree] * tid;
223  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
224  new_counts[i] += counts_tloc[off + i];
225  }
226  }
227 
228  // change layout of counts
229  std::vector<std::vector<uint64_t>>& counts = *out_counts;
230  for (size_t i = 0; i < ntree; ++i) {
231  counts.emplace_back(new_counts.begin() + count_row_ptr[i],
232  new_counts.begin() + count_row_ptr[i + 1]);
233  }
234 }
235 
236 void
237 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) {
238  TypeInfo threshold_type = model.GetThresholdType();
239  model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) {
240  AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_);
241  });
242 }
243 
244 void
245 BranchAnnotator::Load(std::istream& fi) {
246  rapidjson::IStreamWrapper is(fi);
247 
248  rapidjson::Document doc;
249  doc.ParseStream(is);
250 
251  std::string err_msg = "JSON file must contain a list of lists of integers";
252  TREELITE_CHECK(doc.IsArray()) << err_msg;
253  counts_.clear();
254  for (const auto& node_cnt : doc.GetArray()) {
255  TREELITE_CHECK(node_cnt.IsArray()) << err_msg;
256  counts_.emplace_back();
257  for (const auto& e : node_cnt.GetArray()) {
258  counts_.back().push_back(e.GetUint64());
259  }
260  }
261 }
262 
263 void
264 BranchAnnotator::Save(std::ostream& fo) const {
265  rapidjson::OStreamWrapper os(fo);
266  rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
267 
268  writer.StartArray();
269  for (const auto& node_cnt : counts_) {
270  writer.StartArray();
271  for (auto e : node_cnt) {
272  writer.Uint64(e);
273  }
274  writer.EndArray();
275  }
276  writer.EndArray();
277 }
278 
279 } // namespace treelite
Some useful math utilities.
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:502
Represent thread configuration, to be used with parallel loops.
Definition: parallel_for.h:36
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:120
bool CheckNAN(T value)
check for NaN (Not a Number)
Definition: math.h:43
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:58
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
Definition: tree.h:437
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form [lhs] [op] [rhs].
Definition: base.h:78
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
Definition: data.h:116
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:114
std::vector< ElementType > data
feature values
Definition: data.h:56
in-memory representation of a decision tree
Definition: tree.h:222
logging facility for Treelite
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:23
std::vector< ElementType > data
feature values
Definition: data.h:112
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:838
Implementation of parallel for loop.
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:444
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the left child node
Definition: tree.h:581
int num_nodes
number of nodes
Definition: tree.h:409
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:530
int LeftChild(int nid) const
Getters: index of the node's left child
Definition: tree.h:423
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:430
Branch annotation tools.
thin wrapper for tree ensemble model
Definition: tree.h:734
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:458
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:495
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See CategoriesListRightChild(nid) to determine which child node the list belongs to.
Definition: tree.h:513
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:62
Operator
comparison operators
Definition: base.h:27