treelite
annotator.cc
Go to the documentation of this file.
1 
8 #include <treelite/annotator.h>
9 #include <treelite/omp.h>
10 #include <cstdint>
11 #include <limits>
12 
13 namespace {
14 
/*!
 * \brief a single feature slot in a dense, per-thread row buffer
 *
 * The same storage holds either a sentinel or a feature value. `missing` is
 * intentionally the first member, so aggregate initialization `Entry{-1}`
 * sets the slot to the "missing" sentinel.
 */
union Entry {
  int missing;   // equals -1 when the feature is absent from the current row
  float fvalue;  // the feature value when the feature is present
};
19 
28 inline bool CompareWithOp(treelite::tl_float lhs, treelite::Operator op,
29  treelite::tl_float rhs) {
30  switch(op) {
31  case treelite::Operator::kEQ: return lhs == rhs;
32  case treelite::Operator::kLT: return lhs < rhs;
33  case treelite::Operator::kLE: return lhs <= rhs;
34  case treelite::Operator::kGT: return lhs > rhs;
35  case treelite::Operator::kGE: return lhs >= rhs;
36  default: LOG(FATAL) << "operator undefined";
37  }
38 }
39 
40 void Traverse_(const treelite::Tree& tree, const Entry* data,
41  int nid, size_t* out_counts) {
42  const treelite::Tree::Node& node = tree[nid];
43 
44  ++out_counts[nid];
45  if (!node.is_leaf()) {
46  const unsigned split_index = node.split_index();
47 
48  if (data[split_index].missing == -1) {
49  Traverse_(tree, data, node.cdefault(), out_counts);
50  } else {
51  bool result = true;
52  if (node.split_type() == treelite::SplitFeatureType::kNumerical) {
53  const treelite::tl_float threshold = node.threshold();
54  const treelite::Operator op = node.comparison_op();
55  const treelite::tl_float fvalue
56  = static_cast<treelite::tl_float>(data[split_index].fvalue);
57  result = CompareWithOp(fvalue, op, threshold);
58  } else {
59  const auto fvalue = data[split_index].fvalue;
60  const uint32_t fvalue2 = static_cast<uint32_t>(fvalue);
61  const auto left_categories = node.left_categories();
62  result = (std::binary_search(left_categories.begin(),
63  left_categories.end(), fvalue));
64  }
65  if (result) { // left child
66  Traverse_(tree, data, node.cleft(), out_counts);
67  } else { // right child
68  Traverse_(tree, data, node.cright(), out_counts);
69  }
70  }
71  }
72 }
73 
74 void Traverse(const treelite::Tree& tree, const Entry* data,
75  size_t* out_counts) {
76  Traverse_(tree, data, 0, out_counts);
77 }
78 
79 inline void ComputeBranchLoop(const treelite::Model& model,
80  const treelite::DMatrix* dmat,
81  size_t rbegin, size_t rend, int nthread,
82  const size_t* count_row_ptr,
83  size_t* counts_tloc, Entry* inst) {
84  const size_t ntree = model.trees.size();
85  CHECK_LE(rbegin, rend);
86  CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
87  const int64_t rbegin_i = static_cast<int64_t>(rbegin);
88  const int64_t rend_i = static_cast<int64_t>(rend);
89  #pragma omp parallel for schedule(static) num_threads(nthread)
90  for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
91  const int tid = omp_get_thread_num();
92  const size_t off = dmat->num_col * tid;
93  const size_t off2 = count_row_ptr[ntree] * tid;
94  const size_t ibegin = dmat->row_ptr[rid];
95  const size_t iend = dmat->row_ptr[rid + 1];
96  for (size_t i = ibegin; i < iend; ++i) {
97  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
98  }
99  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
100  Traverse(model.trees[tree_id], &inst[off],
101  &counts_tloc[off2 + count_row_ptr[tree_id]]);
102  }
103  for (size_t i = ibegin; i < iend; ++i) {
104  inst[off + dmat->col_ind[i]].missing = -1;
105  }
106  }
107 }
108 
109 } // namespace anonymous
110 
111 namespace treelite {
112 
113 void
114 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat,
115  int nthread, int verbose) {
116  std::vector<size_t> counts;
117  std::vector<size_t> counts_tloc;
118  std::vector<size_t> count_row_ptr;
119  count_row_ptr = {0};
120  const size_t ntree = model.trees.size();
121  const int max_thread = omp_get_max_threads();
122  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
123  for (const Tree& tree : model.trees) {
124  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
125  }
126  counts.resize(count_row_ptr[ntree], 0);
127  counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
128 
129  std::vector<Entry> inst(nthread * dmat->num_col, {-1});
130  const size_t pstep = (dmat->num_row + 19) / 20;
131  // interval to display progress
132  for (size_t rbegin = 0; rbegin < dmat->num_row; rbegin += pstep) {
133  const size_t rend = std::min(rbegin + pstep, dmat->num_row);
134  ComputeBranchLoop(model, dmat, rbegin, rend, nthread,
135  &count_row_ptr[0], &counts_tloc[0], &inst[0]);
136  if (verbose > 0) {
137  LOG(INFO) << rend << " of " << dmat->num_row << " rows processed";
138  }
139  }
140 
141  // perform reduction on counts
142  for (int tid = 0; tid < nthread; ++tid) {
143  const size_t off = count_row_ptr[ntree] * tid;
144  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
145  counts[i] += counts_tloc[off + i];
146  }
147  }
148 
149  // change layout of counts
150  for (size_t i = 0; i < ntree; ++i) {
151  this->counts.emplace_back(&counts[count_row_ptr[i]],
152  &counts[count_row_ptr[i + 1]]);
153  }
154 }
155 
156 void
157 BranchAnnotator::Load(dmlc::Stream* fi) {
158  dmlc::istream is(fi);
159  auto reader = common::make_unique<dmlc::JSONReader>(&is);
160  reader->Read(&counts);
161 }
162 
163 void
164 BranchAnnotator::Save(dmlc::Stream* fo) const {
165  dmlc::ostream os(fo);
166  auto writer = common::make_unique<dmlc::JSONWriter>(&os);
167  writer->Write(counts);
168 }
169 
170 } // namespace treelite
std::vector< float > data
feature values
Definition: data.h:17
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:353
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
Operator comparison_op() const
get comparison operator
Definition: tree.h:83
in-memory representation of a decision tree
Definition: tree.h:19
const std::vector< uint32_t > & left_categories() const
get categories for left child node
Definition: tree.h:87
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:19
tl_float threshold() const
Definition: tree.h:67
int cright() const
index of right child
Definition: tree.h:30
size_t num_row
number of rows
Definition: data.h:23
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:15
int cdefault() const
index of default child when feature is missing
Definition: tree.h:34
int num_nodes
number of nodes
Definition: tree.h:217
size_t num_col
number of columns
Definition: data.h:25
compatibility wrapper for systems that don't support OpenMP
Branch annotation tools.
int cleft() const
index of left child
Definition: tree.h:26
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Definition: data.h:21
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
Operator
comparison operators
Definition: base.h:23