treelite
annotator.cc
Go to the documentation of this file.
1 
#include <treelite/annotator.h>
#include <treelite/common.h>
#include <treelite/omp.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>
13 
14 namespace {
15 
16 union Entry {
17  int missing;
18  float fvalue;
19 };
20 
21 void Traverse_(const treelite::Tree& tree, const Entry* data,
22  int nid, size_t* out_counts) {
23  const treelite::Tree::Node& node = tree[nid];
24 
25  ++out_counts[nid];
26  if (!node.is_leaf()) {
27  const unsigned split_index = node.split_index();
28 
29  if (data[split_index].missing == -1) {
30  Traverse_(tree, data, node.cdefault(), out_counts);
31  } else {
32  bool result = true;
33  if (node.split_type() == treelite::SplitFeatureType::kNumerical) {
34  const treelite::tl_float threshold = node.threshold();
35  const treelite::Operator op = node.comparison_op();
36  const treelite::tl_float fvalue
37  = static_cast<treelite::tl_float>(data[split_index].fvalue);
38  result = treelite::common::CompareWithOp(fvalue, op, threshold);
39  } else {
40  const auto fvalue = data[split_index].fvalue;
41  const uint32_t fvalue2 = static_cast<uint32_t>(fvalue);
42  const auto left_categories = node.left_categories();
43  result = (std::binary_search(left_categories.begin(),
44  left_categories.end(), fvalue));
45  }
46  if (result) { // left child
47  Traverse_(tree, data, node.cleft(), out_counts);
48  } else { // right child
49  Traverse_(tree, data, node.cright(), out_counts);
50  }
51  }
52  }
53 }
54 
55 void Traverse(const treelite::Tree& tree, const Entry* data,
56  size_t* out_counts) {
57  Traverse_(tree, data, 0, out_counts);
58 }
59 
60 inline void ComputeBranchLoop(const treelite::Model& model,
61  const treelite::DMatrix* dmat,
62  size_t rbegin, size_t rend, int nthread,
63  const size_t* count_row_ptr,
64  size_t* counts_tloc, Entry* inst) {
65  const size_t ntree = model.trees.size();
66  CHECK_LE(rbegin, rend);
67  CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
68  const int64_t rbegin_i = static_cast<int64_t>(rbegin);
69  const int64_t rend_i = static_cast<int64_t>(rend);
70  #pragma omp parallel for schedule(static) num_threads(nthread)
71  for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
72  const int tid = omp_get_thread_num();
73  const size_t off = dmat->num_col * tid;
74  const size_t off2 = count_row_ptr[ntree] * tid;
75  const size_t ibegin = dmat->row_ptr[rid];
76  const size_t iend = dmat->row_ptr[rid + 1];
77  for (size_t i = ibegin; i < iend; ++i) {
78  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
79  }
80  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
81  Traverse(model.trees[tree_id], &inst[off],
82  &counts_tloc[off2 + count_row_ptr[tree_id]]);
83  }
84  for (size_t i = ibegin; i < iend; ++i) {
85  inst[off + dmat->col_ind[i]].missing = -1;
86  }
87  }
88 }
89 
90 } // anonymous namespace
91 
92 namespace treelite {
93 
94 void
95 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat,
96  int nthread, int verbose) {
97  std::vector<size_t> counts;
98  std::vector<size_t> counts_tloc;
99  std::vector<size_t> count_row_ptr;
100  count_row_ptr = {0};
101  const size_t ntree = model.trees.size();
102  const int max_thread = omp_get_max_threads();
103  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
104  for (const Tree& tree : model.trees) {
105  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
106  }
107  counts.resize(count_row_ptr[ntree], 0);
108  counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
109 
110  std::vector<Entry> inst(nthread * dmat->num_col, {-1});
111  const size_t pstep = (dmat->num_row + 19) / 20;
112  // interval to display progress
113  for (size_t rbegin = 0; rbegin < dmat->num_row; rbegin += pstep) {
114  const size_t rend = std::min(rbegin + pstep, dmat->num_row);
115  ComputeBranchLoop(model, dmat, rbegin, rend, nthread,
116  &count_row_ptr[0], &counts_tloc[0], &inst[0]);
117  if (verbose > 0) {
118  LOG(INFO) << rend << " of " << dmat->num_row << " rows processed";
119  }
120  }
121 
122  // perform reduction on counts
123  for (int tid = 0; tid < nthread; ++tid) {
124  const size_t off = count_row_ptr[ntree] * tid;
125  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
126  counts[i] += counts_tloc[off + i];
127  }
128  }
129 
130  // change layout of counts
131  for (size_t i = 0; i < ntree; ++i) {
132  this->counts.emplace_back(&counts[count_row_ptr[i]],
133  &counts[count_row_ptr[i + 1]]);
134  }
135 }
136 
137 void
138 BranchAnnotator::Load(dmlc::Stream* fi) {
139  dmlc::istream is(fi);
140  auto reader = common::make_unique<dmlc::JSONReader>(&is);
141  reader->Read(&counts);
142 }
143 
144 void
145 BranchAnnotator::Save(dmlc::Stream* fo) const {
146  dmlc::ostream os(fo);
147  auto writer = common::make_unique<dmlc::JSONWriter>(&os);
148  writer->Write(counts);
149 }
150 
151 } // namespace treelite
std::vector< float > data
feature values
Definition: data.h:18
thin wrapper for tree ensemble model
Definition: tree.h:415
tree node
Definition: tree.h:25
std::vector< Tree > trees
member trees
Definition: tree.h:417
unsigned split_index() const
feature index of split condition
Definition: tree.h:41
Operator comparison_op() const
get comparison operator
Definition: tree.h:86
in-memory representation of a decision tree
Definition: tree.h:22
const std::vector< uint32_t > & left_categories() const
get categories for left child node
Definition: tree.h:90
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:20
tl_float threshold() const
Definition: tree.h:70
int cright() const
index of right child
Definition: tree.h:33
size_t num_row
number of rows
Definition: data.h:24
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:16
int cdefault() const
index of default child when feature is missing
Definition: tree.h:37
int num_nodes
number of nodes
Definition: tree.h:281
double tl_float
float type to be used internally
Definition: base.h:17
size_t num_col
number of columns
Definition: data.h:26
compatibility wrapper for systems that don't support OpenMP
Branch annotation tools.
int cleft() const
index of left child
Definition: tree.h:29
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:94
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Definition: data.h:22
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:49
Operator
comparison operators
Definition: base.h:23