Treelite
annotator.cc
Go to the documentation of this file.
1 
#include <treelite/annotator.h>
#include <treelite/omp.h>
#include <dmlc/json.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>
13 
14 namespace {
15 
/*!
 * \brief one slot of the dense feature buffer used while walking trees
 *
 * A slot either carries a feature value in `fvalue`, or is flagged as absent
 * by storing -1 in `missing`. Both members share the same storage.
 */
union Entry {
  // Must stay the first member: brace-initialization such as Entry{-1}
  // initializes the first member, which is how slots are marked absent.
  int missing;
  // Feature value of a present entry.
  float fvalue;
};
20 
21 void Traverse_(const treelite::Tree& tree, const Entry* data,
22  int nid, size_t* out_counts) {
23  ++out_counts[nid];
24  if (!tree.IsLeaf(nid)) {
25  const unsigned split_index = tree.SplitIndex(nid);
26 
27  if (data[split_index].missing == -1) {
28  Traverse_(tree, data, tree.DefaultChild(nid), out_counts);
29  } else {
30  bool result = true;
31  if (tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
32  const treelite::tl_float threshold = tree.Threshold(nid);
33  const treelite::Operator op = tree.ComparisonOp(nid);
34  const auto fvalue = static_cast<treelite::tl_float>(data[split_index].fvalue);
35  result = treelite::CompareWithOp(fvalue, op, threshold);
36  } else {
37  const auto fvalue = data[split_index].fvalue;
38  const auto left_categories = tree.LeftCategories(nid);
39  result = (std::binary_search(left_categories.begin(),
40  left_categories.end(),
41  static_cast<uint32_t>(fvalue)));
42  }
43  if (result) { // left child
44  Traverse_(tree, data, tree.LeftChild(nid), out_counts);
45  } else { // right child
46  Traverse_(tree, data, tree.RightChild(nid), out_counts);
47  }
48  }
49  }
50 }
51 
52 void Traverse(const treelite::Tree& tree, const Entry* data,
53  size_t* out_counts) {
54  Traverse_(tree, data, 0, out_counts);
55 }
56 
57 inline void ComputeBranchLoop(const treelite::Model& model,
58  const treelite::DMatrix* dmat,
59  size_t rbegin, size_t rend, int nthread,
60  const size_t* count_row_ptr,
61  size_t* counts_tloc, Entry* inst) {
62  const size_t ntree = model.trees.size();
63  CHECK_LE(rbegin, rend);
64  CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
65  const auto rbegin_i = static_cast<int64_t>(rbegin);
66  const auto rend_i = static_cast<int64_t>(rend);
67  #pragma omp parallel for schedule(static) num_threads(nthread)
68  for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
69  const int tid = omp_get_thread_num();
70  const size_t off = dmat->num_col * tid;
71  const size_t off2 = count_row_ptr[ntree] * tid;
72  const size_t ibegin = dmat->row_ptr[rid];
73  const size_t iend = dmat->row_ptr[rid + 1];
74  for (size_t i = ibegin; i < iend; ++i) {
75  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
76  }
77  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
78  Traverse(model.trees[tree_id], &inst[off],
79  &counts_tloc[off2 + count_row_ptr[tree_id]]);
80  }
81  for (size_t i = ibegin; i < iend; ++i) {
82  inst[off + dmat->col_ind[i]].missing = -1;
83  }
84  }
85 }
86 
87 } // anonymous namespace
88 
89 namespace treelite {
90 
91 void
92 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat,
93  int nthread, int verbose) {
94  std::vector<size_t> new_counts;
95  std::vector<size_t> counts_tloc;
96  std::vector<size_t> count_row_ptr;
97  count_row_ptr = {0};
98  const size_t ntree = model.trees.size();
99  const int max_thread = omp_get_max_threads();
100  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
101  for (const Tree& tree : model.trees) {
102  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
103  }
104  new_counts.resize(count_row_ptr[ntree], 0);
105  counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
106 
107  std::vector<Entry> inst(nthread * dmat->num_col, {-1});
108  const size_t pstep = (dmat->num_row + 19) / 20;
109  // interval to display progress
110  for (size_t rbegin = 0; rbegin < dmat->num_row; rbegin += pstep) {
111  const size_t rend = std::min(rbegin + pstep, dmat->num_row);
112  ComputeBranchLoop(model, dmat, rbegin, rend, nthread,
113  &count_row_ptr[0], &counts_tloc[0], &inst[0]);
114  if (verbose > 0) {
115  LOG(INFO) << rend << " of " << dmat->num_row << " rows processed";
116  }
117  }
118 
119  // perform reduction on counts
120  for (int tid = 0; tid < nthread; ++tid) {
121  const size_t off = count_row_ptr[ntree] * tid;
122  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
123  new_counts[i] += counts_tloc[off + i];
124  }
125  }
126 
127  // change layout of counts
128  for (size_t i = 0; i < ntree; ++i) {
129  this->counts.emplace_back(&new_counts[count_row_ptr[i]],
130  &new_counts[count_row_ptr[i + 1]]);
131  }
132 }
133 
134 void
135 BranchAnnotator::Load(dmlc::Stream* fi) {
136  dmlc::istream is(fi);
137  std::unique_ptr<dmlc::JSONReader> reader(new dmlc::JSONReader(&is));
138  reader->Read(&counts);
139 }
140 
141 void
142 BranchAnnotator::Save(dmlc::Stream* fo) const {
143  dmlc::ostream os(fo);
144  std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&os));
145  writer->Write(counts);
146 }
147 
148 } // namespace treelite
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree_impl.h:581
bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:59
std::vector< float > data
feature values
Definition: data.h:18
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
std::vector< Tree > trees
member trees
Definition: tree.h:411
tl_float Threshold(int nid) const
get threshold of the node
Definition: tree_impl.h:576
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
Definition: tree_impl.h:534
std::vector< uint32_t > LeftCategories(int nid) const
Get list of all categories belonging to the left child node. Categories not in this list will belong ...
Definition: tree_impl.h:586
in-memory representation of a decision tree
Definition: tree.h:80
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:20
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree_impl.h:539
size_t num_row
number of rows
Definition: data.h:24
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:16
int num_nodes
number of nodes
Definition: tree.h:167
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree_impl.h:595
int LeftChild(int nid) const
Getters.
Definition: tree_impl.h:524
size_t num_col
number of columns
Definition: data.h:26
compatibility wrapper for systems that don't support OpenMP
int RightChild(int nid) const
index of the node's right child
Definition: tree_impl.h:529
Branch annotation tools.
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree_impl.h:549
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Definition: data.h:22
Operator
comparison operators
Definition: base.h:24