treelite
annotator.cc
Go to the documentation of this file.
1 
8 #include <treelite/annotator.h>
9 #include <treelite/semantic.h>
10 #include <treelite/omp.h>
11 #include <cstdint>
12 #include <limits>
13 
14 namespace {
15 
/*!
 * \brief One slot of a dense row buffer: holds either a feature value or a
 *        "missing" marker in the same storage.
 *
 * A slot is treated as missing when its `missing` member reads -1. Writing
 * `fvalue` overwrites the marker; the marker is restored by writing
 * `missing = -1` again.
 */
16 union Entry {
17  int missing;  // -1 marks the feature as absent; must stay the first member so that `{-1}` initializes it
18  float fvalue;  // feature value copied from the data matrix
19 };
20 
21 void Traverse_(const treelite::Tree& tree, const Entry* data,
22  int nid, size_t* out_counts) {
23  const treelite::Tree::Node& node = tree[nid];
24 
25  ++out_counts[nid];
26  if (!node.is_leaf()) {
27  const unsigned split_index = node.split_index();
28 
29  if (data[split_index].missing == -1) {
30  Traverse_(tree, data, node.cdefault(), out_counts);
31  } else {
32  bool result = true;
33  if (node.split_type() == treelite::SplitFeatureType::kNumerical) {
34  const treelite::tl_float threshold = node.threshold();
35  const treelite::Operator op = node.comparison_op();
36  const treelite::tl_float fvalue
37  = static_cast<treelite::tl_float>(data[split_index].fvalue);
38  result = treelite::semantic::CompareWithOp(fvalue, op, threshold);
39  } else {
40  const auto fvalue = data[split_index].fvalue;
41  CHECK_LT(fvalue, 64) << "Cannot have more than 64 categories";
42  const uint8_t fvalue2 = static_cast<uint8_t>(fvalue);
43  const auto left_categories = node.left_categories();
44  result = (std::binary_search(left_categories.begin(),
45  left_categories.end(), fvalue));
46  }
47  if (result) { // left child
48  Traverse_(tree, data, node.cleft(), out_counts);
49  } else { // right child
50  Traverse_(tree, data, node.cright(), out_counts);
51  }
52  }
53  }
54 }
55 
56 void Traverse(const treelite::Tree& tree, const Entry* data,
57  size_t* out_counts) {
58  Traverse_(tree, data, 0, out_counts);
59 }
60 
61 inline void ComputeBranchLoop(const treelite::Model& model,
62  const treelite::DMatrix* dmat,
63  size_t rbegin, size_t rend, int nthread,
64  const size_t* count_row_ptr,
65  size_t* counts_tloc, Entry* inst) {
66  const size_t ntree = model.trees.size();
67  CHECK_LE(rbegin, rend);
68  CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
69  const int64_t rbegin_i = static_cast<int64_t>(rbegin);
70  const int64_t rend_i = static_cast<int64_t>(rend);
71  #pragma omp parallel for schedule(static) num_threads(nthread)
72  for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
73  const int tid = omp_get_thread_num();
74  const size_t off = dmat->num_col * tid;
75  const size_t off2 = count_row_ptr[ntree] * tid;
76  const size_t ibegin = dmat->row_ptr[rid];
77  const size_t iend = dmat->row_ptr[rid + 1];
78  for (size_t i = ibegin; i < iend; ++i) {
79  inst[off + dmat->col_ind[i]].fvalue = dmat->data[i];
80  }
81  for (size_t tree_id = 0; tree_id < ntree; ++tree_id) {
82  Traverse(model.trees[tree_id], &inst[off],
83  &counts_tloc[off2 + count_row_ptr[tree_id]]);
84  }
85  for (size_t i = ibegin; i < iend; ++i) {
86  inst[off + dmat->col_ind[i]].missing = -1;
87  }
88  }
89 }
90 
91 } // namespace anonymous
92 
93 namespace treelite {
94 
95 void
96 BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat,
97  int nthread, int verbose) {
98  std::vector<size_t> counts;
99  std::vector<size_t> counts_tloc;
100  std::vector<size_t> count_row_ptr;
101  count_row_ptr = {0};
102  const size_t ntree = model.trees.size();
103  const int max_thread = omp_get_max_threads();
104  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
105  for (const Tree& tree : model.trees) {
106  count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes);
107  }
108  counts.resize(count_row_ptr[ntree], 0);
109  counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
110 
111  std::vector<Entry> inst(nthread * dmat->num_col, {-1});
112  const size_t pstep = (dmat->num_row + 19) / 20;
113  // interval to display progress
114  for (size_t rbegin = 0; rbegin < dmat->num_row; rbegin += pstep) {
115  const size_t rend = std::min(rbegin + pstep, dmat->num_row);
116  ComputeBranchLoop(model, dmat, rbegin, rend, nthread,
117  &count_row_ptr[0], &counts_tloc[0], &inst[0]);
118  if (verbose > 0) {
119  LOG(INFO) << rend << " of " << dmat->num_row << " rows processed";
120  }
121  }
122 
123  // perform reduction on counts
124  for (int tid = 0; tid < nthread; ++tid) {
125  const size_t off = count_row_ptr[ntree] * tid;
126  for (size_t i = 0; i < count_row_ptr[ntree]; ++i) {
127  counts[i] += counts_tloc[off + i];
128  }
129  }
130 
131  // change layout of counts
132  for (size_t i = 0; i < ntree; ++i) {
133  this->counts.emplace_back(&counts[count_row_ptr[i]],
134  &counts[count_row_ptr[i + 1]]);
135  }
136 }
137 
138 void
139 BranchAnnotator::Load(dmlc::Stream* fi) {
140  dmlc::istream is(fi);
141  auto reader = common::make_unique<dmlc::JSONReader>(&is);
142  reader->Read(&counts);
143 }
144 
145 void
146 BranchAnnotator::Save(dmlc::Stream* fo) const {
147  dmlc::ostream os(fo);
148  auto writer = common::make_unique<dmlc::JSONWriter>(&os);
149  writer->Write(counts);
150 }
151 
152 } // namespace treelite
std::vector< float > data
feature values
Definition: data.h:17
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:352
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
Operator comparison_op() const
get comparison operator
Definition: tree.h:83
in-memory representation of a decision tree
Definition: tree.h:19
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:19
bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: semantic.h:47
tl_float threshold() const
Definition: tree.h:67
int cright() const
index of right child
Definition: tree.h:30
size_t num_row
number of rows
Definition: data.h:23
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:15
int cdefault() const
index of default child when feature is missing
Definition: tree.h:34
const std::vector< uint8_t > & left_categories() const
get categories for left child node
Definition: tree.h:87
int num_nodes
number of nodes
Definition: tree.h:216
size_t num_col
number of columns
Definition: data.h:25
compatibility wrapper for systems that don't support OpenMP
Branch annotation tools.
int cleft() const
index of left child
Definition: tree.h:26
Building blocks for semantic model of tree prediction code.
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Definition: data.h:21
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
Operator
comparison operators
Definition: base.h:23