11 #include <rapidjson/istreamwrapper.h> 12 #include <rapidjson/ostreamwrapper.h> 13 #include <rapidjson/writer.h> 14 #include <rapidjson/document.h> 24 template <
typename ElementType>
// Traverse_: recursive worker that follows one input row down a single tree.
// NOTE(review): this is a partial listing -- the line declaring the function
// name and the `tree` parameter, the per-node counter update, the split
// comparison and all closing braces are not visible here.
//   data       -- per-feature Entry array for one row, indexed by feature id
//   nid        -- id of the node currently being visited
//   out_counts -- counter array for this tree; presumably out_counts[nid] is
//                 incremented in a line not shown (TODO confirm)
30 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
32 const Entry<ElementType>* data,
int nid, uint64_t* out_counts) {
// Feature tested by this node's split condition.
35 const unsigned split_index = tree.
SplitIndex(nid);
// missing == -1 marks a feature absent from the row (assumes Entry overlays
// fvalue/missing, e.g. a union -- confirm); absent features follow the
// node's default child.
37 if (data[split_index].missing == -1) {
38 Traverse_(tree, data, tree.
DefaultChild(nid), out_counts);
// Numerical split: read the node threshold and the row's feature value
// (the comparison itself is not visible here).
41 if (tree.
SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
42 const ThresholdType threshold = tree.
Threshold(nid);
44 const auto fvalue =
static_cast<ElementType
>(data[split_index].fvalue);
// Categorical split: the value is truncated to uint32_t and looked up with
// std::binary_search, which requires matching_categories to be sorted.
// NOTE(review): static_cast<uint32_t> of a negative or out-of-range floating
// value is undefined behaviour; confirm inputs are range-checked beforehand.
47 const auto fvalue = data[split_index].fvalue;
49 result = (std::binary_search(matching_categories.begin(),
50 matching_categories.end(),
51 static_cast<uint32_t
>(fvalue)));
// Recurse into the left or right child according to the split outcome
// (the branch condition is not visible here).
57 Traverse_(tree, data, tree.
LeftChild(nid), out_counts);
59 Traverse_(tree, data, tree.
RightChild(nid), out_counts);
// Traverse: entry point that walks one tree for a single row by starting the
// recursive Traverse_ at node id 0 (presumably the root -- confirm).
//   data       -- per-feature Entry array for the row
//   out_counts -- this tree's node-counter array, forwarded to Traverse_
// NOTE(review): the line naming the function and its `tree` parameter is not
// visible in this listing.
65 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
67 const Entry<ElementType>* data, uint64_t* out_counts) {
68 Traverse_(tree, data, 0, out_counts);
// ComputeBranchLoopImpl (dense overload): counts node visits for rows
// [rbegin, rend) of a dense DMatrix, accumulating into per-thread counters.
// NOTE(review): the parameter lines declaring model, dmat, rbegin and rend,
// and the lines defining num_col, nan_missing and missing_value, are not
// visible in this listing.
//   thread_config -- thread configuration handed to ParallelFor
//   count_row_ptr -- prefix offsets of each tree's node counters;
//                    count_row_ptr[ntree] is the total node count
//   counts_tloc   -- per-thread counter buffer; thread t owns the slice that
//                    starts at count_row_ptr[ntree] * t
71 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
72 inline void ComputeBranchLoopImpl(
75 const ThreadConfig& thread_config,
const size_t* count_row_ptr, uint64_t* counts_tloc) {
// One scratch row of Entry per thread, all initialized to {-1} (= missing).
76 std::vector<Entry<ElementType>> inst(thread_config.nthread * dmat->
num_col, {-1});
77 size_t ntree = model.
trees.size();
78 TREELITE_CHECK_LE(rbegin, rend);
82 auto sched = treelite::threading_utils::ParallelSchedule::Static();
83 treelite::threading_utils::ParallelFor(rbegin, rend, thread_config, sched,
84 [&](std::size_t rid,
int thread_id) {
// Row rid starts at rid * num_col in the row-major dense buffer.
85 const ElementType* row = &dmat->
data[rid * num_col];
// off: this thread's scratch row; off2: this thread's counter slice.
86 const size_t off = dmat->
num_col * thread_id;
87 const size_t off2 = count_row_ptr[ntree] * thread_id;
// Fill the scratch row. A NaN is rejected unless nan_missing is set
// (presumably meaning missing_value is NaN); values equal to missing_value
// are left marked as missing.
88 for (
size_t j = 0; j < num_col; ++j) {
89 if (treelite::math::CheckNAN(row[j])) {
90 TREELITE_CHECK(nan_missing)
91 <<
"The missing_value argument must be set to NaN if there is any NaN in the matrix.";
92 }
else if (nan_missing || row[j] != missing_value) {
93 inst[off + j].fvalue = row[j];
// Walk every tree; tree tree_id writes into its own counter range.
96 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
97 Traverse(model.
trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
// Reset the scratch row to all-missing so it can be reused for the next row.
99 for (
size_t j = 0; j < num_col; ++j) {
100 inst[off + j].missing = -1;
// ComputeBranchLoopImpl (CSR overload): counts node visits for rows
// [rbegin, rend) of a CSR sparse DMatrix, accumulating into per-thread
// counters laid out exactly as in the dense overload.
// NOTE(review): the parameter lines declaring model, dmat, rbegin and rend are
// not visible in this listing.
// Only entries explicitly stored for a row are marked present; every other
// feature stays missing. NOTE(review): unlike the dense overload, this path
// performs no NaN / missing_value check on stored values -- confirm intended.
105 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
106 inline void ComputeBranchLoopImpl(
109 const ThreadConfig& thread_config,
const size_t* count_row_ptr, uint64_t* counts_tloc) {
// One scratch row of Entry per thread, all initialized to {-1} (= missing).
110 std::vector<Entry<ElementType>> inst(thread_config.nthread * dmat->
num_col, {-1});
111 size_t ntree = model.
trees.size();
112 TREELITE_CHECK_LE(rbegin, rend);
113 auto sched = treelite::threading_utils::ParallelSchedule::Static();
114 treelite::threading_utils::ParallelFor(rbegin, rend, thread_config, sched,
115 [&](std::size_t rid,
int thread_id) {
// off: this thread's scratch row; off2: this thread's counter slice.
116 const size_t off = dmat->
num_col * thread_id;
117 const size_t off2 = count_row_ptr[ntree] * thread_id;
// Stored entries of row rid occupy [row_ptr[rid], row_ptr[rid + 1]).
118 const size_t ibegin = dmat->
row_ptr[rid];
119 const size_t iend = dmat->
row_ptr[rid + 1];
// Scatter the stored values into the scratch row by feature index.
120 for (
size_t i = ibegin; i < iend; ++i) {
121 inst[off + dmat->
col_ind[i]].fvalue = dmat->
data[i];
// Walk every tree; tree tree_id writes into its own counter range.
123 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
124 Traverse(model.
trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
// Reset only the features touched by this row back to missing.
126 for (
size_t i = ibegin; i < iend; ++i) {
127 inst[off + dmat->
col_ind[i]].missing = -1;
// Dispatcher passed to DispatchWithTypeInfo for dense matrices: forwards to the
// dense ComputeBranchLoopImpl overload for element type ElementType.
// NOTE(review): the access specifier, the `model` parameter line and the line
// defining dmat_ are not visible in this listing; dmat_ is presumably dmat
// down-cast to the dense DMatrix type (null on mismatch) -- TODO confirm.
132 template <
typename ElementType>
133 class ComputeBranchLoopDispatcherWithDenseDMatrix {
135 template <
typename ThresholdType,
typename LeafOutputType>
136 inline static void Dispatch(
138 const treelite::DMatrix* dmat,
size_t rbegin,
size_t rend,
const ThreadConfig& thread_config,
139 const size_t* count_row_ptr, uint64_t* counts_tloc) {
// Reject a null typed-matrix pointer before using it.
141 TREELITE_CHECK(dmat_) <<
"Dangling data matrix reference detected";
142 ComputeBranchLoopImpl(model, dmat_, rbegin, rend, thread_config, count_row_ptr, counts_tloc);
// Dispatcher passed to DispatchWithTypeInfo for CSR matrices: forwards to the
// CSR ComputeBranchLoopImpl overload for element type ElementType.
// NOTE(review): the access specifier, the `model` parameter line and the line
// defining dmat_ are not visible in this listing; dmat_ is presumably dmat
// down-cast to the CSR DMatrix type (null on mismatch) -- TODO confirm.
146 template <
typename ElementType>
147 class ComputeBranchLoopDispatcherWithCSRDMatrix {
149 template <
typename ThresholdType,
typename LeafOutputType>
150 inline static void Dispatch(
152 const treelite::DMatrix* dmat,
size_t rbegin,
size_t rend,
const ThreadConfig& thread_config,
153 const size_t* count_row_ptr, uint64_t* counts_tloc) {
// Reject a null typed-matrix pointer before using it.
155 TREELITE_CHECK(dmat_) <<
"Dangling data matrix reference detected";
156 ComputeBranchLoopImpl(model, dmat_, rbegin, rend, thread_config, count_row_ptr, counts_tloc);
// ComputeBranchLoop: picks the dispatcher matching the DMatrix storage type
// (dense or CSR), which then selects the element type via
// DispatchWithTypeInfo and runs ComputeBranchLoopImpl for rows [rbegin, rend).
// NOTE(review): the function-name/model/dmat/rbegin parameter lines, the tail
// of each dispatch call, the `break`s and the `default:` label are not visible
// in this listing. Unsupported storage types reach the error message below
// (presumably a fatal log -- confirm).
160 template <
typename ThresholdType,
typename LeafOutputType>
163 size_t rend,
const ThreadConfig& thread_config,
164 const size_t* count_row_ptr, uint64_t* counts_tloc) {
165 switch (dmat->GetType()) {
166 case treelite::DMatrixType::kDense: {
167 treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
168 dmat->GetElementType(), model, dmat, rbegin, rend, thread_config, count_row_ptr,
172 case treelite::DMatrixType::kSparseCSR: {
173 treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithCSRDMatrix>(
174 dmat->GetElementType(), model, dmat, rbegin, rend, thread_config, count_row_ptr,
180 <<
"Annotator does not support DMatrix of type " <<
static_cast<int>(dmat->GetType());
// AnnotateImpl: counts, for every node of every tree, how many rows of dmat
// reach it, and appends one count vector per tree to *out_counts.
// NOTE(review): several lines are missing from this listing (name and leading
// parameters, count_row_ptr seeding, the per-tree loop header, the condition
// around the progress log, the tail of the ComputeBranchLoop call, braces).
// Layout: count_row_ptr[i] is the offset of tree i's first node counter, so
// count_row_ptr[ntree] is the total number of nodes across all trees.
189 template <
typename ThresholdType,
typename LeafOutputType>
194 std::vector<std::vector<uint64_t>>* out_counts) {
195 std::vector<uint64_t> new_counts;
196 std::vector<uint64_t> counts_tloc;
197 std::vector<size_t> count_row_ptr;
200 const size_t ntree = model.
trees.size();
201 ThreadConfig thread_config = threading_utils::ConfigureThreadConfig(nthread);
// Prefix sum of node counts; presumably executed once per tree with
// count_row_ptr seeded with 0 in lines not shown -- TODO confirm.
203 count_row_ptr.push_back(count_row_ptr.back() + tree.
num_nodes);
// Global totals, plus one full-size counter slice per thread so that
// concurrent threads never write to the same counter.
205 new_counts.resize(count_row_ptr[ntree], 0);
206 counts_tloc.resize(count_row_ptr[ntree] * thread_config.nthread, 0);
// Process rows in chunks of ceil(num_row / 20), i.e. at most 20 chunks,
// so progress can be reported between chunks. num_row == 0 skips the loop.
208 const size_t num_row = dmat->GetNumRow();
209 const size_t pstep = (num_row + 19) / 20;
211 for (
size_t rbegin = 0; rbegin < num_row; rbegin += pstep) {
212 const size_t rend = std::min(rbegin + pstep, num_row);
213 ComputeBranchLoop(model, dmat, rbegin, rend, thread_config, &count_row_ptr[0],
216 TREELITE_LOG(INFO) << rend <<
" of " << num_row <<
" rows processed";
// Reduce the per-thread counter slices into new_counts.
221 for (std::uint32_t tid = 0; tid < thread_config.nthread; ++tid) {
222 const size_t off = count_row_ptr[ntree] * tid;
223 for (
size_t i = 0; i < count_row_ptr[ntree]; ++i) {
224 new_counts[i] += counts_tloc[off + i];
// Split the flat totals into one vector per tree and append to *out_counts.
// NOTE(review): for the last tree, &new_counts[count_row_ptr[i + 1]] calls
// operator[] with index == size(), which is undefined behaviour (and traps
// under checked-iterator builds); new_counts.data() + count_row_ptr[i + 1]
// forms the same end pointer safely. *out_counts is also not cleared in the
// visible lines, so repeated calls may accumulate entries -- confirm.
229 std::vector<std::vector<uint64_t>>& counts = *out_counts;
230 for (
size_t i = 0; i < ntree; ++i) {
231 counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
// Annotate: computes per-node visit counts of `model` over the rows of `dmat`
// by dispatching on the model's concrete type and calling AnnotateImpl, which
// appends the results to this->counts_.
//   nthread -- forwarded to AnnotateImpl's thread configuration
//   verbose -- forwarded to AnnotateImpl (presumably gates progress logging)
// NOTE(review): threshold_type is captured by the lambda but unused in the
// visible lines; the return-type line and closing braces are not visible.
236 BranchAnnotator::Annotate(
const Model& model,
const DMatrix* dmat,
int nthread,
int verbose) {
237 TypeInfo threshold_type = model.GetThresholdType();
238 model.Dispatch([
this, dmat, nthread, verbose, threshold_type](
auto& handle) {
239 AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_);
// Load: reads branch counts from a JSON stream into counts_.
// Expected input: a JSON array of arrays of unsigned integers (presumably one
// inner array per tree, mirroring Annotate's output).
// NOTE(review): the parse call and closing braces are not visible; counts_ is
// appended to with no visible clear(). e.GetUint64() is called without an
// IsUint64() check, so a non-integer element is only guarded by
// RAPIDJSON_ASSERT, not by the err_msg check -- consider validating it.
244 BranchAnnotator::Load(std::istream& fi) {
245 rapidjson::IStreamWrapper is(fi);
247 rapidjson::Document doc;
250 std::string err_msg =
"JSON file must contain a list of lists of integers";
251 TREELITE_CHECK(doc.IsArray()) << err_msg;
// One inner array per entry of counts_.
253 for (
const auto& node_cnt : doc.GetArray()) {
254 TREELITE_CHECK(node_cnt.IsArray()) << err_msg;
255 counts_.emplace_back();
256 for (
const auto& e : node_cnt.GetArray()) {
257 counts_.back().push_back(e.GetUint64());
// Save: serializes counts_ as JSON to the output stream `fo` through a
// RapidJSON Writer (presumably an array of per-tree integer arrays, the
// format Load reads -- confirm).
// NOTE(review): the StartArray/EndArray and value-writing calls, and the rest
// of the body, fall outside the visible lines of this listing.
263 BranchAnnotator::Save(std::ostream& fo)
const {
264 rapidjson::OStreamWrapper os(fo);
265 rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
// Outer loop: one entry of counts_; inner loop: one count per element.
268 for (
const auto& node_cnt : counts_) {
270 for (
auto e : node_cnt) {
Some useful math utilities.
Operator ComparisonOp(int nid) const
get comparison operator
Represent thread configuration, to be used with parallel loops.
size_t num_col
number of columns (i.e. # of features used)
bool CheckNAN(T value)
check for NaN (Not a Number)
ElementType missing_value
value representing the missing value (usually NaN)
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
in-memory representation of a decision tree
logging facility for Treelite
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Implementation of parallel for loop.
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
int num_nodes
number of nodes
SplitFeatureType SplitType(int nid) const
get feature split type
int LeftChild(int nid) const
Getters.
int RightChild(int nid) const
index of the node's right child
thin wrapper for tree ensemble model
bool IsLeaf(int nid) const
whether the node is leaf node
ThresholdType Threshold(int nid) const
get threshold of the node
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
size_t num_col
number of columns (i.e. # of features used)
Operator
comparison operators