12 #include <dmlc/logging.h> 14 #include <dmlc/timer.h> 21 #include <type_traits> 41 size_t query_result_size;
// Prediction loop over the row range [rbegin, rend) of a sparse CSR matrix.
// NOTE(review): the line naming this function (original line 48) is not visible
// in this excerpt; from the body it is presumably the CSR overload of PredLoop,
// also taking the matrix `dmat` and `num_feature` — confirm.
// For each row, the row's nonzero entries are scattered into a dense scratch
// buffer `inst` of size max(num_col, num_feature). Every slot starts as {-1},
// which marks a missing value per Entry's documented data layout. Then
// `func(rid, &inst[0], out_pred)` is invoked. Afterwards only the touched slots
// are reset to missing (-1), so the buffer can be reused for the next row.
// Returns the sum of the values returned by `func` over all rows in the range.
47 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
49 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
// The matrix must not have more columns than the model has features.
50 CHECK_LE(dmat->
num_col, static_cast<size_t>(num_feature));
51 std::vector<treelite::predictor::Entry<ThresholdType>> inst(
52 std::max(dmat->
num_col, static_cast<size_t>(num_feature)), {-1});
53 CHECK(rbegin < rend && rend <= dmat->num_row);
54 const ElementType* data = dmat->
data.data();
55 const uint32_t* col_ind = dmat->
col_ind.data();
56 const size_t* row_ptr = dmat->
row_ptr.data();
57 size_t total_output_size = 0;
58 for (
size_t rid = rbegin; rid < rend; ++rid) {
// [row_ptr[rid], row_ptr[rid + 1]) delimits this row's entries in data/col_ind.
59 const size_t ibegin = row_ptr[rid];
60 const size_t iend = row_ptr[rid + 1];
61 for (
size_t i = ibegin; i < iend; ++i) {
62 inst[col_ind[i]].fvalue =
static_cast<ThresholdType
>(data[i]);
64 total_output_size += func(rid, &inst[0], out_pred);
// Reset only the slots written above, restoring the all-missing state.
65 for (
size_t i = ibegin; i < iend; ++i) {
66 inst[col_ind[i]].missing = -1;
69 return total_output_size;
// Prediction loop over the row range [rbegin, rend) of a dense, row-major matrix.
// NOTE(review): the signature line (original 73) and several body lines
// (81, 89, 93-94, 97-99) are not visible in this excerpt. `nan_missing` and
// `missing_value` are presumably defined on missing line 81 — confirm.
// Row `rid` occupies data[rid * num_col, rid * num_col + num_col).
// Each entry is handled as follows:
//   - NaN: takes a branch whose failure message requires missing_value to be
//     NaN (the check itself is on a missing line).
//   - Otherwise: copied into inst[j].fvalue when the missing marker is NaN
//     or when the entry differs from missing_value.
// `func(rid, &inst[0], out_pred)` is then called. Its return values are summed
// and returned.
72 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
74 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
76 CHECK_LE(dmat->
num_col, static_cast<size_t>(num_feature));
77 std::vector<treelite::predictor::Entry<ThresholdType>> inst(
78 std::max(dmat->
num_col, static_cast<size_t>(num_feature)), {-1});
79 CHECK(rbegin < rend && rend <= dmat->num_row);
80 const size_t num_col = dmat->
num_col;
82 const ElementType* data = dmat->
data.data();
83 const ElementType* row =
nullptr;
84 size_t total_output_size = 0;
85 for (
size_t rid = rbegin; rid < rend; ++rid) {
86 row = &data[rid * num_col];
87 for (
size_t j = 0; j < num_col; ++j) {
88 if (treelite::math::CheckNAN(row[j])) {
90 <<
"The missing_value argument must be set to NaN if there is any NaN in the matrix.";
91 }
else if (nan_missing || row[j] != missing_value) {
92 inst[j].fvalue =
static_cast<ThresholdType
>(row[j]);
95 total_output_size += func(rid, &inst[0], out_pred);
// NOTE(review): the body of this loop is not visible; presumably it resets
// inst[j] back to missing for the next row — confirm.
96 for (
size_t j = 0; j < num_col; ++j) {
100 return total_output_size;
// Type-dispatch adapter for dense matrices. DispatchWithTypeInfo instantiates it
// with the matrix's runtime element type, and it forwards to the dense PredLoop.
// The unnamed ThresholdType parameter carries no data; it only lets the compiler
// deduce ThresholdType.
// NOTE(review): line 110 (not visible) presumably defines `dmat_` as `dmat`
// downcast to the dense matrix implementation type — confirm.
103 template <
typename ElementType>
104 class PredLoopDispatcherWithDenseDMatrix {
106 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
107 inline static size_t Dispatch(
108 const treelite::DMatrix* dmat, ThresholdType,
int num_feature,
size_t rbegin,
size_t rend,
109 LeafOutputType* out_pred, PredFunc func) {
111 return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
112 dmat_, num_feature, rbegin, rend, out_pred, func);
// Type-dispatch adapter for sparse CSR matrices. DispatchWithTypeInfo
// instantiates it with the matrix's runtime element type, and it forwards to
// the CSR PredLoop.
// The unnamed ThresholdType parameter carries no data; it only lets the compiler
// deduce ThresholdType.
// NOTE(review): line 123 (not visible) presumably defines `dmat_` as `dmat`
// downcast to the CSR matrix implementation type — confirm.
116 template <
typename ElementType>
117 class PredLoopDispatcherWithCSRDMatrix {
119 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
120 inline static size_t Dispatch(
121 const treelite::DMatrix* dmat, ThresholdType,
int num_feature,
size_t rbegin,
size_t rend,
122 LeafOutputType* out_pred, PredFunc func) {
124 return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
125 dmat_, num_feature, rbegin, rend, out_pred, func);
// Type-erased entry point for the prediction loop. It inspects the runtime
// matrix type (dense or CSR) and element type, then dispatches to the matching
// typed loop.
// `test_val` carries no data: callers pass static_cast<ThresholdType>(0) so that
// ThresholdType reaches the dispatchers. Unrecognized matrix types abort via
// LOG(FATAL).
129 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
130 inline size_t PredLoop(
const treelite::DMatrix* dmat, ThresholdType test_val,
int num_feature,
131 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
132 treelite::DMatrixType dmat_type = dmat->GetType();
// NOTE(review): the `switch (dmat_type)` line (original 133) is not visible here.
134 case treelite::DMatrixType::kDense: {
135 return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithDenseDMatrix>(
136 dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
138 case treelite::DMatrixType::kSparseCSR: {
139 return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithCSRDMatrix>(
140 dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
143 LOG(FATAL) <<
"Unrecognized data matrix type: " <<
static_cast<int>(dmat_type);
151 namespace predictor {
// SharedLibrary starts with no loaded library handle and an empty path.
153 SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {}
// Destructor: unloads the library with FreeLibrary (Win32 API) or dlclose (POSIX).
// NOTE(review): the platform #if lines and any null-handle guard are not visible
// in this excerpt — confirm handle_ is checked before it is released.
155 SharedLibrary::~SharedLibrary() {
158 FreeLibrary(static_cast<HMODULE>(handle_));
160 dlclose(static_cast<void*>(handle_));
// Loads the shared library at `libpath`: LoadLibraryA on Win32, or dlopen with
// RTLD_LAZY | RTLD_LOCAL elsewhere. Aborts via CHECK if loading fails.
// On success it stores the handle and keeps the path, which is used in later
// error messages.
166 SharedLibrary::Load(
const char* libpath) {
168 HMODULE handle = LoadLibraryA(libpath);
170 void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL);
172 CHECK(handle) <<
"Failed to load dynamic shared library `" << libpath <<
"'";
173 handle_ =
static_cast<LibraryHandle
>(handle);
174 libpath_ = std::string(libpath);
// Looks up the exported symbol `name` (GetProcAddress on Win32, dlsym elsewhere)
// and returns it as an untyped FunctionHandle.
// NOTE(review): the failure message below belongs to a check on lines not
// visible here (original 183-184), presumably CHECK(func_handle) — confirm.
177 SharedLibrary::FunctionHandle
178 SharedLibrary::LoadFunction(
const char* name)
const {
180 FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(handle_), name);
182 void* func_handle = dlsym(static_cast<void*>(handle_), name);
185 <<
"Dynamic shared library `" << libpath_ <<
"' does not contain a function " << name <<
"().";
186 return static_cast<SharedLibrary::FunctionHandle
>(func_handle);
// Typed variant of LoadFunction: reinterprets the symbol as HandleType and
// CHECKs that it is non-null.
// NOTE(review): reinterpret_cast cannot verify the function's real signature.
// Despite the message wording, only a null symbol is detected, so the caller
// must request the correct function type.
189 template <
typename HandleType>
191 SharedLibrary::LoadFunctionWithSignature(
const char* name)
const {
192 auto func_handle =
reinterpret_cast<HandleType
>(LoadFunction(name));
193 CHECK(func_handle) <<
"Dynamic shared library `" << libpath_ <<
"' does not contain a function " 194 << name <<
"() with the requested signature";
// Dispatcher used by PredFunction::Create. It constructs a PredFunctionImpl
// specialized for the given threshold and leaf-output types and returns it as
// an owning std::unique_ptr<PredFunction>.
198 template <
typename ThresholdType,
typename LeafOutputType>
201 inline static std::unique_ptr<PredFunction> Dispatch(
202 const SharedLibrary& library,
int num_feature,
int num_class) {
203 return std::make_unique<PredFunctionImpl<ThresholdType, LeafOutputType>>(
204 library, num_feature, num_class);
// Factory: maps the runtime (threshold_type, leaf_output_type) pair onto the
// matching compile-time PredFunctionImpl specialization via DispatchWithModelTypes.
// NOTE(review): the parameters threshold_type, leaf_output_type and library are
// declared on a line not visible here (original 210).
208 std::unique_ptr<PredFunction>
209 PredFunction::Create(
211 int num_feature,
int num_class) {
212 return DispatchWithModelTypes<PredFunctionInitDispatcher>(
213 threshold_type, leaf_output_type, library, num_feature, num_class);
// Binds the compiled prediction function from `library` and records the model
// dimensions. num_class must be positive.
// NOTE(review): the branch lines (original 220/222) are not visible. Presumably
// num_class > 1 loads "predict_multiclass" and otherwise "predict", matching
// how PredictBatch selects the function signature — confirm.
216 template <
typename ThresholdType,
typename LeafOutputType>
218 const SharedLibrary& library,
int num_feature,
int num_class) {
219 CHECK_GT(num_class, 0) <<
"num_class cannot be zero";
221 handle_ = library.LoadFunction(
"predict_multiclass");
223 handle_ = library.LoadFunction(
"predict");
225 num_feature_ = num_feature;
226 num_class_ = num_class;
// Reports the runtime TypeInfo tag for this specialization's ThresholdType.
229 template <
typename ThresholdType,
typename LeafOutputType>
232 return TypeToInfo<ThresholdType>();
// Reports the runtime TypeInfo tag for this specialization's LeafOutputType.
235 template <
typename ThresholdType,
typename LeafOutputType>
238 return TypeToInfo<LeafOutputType>();
// Runs the compiled prediction function over rows [rbegin, rend) of `dmat`.
// Results are written into `out_pred`, reinterpreted as LeafOutputType*.
// The values returned per row by the wrapper lambdas are summed by PredLoop into
// `result_size`.
// NOTE(review): lines 245-252, 255, 260, 263, 266-267, 272 and 274-275 are not
// visible (PredFunc aliases, lambda parameter lists, the `else`, and the
// single-class wrapper's return value).
241 template <
typename ThresholdType,
typename LeafOutputType>
244 const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
253 CHECK(rbegin < rend && rend <= dmat->GetNumRow());
// Multi-class path: predict_multiclass writes num_class values per row starting
// at out_pred[rid * num_class], and its return value feeds the result size.
254 if (num_class_ > 1) {
256 auto pred_func =
reinterpret_cast<PredFunc
>(handle_);
// NOTE(review): this only detects a null handle, not a mismatched signature.
257 CHECK(pred_func) <<
"The predict_multiclass() function has incorrect signature.";
258 auto pred_func_wrapper
259 = [pred_func, num_class = num_class_, pred_margin]
261 return pred_func(inst, static_cast<int>(pred_margin),
262 &out_pred[rid * num_class]);
264 result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
265 static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
// Single-output path: predict() returns one value, stored at out_pred[rid].
268 auto pred_func =
reinterpret_cast<PredFunc
>(handle_);
269 CHECK(pred_func) <<
"The predict() function has incorrect signature.";
270 auto pred_func_wrapper
271 = [pred_func, pred_margin]
273 out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
276 result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
277 static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
// Constructs an unloaded predictor. The initial state is:
//   - no prediction function and no thread pool;
//   - NaN placeholders for sigmoid_alpha / global_bias;
//   - invalid threshold / leaf-output type tags until Load() is called.
// A num_worker_thread of -1 is resolved to std::thread::hardware_concurrency()
// inside Load().
282 Predictor::Predictor(
int num_worker_thread)
283 : pred_func_(
nullptr),
284 thread_pool_handle_(
nullptr),
287 sigmoid_alpha_(std::numeric_limits<float>::quiet_NaN()),
288 global_bias_(std::numeric_limits<float>::quiet_NaN()),
289 num_worker_thread_(num_worker_thread),
290 threshold_type_(TypeInfo::kInvalid),
291 leaf_output_type_(TypeInfo::kInvalid) {}
// Destructor: acts only if a thread pool was created.
// NOTE(review): the body after the null check is not visible in this excerpt.
292 Predictor::~Predictor() {
293 if (thread_pool_handle_) {
// Loads a compiled model from the shared library at `libpath`.
// It reads model metadata from exported query functions: get_num_class,
// get_num_feature, get_pred_transform, get_sigmoid_alpha, get_global_bias,
// get_threshold_type and get_leaf_output_type.
// It then builds the typed prediction function and starts a worker pool with
// num_worker_thread_ - 1 threads; the calling thread acts as the last worker in
// PredictBatch.
// NOTE(review): the library-load line and the lines that convert the threshold /
// leaf-output type strings are not visible in this excerpt.
299 Predictor::Load(
const char* libpath) {
302 using UnsignedQueryFunc = size_t (*)();
303 using StringQueryFunc =
const char* (*)();
304 using FloatQueryFunc = float (*)();
307 auto* num_class_query_func
308 = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>(
"get_num_class");
309 num_class_ = num_class_query_func();
312 auto* num_feature_query_func
313 = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>(
"get_num_feature");
314 num_feature_ = num_feature_query_func();
315 CHECK_GT(num_feature_, 0) <<
"num_feature cannot be zero";
318 auto* pred_transform_query_func
319 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_pred_transform");
320 pred_transform_ = pred_transform_query_func();
323 auto* sigmoid_alpha_query_func
324 = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_sigmoid_alpha");
325 sigmoid_alpha_ = sigmoid_alpha_query_func();
328 auto* global_bias_query_func = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_global_bias");
329 global_bias_ = global_bias_query_func();
332 auto* threshold_type_query_func
333 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_threshold_type");
335 auto* leaf_output_type_query_func
336 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_leaf_output_type");
340 CHECK_GT(num_class_, 0) <<
"num_class cannot be zero";
341 pred_func_ = PredFunction::Create(
342 threshold_type_, leaf_output_type_, lib_,
343 static_cast<int>(num_feature_), static_cast<int>(num_class_));
// NOTE(review): std::thread::hardware_concurrency() may return 0 when the count
// is not computable. That would make num_worker_thread_ 0 and the pool size -1;
// consider clamping to at least 1.
345 if (num_worker_thread_ == -1) {
346 num_worker_thread_ =
static_cast<int>(std::thread::hardware_concurrency());
349 new PredThreadPool(num_worker_thread_ - 1,
this,
// Worker body: pop an InputToken, predict its row range through the shared
// pred_func_, and push the result size back as an OutputToken. The worker is
// wrapped in exception_catcher_.Run.
353 predictor->exception_catcher_.Run([&]() {
355 while (incoming_queue->Pop(&input)) {
356 const size_t rbegin = input.rbegin;
357 const size_t rend = input.rend;
358 size_t query_result_size
359 = predictor->pred_func_->PredictBatch(
360 input.dmat, rbegin, rend, input.pred_margin, input.out_pred);
361 outgoing_queue->Push(OutputToken{query_result_size});
// Destroys the worker thread pool owned through the opaque thread_pool_handle_.
// NOTE(review): this is a fragment; the enclosing function (presumably Free()) is
// not visible here.
369 delete static_cast<PredThreadPool*
>(thread_pool_handle_);
// Partitions rows [0, num_row) of `dmat` into `split_factor` contiguous chunks.
// Each chunk gets num_row / split_factor rows, and the remainder loop (body not
// visible) presumably gives one extra row to each of the first `remainder`
// chunks.
// `row_ptr` (length split_factor + 1) accumulates the chunk boundaries; it is
// presumably returned on a missing line — confirm.
// NOTE(review): CHECK_LE does not guard split_factor == 0, which would divide by
// zero below.
373 std::vector<size_t> SplitBatch(
const DMatrix* dmat,
size_t split_factor) {
374 const size_t num_row = dmat->GetNumRow();
375 CHECK_LE(split_factor, num_row);
376 const size_t portion = num_row / split_factor;
377 const size_t remainder = num_row % split_factor;
378 std::vector<size_t> workload(split_factor, portion);
379 std::vector<size_t> row_ptr(split_factor + 1, 0);
380 for (
size_t i = 0; i < remainder; ++i) {
// Prefix sums of the chunk sizes give the chunk boundaries.
384 for (
size_t i = 0; i < split_factor; ++i) {
385 accum += workload[i];
386 row_ptr[i + 1] = accum;
// Template headers of the LeafOutputType-dispatched helpers used below
// (ShrinkResultToFit, AllocateOutputVector, DeallocateOutputVector).
// NOTE(review): only fragments are visible here; the struct names and bodies
// are on lines not shown — confirm which header belongs to which helper.
391 template <
typename LeafOutputType>
394 inline static void Dispatch(
395 size_t num_row,
size_t query_size_per_instance,
size_t num_class,
399 template <
typename LeafOutputType>
405 template <
typename LeafOutputType>
// Predicts every row of `dmat` into `out_result` and returns the total number of
// output values written.
// Rows are split into nthread = min(num_worker_thread_, num_row) chunks. The
// first nthread - 1 chunks go to the worker pool and the last chunk runs on the
// calling thread.
// If the rows produced fewer outputs than QueryResultSize reports for the full
// range, the num_class-strided buffer is compacted in place.
// NOTE(review): if num_worker_thread_ <= 0, nthread is <= 0, so SplitBatch
// divides by zero and row_ptr[nthread - 1] indexes out of range. Also,
// static_cast<int>(num_row) narrows for num_row > INT_MAX.
412 Predictor::PredictBatch(
414 const double tstart = dmlc::GetTime();
416 const size_t num_row = dmat->GetNumRow();
417 auto* pool =
static_cast<PredThreadPool*
>(thread_pool_handle_);
418 InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result};
419 OutputToken response;
420 CHECK_GT(num_row, 0);
421 const int nthread = std::min(num_worker_thread_, static_cast<int>(num_row));
422 const std::vector<size_t> row_ptr = SplitBatch(dmat, nthread);
// Hand chunks 0 .. nthread-2 to pool workers; each worker writes into the
// shared out_result at its own row offsets.
423 for (
int tid = 0; tid < nthread - 1; ++tid) {
424 request.rbegin = row_ptr[tid];
425 request.rend = row_ptr[tid + 1];
426 pool->SubmitTask(tid, request);
428 size_t total_size = 0;
// The calling thread processes the final chunk itself.
431 const size_t rbegin = row_ptr[nthread - 1];
432 const size_t rend = row_ptr[nthread];
433 const size_t query_result_size
434 = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result);
435 total_size += query_result_size;
// Collect each worker's result size.
437 for (
int tid = 0; tid < nthread - 1; ++tid) {
438 if (pool->WaitForTask(tid, &response)) {
439 total_size += response.query_result_size;
// Fewer outputs than expected: every row produced the same smaller count, so
// shrink the row stride from num_class_ to query_size_per_instance.
443 if (total_size < QueryResultSize(dmat, 0, num_row)) {
444 CHECK_GT(num_class_, 1);
445 CHECK_EQ(total_size % num_row, 0);
446 const size_t query_size_per_instance = total_size / num_row;
447 CHECK_GT(query_size_per_instance, 0);
448 CHECK_LT(query_size_per_instance, num_class_);
449 DispatchWithTypeInfo<ShrinkResultToFit>(
450 leaf_output_type_, num_row, query_size_per_instance, num_class_, out_result);
452 const double tend = dmlc::GetTime();
454 LOG(INFO) <<
"Treelite: Finished prediction in " << tend - tstart <<
" sec";
// Allocates an output buffer large enough for all predictions on `dmat`
// (QueryResultSize elements). The element type matches the model's
// leaf-output type.
460 Predictor::CreateOutputVector(
const DMatrix* dmat)
const {
461 const size_t output_vector_size = this->QueryResultSize(dmat);
462 return DispatchWithTypeInfo<AllocateOutputVector>(leaf_output_type_, output_vector_size);
// Frees a buffer from CreateOutputVector, using the same leaf-output type it was
// allocated with (fragment; enclosing function header not visible).
467 DispatchWithTypeInfo<DeallocateOutputVector>(leaf_output_type_, output_vector);
// Compacts a row-major buffer in place, from a stride of num_class values per
// row to a stride of query_size_per_instance values per row.
// Forward iteration is safe because the destination index
// rid * query_size_per_instance + k never exceeds the source index
// rid * num_class + k when query_size_per_instance <= num_class.
470 template <
typename LeafOutputType>
473 size_t num_row,
size_t query_size_per_instance,
size_t num_class,
475 auto* out_result_ =
static_cast<LeafOutputType*
>(out_result);
476 for (
size_t rid = 0; rid < num_row; ++rid) {
477 for (
size_t k = 0; k < query_size_per_instance; ++k) {
478 out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_class + k];
// Template header of the output-vector allocation helper (body not visible).
483 template <
typename LeafOutputType>
// Output-vector deallocation: casts the opaque buffer back to LeafOutputType* and
// releases it with delete[]. It must pair with an array-new of the same type.
489 template <
typename LeafOutputType>
492 delete[] (
static_cast<LeafOutputType*
>(output_vector));
Load prediction function exported as a shared library.
Some useful math utilities.
size_t num_col
number of columns (i.e., the number of features used)
Input data structure of Treelite.
bool CheckNAN(T value)
check for NaN (Not a Number)
ElementType missing_value
value representing the missing value (usually NaN)
void * PredictorOutputHandle
handle to output from predictor
std::vector< size_t > row_ptr
pointer to row headers; length is num_row + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
Defines TypeInfo class and utilities.
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
a simple thread pool implementation
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Predictor class: a wrapper for optimized prediction code
void * ThreadPoolHandle
opaque handle types
size_t num_col
number of columns (i.e., the number of features used)