11 #include <dmlc/json.h> 17 template <
typename ElementType>
23 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
25 const Entry<ElementType>* data,
int nid,
size_t* out_counts) {
28 const unsigned split_index = tree.
SplitIndex(nid);
30 if (data[split_index].missing == -1) {
31 Traverse_(tree, data, tree.
DefaultChild(nid), out_counts);
34 if (tree.
SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
35 const ThresholdType threshold = tree.
Threshold(nid);
37 const auto fvalue =
static_cast<ElementType
>(data[split_index].fvalue);
40 const auto fvalue = data[split_index].fvalue;
42 result = (std::binary_search(matching_categories.begin(),
43 matching_categories.end(),
44 static_cast<uint32_t
>(fvalue)));
50 Traverse_(tree, data, tree.
LeftChild(nid), out_counts);
52 Traverse_(tree, data, tree.
RightChild(nid), out_counts);
58 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
60 const Entry<ElementType>* data,
size_t* out_counts) {
61 Traverse_(tree, data, 0, out_counts);
64 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
65 inline void ComputeBranchLoopImpl(
68 const size_t* count_row_ptr,
size_t* counts_tloc) {
69 std::vector<Entry<ElementType>> inst(nthread * dmat->
num_col, {-1});
70 const size_t ntree = model.
trees.size();
71 CHECK_LE(rbegin, rend);
72 CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
73 const size_t num_col = dmat->
num_col;
76 const auto rbegin_i =
static_cast<int64_t
>(rbegin);
77 const auto rend_i =
static_cast<int64_t
>(rend);
78 #pragma omp parallel for schedule(static) num_threads(nthread) 79 for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
80 const int tid = omp_get_thread_num();
81 const ElementType* row = &dmat->
data[rid * num_col];
82 const size_t off = dmat->
num_col * tid;
83 const size_t off2 = count_row_ptr[ntree] * tid;
84 for (
size_t j = 0; j < num_col; ++j) {
85 if (treelite::math::CheckNAN(row[j])) {
87 <<
"The missing_value argument must be set to NaN if there is any NaN in the matrix.";
88 }
else if (nan_missing || row[j] != missing_value) {
89 inst[off + j].fvalue = row[j];
92 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
93 Traverse(model.
trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
95 for (
size_t j = 0; j < num_col; ++j) {
96 inst[off + j].missing = -1;
101 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
102 inline void ComputeBranchLoopImpl(
105 const size_t* count_row_ptr,
size_t* counts_tloc) {
106 std::vector<Entry<ElementType>> inst(nthread * dmat->
num_col, {-1});
107 const size_t ntree = model.
trees.size();
108 CHECK_LE(rbegin, rend);
109 CHECK_LT(static_cast<int64_t>(rend), std::numeric_limits<int64_t>::max());
110 const auto rbegin_i =
static_cast<int64_t
>(rbegin);
111 const auto rend_i =
static_cast<int64_t
>(rend);
112 #pragma omp parallel for schedule(static) num_threads(nthread) 113 for (int64_t rid = rbegin_i; rid < rend_i; ++rid) {
114 const int tid = omp_get_thread_num();
115 const size_t off = dmat->
num_col * tid;
116 const size_t off2 = count_row_ptr[ntree] * tid;
117 const size_t ibegin = dmat->
row_ptr[rid];
118 const size_t iend = dmat->
row_ptr[rid + 1];
119 for (
size_t i = ibegin; i < iend; ++i) {
120 inst[off + dmat->
col_ind[i]].fvalue = dmat->
data[i];
122 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
123 Traverse(model.
trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
125 for (
size_t i = ibegin; i < iend; ++i) {
126 inst[off + dmat->
col_ind[i]].missing = -1;
131 template <
typename ElementType>
132 class ComputeBranchLoopDispatcherWithDenseDMatrix {
134 template <
typename ThresholdType,
typename LeafOutputType>
135 inline static void Dispatch(
138 const size_t* count_row_ptr,
size_t* counts_tloc) {
140 CHECK(dmat_) <<
"Dangling data matrix reference detected";
141 ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
145 template <
typename ElementType>
146 class ComputeBranchLoopDispatcherWithCSRDMatrix {
148 template <
typename ThresholdType,
typename LeafOutputType>
149 inline static void Dispatch(
152 const size_t* count_row_ptr,
size_t* counts_tloc) {
154 CHECK(dmat_) <<
"Dangling data matrix reference detected";
155 ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
159 template <
typename ThresholdType,
typename LeafOutputType>
162 size_t rend,
int nthread,
const size_t* count_row_ptr,
163 size_t* counts_tloc) {
164 switch (dmat->GetType()) {
165 case treelite::DMatrixType::kDense: {
166 treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
167 dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
170 case treelite::DMatrixType::kSparseCSR: {
171 treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithCSRDMatrix>(
172 dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
177 <<
"Annotator does not support DMatrix of type " <<
static_cast<int>(dmat->GetType());
186 template <
typename ThresholdType,
typename LeafOutputType>
191 std::vector<std::vector<size_t>>* out_counts) {
192 std::vector<size_t> new_counts;
193 std::vector<size_t> counts_tloc;
194 std::vector<size_t> count_row_ptr;
197 const size_t ntree = model.
trees.size();
198 const int max_thread = omp_get_max_threads();
199 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
201 count_row_ptr.push_back(count_row_ptr.back() + tree.
num_nodes);
203 new_counts.resize(count_row_ptr[ntree], 0);
204 counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
206 const size_t num_row = dmat->GetNumRow();
207 const size_t pstep = (num_row + 19) / 20;
209 for (
size_t rbegin = 0; rbegin < num_row; rbegin += pstep) {
210 const size_t rend = std::min(rbegin + pstep, num_row);
211 ComputeBranchLoop(model, dmat, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0]);
213 LOG(INFO) << rend <<
" of " << num_row <<
" rows processed";
218 for (
int tid = 0; tid < nthread; ++tid) {
219 const size_t off = count_row_ptr[ntree] * tid;
220 for (
size_t i = 0; i < count_row_ptr[ntree]; ++i) {
221 new_counts[i] += counts_tloc[off + i];
226 std::vector<std::vector<size_t>>& counts = *out_counts;
227 for (
size_t i = 0; i < ntree; ++i) {
228 counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
233 BranchAnnotator::Annotate(
const Model& model,
const DMatrix* dmat,
int nthread,
int verbose) {
234 TypeInfo threshold_type = model.GetThresholdType();
235 model.Dispatch([
this, dmat, nthread, verbose, threshold_type](
auto& handle) {
236 AnnotateImpl(handle, dmat, nthread, verbose, &this->counts);
241 BranchAnnotator::Load(dmlc::Stream* fi) {
242 dmlc::istream is(fi);
243 std::unique_ptr<dmlc::JSONReader> reader(
new dmlc::JSONReader(&is));
244 reader->Read(&counts);
248 BranchAnnotator::Save(dmlc::Stream* fo)
const {
249 dmlc::ostream os(fo);
250 std::unique_ptr<dmlc::JSONWriter> writer(
new dmlc::JSONWriter(&os));
251 writer->Write(counts);
Some useful math utilities.
Operator ComparisonOp(int nid) const
get comparison operator
size_t num_col
number of columns (i.e. # of features used)
bool CheckNAN(T value)
check for NaN (Not a Number)
ElementType missing_value
value representing the missing value (usually NaN)
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
in-memory representation of a decision tree
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
int num_nodes
number of nodes
SplitFeatureType SplitType(int nid) const
get feature split type
int LeftChild(int nid) const
Getters.
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
compatibility wrapper for systems that don't support OpenMP
int RightChild(int nid) const
index of the node's right child
thin wrapper for tree ensemble model
bool IsLeaf(int nid) const
whether the node is leaf node
ThresholdType Threshold(int nid) const
get threshold of the node
size_t num_col
number of columns (i.e. # of features used)
Operator
comparison operators