#include <dmlc/logging.h>
#include <dmlc/timer.h>
#include <type_traits>

enum class InputType : uint8_t {
  kSparseBatch = 0, kDenseBatch = 1
};

struct InputToken {
  InputType input_type;
  const void* data;   // pointer to CSRBatch or DenseBatch
  bool pred_margin;
  size_t num_feature;
  size_t num_output_group;
  treelite::Predictor::PredFuncHandle pred_func_handle;
  size_t rbegin, rend;
  float* out_pred;
};

struct OutputToken {
  size_t query_result_size;
};
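
// Cross-platform helpers for working with the compiled model library:
// OpenLibrary/CloseLibrary/LoadFunction wrap LoadLibraryA, GetProcAddress,
// and FreeLibrary on Windows and dlopen, dlsym, and dlclose elsewhere,
// hiding the platform difference behind Predictor's opaque handle types.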
inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
#ifdef _WIN32
  HMODULE handle = LoadLibraryA(name);
#else
  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
#endif
  return static_cast<treelite::Predictor::LibraryHandle>(handle);
}

inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
#ifdef _WIN32
  FreeLibrary(static_cast<HMODULE>(handle));
#else
  dlclose(static_cast<void*>(handle));
#endif
}
template <typename HandleType>
inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
                               const char* name) {
#ifdef _WIN32
  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
#else
  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
#endif
  return static_cast<HandleType>(func_handle);
}
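
// PredLoop, CSR variant: apply the prediction function to rows
// [rbegin, rend) of a sparse batch. Each row is scattered into a dense
// working array of TreelitePredictorEntry (every slot initially marked
// missing), the callback is invoked, and the touched slots are reset to
// missing before the next row. The return value accumulates the number of
// outputs produced.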
template <typename PredFunc>
inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature,
                       size_t rbegin, size_t rend,
                       float* out_pred, PredFunc func) {
  CHECK_LE(batch->num_col, num_feature);
  std::vector<TreelitePredictorEntry> inst(
    std::max(batch->num_col, num_feature), {-1});
  CHECK(rbegin < rend && rend <= batch->num_row);
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
            && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
  const int64_t rend_ = static_cast<int64_t>(rend);
  const size_t num_col = batch->num_col;
  const float* data = batch->data;
  const uint32_t* col_ind = batch->col_ind;
  const size_t* row_ptr = batch->row_ptr;
  size_t total_output_size = 0;
  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
    const size_t ibegin = row_ptr[rid];
    const size_t iend = row_ptr[rid + 1];
    for (size_t i = ibegin; i < iend; ++i) {
      inst[col_ind[i]].fvalue = data[i];
    }
    total_output_size += func(rid, &inst[0], out_pred);
    for (size_t i = ibegin; i < iend; ++i) {
      inst[col_ind[i]].missing = -1;
    }
  }
  return total_output_size;
}
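
// PredLoop, dense variant: same row-by-row scheme for a dense batch.
// A cell equal to batch->missing_value is treated as missing. If the
// missing value is NaN, every NaN cell counts as missing; otherwise any
// NaN in the data is rejected with a fatal check.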
template <typename PredFunc>
inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature,
                       size_t rbegin, size_t rend,
                       float* out_pred, PredFunc func) {
  const bool nan_missing = treelite::math::CheckNAN(batch->missing_value);
  CHECK_LE(batch->num_col, num_feature);
  std::vector<TreelitePredictorEntry> inst(
    std::max(batch->num_col, num_feature), {-1});
  CHECK(rbegin < rend && rend <= batch->num_row);
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
            && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
  const int64_t rend_ = static_cast<int64_t>(rend);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row;
  size_t total_output_size = 0;
  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
    row = &data[rid * num_col];
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::math::CheckNAN(row[j])) {
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        inst[j].fvalue = row[j];
      }
    }
    total_output_size += func(rid, &inst[0], out_pred);
    for (size_t j = 0; j < num_col; ++j) {
      inst[j].missing = -1;
    }
  }
  return total_output_size;
}
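
// PredictBatch_: shared implementation behind both batch types. The compiled
// library exports either predict_multiclass(), which writes one score per
// output group and returns the number of values written, or predict(), which
// returns a single score per row. This helper casts the stored function
// handle to the matching signature and drives PredLoop with a small lambda
// adapter around it.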
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch, bool pred_margin,
                            size_t num_feature, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle pred_func_handle,
                            size_t rbegin, size_t rend,
                            size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  size_t query_result_size;
  if (num_output_group > 1) {  // multi-class classification
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, num_feature, rbegin, rend, out_pred,
        [pred_func, num_output_group, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          return pred_func(inst, static_cast<int>(pred_margin),
                           &out_pred[rid * num_output_group]);
        });
  } else {  // everything else
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, num_feature, rbegin, rend, out_pred,
        [pred_func, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
          return 1;
        });
  }
  return query_result_size;
}
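
// PredictInst_: single-row counterpart of PredictBatch_. The caller supplies
// the row already packed into a TreelitePredictorEntry array, so the exported
// function is invoked directly, without PredLoop.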
inline size_t PredictInst_(TreelitePredictorEntry* inst,
                           bool pred_margin, size_t num_output_group,
                           treelite::Predictor::PredFuncHandle pred_func_handle,
                           size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  size_t query_result_size;
  if (num_output_group > 1) {  // multi-class classification
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size = pred_func(inst, static_cast<int>(pred_margin), out_pred);
  } else {  // everything else
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    out_pred[0] = pred_func(inst, static_cast<int>(pred_margin));
    query_result_size = 1;
  }
  return query_result_size;
}
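
// Predictor lifecycle: the constructor only records the requested worker
// thread count (-1 means "use all hardware threads"); loading the shared
// library and spawning the worker pool happen in Load(), and Free() undoes
// both.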
Predictor::Predictor(int num_worker_thread)
  : lib_handle_(nullptr),
    num_output_group_query_func_handle_(nullptr),
    num_feature_query_func_handle_(nullptr),
    pred_func_handle_(nullptr),
    thread_pool_handle_(nullptr),
    num_worker_thread_(num_worker_thread) {}
Predictor::~Predictor() {
  Free();
}

void
Predictor::Load(const char* name) {
  lib_handle_ = OpenLibrary(name);
  if (lib_handle_ == nullptr) {
    LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'";
  }
  // query the number of output groups (classes)
  num_output_group_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_output_group");
  using UnsignedQueryFunc = size_t (*)(void);
  auto uint_query_func
    = reinterpret_cast<UnsignedQueryFunc>(num_output_group_query_func_handle_);
  CHECK(uint_query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = uint_query_func();
  num_feature_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_feature");
  uint_query_func
    = reinterpret_cast<UnsignedQueryFunc>(num_feature_query_func_handle_);
  CHECK(uint_query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_feature() function";
  num_feature_ = uint_query_func();
  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";
  // query the prediction transform (optional: default to "unknown" if absent)
  pred_transform_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_pred_transform");
  using StringQueryFunc = const char* (*)(void);
  auto str_query_func
    = reinterpret_cast<StringQueryFunc>(pred_transform_query_func_handle_);
  if (str_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_pred_transform() function";
    pred_transform_ = "unknown";
  } else {
    pred_transform_ = str_query_func();
  }
  // query the sigmoid parameter alpha (optional: default to NaN if absent)
  sigmoid_alpha_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_sigmoid_alpha");
  using FloatQueryFunc = float (*)(void);
  auto float_query_func
    = reinterpret_cast<FloatQueryFunc>(sigmoid_alpha_query_func_handle_);
  if (float_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_sigmoid_alpha() function";
    sigmoid_alpha_ = NAN;
  } else {
    sigmoid_alpha_ = float_query_func();
  }
  // query the global bias (optional: default to NaN if absent)
  global_bias_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_global_bias");
  float_query_func
    = reinterpret_cast<FloatQueryFunc>(global_bias_query_func_handle_);
  if (float_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_global_bias() function";
    global_bias_ = NAN;
  } else {
    global_bias_ = float_query_func();
  }
  // load the appropriate prediction function: multi-class models export
  // predict_multiclass(); everything else exports predict()
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_multiclass");
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_multiclass() function";
  } else {  // everything else
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict() function";
  }

  if (num_worker_thread_ == -1) {
    num_worker_thread_ = std::thread::hardware_concurrency();
  }
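
  // Spawn the worker pool. Each worker runs a task loop: pop an InputToken,
  // predict the assigned row slice, and push an OutputToken holding the
  // number of outputs written. In the sketch below, the queue parameter
  // types and the predictor->QueryResultSize(batch, rbegin, rend) call are
  // assumptions taken from the surrounding Predictor/thread-pool interfaces.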
  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
    new PredThreadPool(num_worker_thread_ - 1, this,
      [](SpscQueue<InputToken>* incoming_queue,
         SpscQueue<OutputToken>* outgoing_queue,
         const Predictor* predictor) {
        InputToken input;
        while (incoming_queue->Pop(&input)) {
          size_t query_result_size;
          const size_t rbegin = input.rbegin;
          const size_t rend = input.rend;
          switch (input.input_type) {
           case InputType::kSparseBatch: {
             const CSRBatch* batch = static_cast<const CSRBatch*>(input.data);
             query_result_size
               = PredictBatch_(batch, input.pred_margin, input.num_feature,
                               input.num_output_group, input.pred_func_handle,
                               rbegin, rend,
                               predictor->QueryResultSize(batch, rbegin, rend),
                               input.out_pred);
             break;
           }
           case InputType::kDenseBatch: {
             const DenseBatch* batch = static_cast<const DenseBatch*>(input.data);
             query_result_size
               = PredictBatch_(batch, input.pred_margin, input.num_feature,
                               input.num_output_group, input.pred_func_handle,
                               rbegin, rend,
                               predictor->QueryResultSize(batch, rbegin, rend),
                               input.out_pred);
             break;
           }
          }
          outgoing_queue->Push(OutputToken{query_result_size});
        }
      }));
}
void
Predictor::Free() {
  CloseLibrary(lib_handle_);
  delete static_cast<PredThreadPool*>(thread_pool_handle_);
}
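
// SplitBatch divides the rows of a batch evenly among split_factor workers:
// every worker receives floor(num_row / split_factor) rows and the first
// (num_row % split_factor) workers get one extra. For example, 10 rows split
// 4 ways gives workloads {3, 3, 2, 2} and row boundaries {0, 3, 6, 8, 10}.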
template <typename BatchType>
inline std::vector<size_t>
SplitBatch(const BatchType* batch, size_t split_factor) {
  const size_t num_row = batch->num_row;
  CHECK_LE(split_factor, num_row);
  const size_t portion = num_row / split_factor;
  const size_t remainder = num_row % split_factor;
  std::vector<size_t> workload(split_factor, portion);
  std::vector<size_t> row_ptr(split_factor + 1, 0);
  for (size_t i = 0; i < remainder; ++i) {
    ++workload[i];
  }
  size_t accum = 0;
  for (size_t i = 0; i < split_factor; ++i) {
    accum += workload[i];
    row_ptr[i + 1] = accum;
  }
  return row_ptr;
}
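
// PredictBatchBase_: synchronous batch prediction. The batch is split into
// nthread contiguous row ranges; ranges 0..nthread-2 go to the worker pool
// and the final range is predicted on the calling thread. Once every worker
// has reported back, multi-class results are compacted in place whenever the
// model produced fewer than num_output_group values per row. The
// QueryResultSize() calls below, used for the expected-size argument, are
// assumed from the Predictor interface.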
template <typename BatchType>
size_t
Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
                             bool pred_margin, float* out_result) {
  static_assert(std::is_same<BatchType, DenseBatch>::value
                || std::is_same<BatchType, CSRBatch>::value,
                "PredictBatchBase_: unrecognized batch type");
  const double tstart = dmlc::GetTime();
  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
  const InputType input_type
    = std::is_same<BatchType, CSRBatch>::value
      ? InputType::kSparseBatch : InputType::kDenseBatch;
  InputToken request{input_type, static_cast<const void*>(batch), pred_margin,
                     num_feature_, num_output_group_, pred_func_handle_,
                     0, batch->num_row, out_result};
  OutputToken response;
  CHECK_GT(batch->num_row, 0);
  const int nthread = std::min(num_worker_thread_,
                               static_cast<int>(batch->num_row));
  const std::vector<size_t> row_ptr = SplitBatch(batch, nthread);
  for (int tid = 0; tid < nthread - 1; ++tid) {
    request.rbegin = row_ptr[tid];
    request.rend = row_ptr[tid + 1];
    pool->SubmitTask(tid, request);
  }
  size_t total_size = 0;
  {
    // process the last chunk on the calling thread
    const size_t rbegin = row_ptr[nthread - 1];
    const size_t rend = row_ptr[nthread];
    const size_t query_result_size
      = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_,
                      pred_func_handle_, rbegin, rend,
                      QueryResultSize(batch, rbegin, rend), out_result);
    total_size += query_result_size;
  }
  for (int tid = 0; tid < nthread - 1; ++tid) {
    if (pool->WaitForTask(tid, &response)) {
      total_size += response.query_result_size;
    }
  }
  // re-shape output if the model wrote fewer values than out_result can hold
  if (total_size < QueryResultSize(batch)) {
    CHECK_GT(num_output_group_, 1);
    CHECK_EQ(total_size % batch->num_row, 0);
    const size_t query_size_per_instance = total_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group_);
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        out_result[rid * query_size_per_instance + k]
          = out_result[rid * num_output_group_ + k];
      }
    }
  }
  const double tend = dmlc::GetTime();
  if (verbose > 0) {
    LOG(INFO) << "Treelite: Finished prediction in "
              << tend - tstart << " sec";
  }
  return total_size;
}
// Public PredictBatch() overloads: thin wrappers that forward a CSRBatch or
// a DenseBatch to the shared PredictBatchBase_ implementation above.
size_t
Predictor::PredictBatch(const CSRBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}

size_t
Predictor::PredictBatch(const DenseBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
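
// PredictInst(): synchronous single-row prediction on the calling thread.
// It forwards to PredictInst_ with the predictor's cached handles; the
// trailing arguments below (pred_func_handle_, QueryResultSizeSingleInst())
// are assumed from the Predictor interface.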
size_t
Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
                       float* out_result) {
  size_t total_size;
  total_size = PredictInst_(inst, pred_margin, num_output_group_,
                            pred_func_handle_, QueryResultSizeSingleInst(),
                            out_result);
  return total_size;
}