31 case treelite::Operator::kEQ:
return lhs == rhs;
32 case treelite::Operator::kLT:
return lhs < rhs;
33 case treelite::Operator::kLE:
return lhs <= rhs;
34 case treelite::Operator::kGT:
return lhs > rhs;
35 case treelite::Operator::kGE:
return lhs >= rhs;
36 default: LOG(FATAL) <<
"operator undefined";
41 int nid,
size_t* out_counts) {
48 if (data[split_index].missing == -1) {
49 Traverse_(tree, data, node.
cdefault(), out_counts);
52 if (node.
split_type() == treelite::SplitFeatureType::kNumerical) {
57 result = CompareWithOp(fvalue, op, threshold);
59 const auto fvalue = data[split_index].fvalue;
60 const uint32_t fvalue2 =
static_cast<uint32_t
>(fvalue);
62 result = (std::binary_search(left_categories.begin(),
63 left_categories.end(), fvalue));
66 Traverse_(tree, data, node.
cleft(), out_counts);
68 Traverse_(tree, data, node.
cright(), out_counts);
76 Traverse_(tree, data, 0, out_counts);
81 size_t rbegin,
size_t rend,
int nthread,
82 const size_t* count_row_ptr,
83 size_t* counts_tloc, Entry* inst) {
84 const size_t ntree = model.
trees.size();
85 CHECK_LE(rbegin, rend);
86 CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
87 const int64_t rbegin_i =
static_cast<int64_t
>(rbegin);
88 const int64_t rend_i =
static_cast<int64_t
>(rend);
89 #pragma omp parallel for schedule(static) num_threads(nthread) 90 for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
91 const int tid = omp_get_thread_num();
92 const size_t off = dmat->
num_col * tid;
93 const size_t off2 = count_row_ptr[ntree] * tid;
94 const size_t ibegin = dmat->
row_ptr[rid];
95 const size_t iend = dmat->
row_ptr[rid + 1];
96 for (
size_t i = ibegin; i < iend; ++i) {
97 inst[off + dmat->
col_ind[i]].fvalue = dmat->
data[i];
99 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
100 Traverse(model.
trees[tree_id], &inst[off],
101 &counts_tloc[off2 + count_row_ptr[tree_id]]);
103 for (
size_t i = ibegin; i < iend; ++i) {
104 inst[off + dmat->
col_ind[i]].missing = -1;
115 int nthread,
int verbose) {
116 std::vector<size_t> counts;
117 std::vector<size_t> counts_tloc;
118 std::vector<size_t> count_row_ptr;
120 const size_t ntree = model.
trees.size();
121 const int max_thread = omp_get_max_threads();
122 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
124 count_row_ptr.push_back(count_row_ptr.back() + tree.
num_nodes);
126 counts.resize(count_row_ptr[ntree], 0);
127 counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
129 std::vector<Entry> inst(nthread * dmat->
num_col, {-1});
130 const size_t pstep = (dmat->
num_row + 19) / 20;
132 for (
size_t rbegin = 0; rbegin < dmat->
num_row; rbegin += pstep) {
133 const size_t rend = std::min(rbegin + pstep, dmat->
num_row);
134 ComputeBranchLoop(model, dmat, rbegin, rend, nthread,
135 &count_row_ptr[0], &counts_tloc[0], &inst[0]);
137 LOG(INFO) << rend <<
" of " << dmat->
num_row <<
" rows processed";
142 for (
int tid = 0; tid < nthread; ++tid) {
143 const size_t off = count_row_ptr[ntree] * tid;
144 for (
size_t i = 0; i < count_row_ptr[ntree]; ++i) {
145 counts[i] += counts_tloc[off + i];
150 for (
size_t i = 0; i < ntree; ++i) {
151 this->counts.emplace_back(&counts[count_row_ptr[i]],
152 &counts[count_row_ptr[i + 1]]);
157 BranchAnnotator::Load(dmlc::Stream* fi) {
158 dmlc::istream is(fi);
159 auto reader = common::make_unique<dmlc::JSONReader>(&is);
160 reader->Read(&counts);
164 BranchAnnotator::Save(dmlc::Stream* fo)
const {
165 dmlc::ostream os(fo);
166 auto writer = common::make_unique<dmlc::JSONWriter>(&os);
167 writer->Write(counts);
std::vector< float > data
feature values
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::vector< Tree > trees
member trees
unsigned split_index() const
feature index of split condition
Operator comparison_op() const
get comparison operator
in-memory representation of a decision tree
const std::vector< uint32_t > & left_categories() const
get categories for left child node
std::vector< uint32_t > col_ind
feature indices
tl_float threshold() const
int cright() const
index of right child
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
int cdefault() const
index of default child when feature is missing
int num_nodes
number of nodes
size_t num_col
number of columns
compatibility wrapper for systems that don't support OpenMP
int cleft() const
index of left child
SplitFeatureType split_type() const
get feature split type
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
bool is_leaf() const
whether current node is leaf node
Operator
comparison operators