11 #include <rapidjson/istreamwrapper.h> 12 #include <rapidjson/ostreamwrapper.h> 13 #include <rapidjson/writer.h> 14 #include <rapidjson/document.h> 22 template <
typename ElementType>
28 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
30 const Entry<ElementType>* data,
int nid, uint64_t* out_counts) {
33 const unsigned split_index = tree.
SplitIndex(nid);
35 if (data[split_index].missing == -1) {
36 Traverse_(tree, data, tree.
DefaultChild(nid), out_counts);
39 if (tree.
SplitType(nid) == treelite::SplitFeatureType::kNumerical) {
40 const ThresholdType threshold = tree.
Threshold(nid);
42 const auto fvalue =
static_cast<ElementType
>(data[split_index].fvalue);
45 const auto fvalue = data[split_index].fvalue;
47 result = (std::binary_search(matching_categories.begin(),
48 matching_categories.end(),
49 static_cast<uint32_t
>(fvalue)));
55 Traverse_(tree, data, tree.
LeftChild(nid), out_counts);
57 Traverse_(tree, data, tree.
RightChild(nid), out_counts);
63 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
65 const Entry<ElementType>* data, uint64_t* out_counts) {
66 Traverse_(tree, data, 0, out_counts);
69 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
70 inline void ComputeBranchLoopImpl(
73 const size_t* count_row_ptr, uint64_t* counts_tloc) {
74 std::vector<Entry<ElementType>> inst(nthread * dmat->
num_col, {-1});
75 const size_t ntree = model.
trees.size();
76 TREELITE_CHECK_LE(rbegin, rend);
77 const size_t num_col = dmat->
num_col;
80 treelite::threading_utils::ParallelFor(rbegin, rend, nthread,
81 [&](std::size_t rid, std::size_t thread_id) {
82 const ElementType* row = &dmat->
data[rid * num_col];
83 const size_t off = dmat->
num_col * thread_id;
84 const size_t off2 = count_row_ptr[ntree] * thread_id;
85 for (
size_t j = 0; j < num_col; ++j) {
86 if (treelite::math::CheckNAN(row[j])) {
87 TREELITE_CHECK(nan_missing)
88 <<
"The missing_value argument must be set to NaN if there is any NaN in the matrix.";
89 }
else if (nan_missing || row[j] != missing_value) {
90 inst[off + j].fvalue = row[j];
93 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
94 Traverse(model.
trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
96 for (
size_t j = 0; j < num_col; ++j) {
97 inst[off + j].missing = -1;
102 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType>
103 inline void ComputeBranchLoopImpl(
106 const size_t* count_row_ptr, uint64_t* counts_tloc) {
107 std::vector<Entry<ElementType>> inst(nthread * dmat->
num_col, {-1});
108 const size_t ntree = model.
trees.size();
109 TREELITE_CHECK_LE(rbegin, rend);
110 treelite::threading_utils::ParallelFor(rbegin, rend, nthread,
111 [&](std::size_t rid, std::size_t thread_id) {
112 const size_t off = dmat->
num_col * thread_id;
113 const size_t off2 = count_row_ptr[ntree] * thread_id;
114 const size_t ibegin = dmat->
row_ptr[rid];
115 const size_t iend = dmat->
row_ptr[rid + 1];
116 for (
size_t i = ibegin; i < iend; ++i) {
117 inst[off + dmat->
col_ind[i]].fvalue = dmat->
data[i];
119 for (
size_t tree_id = 0; tree_id < ntree; ++tree_id) {
120 Traverse(model.
trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]);
122 for (
size_t i = ibegin; i < iend; ++i) {
123 inst[off + dmat->
col_ind[i]].missing = -1;
128 template <
typename ElementType>
129 class ComputeBranchLoopDispatcherWithDenseDMatrix {
131 template <
typename ThresholdType,
typename LeafOutputType>
132 inline static void Dispatch(
135 const size_t* count_row_ptr, uint64_t* counts_tloc) {
137 TREELITE_CHECK(dmat_) <<
"Dangling data matrix reference detected";
138 ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
142 template <
typename ElementType>
143 class ComputeBranchLoopDispatcherWithCSRDMatrix {
145 template <
typename ThresholdType,
typename LeafOutputType>
146 inline static void Dispatch(
149 const size_t* count_row_ptr, uint64_t* counts_tloc) {
151 TREELITE_CHECK(dmat_) <<
"Dangling data matrix reference detected";
152 ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
156 template <
typename ThresholdType,
typename LeafOutputType>
159 size_t rend,
int nthread,
const size_t* count_row_ptr,
160 uint64_t* counts_tloc) {
161 switch (dmat->GetType()) {
162 case treelite::DMatrixType::kDense: {
163 treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
164 dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
167 case treelite::DMatrixType::kSparseCSR: {
168 treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithCSRDMatrix>(
169 dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc);
174 <<
"Annotator does not support DMatrix of type " <<
static_cast<int>(dmat->GetType());
183 template <
typename ThresholdType,
typename LeafOutputType>
188 std::vector<std::vector<uint64_t>>* out_counts) {
189 std::vector<uint64_t> new_counts;
190 std::vector<uint64_t> counts_tloc;
191 std::vector<size_t> count_row_ptr;
194 const size_t ntree = model.
trees.size();
195 const int max_thread =
static_cast<int>(std::thread::hardware_concurrency());
196 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
198 count_row_ptr.push_back(count_row_ptr.back() + tree.
num_nodes);
200 new_counts.resize(count_row_ptr[ntree], 0);
201 counts_tloc.resize(count_row_ptr[ntree] * nthread, 0);
203 const size_t num_row = dmat->GetNumRow();
204 const size_t pstep = (num_row + 19) / 20;
206 for (
size_t rbegin = 0; rbegin < num_row; rbegin += pstep) {
207 const size_t rend = std::min(rbegin + pstep, num_row);
208 ComputeBranchLoop(model, dmat, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0]);
210 TREELITE_LOG(INFO) << rend <<
" of " << num_row <<
" rows processed";
215 for (
int tid = 0; tid < nthread; ++tid) {
216 const size_t off = count_row_ptr[ntree] * tid;
217 for (
size_t i = 0; i < count_row_ptr[ntree]; ++i) {
218 new_counts[i] += counts_tloc[off + i];
223 std::vector<std::vector<uint64_t>>& counts = *out_counts;
224 for (
size_t i = 0; i < ntree; ++i) {
225 counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
230 BranchAnnotator::Annotate(
const Model& model,
const DMatrix* dmat,
int nthread,
int verbose) {
231 TypeInfo threshold_type = model.GetThresholdType();
232 model.Dispatch([
this, dmat, nthread, verbose, threshold_type](
auto& handle) {
233 AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_);
238 BranchAnnotator::Load(std::istream& fi) {
239 rapidjson::IStreamWrapper is(fi);
241 rapidjson::Document doc;
244 std::string err_msg =
"JSON file must contain a list of lists of integers";
245 TREELITE_CHECK(doc.IsArray()) << err_msg;
247 for (
const auto& node_cnt : doc.GetArray()) {
248 TREELITE_CHECK(node_cnt.IsArray()) << err_msg;
249 counts_.emplace_back();
250 for (
const auto& e : node_cnt.GetArray()) {
251 counts_.back().push_back(e.GetUint64());
257 BranchAnnotator::Save(std::ostream& fo)
const {
258 rapidjson::OStreamWrapper os(fo);
259 rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
262 for (
const auto& node_cnt : counts_) {
264 for (
auto e : node_cnt) {
Some useful math utilities.
Operator ComparisonOp(int nid) const
get comparison operator
size_t num_col
number of columns (i.e. # of features used)
bool CheckNAN(T value)
check for NaN (Not a Number)
ElementType missing_value
value representing the missing value (usually NaN)
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
in-memory representation of a decision tree
logging facility for Treelite
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Implementation of parallel for loop.
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
int num_nodes
number of nodes
SplitFeatureType SplitType(int nid) const
get feature split type
int LeftChild(int nid) const
Getters.
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
int RightChild(int nid) const
index of the node's right child
thin wrapper for tree ensemble model
bool IsLeaf(int nid) const
whether the node is leaf node
ThresholdType Threshold(int nid) const
get threshold of the node
size_t num_col
number of columns (i.e. # of features used)
Operator
comparison operators