#include <type_traits>

/*!
 * \brief Current time in seconds, intended for measuring elapsed intervals
 *        (callers compute `tend - tstart`).
 *
 * Uses std::chrono::steady_clock, which is guaranteed to be monotonic, so the
 * difference of two readings can never be negative even if the system wall
 * clock is adjusted between them. std::chrono::high_resolution_clock is
 * frequently just an alias of system_clock and is not suitable for duration
 * measurement. The epoch of the returned value is unspecified; only
 * differences between two calls are meaningful.
 * \return seconds elapsed since the steady clock's epoch
 */
inline double GetTime(void) {
  return std::chrono::duration<double>(
      std::chrono::steady_clock::now().time_since_epoch()).count();
}
// Number of prediction outputs reported for a batch of rows.
// NOTE(review): appears to be the field that worker threads fill via
// `OutputToken{query_result_size}` and that PredictBatch reads back as
// `response.query_result_size` — confirm the enclosing struct.
44 size_t query_result_size;
// Prediction loop over a CSR (compressed sparse row) matrix.
// For each row rid in [rbegin, rend), the row's stored entries are scattered
// into a dense scratch buffer `inst`, then `func(rid, &inst[0], out_pred)` is
// invoked. The return values of `func` are summed and returned.
// Entries whose `missing` field is -1 are treated as missing values.
50 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
52 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
// The matrix must not have more columns than the model has features.
53 TREELITE_CHECK_LE(dmat->
num_col, static_cast<size_t>(num_feature));
// Scratch buffer initialized to {-1}, i.e. every feature missing. Given the
// check above, the max() evaluates to num_feature.
54 std::vector<treelite::predictor::Entry<ThresholdType>> inst(
55 std::max(dmat->
num_col, static_cast<size_t>(num_feature)), {-1});
// The requested row range must be non-empty and lie inside the matrix.
56 TREELITE_CHECK(rbegin < rend && rend <= dmat->num_row);
57 const ElementType* data = dmat->
data.data();
58 const uint32_t* col_ind = dmat->
col_ind.data();
59 const size_t* row_ptr = dmat->
row_ptr.data();
60 size_t total_output_size = 0;
61 for (
size_t rid = rbegin; rid < rend; ++rid) {
// Entries of row rid occupy positions [row_ptr[rid], row_ptr[rid + 1]).
62 const size_t ibegin = row_ptr[rid];
63 const size_t iend = row_ptr[rid + 1];
// Scatter this row's stored values into the dense buffer.
// NOTE(review): col_ind[i] is used as an index without a bounds check;
// assumes every col_ind[i] < num_col — confirm the matrix is validated upstream.
64 for (
size_t i = ibegin; i < iend; ++i) {
65 inst[col_ind[i]].fvalue =
static_cast<ThresholdType
>(data[i]);
// Run the prediction function on this row; it reports how many outputs it wrote.
67 total_output_size += func(rid, &inst[0], out_pred);
// Reset only the entries written above, restoring the buffer to all-missing
// for the next row without touching the whole vector.
68 for (
size_t i = ibegin; i < iend; ++i) {
69 inst[col_ind[i]].missing = -1;
72 return total_output_size;
// Prediction loop over a dense, row-major matrix (row rid starts at
// data[rid * num_col]). Each row is copied into the scratch buffer `inst`,
// skipping cells that represent missing values. Then
// `func(rid, &inst[0], out_pred)` is invoked, and the summed return values of
// `func` are returned.
// NOTE(review): `missing_value` and `nan_missing` are used below; presumably
// they are locals holding the matrix's missing-value sentinel and whether that
// sentinel is NaN — confirm.
75 template <
typename ElementType,
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
77 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
// The matrix must not have more columns than the model has features.
79 TREELITE_CHECK_LE(dmat->
num_col, static_cast<size_t>(num_feature));
// Scratch buffer initialized to {-1} (all features missing).
80 std::vector<treelite::predictor::Entry<ThresholdType>> inst(
81 std::max(dmat->
num_col, static_cast<size_t>(num_feature)), {-1});
// The requested row range must be non-empty and lie inside the matrix.
82 TREELITE_CHECK(rbegin < rend && rend <= dmat->num_row);
83 const size_t num_col = dmat->
num_col;
85 const ElementType* data = dmat->
data.data();
86 const ElementType* row =
nullptr;
87 size_t total_output_size = 0;
88 for (
size_t rid = rbegin; rid < rend; ++rid) {
89 row = &data[rid * num_col];
90 for (
size_t j = 0; j < num_col; ++j) {
// A NaN cell is only legal when the missing-value sentinel is NaN; in that
// case the entry is left as missing. Any other cell is stored unless it
// equals a non-NaN sentinel.
91 if (treelite::math::CheckNAN(row[j])) {
92 TREELITE_CHECK(nan_missing)
93 <<
"The missing_value argument must be set to NaN if there is any NaN in the matrix.";
94 }
else if (nan_missing || row[j] != missing_value) {
95 inst[j].fvalue =
static_cast<ThresholdType
>(row[j]);
// Run the prediction function on this row; it reports how many outputs it wrote.
98 total_output_size += func(rid, &inst[0], out_pred);
// Any of the num_col entries may have been written, so reset all of them.
99 for (
size_t j = 0; j < num_col; ++j) {
100 inst[j].missing = -1;
103 return total_output_size;
// Functor used with DispatchWithTypeInfo on the matrix element type. It
// forwards a dense DMatrix to the ElementType-specific PredLoop.
106 template <
typename ElementType>
107 class PredLoopDispatcherWithDenseDMatrix {
109 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
// The unnamed ThresholdType argument carries the threshold type for template
// deduction only; its value is not used.
110 inline static size_t Dispatch(
111 const treelite::DMatrix* dmat, ThresholdType,
int num_feature,
size_t rbegin,
size_t rend,
112 LeafOutputType* out_pred, PredFunc func) {
// NOTE(review): `dmat_` is presumably `dmat` downcast to the dense matrix
// implementation for ElementType — confirm where it is declared.
114 return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
115 dmat_, num_feature, rbegin, rend, out_pred, func);
// Functor used with DispatchWithTypeInfo on the matrix element type. It
// forwards a sparse CSR DMatrix to the ElementType-specific PredLoop.
119 template <
typename ElementType>
120 class PredLoopDispatcherWithCSRDMatrix {
122 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
// The unnamed ThresholdType argument carries the threshold type for template
// deduction only; its value is not used.
123 inline static size_t Dispatch(
124 const treelite::DMatrix* dmat, ThresholdType,
int num_feature,
size_t rbegin,
size_t rend,
125 LeafOutputType* out_pred, PredFunc func) {
// NOTE(review): `dmat_` is presumably `dmat` downcast to the CSR matrix
// implementation for ElementType — confirm where it is declared.
127 return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
128 dmat_, num_feature, rbegin, rend, out_pred, func);
// Type-erased entry point of the prediction loop. It selects the dense or CSR
// dispatcher from the matrix's storage type, then uses DispatchWithTypeInfo to
// instantiate the loop for the matrix's element type.
// `test_val` is only a carrier of ThresholdType; callers pass
// static_cast<ThresholdType>(0).
// Returns the total number of outputs reported by `func` over [rbegin, rend).
132 template <
typename ThresholdType,
typename LeafOutputType,
typename PredFunc>
133 inline size_t PredLoop(
const treelite::DMatrix* dmat, ThresholdType test_val,
int num_feature,
134 size_t rbegin,
size_t rend, LeafOutputType* out_pred, PredFunc func) {
135 treelite::DMatrixType dmat_type = dmat->GetType();
// Dense row-major storage.
137 case treelite::DMatrixType::kDense: {
138 return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithDenseDMatrix>(
139 dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
// Compressed sparse row storage.
141 case treelite::DMatrixType::kSparseCSR: {
142 return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithCSRDMatrix>(
143 dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
// Any other matrix type is reported as a fatal error.
146 TREELITE_LOG(FATAL) <<
"Unrecognized data matrix type: " <<
static_cast<int>(dmat_type);
154 namespace predictor {
156 SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {}
// Releases the loaded library handle with the platform's API:
// FreeLibrary (Win32) or dlclose (POSIX).
// NOTE(review): presumably only one of the two calls is compiled, chosen by a
// platform preprocessor conditional — confirm.
158 SharedLibrary::~SharedLibrary() {
161 FreeLibrary(static_cast<HMODULE>(handle_));
163 dlclose(static_cast<void*>(handle_));
// Loads the dynamic shared library at `libpath`. It uses LoadLibraryA on
// Windows and dlopen on POSIX. Fails via TREELITE_CHECK if loading fails.
// On success, stores the handle and remembers the path for error messages.
// NOTE(review): any previously held handle_ is overwritten below without being
// released — confirm Load() is called at most once per instance.
169 SharedLibrary::Load(
const char* libpath) {
171 HMODULE handle = LoadLibraryA(libpath);
// RTLD_LAZY defers symbol resolution. RTLD_LOCAL keeps this library's symbols
// from being used to resolve symbols in libraries loaded later.
173 void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL);
175 TREELITE_CHECK(handle) <<
"Failed to load dynamic shared library `" << libpath <<
"'";
176 handle_ =
static_cast<LibraryHandle
>(handle);
177 libpath_ = std::string(libpath);
// Looks up the exported symbol `name` in the loaded library. It uses
// GetProcAddress on Windows and dlsym on POSIX. Fails via TREELITE_CHECK if
// the symbol is absent; otherwise returns it as an untyped FunctionHandle.
180 SharedLibrary::FunctionHandle
181 SharedLibrary::LoadFunction(
const char* name)
const {
183 FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(handle_), name);
185 void* func_handle = dlsym(static_cast<void*>(handle_), name);
187 TREELITE_CHECK(func_handle)
188 <<
"Dynamic shared library `" << libpath_ <<
"' does not contain a function " << name <<
"().";
189 return static_cast<SharedLibrary::FunctionHandle
>(func_handle);
// Looks up symbol `name` and reinterprets it as the function-pointer type
// HandleType.
// NOTE(review): LoadFunction() already rejects a null symbol, and
// reinterpret_cast cannot validate a function signature. This check therefore
// cannot detect a signature mismatch, despite what its message says. Calling
// through a pointer of the wrong type is undefined behavior.
192 template <
typename HandleType>
194 SharedLibrary::LoadFunctionWithSignature(
const char* name)
const {
195 auto func_handle =
reinterpret_cast<HandleType
>(LoadFunction(name));
196 TREELITE_CHECK(func_handle) <<
"Dynamic shared library `" << libpath_
197 <<
"' does not contain a function " << name
198 <<
"() with the requested signature";
// Dispatch target for DispatchWithModelTypes. It constructs the
// PredFunctionImpl specialization that matches the model's threshold and
// leaf-output types.
202 template <
typename ThresholdType,
typename LeafOutputType>
205 inline static std::unique_ptr<PredFunction> Dispatch(
206 const SharedLibrary& library,
int num_feature,
int num_class) {
207 return std::make_unique<PredFunctionImpl<ThresholdType, LeafOutputType>>(
208 library, num_feature, num_class);
// Factory: maps the runtime (threshold_type, leaf_output_type) pair onto a
// concrete PredFunctionImpl<ThresholdType, LeafOutputType> using
// PredFunctionInitDispatcher, and returns it as an owning PredFunction pointer.
212 std::unique_ptr<PredFunction>
213 PredFunction::Create(
215 int num_feature,
int num_class) {
216 return DispatchWithModelTypes<PredFunctionInitDispatcher>(
217 threshold_type, leaf_output_type, library, num_feature, num_class);
// Binds the prediction entry point exported by `library` and records the model
// dimensions. The check rejects num_class <= 0, not just zero.
// NOTE(review): presumably "predict_multiclass" is loaded when num_class > 1
// and "predict" otherwise, mirroring the num_class_ > 1 branch in
// PredictBatch — confirm the selecting condition.
220 template <
typename ThresholdType,
typename LeafOutputType>
222 const SharedLibrary& library,
int num_feature,
int num_class) {
223 TREELITE_CHECK_GT(num_class, 0) <<
"num_class cannot be zero";
225 handle_ = library.LoadFunction(
"predict_multiclass");
227 handle_ = library.LoadFunction(
"predict");
229 num_feature_ = num_feature;
230 num_class_ = num_class;
// Reports the compile-time ThresholdType of this specialization as a runtime TypeInfo.
233 template <
typename ThresholdType,
typename LeafOutputType>
236 return TypeToInfo<ThresholdType>();
// Reports the compile-time LeafOutputType of this specialization as a runtime TypeInfo.
239 template <
typename ThresholdType,
typename LeafOutputType>
242 return TypeToInfo<LeafOutputType>();
// Runs the compiled prediction function on rows [rbegin, rend) of `dmat`. It
// wraps the raw exported function in a lambda and drives it with PredLoop.
// `out_pred` is indexed by the absolute row id `rid`, so it must cover the
// whole matrix, not just this row range.
// NOTE(review): the reinterpret_cast + null checks below cannot detect a
// signature mismatch; handle_ is non-null whenever LoadFunction succeeded.
245 template <
typename ThresholdType,
typename LeafOutputType>
248 const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
257 TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
// Multiclass: each row owns num_class consecutive output slots.
258 if (num_class_ > 1) {
260 auto pred_func =
reinterpret_cast<PredFunc
>(handle_);
261 TREELITE_CHECK(pred_func) <<
"The predict_multiclass() function has incorrect signature.";
// The wrapper returns whatever predict_multiclass returns; PredLoop sums it as
// the number of outputs written for the row.
262 auto pred_func_wrapper
263 = [pred_func, num_class = num_class_, pred_margin]
265 return pred_func(inst, static_cast<int>(pred_margin),
266 &out_pred[rid * num_class]);
// static_cast<ThresholdType>(0) only selects the threshold type for dispatch.
268 result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
269 static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
// Single output per row.
272 auto pred_func =
reinterpret_cast<PredFunc
>(handle_);
273 TREELITE_CHECK(pred_func) <<
"The predict() function has incorrect signature.";
274 auto pred_func_wrapper
275 = [pred_func, pred_margin]
277 out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
280 result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
281 static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
// Constructs a predictor with nothing loaded. The float model parameters start
// as NaN and the type info as kInvalid; Load() later fills them from the
// shared library. A num_worker_thread of -1 is resolved to the hardware
// concurrency inside Load().
286 Predictor::Predictor(
int num_worker_thread)
287 : pred_func_(
nullptr),
288 thread_pool_handle_(
nullptr),
291 sigmoid_alpha_(std::numeric_limits<float>::quiet_NaN()),
292 ratio_c_(std::numeric_limits<float>::quiet_NaN()),
293 global_bias_(std::numeric_limits<float>::quiet_NaN()),
294 num_worker_thread_(num_worker_thread),
295 threshold_type_(TypeInfo::kInvalid),
296 leaf_output_type_(TypeInfo::kInvalid) {}
// Destructor: checks whether a worker thread pool was ever created
// (thread_pool_handle_ non-null). Presumably it then releases that pool —
// confirm against the rest of the body.
297 Predictor::~Predictor() {
298 if (thread_pool_handle_) {
// Loads a compiled model from the shared library at `libpath`. It queries the
// library's exported metadata getters (num_class, num_feature, pred_transform,
// sigmoid_alpha, ratio_c, global_bias, threshold/leaf types). It then builds
// the typed PredFunction and creates a pool of num_worker_thread_ - 1 workers.
// The calling thread acts as the remaining worker in PredictBatch.
304 Predictor::Load(
const char* libpath) {
// Signatures of the metadata getters exported by the compiled model.
307 using UnsignedQueryFunc = size_t (*)();
308 using StringQueryFunc =
const char* (*)();
309 using FloatQueryFunc = float (*)();
312 auto* num_class_query_func
313 = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>(
"get_num_class");
314 num_class_ = num_class_query_func();
317 auto* num_feature_query_func
318 = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>(
"get_num_feature");
319 num_feature_ = num_feature_query_func();
320 TREELITE_CHECK_GT(num_feature_, 0) <<
"num_feature cannot be zero";
323 auto* pred_transform_query_func
324 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_pred_transform");
325 pred_transform_ = pred_transform_query_func();
328 auto* sigmoid_alpha_query_func
329 = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_sigmoid_alpha");
330 sigmoid_alpha_ = sigmoid_alpha_query_func();
333 auto* ratio_c_query_func
334 = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_ratio_c");
335 ratio_c_ = ratio_c_query_func();
338 auto* global_bias_query_func = lib_.LoadFunctionWithSignature<FloatQueryFunc>(
"get_global_bias");
339 global_bias_ = global_bias_query_func();
342 auto* threshold_type_query_func
343 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_threshold_type");
345 auto* leaf_output_type_query_func
346 = lib_.LoadFunctionWithSignature<StringQueryFunc>(
"get_leaf_output_type");
350 TREELITE_CHECK_GT(num_class_, 0) <<
"num_class cannot be zero";
// Narrowing casts: the library reports size_t, PredFunction takes int.
351 pred_func_ = PredFunction::Create(
352 threshold_type_, leaf_output_type_, lib_,
353 static_cast<int>(num_feature_), static_cast<int>(num_class_));
// -1 means "use all hardware threads".
// NOTE(review): std::thread::hardware_concurrency() may return 0 when the value
// is not computable. num_worker_thread_ would then be 0, the pool below would
// be asked for -1 threads, and SplitBatch in PredictBatch would divide by zero.
// Consider clamping to at least 1.
355 if (num_worker_thread_ == -1) {
356 num_worker_thread_ =
static_cast<int>(std::thread::hardware_concurrency());
359 new PredThreadPool(num_worker_thread_ - 1,
this,
// Worker body: exceptions are routed through exception_catcher_. Each popped
// InputToken is predicted, and the output count is pushed back as an OutputToken.
363 predictor->exception_catcher_.
Run([&]() {
365 while (incoming_queue->Pop(&input)) {
366 const size_t rbegin = input.rbegin;
367 const size_t rend = input.rend;
368 size_t query_result_size
369 = predictor->pred_func_->PredictBatch(
370 input.dmat, rbegin, rend, input.pred_margin, input.out_pred);
371 outgoing_queue->Push(OutputToken{query_result_size});
// Destroys the worker thread pool held behind the opaque thread_pool_handle_.
379 delete static_cast<PredThreadPool*
>(thread_pool_handle_);
// Splits the rows of `dmat` into `split_factor` contiguous chunks. Returns a
// row-pointer vector of length split_factor + 1; chunk i covers
// [row_ptr[i], row_ptr[i + 1]).
// Every chunk gets num_row / split_factor rows. The loop over `remainder`
// presumably gives one extra row to each of the first `remainder` chunks —
// confirm.
// NOTE(review): split_factor == 0 passes the CHECK_LE below but then divides
// by zero.
383 std::vector<size_t> SplitBatch(
const DMatrix* dmat,
size_t split_factor) {
384 const size_t num_row = dmat->GetNumRow();
385 TREELITE_CHECK_LE(split_factor, num_row);
386 const size_t portion = num_row / split_factor;
387 const size_t remainder = num_row % split_factor;
388 std::vector<size_t> workload(split_factor, portion);
389 std::vector<size_t> row_ptr(split_factor + 1, 0);
390 for (
size_t i = 0; i < remainder; ++i) {
// Prefix sums of the workloads give the chunk boundaries.
394 for (
size_t i = 0; i < split_factor; ++i) {
395 accum += workload[i];
396 row_ptr[i + 1] = accum;
// Declarations of helpers templated on LeafOutputType. They are invoked through
// DispatchWithTypeInfo on leaf_output_type_ (ShrinkResultToFit,
// AllocateOutputVector, DeallocateOutputVector).
// NOTE(review): mapping of each template header below to its struct is inferred
// from usage — confirm.
401 template <
typename LeafOutputType>
404 inline static void Dispatch(
405 size_t num_row,
size_t query_size_per_instance,
size_t num_class,
409 template <
typename LeafOutputType>
415 template <
typename LeafOutputType>
// Predicts all rows of `dmat` into `out_result`. The rows are split into
// nthread chunks: the first nthread - 1 go to pool workers and the last is
// processed on the calling thread. The per-chunk output counts are summed. If
// fewer outputs were produced than QueryResultSize reserves (models whose
// compiled function emits fewer than num_class_ values per row), the buffer is
// compacted in place with ShrinkResultToFit. Elapsed time is logged at INFO.
422 Predictor::PredictBatch(
424 const double tstart = GetTime();
426 const size_t num_row = dmat->GetNumRow();
427 auto* pool =
static_cast<PredThreadPool*
>(thread_pool_handle_);
// Request template; rbegin/rend are overwritten per worker below.
428 InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result};
429 OutputToken response;
430 TREELITE_CHECK_GT(num_row, 0);
// NOTE(review): static_cast<int>(num_row) narrows; row counts above INT_MAX
// yield a non-positive nthread. Also, if num_worker_thread_ is 0, nthread is 0
// and SplitBatch divides by zero.
431 const int nthread = std::min(num_worker_thread_, static_cast<int>(num_row));
432 const std::vector<size_t> row_ptr = SplitBatch(dmat, nthread);
// Hand chunks 0 .. nthread-2 to the pool workers.
433 for (
int tid = 0; tid < nthread - 1; ++tid) {
434 request.rbegin = row_ptr[tid];
435 request.rend = row_ptr[tid + 1];
436 pool->SubmitTask(tid, request);
438 size_t total_size = 0;
// The calling thread processes the last chunk itself.
441 const size_t rbegin = row_ptr[nthread - 1];
442 const size_t rend = row_ptr[nthread];
443 const size_t query_result_size
444 = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result);
445 total_size += query_result_size;
// Collect the worker results.
447 for (
int tid = 0; tid < nthread - 1; ++tid) {
448 if (pool->WaitForTask(tid, &response)) {
449 total_size += response.query_result_size;
// Fewer outputs than reserved: only valid for multiclass models that produce
// the same number (0 < k < num_class_) of outputs for every row. Compact the
// rows from stride num_class_ to stride k.
453 if (total_size < QueryResultSize(dmat, 0, num_row)) {
454 TREELITE_CHECK_GT(num_class_, 1);
455 TREELITE_CHECK_EQ(total_size % num_row, 0);
456 const size_t query_size_per_instance = total_size / num_row;
457 TREELITE_CHECK_GT(query_size_per_instance, 0);
458 TREELITE_CHECK_LT(query_size_per_instance, num_class_);
459 DispatchWithTypeInfo<ShrinkResultToFit>(
460 leaf_output_type_, num_row, query_size_per_instance, num_class_, out_result);
462 const double tend = GetTime();
464 TREELITE_LOG(INFO) <<
"Treelite: Finished prediction in " << tend - tstart <<
" sec";
// Allocates an output buffer large enough for predictions on `dmat`
// (QueryResultSize(dmat) elements). The element type is chosen at runtime from
// leaf_output_type_ using AllocateOutputVector.
470 Predictor::CreateOutputVector(
const DMatrix* dmat)
const {
471 const size_t output_vector_size = this->QueryResultSize(dmat);
472 return DispatchWithTypeInfo<AllocateOutputVector>(leaf_output_type_, output_vector_size);
// Frees an output buffer using the element type implied by leaf_output_type_.
// The buffer must have been allocated for the same leaf output type.
477 DispatchWithTypeInfo<DeallocateOutputVector>(leaf_output_type_, output_vector);
// Compacts `out_result` in place from a row stride of `num_class` down to a
// row stride of `query_size_per_instance`, keeping the first
// query_size_per_instance values of each row.
// Copying forward in place is safe when query_size_per_instance <= num_class,
// because each write index then never exceeds its read index. The caller in
// Predictor::PredictBatch checks query_size_per_instance < num_class.
480 template <
typename LeafOutputType>
483 size_t num_row,
size_t query_size_per_instance,
size_t num_class,
485 auto* out_result_ =
static_cast<LeafOutputType*
>(out_result);
486 for (
size_t rid = 0; rid < num_row; ++rid) {
487 for (
size_t k = 0; k < query_size_per_instance; ++k) {
488 out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_class + k];
// Template header of a helper specialized on the leaf output type.
// NOTE(review): presumably AllocateOutputVector::Dispatch — confirm.
493 template <
typename LeafOutputType>
// Frees an output buffer as an array of LeafOutputType. For this delete[] to be
// valid, the buffer must have been allocated with new[] using the same
// LeafOutputType.
499 template <
typename LeafOutputType>
502 delete[] (
static_cast<LeafOutputType*
>(output_vector));
Load prediction function exported as a shared library.
Some useful math utilities.
size_t num_col
number of columns (i.e. # of features used)
Input data structure of Treelite.
bool CheckNAN(T value)
check for NaN (Not a Number)
ElementType missing_value
value representing the missing value (usually NaN)
void * PredictorOutputHandle
handle to output from predictor
std::vector< size_t > row_ptr
pointers to row headers; length is num_row + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
Defines TypeInfo class and utilities.
void Run(Function f, Parameters... params)
Parallel OpenMP blocks should be placed within Run() so that any exception they throw is captured and saved.
TypeInfo GetTypeInfoByName(const std::string &str)
converts a type-name string to TypeInfo, using the conversion table defined in tables.cc
a simple thread pool implementation
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
predictor class: wrapper for optimized prediction code
void * ThreadPoolHandle
opaque handle types
size_t num_col
number of columns (i.e. # of features used)