// NOTE(review): extraction-garbled chunk — original source line numbers are
// fused into the text (e.g. "9", "33") and surrounding lines are missing.
// Lines kept verbatim below. Headers: treelite/dmlc utilities plus
// project-local math and filesystem helpers.
9 #include <treelite/common.h> 10 #include <dmlc/logging.h> 12 #include <dmlc/timer.h> 19 #include <type_traits> 20 #include "common/math.h" 21 #include "common/filesystem.h" 33 enum class InputType : uint8_t {
// Discriminates the two batch layouts a worker task may receive.
34 kSparseBatch = 0, kDenseBatch = 1
// Fields of the per-task token handed to worker threads (presumably part of
// an InputToken struct — enclosing declaration missing from this view).
43 size_t num_output_group;
45 treelite::Predictor::PredFuncHandle pred_func_handle;
// Number of outputs a task produced (presumably an OutputToken field; see
// its use at the worker loop's Push call — TODO confirm).
53 size_t query_result_size;
/*!
 * \brief Extract the protocol prefix (e.g. "file://", "s3://") from a path.
 * \param name path or URI to inspect
 * \return the protocol prefix including the "://" marker, or an empty string
 *         when the path carries no protocol
 */
inline std::string GetProtocol(const char* name) {
  const char* p = std::strstr(name, "://");
  if (p == nullptr) {
    // No "://" marker: plain local path. This guard is required — computing
    // `p - name` on a null strstr() result would be undefined behavior.
    return std::string("");
  } else {
    // +3 keeps the "://" marker as part of the returned prefix.
    return std::string(name, p - name + 3);
  }
}
68 inline treelite::Predictor::LibraryHandle OpenLibrary(
const char* name) {
70 HMODULE handle = LoadLibraryA(name);
72 void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
74 return static_cast<treelite::Predictor::LibraryHandle
>(handle);
77 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
79 FreeLibrary(static_cast<HMODULE>(handle));
81 dlclose(static_cast<void*>(handle));
85 template <
typename HandleType>
86 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
89 FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
91 void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
93 return static_cast<HandleType
>(func_handle);
96 template <
typename PredFunc>
98 size_t rbegin,
size_t rend,
99 float* out_pred, PredFunc func) {
100 CHECK_LE(batch->
num_col, num_feature);
101 std::vector<TreelitePredictorEntry> inst(
102 std::max(batch->
num_col, num_feature), {-1});
103 CHECK(rbegin < rend && rend <= batch->num_row);
104 CHECK(
sizeof(
size_t) <
sizeof(int64_t)
105 || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
106 && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
107 const int64_t rbegin_ =
static_cast<int64_t
>(rbegin);
108 const int64_t rend_ =
static_cast<int64_t
>(rend);
109 const size_t num_col = batch->
num_col;
110 const float* data = batch->
data;
111 const uint32_t* col_ind = batch->
col_ind;
112 const size_t* row_ptr = batch->
row_ptr;
113 size_t total_output_size = 0;
114 for (int64_t rid = rbegin_; rid < rend_; ++rid) {
115 const size_t ibegin = row_ptr[rid];
116 const size_t iend = row_ptr[rid + 1];
117 for (
size_t i = ibegin; i < iend; ++i) {
118 inst[col_ind[i]].fvalue = data[i];
120 total_output_size += func(rid, &inst[0], out_pred);
121 for (
size_t i = ibegin; i < iend; ++i) {
122 inst[col_ind[i]].missing = -1;
125 return total_output_size;
128 template <
typename PredFunc>
130 size_t rbegin,
size_t rend,
131 float* out_pred, PredFunc func) {
132 const bool nan_missing
134 CHECK_LE(batch->
num_col, num_feature);
135 std::vector<TreelitePredictorEntry> inst(
136 std::max(batch->
num_col, num_feature), {-1});
137 CHECK(rbegin < rend && rend <= batch->num_row);
138 CHECK(
sizeof(
size_t) <
sizeof(int64_t)
139 || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
140 && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
141 const int64_t rbegin_ =
static_cast<int64_t
>(rbegin);
142 const int64_t rend_ =
static_cast<int64_t
>(rend);
143 const size_t num_col = batch->
num_col;
145 const float* data = batch->
data;
147 size_t total_output_size = 0;
148 for (int64_t rid = rbegin_; rid < rend_; ++rid) {
149 row = &data[rid * num_col];
150 for (
size_t j = 0; j < num_col; ++j) {
151 if (treelite::common::math::CheckNAN(row[j])) {
153 <<
"The missing_value argument must be set to NaN if there is any " 154 <<
"NaN in the matrix.";
155 }
else if (nan_missing || row[j] != missing_value) {
156 inst[j].fvalue = row[j];
159 total_output_size += func(rid, &inst[0], out_pred);
160 for (
size_t j = 0; j < num_col; ++j) {
161 inst[j].missing = -1;
164 return total_output_size;
167 template <
typename BatchType>
168 inline size_t PredictBatch_(
const BatchType* batch,
bool pred_margin,
169 size_t num_feature,
size_t num_output_group,
170 treelite::Predictor::PredFuncHandle pred_func_handle,
171 size_t rbegin,
size_t rend,
172 size_t expected_query_result_size,
float* out_pred) {
173 CHECK(pred_func_handle !=
nullptr)
174 <<
"A shared library needs to be loaded first using Load()";
177 size_t query_result_size;
182 if (num_output_group > 1) {
184 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
186 PredLoop(batch, num_feature, rbegin, rend, out_pred,
187 [pred_func, num_output_group, pred_margin]
189 return pred_func(inst, static_cast<int>(pred_margin),
190 &out_pred[rid * num_output_group]);
194 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
196 PredLoop(batch, num_feature, rbegin, rend, out_pred,
197 [pred_func, pred_margin]
199 out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
203 return query_result_size;
207 bool pred_margin,
size_t num_output_group,
208 treelite::Predictor::PredFuncHandle pred_func_handle,
209 size_t expected_query_result_size,
float* out_pred) {
210 CHECK(pred_func_handle !=
nullptr)
211 <<
"A shared library needs to be loaded first using Load()";
212 size_t query_result_size;
213 if (num_output_group > 1) {
215 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
216 query_result_size = pred_func(inst, (
int)pred_margin, out_pred);
219 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
220 out_pred[0] = pred_func(inst, (
int)pred_margin);
221 query_result_size = 1;
223 return query_result_size;
// NOTE(review): garbled fragment, lines kept verbatim. Constructor
// initializer list: every library/function handle starts out null; the
// worker-thread count comes from the caller (-1 appears to mean "use
// hardware concurrency" — it is resolved that way later, in Load()). The
// trailing initializer(s) and the constructor body are missing from this
// view.
230 Predictor::Predictor(
int num_worker_thread)
231 : lib_handle_(nullptr),
232 num_output_group_query_func_handle_(nullptr),
233 num_feature_query_func_handle_(nullptr),
234 pred_func_handle_(nullptr),
235 thread_pool_handle_(nullptr),
236 num_worker_thread_(num_worker_thread),
// Destructor: body not visible in this view.
238 Predictor::~Predictor() {
// NOTE(review): extraction-garbled view of the library-loading routine
// (original source line numbers are fused into the text; some lines are
// missing). Statements kept verbatim; comments annotate the visible logic.
// Loads the compiled model library `name`, resolves its exported query and
// predict entry points, then creates the worker thread pool.
243 Predictor::Load(
const char* name) {
244 const std::string protocol = GetProtocol(name);
// Local files are opened directly; other protocols are streamed through
// dmlc::Stream into a temporary file, which is opened instead.
245 if (protocol ==
"file://" || protocol.empty()) {
247 lib_handle_ = OpenLibrary(name);
251 temp_libfile_ = tempdir_->AddFile(common::filesystem::GetBasename(name));
253 std::unique_ptr<dmlc::Stream> strm(dmlc::Stream::Create(name,
"r"));
254 dmlc::istream is(strm.get());
255 std::ofstream of(temp_libfile_);
258 lib_handle_ = OpenLibrary(temp_libfile_.c_str());
260 if (lib_handle_ ==
nullptr) {
261 LOG(FATAL) <<
"Failed to load dynamic shared library `" << name <<
"'";
// Mandatory query: number of output groups (classes).
265 num_output_group_query_func_handle_
266 = LoadFunction<QueryFuncHandle>(lib_handle_,
"get_num_output_group");
267 using UnsignedQueryFunc = size_t (*)(void);
269 =
reinterpret_cast<UnsignedQueryFunc
>(num_output_group_query_func_handle_);
270 CHECK(uint_query_func !=
nullptr)
271 <<
"Dynamic shared library `" << name
272 <<
"' does not contain valid get_num_output_group() function";
273 num_output_group_ = uint_query_func();
// Mandatory query: number of features the model expects (must be > 0).
276 num_feature_query_func_handle_
277 = LoadFunction<QueryFuncHandle>(lib_handle_,
"get_num_feature");
278 uint_query_func =
reinterpret_cast<UnsignedQueryFunc
>(num_feature_query_func_handle_);
279 CHECK(uint_query_func !=
nullptr)
280 <<
"Dynamic shared library `" << name
281 <<
"' does not contain valid get_num_feature() function";
282 num_feature_ = uint_query_func();
283 CHECK_GT(num_feature_, 0) <<
"num_feature cannot be zero";
// Optional query: prediction transform name; falls back to "unknown".
286 pred_transform_query_func_handle_
287 = LoadFunction<QueryFuncHandle>(lib_handle_,
"get_pred_transform");
288 using StringQueryFunc =
const char* (*)(void);
289 auto str_query_func =
290 reinterpret_cast<StringQueryFunc
>(pred_transform_query_func_handle_);
291 if (str_query_func ==
nullptr) {
292 LOG(INFO) <<
"Dynamic shared library `" << name
293 <<
"' does not contain valid get_pred_transform() function";
294 pred_transform_ =
"unknown";
296 pred_transform_ = str_query_func();
// Optional query: sigmoid alpha; falls back to NAN when absent.
300 sigmoid_alpha_query_func_handle_
301 = LoadFunction<QueryFuncHandle>(lib_handle_,
"get_sigmoid_alpha");
302 using FloatQueryFunc = float (*)(void);
303 auto float_query_func =
304 reinterpret_cast<FloatQueryFunc
>(sigmoid_alpha_query_func_handle_);
305 if (float_query_func ==
nullptr) {
306 LOG(INFO) <<
"Dynamic shared library `" << name
307 <<
"' does not contain valid get_sigmoid_alpha() function";
308 sigmoid_alpha_ = NAN;
310 sigmoid_alpha_ = float_query_func();
// Optional query: global bias.
314 global_bias_query_func_handle_
315 = LoadFunction<QueryFuncHandle>(lib_handle_,
"get_global_bias");
316 float_query_func =
reinterpret_cast<FloatQueryFunc
>(global_bias_query_func_handle_);
317 if (float_query_func ==
nullptr) {
318 LOG(INFO) <<
"Dynamic shared library `" << name
319 <<
"' does not contain valid get_global_bias() function";
322 global_bias_ = float_query_func();
// Resolve the prediction entry point: multi-class models export
// predict_multiclass(); single-output models export predict().
326 CHECK_GT(num_output_group_, 0) <<
"num_output_group cannot be zero";
327 if (num_output_group_ > 1) {
328 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
329 "predict_multiclass");
331 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle_);
332 CHECK(pred_func !=
nullptr)
333 <<
"Dynamic shared library `" << name
334 <<
"' does not contain valid predict_multiclass() function";
336 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
"predict");
338 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle_);
339 CHECK(pred_func !=
nullptr)
340 <<
"Dynamic shared library `" << name
341 <<
"' does not contain valid predict() function";
// -1 worker threads means "use hardware concurrency". The pool receives
// num_worker_thread_ - 1 threads; presumably the remaining share of the
// work runs on the calling thread (see PredictBatchBase_) — TODO confirm.
344 if (num_worker_thread_ == -1) {
345 num_worker_thread_ = std::thread::hardware_concurrency();
347 thread_pool_handle_ =
static_cast<ThreadPoolHandle
>(
348 new PredThreadPool(num_worker_thread_ - 1,
this,
// Worker loop: pop input tokens, dispatch to PredictBatch_ according to
// the batch layout, and push back the number of outputs produced.
353 while (incoming_queue->Pop(&input)) {
354 size_t query_result_size;
355 const size_t rbegin = input.rbegin;
356 const size_t rend = input.rend;
357 switch (input.input_type) {
358 case InputType::kSparseBatch:
362 = PredictBatch_(batch, input.pred_margin, input.num_feature,
363 input.num_output_group, input.pred_func_handle,
369 case InputType::kDenseBatch:
373 = PredictBatch_(batch, input.pred_margin, input.num_feature,
374 input.num_output_group, input.pred_func_handle,
381 outgoing_queue->Push(OutputToken{query_result_size});
// NOTE(review): garbled fragment — appears to be the teardown path
// (presumably Predictor::Free): unloads the shared library and destroys the
// owned worker thread pool. The enclosing function signature is missing
// from this view.
388 CloseLibrary(lib_handle_);
389 delete static_cast<PredThreadPool*
>(thread_pool_handle_);
392 template <
typename BatchType>
394 std::vector<size_t> SplitBatch(
const BatchType* batch,
size_t split_factor) {
395 const size_t num_row = batch->num_row;
396 CHECK_LE(split_factor, num_row);
397 const size_t portion = num_row / split_factor;
398 const size_t remainder = num_row % split_factor;
399 std::vector<size_t> workload(split_factor, portion);
400 std::vector<size_t> row_ptr(split_factor + 1, 0);
401 for (
size_t i = 0; i < remainder; ++i) {
405 for (
size_t i = 0; i < split_factor; ++i) {
406 accum += workload[i];
407 row_ptr[i + 1] = accum;
412 template <
typename BatchType>
414 Predictor::PredictBatchBase_(
const BatchType* batch,
int verbose,
415 bool pred_margin,
float* out_result) {
416 static_assert(std::is_same<BatchType, DenseBatch>::value
417 || std::is_same<BatchType, CSRBatch>::value,
418 "PredictBatchBase_: unrecognized batch type");
419 const double tstart = dmlc::GetTime();
420 PredThreadPool* pool =
static_cast<PredThreadPool*
>(thread_pool_handle_);
421 const InputType input_type
422 = std::is_same<BatchType, CSRBatch>::value
423 ? InputType::kSparseBatch : InputType::kDenseBatch;
424 InputToken request{input_type,
static_cast<const void*
>(batch), pred_margin,
425 num_feature_, num_output_group_, pred_func_handle_,
426 0, batch->num_row, out_result};
427 OutputToken response;
428 CHECK_GT(batch->num_row, 0);
429 const int nthread = std::min(num_worker_thread_,
430 static_cast<int>(batch->num_row));
431 const std::vector<size_t> row_ptr = SplitBatch(batch, nthread);
432 for (
int tid = 0; tid < nthread - 1; ++tid) {
433 request.rbegin = row_ptr[tid];
434 request.rend = row_ptr[tid + 1];
435 pool->SubmitTask(tid, request);
437 size_t total_size = 0;
440 const size_t rbegin = row_ptr[nthread - 1];
441 const size_t rend = row_ptr[nthread];
442 const size_t query_result_size
443 = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_,
445 rbegin, rend, QueryResultSize(batch, rbegin, rend),
447 total_size += query_result_size;
449 for (
int tid = 0; tid < nthread - 1; ++tid) {
450 if (pool->WaitForTask(tid, &response)) {
451 total_size += response.query_result_size;
455 if (total_size < QueryResultSize(batch, 0, batch->num_row)) {
456 CHECK_GT(num_output_group_, 1);
457 CHECK_EQ(total_size % batch->num_row, 0);
458 const size_t query_size_per_instance = total_size / batch->num_row;
459 CHECK_GT(query_size_per_instance, 0);
460 CHECK_LT(query_size_per_instance, num_output_group_);
461 for (
size_t rid = 0; rid < batch->num_row; ++rid) {
462 for (
size_t k = 0; k < query_size_per_instance; ++k) {
463 out_result[rid * query_size_per_instance + k]
464 = out_result[rid * num_output_group_ + k];
468 const double tend = dmlc::GetTime();
470 LOG(INFO) <<
"Treelite: Finished prediction in " 471 << tend - tstart <<
" sec";
477 Predictor::PredictBatch(
const CSRBatch* batch,
int verbose,
478 bool pred_margin,
float* out_result) {
479 return PredictBatchBase_(batch, verbose, pred_margin, out_result);
483 Predictor::PredictBatch(
const DenseBatch* batch,
int verbose,
484 bool pred_margin,
float* out_result) {
485 return PredictBatchBase_(batch, verbose, pred_margin, out_result);
// NOTE(review): garbled fragment of the single-instance prediction path.
// Delegates to the free function PredictInst_; the enclosing signature and
// the pred_func_handle_ argument line are missing from this view.
492 total_size = PredictInst_(inst, pred_margin, num_output_group_,
494 QueryResultSizeSingleInst(), out_result);
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
sparse batch in Compressed Sparse Row (CSR) format
const size_t * row_ptr
pointer to row headers; length is [num_row] + 1
const float * data
feature values
float missing_value
value representing the missing value (usually nan)
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
a simple thread pool implementation
const float * data
feature values
predictor class: wrapper for optimized prediction code
size_t num_col
number of columns (i.e. # of features used)
size_t num_col
number of columns (i.e. # of features used)