#include <type_traits>

/*!
 * \brief Return a timestamp in seconds, for measuring elapsed time.
 *
 * Uses std::chrono::steady_clock, which is monotonic: the difference between
 * two calls can never go negative, even if the system clock is adjusted
 * (high_resolution_clock may be an alias of the adjustable system_clock).
 * The epoch is unspecified, so only differences between calls are meaningful.
 * \return seconds since an unspecified but fixed starting point
 */
inline double GetTime(void) {
  return std::chrono::duration<double>(
      std::chrono::steady_clock::now().time_since_epoch()).count();
}
// Number of output entries produced by one PredictBatch() task; worker
// threads report it back as OutputToken{query_result_size} (presumably the
// sole field of OutputToken -- confirm).
44 size_t query_result_size;
/*!
 * \brief Prediction loop over rows [rbegin, rend) of a CSR (sparse) matrix.
 *
 * For each row, the row's stored entries are scattered into the dense scratch
 * buffer `inst`, `func(rid, &inst[0], out_pred)` is called, and the touched
 * entries are reset to missing before moving on to the next row.
 * \return sum of the values returned by `func` over all rows in the range
 */
50 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
52 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
// The model must accept at least as many features as the matrix has columns.
53 TREELITE_CHECK_LE(dmat->
num_col, static_cast<size_t>(num_feature));
// Scratch row buffer; {-1} marks every entry as missing initially.
54 std::vector<treelite::predictor::Entry<ThresholdType>> inst(
55 std::max(dmat->
num_col, static_cast<size_t>(num_feature)), {-1});
56 TREELITE_CHECK(rbegin < rend && rend <= dmat->num_row);
57 const ElementType* data = dmat->
data.data();
58 const uint32_t* col_ind = dmat->
col_ind.data();
59 const size_t* row_ptr = dmat->
row_ptr.data();
60 size_t total_output_size = 0;
61 for (
size_t rid = rbegin; rid < rend; ++rid) {
// Row rid's stored entries are data[row_ptr[rid] .. row_ptr[rid + 1]).
62 const size_t ibegin = row_ptr[rid];
63 const size_t iend = row_ptr[rid + 1];
64 for (
size_t i = ibegin; i < iend; ++i) {
65 inst[col_ind[i]].fvalue =
static_cast<ThresholdType
>(data[i]);
67 total_output_size += func(rid, &inst[0], out_pred);
// Reset only the entries this row set, so the buffer is all-missing again
// without clearing every column for each row.
68 for (
size_t i = ibegin; i < iend; ++i) {
69 inst[col_ind[i]].missing = -1;
72 return total_output_size;
/*!
 * \brief Prediction loop over rows [rbegin, rend) of a dense row-major matrix.
 *
 * Row rid starts at data[rid * num_col]. Entries equal to the missing value
 * (or NaN, when NaN is the missing marker) are left missing in the scratch
 * buffer `inst`; all other entries are copied in. After `func` runs, every
 * column is reset to missing for the next row.
 * \return sum of the values returned by `func` over all rows in the range
 */
75 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
77 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
// The model must accept at least as many features as the matrix has columns.
79 TREELITE_CHECK_LE(dmat->
num_col, static_cast<size_t>(num_feature));
// Scratch row buffer; {-1} marks every entry as missing initially.
80 std::vector<treelite::predictor::Entry<ThresholdType>> inst(
81 std::max(dmat->
num_col, static_cast<size_t>(num_feature)), {-1});
82 TREELITE_CHECK(rbegin < rend && rend <= dmat->num_row);
83 const size_t num_col = dmat->
num_col;
85 const ElementType* data = dmat->
data.data();
86 const ElementType* row =
nullptr;
87 size_t total_output_size = 0;
88 for (
size_t rid = rbegin; rid < rend; ++rid) {
89 row = &data[rid * num_col];
90 for (
size_t j = 0; j < num_col; ++j) {
// A NaN entry is accepted only when NaN is the missing-value marker, and then
// it stays missing. NOTE(review): nan_missing presumably caches
// CheckNAN(missing_value) -- confirm.
91 if (treelite::math::CheckNAN(row[j])) {
92 TREELITE_CHECK(nan_missing)
93 <<
"The missing_value argument must be set to NaN if there is any NaN in the matrix.";
94 }
else if (nan_missing || row[j] != missing_value) {
95 inst[j].fvalue =
static_cast<ThresholdType
>(row[j]);
98 total_output_size += func(rid, &inst[0], out_pred);
// A dense row may have set any column, so reset all of them.
99 for (
size_t j = 0; j < num_col; ++j) {
100 inst[j].missing = -1;
103 return total_output_size;
/*!
 * \brief Dispatch target for dense matrices: forwards to the PredLoop overload
 *        for the concrete matrix type with element type ElementType.
 *
 * The unnamed ThresholdType argument carries no data; it only lets the
 * compiler deduce ThresholdType.
 */
106 template <
typename ElementType>
107 class PredLoopDispatcherWithDenseDMatrix {
109 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
110 inline static size_t Dispatch(
111 const treelite::DMatrix* dmat, ThresholdType,
int num_feature,
size_t rbegin,
size_t rend,
112 LeafOutputType* out_pred, PredFunc func) {
// NOTE(review): dmat_ is assumed to be `dmat` downcast to the dense DMatrix
// implementation for ElementType -- confirm.
114 return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
115 dmat_, num_feature, rbegin, rend, out_pred, func);
/*!
 * \brief Dispatch target for CSR (sparse) matrices: forwards to the PredLoop
 *        overload for the concrete matrix type with element type ElementType.
 *
 * The unnamed ThresholdType argument carries no data; it only lets the
 * compiler deduce ThresholdType.
 */
119 template <
typename ElementType>
120 class PredLoopDispatcherWithCSRDMatrix {
122 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
123 inline static size_t Dispatch(
124 const treelite::DMatrix* dmat, ThresholdType,
int num_feature,
size_t rbegin,
size_t rend,
125 LeafOutputType* out_pred, PredFunc func) {
// NOTE(review): dmat_ is assumed to be `dmat` downcast to the CSR DMatrix
// implementation for ElementType -- confirm.
127 return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
128 dmat_, num_feature, rbegin, rend, out_pred, func);
/*!
 * \brief Run `func` over rows [rbegin, rend) of `dmat`, dispatching on the
 *        matrix layout (dense or CSR) and then on its element type.
 * \param test_val carries no data; it only fixes ThresholdType for template
 *        deduction (callers pass a zero of that type)
 * \return sum of the values returned by `func` over the row range
 */
132 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
133 inline size_t PredLoop(
const treelite::DMatrix* dmat, ThresholdType test_val,
int num_feature,
134 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
135 treelite::DMatrixType dmat_type = dmat->GetType();
137 case treelite::DMatrixType::kDense: {
138 return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithDenseDMatrix>(
139 dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
141 case treelite::DMatrixType::kSparseCSR: {
142 return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithCSRDMatrix>(
143 dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
// Any other layout is unsupported.
146 TREELITE_LOG(FATAL) <<
"Unrecognized data matrix type: " <<
static_cast<int>(dmat_type);
154 namespace predictor {
// Constructs an empty wrapper; no library is loaded until Load() is called.
156 SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {}
// Releases the library handle: FreeLibrary() in the Windows build, dlclose()
// in the POSIX build.
158 SharedLibrary::~SharedLibrary() {
161 FreeLibrary(static_cast<HMODULE>(handle_));
163 dlclose(static_cast<void*>(handle_));
/*!
 * \brief Load the shared library at `libpath` and remember its path.
 *
 * Uses LoadLibraryA() on Windows and dlopen(RTLD_LAZY | RTLD_LOCAL)
 * elsewhere; a null handle fails TREELITE_CHECK.
 * \param libpath path to the shared library file
 */
169 SharedLibrary::Load(
const char* libpath) {
171 HMODULE handle = LoadLibraryA(libpath);
173 void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL);
175 TREELITE_CHECK(handle) <<
"Failed to load dynamic shared library `" << libpath <<
"'";
176 handle_ =
static_cast<LibraryHandle
>(handle);
// Kept for error messages produced by later symbol lookups.
177 libpath_ = std::string(libpath);
/*!
 * \brief Look up the exported symbol `name` in the loaded library.
 * \param name symbol name to resolve
 * \return the symbol's address as an opaque FunctionHandle; a missing symbol
 *         fails TREELITE_CHECK
 */
180 SharedLibrary::FunctionHandle
181 SharedLibrary::LoadFunction(
const char* name)
const {
183 FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(handle_), name);
185 void* func_handle = dlsym(static_cast<void*>(handle_), name);
187 TREELITE_CHECK(func_handle)
188 <<
"Dynamic shared library `" << libpath_ <<
"' does not contain a function " << name <<
"().";
189 return static_cast<SharedLibrary::FunctionHandle
>(func_handle);
/*!
 * \brief Look up symbol `name` and cast it to the function-pointer type
 *        HandleType.
 *
 * NOTE(review): the cast cannot verify the symbol's real signature, and the
 * null check below repeats the one LoadFunction() already performs.
 */
192 template <
typename HandleType>
194 SharedLibrary::LoadFunctionWithSignature(
const char* name)
const {
195 auto func_handle =
reinterpret_cast<HandleType
>(LoadFunction(name));
196 TREELITE_CHECK(func_handle) <<
"Dynamic shared library `" << libpath_
197 <<
"' does not contain a function " << name
198 <<
"() with the requested signature";
// Dispatch target for PredFunction::Create(): builds the PredFunctionImpl
// instantiated for the given threshold and leaf-output types.
202 template <
typename ThresholdType,
typename LeafOutputType>
205 inline static std::unique_ptr<PredFunction> Dispatch(
206 const SharedLibrary& library,
int num_feature,
int num_class) {
207 return std::make_unique<PredFunctionImpl<ThresholdType, LeafOutputType>>(
208 library, num_feature, num_class);
/*!
 * \brief Create a prediction function bound to `library`. The run-time
 *        threshold_type and leaf_output_type select which PredFunctionImpl
 *        instantiation is built.
 */
212 std::unique_ptr<PredFunction>
213 PredFunction::Create(
215 int num_feature,
int num_class) {
216 return DispatchWithModelTypes<PredFunctionInitDispatcher>(
217 threshold_type, leaf_output_type, library, num_feature, num_class);
/*!
 * \brief Bind the compiled prediction entry point from `library`.
 *
 * Loads either predict_multiclass() or predict(). Presumably the multiclass
 * entry point is chosen when num_class > 1, matching PredictBatch() -- confirm.
 * \param num_feature number of features the compiled model expects
 * \param num_class number of output classes; must be positive
 */
220 template <
typename ThresholdType,
typename LeafOutputType>
222 const SharedLibrary& library,
int num_feature,
int num_class) {
223 TREELITE_CHECK_GT(num_class, 0) <<
"num_class cannot be zero";
225 handle_ = library.LoadFunction(
"predict_multiclass");
227 handle_ = library.LoadFunction(
"predict");
229 num_feature_ = num_feature;
230 num_class_ = num_class;
// Report the threshold type this instantiation was compiled for.
233 template <
typename ThresholdType,
typename LeafOutputType>
236 return TypeToInfo<ThresholdType>();
// Report the leaf output type this instantiation was compiled for.
239 template <
typename ThresholdType,
typename LeafOutputType>
242 return TypeToInfo<LeafOutputType>();
/*!
 * \brief Predict rows [rbegin, rend) of `dmat` with the compiled function
 *        bound to handle_, writing into `out_pred`.
 *
 * Multi-class models write each row's outputs starting at
 * out_pred[rid * num_class]. Single-output models write one value to
 * out_pred[rid]. The per-row sizes reported by the wrappers are summed by
 * PredLoop into result_size.
 */
245 template <
typename ThresholdType,
typename LeafOutputType>
248 const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
257 TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
258 if (num_class_ > 1) {
// NOTE(review): reinterpret_cast cannot detect a mismatched signature; the
// check below only rejects a null handle.
260 auto pred_func =
reinterpret_cast<PredFunc
>(handle_);
261 TREELITE_CHECK(pred_func) <<
"The predict_multiclass() function has incorrect signature.";
// The compiled function's return value is forwarded as this row's output size.
262 auto pred_func_wrapper
263 = [pred_func, num_class = num_class_, pred_margin]
265 return pred_func(inst, static_cast<int>(pred_margin),
266 &out_pred[rid * num_class]);
// The zero argument only selects ThresholdType for dispatch.
268 result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
269 static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
272 auto pred_func =
reinterpret_cast<PredFunc
>(handle_);
273 TREELITE_CHECK(pred_func) <<
"The predict() function has incorrect signature.";
274 auto pred_func_wrapper
275 = [pred_func, pred_margin]
277 out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
280 result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
281 static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
/*!
 * \brief Construct a predictor with no model loaded.
 * \param num_worker_thread number of threads used for batch prediction;
 *        -1 means one per hardware thread (resolved in Load())
 *
 * Model metadata starts as NaN / TypeInfo::kInvalid until Load() fills it in.
 */
286 Predictor::Predictor(
int num_worker_thread)
287 : pred_func_(
nullptr),
288 thread_pool_handle_(
nullptr),
291 sigmoid_alpha_(std::numeric_limits<float>::quiet_NaN()),
292 global_bias_(std::numeric_limits<float>::quiet_NaN()),
293 num_worker_thread_(num_worker_thread),
294 threshold_type_(TypeInfo::kInvalid),
295 leaf_output_type_(TypeInfo::kInvalid) {}
// Destructor. NOTE(review): presumably releases resources such as the worker
// pool when thread_pool_handle_ is set -- confirm.
296 Predictor::~Predictor() {
297 if (thread_pool_handle_) {
/*!
 * \brief Load a compiled model from the shared library at `libpath`.
 *
 * Reads model metadata through the library's exported get_* query functions,
 * then binds the prediction function with PredFunction::Create(). Each query
 * symbol must exist; a missing one fails TREELITE_CHECK in LoadFunction().
 * \param libpath path to the compiled model's shared library
 */
303 Predictor::Load(
const char* libpath) {
// Signatures of the metadata query functions exported by the library.
306 using UnsignedQueryFunc = size_t (*)();
307 using StringQueryFunc =
const char* (*)();
308 using FloatQueryFunc = float (*)();
311 auto* num_class_query_func
312 = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>(
"get_num_class");
313 num_class_ = num_class_query_func();
316 auto* num_feature_query_func
317 = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>(
"get_num_feature");
318 num_feature_ = num_feature_query_func();
319 TREELITE_CHECK_GT(num_feature_, 0) <<
"num_feature cannot be zero";
322 auto* pred_transform_query_func
323 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_pred_transform");
324 pred_transform_ = pred_transform_query_func();
327 auto* sigmoid_alpha_query_func
328 = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_sigmoid_alpha");
329 sigmoid_alpha_ = sigmoid_alpha_query_func();
332 auto* global_bias_query_func = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_global_bias");
333 global_bias_ = global_bias_query_func();
// NOTE(review): threshold_type_ / leaf_output_type_ are presumably parsed from
// the strings these functions return (e.g. via GetTypeInfoByName) -- confirm.
336 auto* threshold_type_query_func
337 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_threshold_type");
339 auto* leaf_output_type_query_func
340 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_leaf_output_type");
344 TREELITE_CHECK_GT(num_class_, 0) <<
"num_class cannot be zero";
345 pred_func_ = PredFunction::Create(
346 threshold_type_, leaf_output_type_, lib_,
347 static_cast<int>(num_feature_), static_cast<int>(num_class_));
// -1 requests one thread per hardware thread. hardware_concurrency() returns 0
// when the count cannot be determined, and a thread count of 0 would lead to a
// pool of -1 workers and a zero split factor in PredictBatch(). Clamp to >= 1.
349 if (num_worker_thread_ == -1) {
350 num_worker_thread_ =
std::max(1, static_cast<int>(std::thread::hardware_concurrency()));
// PredictBatch() runs one row range on the calling thread, so the pool needs
// only num_worker_thread_ - 1 extra workers.
353 new PredThreadPool(num_worker_thread_ - 1,
this,
// Worker-thread body, run inside exception_catcher_.Run() so that exceptions
// are captured rather than escaping the thread. It pops input tokens until
// Pop() returns false (presumably on pool shutdown -- confirm), predicts the
// token's row range, and pushes the produced output size back to the caller.
357 predictor->exception_catcher_.
Run([&]() {
359 while (incoming_queue->Pop(&input)) {
360 const size_t rbegin = input.rbegin;
361 const size_t rend = input.rend;
362 size_t query_result_size
363 = predictor->pred_func_->PredictBatch(
364 input.dmat, rbegin, rend, input.pred_margin, input.out_pred);
365 outgoing_queue->Push(OutputToken{query_result_size});
// Destroy the worker pool; thread_pool_handle_ is an opaque pointer to the
// PredThreadPool allocated with new during Load().
373 delete static_cast<PredThreadPool*
>(thread_pool_handle_);
/*!
 * \brief Partition rows [0, num_row) of `dmat` into `split_factor` contiguous
 *        ranges of near-equal size.
 *
 * Each range gets num_row / split_factor rows. The num_row % split_factor
 * leftover rows are presumably added one each to the first ranges (loop body
 * to confirm).
 * \return (presumably) boundaries row_ptr of length split_factor + 1, where
 *         range i is [row_ptr[i], row_ptr[i + 1])
 */
377 std::vector<size_t> SplitBatch(
const DMatrix* dmat,
size_t split_factor) {
378 const size_t num_row = dmat->GetNumRow();
// Reject split_factor == 0 explicitly. It passes the <= num_row check below
// and would then divide by zero (undefined behavior), e.g. when the
// worker-thread count is 0.
TREELITE_CHECK_GT(split_factor, 0) << "split_factor must be positive";
379 TREELITE_CHECK_LE(split_factor, num_row);
380 const size_t portion = num_row / split_factor;
381 const size_t remainder = num_row % split_factor;
382 std::vector<size_t> workload(split_factor, portion);
383 std::vector<size_t> row_ptr(split_factor + 1, 0);
384 for (
size_t i = 0; i < remainder; ++i) {
388 for (
size_t i = 0; i < split_factor; ++i) {
389 accum += workload[i];
390 row_ptr[i + 1] = accum;
// Helpers keyed on LeafOutputType and selected at run time by
// DispatchWithTypeInfo on leaf_output_type_. The first has the Dispatch
// signature used by the ShrinkResultToFit call in Predictor::PredictBatch().
// NOTE(review): the other two presumably allocate and free output buffers --
// confirm.
395 template <
typename LeafOutputType>
398 inline static void Dispatch(
399 size_t num_row,
size_t query_size_per_instance,
size_t num_class,
403 template <
typename LeafOutputType>
409 template <
typename LeafOutputType>
/*!
 * \brief Predict all rows of `dmat` into `out_result`, splitting the rows
 *        across the worker pool and the calling thread.
 *
 * The first nthread - 1 row ranges are submitted to pool workers. The calling
 * thread predicts the last range itself, then collects the workers' output
 * sizes. If fewer outputs were produced than QueryResultSize() reserved (a
 * multi-class model emitting fewer than num_class_ values per row), the
 * buffer is compacted in place to query_size_per_instance values per row.
 * NOTE(review): the return value is presumably total_size -- confirm.
 */
416 Predictor::PredictBatch(
418 const double tstart = GetTime();
420 const size_t num_row = dmat->GetNumRow();
421 auto* pool =
static_cast<PredThreadPool*
>(thread_pool_handle_);
// One request object is reused; only its row range changes per task.
422 InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result};
423 OutputToken response;
424 TREELITE_CHECK_GT(num_row, 0);
// Never use more threads than there are rows.
425 const int nthread = std::min(num_worker_thread_, static_cast<int>(num_row));
426 const std::vector<size_t> row_ptr = SplitBatch(dmat, nthread);
427 for (
int tid = 0; tid < nthread - 1; ++tid) {
428 request.rbegin = row_ptr[tid];
429 request.rend = row_ptr[tid + 1];
430 pool->SubmitTask(tid, request);
432 size_t total_size = 0;
// The calling thread handles the last range while the workers run.
435 const size_t rbegin = row_ptr[nthread - 1];
436 const size_t rend = row_ptr[nthread];
437 const size_t query_result_size
438 = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result);
439 total_size += query_result_size;
441 for (
int tid = 0; tid < nthread - 1; ++tid) {
442 if (pool->WaitForTask(tid, &response)) {
443 total_size += response.query_result_size;
// Fewer outputs than reserved: every row must have produced the same number
// of values (fewer than num_class_), so pack them contiguously.
447 if (total_size < QueryResultSize(dmat, 0, num_row)) {
448 TREELITE_CHECK_GT(num_class_, 1);
449 TREELITE_CHECK_EQ(total_size % num_row, 0);
450 const size_t query_size_per_instance = total_size / num_row;
451 TREELITE_CHECK_GT(query_size_per_instance, 0);
452 TREELITE_CHECK_LT(query_size_per_instance, num_class_);
453 DispatchWithTypeInfo<ShrinkResultToFit>(
454 leaf_output_type_, num_row, query_size_per_instance, num_class_, out_result);
// Elapsed time is logged at INFO level (possibly only when verbose -- confirm).
456 const double tend = GetTime();
458 TREELITE_LOG(INFO) <<
"Treelite: Finished prediction in " << tend - tstart <<
" sec";
/*!
 * \brief Allocate an output buffer large enough to hold predictions for all
 *        rows of `dmat`, typed by the model's leaf output type.
 */
464 Predictor::CreateOutputVector(
const DMatrix* dmat)
const {
465 const size_t output_vector_size = this->QueryResultSize(dmat);
466 return DispatchWithTypeInfo<AllocateOutputVector>(leaf_output_type_, output_vector_size);
// Free an output buffer using the same leaf output type it was allocated with.
471 DispatchWithTypeInfo<DeallocateOutputVector>(leaf_output_type_, output_vector);
/*!
 * \brief Compact `out_result` in place, from num_class values per row to the
 *        first query_size_per_instance values per row.
 *
 * Front-to-back copying is safe because Predictor::PredictBatch() checks
 * query_size_per_instance < num_class before calling. Each write index
 * rid * query_size_per_instance + k is therefore never greater than its read
 * index rid * num_class + k, so no entry is overwritten before it is read.
 */
474 template <
typename LeafOutputType>
477 size_t num_row,
size_t query_size_per_instance,
size_t num_class,
479 auto* out_result_ =
static_cast<LeafOutputType*
>(out_result);
480 for (
size_t rid = 0; rid < num_row; ++rid) {
481 for (
size_t k = 0; k < query_size_per_instance; ++k) {
482 out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_class + k];
// Allocation helper. NOTE(review): presumably allocates with new[], matching
// the delete[] below -- confirm.
487 template <
typename LeafOutputType>
// Deallocation helper: frees a LeafOutputType buffer with delete[].
493 template <
typename LeafOutputType>
496 delete[] (
static_cast<LeafOutputType*
>(output_vector));
Load prediction function exported as a shared library.
Some useful math utilities.
size_t num_col
number of columns (i.e. # of features used)
Input data structure of Treelite.
bool CheckNAN(T value)
check for NaN (Not a Number)
ElementType missing_value
value representing the missing value (usually NaN)
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run so that exceptions thrown inside them are caught and saved.
void * PredictorOutputHandle
handle to output from predictor
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
Defines TypeInfo class and utilities.
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
a simple thread pool implementation
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
predictor class: wrapper for optimized prediction code
void * ThreadPoolHandle
opaque handle types
size_t num_col
number of columns (i.e. # of features used)