10 #include <dmlc/json.h> 22 int nid,
size_t* out_counts) {
25 const unsigned split_index = tree.
SplitIndex(nid);
27 if (data[split_index].missing == -1) {
28 Traverse_(tree, data, tree.
DefaultChild(nid), out_counts);
31 if (tree.
SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
37 const auto fvalue = data[split_index].fvalue;
39 result = (std::binary_search(left_categories.begin(),
40 left_categories.end(),
41 static_cast<uint32_t
>(fvalue)));
44 Traverse_(tree, data, tree.
LeftChild(nid), out_counts);
46 Traverse_(tree, data, tree.
RightChild(nid), out_counts);
54 Traverse_(tree, data, 0, out_counts);
59 size_t rbegin,
size_t rend,
int nthread,
60 const size_t* count_row_ptr,
61 size_t* counts_tloc, Entry* inst) {
62 const size_t ntree = model.
trees.size();
63 CHECK_LE(rbegin, rend);
64 CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
65 const auto rbegin_i =
static_cast<int64_t
>(rbegin);
66 const auto rend_i =
static_cast<int64_t
>(rend);
67 #pragma omp parallel for schedule(static) num_threads(nthread) 68 for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
69 const int tid = omp_get_thread_num();
70 const size_t off = dmat->
num_col * tid;
71 const size_t off2 = count_row_ptr[ntree] * tid;
72 const size_t ibegin = dmat->
row_ptr[rid];
73 const size_t iend = dmat->
row_ptr[rid + 1];
74 for (
size_t i = ibegin; i < iend; ++i) {
75 inst[off + dmat->
col_ind[i]].fvalue = dmat->
data[i];
77 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
78 Traverse(model.
trees[tree_id], &inst[off],
79 &counts_tloc[off2 + count_row_ptr[tree_id]]);
81 for (
size_t i = ibegin; i < iend; ++i) {
82 inst[off + dmat->
col_ind[i]].missing = -1;
93 int nthread,
int verbose) {
94 std::vector<size_t> new_counts;
95 std::vector<size_t> counts_tloc;
96 std::vector<size_t> count_row_ptr;
98 const size_t ntree = model.
trees.size();
99 const int max_thread = omp_get_max_threads();
100 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
102 count_row_ptr.push_back(count_row_ptr.back() + tree.
num_nodes);
104 new_counts.resize(count_row_ptr[ntree], 0);
105 counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
107 std::vector<Entry> inst(nthread * dmat->
num_col, {-1});
108 const size_t pstep = (dmat->
num_row + 19) / 20;
110 for (
size_t rbegin = 0; rbegin < dmat->
num_row; rbegin += pstep) {
111 const size_t rend = std::min(rbegin + pstep, dmat->
num_row);
112 ComputeBranchLoop(model, dmat, rbegin, rend, nthread,
113 &count_row_ptr[0], &counts_tloc[0], &inst[0]);
115 LOG(INFO) << rend <<
" of " << dmat->
num_row <<
" rows processed";
120 for (
int tid = 0; tid < nthread; ++tid) {
121 const size_t off = count_row_ptr[ntree] * tid;
122 for (
size_t i = 0; i < count_row_ptr[ntree]; ++i) {
123 new_counts[i] += counts_tloc[off + i];
128 for (
size_t i = 0; i < ntree; ++i) {
129 this->counts.emplace_back(&new_counts[count_row_ptr[i]],
130 &new_counts[count_row_ptr[i + 1]]);
135 BranchAnnotator::Load(dmlc::Stream* fi) {
136 dmlc::istream is(fi);
137 std::unique_ptr<dmlc::JSONReader> reader(
new dmlc::JSONReader(&is));
138 reader->Read(&counts);
142 BranchAnnotator::Save(dmlc::Stream* fo)
const {
143 dmlc::ostream os(fo);
144 std::unique_ptr<dmlc::JSONWriter> writer(
new dmlc::JSONWriter(&os));
145 writer->Write(counts);
Operator ComparisonOp(int nid) const
get comparison operator
bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
std::vector< float > data
feature values
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::vector< Tree > trees
member trees
tl_float Threshold(int nid) const
get threshold of the node
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
std::vector< uint32_t > LeftCategories(int nid) const
Get list of all categories belonging to the left child node. Categories not in this list will belong ...
in-memory representation of a decision tree
std::vector< uint32_t > col_ind
feature indices
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
size_t num_row
number of rows
a simple data matrix in CSR (Compressed Sparse Row) storage
int num_nodes
number of nodes
SplitFeatureType SplitType(int nid) const
get feature split type
int LeftChild(int nid) const
Getters. Index of the node's left child
size_t num_col
number of columns
compatibility wrapper for systems that don't support OpenMP
int RightChild(int nid) const
index of the node's right child
bool IsLeaf(int nid) const
whether the node is leaf node
std::vector< size_t > row_ptr
pointer to row headers; length of [num_row] + 1
Operator
comparison operators