10 #include <dmlc/logging.h> 11 #include <dmlc/timer.h> 16 #include <type_traits> 17 #include "common/math.h" 33 size_t num_output_group;
34 treelite::Predictor::PredFuncHandle pred_func_handle;
40 size_t query_result_size;
46 inline treelite::Predictor::LibraryHandle OpenLibrary(
const char* name) {
48 HMODULE handle = LoadLibraryA(name);
50 void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
52 return static_cast<treelite::Predictor::LibraryHandle
>(handle);
55 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
57 FreeLibrary(static_cast<HMODULE>(handle));
59 dlclose(static_cast<void*>(handle));
63 template <
typename HandleType>
64 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
67 FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
69 void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
71 return static_cast<HandleType
>(func_handle);
74 template <
typename PredFunc>
76 size_t rbegin,
size_t rend,
77 float* out_pred, PredFunc func) {
78 std::vector<treelite::Predictor::Entry> inst(batch->
num_col, {-1});
79 CHECK(rbegin < rend && rend <= batch->num_row);
80 CHECK(
sizeof(
size_t) <
sizeof(int64_t)
81 || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
82 && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
83 const int64_t rbegin_ =
static_cast<int64_t
>(rbegin);
84 const int64_t rend_ =
static_cast<int64_t
>(rend);
85 const size_t num_col = batch->
num_col;
86 const float* data = batch->
data;
87 const uint32_t* col_ind = batch->
col_ind;
88 const size_t* row_ptr = batch->
row_ptr;
89 size_t total_output_size = 0;
90 for (int64_t rid = rbegin_; rid < rend_; ++rid) {
91 const size_t ibegin = row_ptr[rid];
92 const size_t iend = row_ptr[rid + 1];
93 for (
size_t i = ibegin; i < iend; ++i) {
94 inst[col_ind[i]].fvalue = data[i];
96 total_output_size += func(rid, &inst[0], out_pred);
97 for (
size_t i = ibegin; i < iend; ++i) {
98 inst[col_ind[i]].missing = -1;
101 return total_output_size;
104 template <
typename PredFunc>
106 size_t rbegin,
size_t rend,
107 float* out_pred, PredFunc func) {
108 const bool nan_missing
110 std::vector<treelite::Predictor::Entry> inst(batch->
num_col, {-1});
111 CHECK(rbegin < rend && rend <= batch->num_row);
112 CHECK(
sizeof(
size_t) <
sizeof(int64_t)
113 || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
114 && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
115 const int64_t rbegin_ =
static_cast<int64_t
>(rbegin);
116 const int64_t rend_ =
static_cast<int64_t
>(rend);
117 const size_t num_col = batch->
num_col;
119 const float* data = batch->
data;
121 size_t total_output_size = 0;
122 for (int64_t rid = rbegin_; rid < rend_; ++rid) {
123 row = &data[rid * num_col];
124 for (
size_t j = 0; j < num_col; ++j) {
125 if (treelite::common::math::CheckNAN(row[j])) {
127 <<
"The missing_value argument must be set to NaN if there is any " 128 <<
"NaN in the matrix.";
129 }
else if (nan_missing || row[j] != missing_value) {
130 inst[j].fvalue = row[j];
133 total_output_size += func(rid, &inst[0], out_pred);
134 for (
size_t j = 0; j < num_col; ++j) {
135 inst[j].missing = -1;
138 return total_output_size;
141 template <
typename BatchType>
142 inline size_t PredictBatch_(
const BatchType* batch,
143 bool pred_margin,
size_t num_output_group,
144 treelite::Predictor::PredFuncHandle pred_func_handle,
145 size_t rbegin,
size_t rend,
146 size_t expected_query_result_size,
float* out_pred) {
147 CHECK(pred_func_handle !=
nullptr)
148 <<
"A shared library needs to be loaded first using Load()";
151 size_t query_result_size;
156 if (num_output_group > 1) {
158 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
160 PredLoop(batch, rbegin, rend, out_pred,
161 [pred_func, num_output_group, pred_margin]
163 return pred_func(inst, (
int)pred_margin, &out_pred[rid * num_output_group]);
167 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
169 PredLoop(batch, rbegin, rend, out_pred,
170 [pred_func, pred_margin]
172 out_pred[rid] = pred_func(inst, (
int)pred_margin);
177 if (query_result_size < expected_query_result_size) {
178 CHECK_GT(num_output_group, 1);
179 CHECK_EQ(query_result_size % batch->num_row, 0);
180 const size_t query_size_per_instance = query_result_size / batch->num_row;
181 CHECK_GT(query_size_per_instance, 0);
182 CHECK_LT(query_size_per_instance, num_output_group);
183 for (
size_t rid = 0; rid < batch->num_row; ++rid) {
184 for (
size_t k = 0; k < query_size_per_instance; ++k) {
185 out_pred[rid * query_size_per_instance + k]
186 = out_pred[rid * num_output_group + k];
190 return query_result_size;
197 Predictor::Predictor(
int num_worker_thread,
198 bool include_master_thread)
199 : lib_handle_(nullptr),
200 query_func_handle_(nullptr),
201 pred_func_handle_(nullptr),
202 thread_pool_handle_(nullptr),
203 include_master_thread_(include_master_thread),
204 num_worker_thread_(num_worker_thread) {}
205 Predictor::~Predictor() {
211 lib_handle_ = OpenLibrary(name);
212 CHECK(lib_handle_ !=
nullptr)
213 <<
"Failed to load dynamic shared library `" << name <<
"'";
216 query_func_handle_ = LoadFunction<QueryFuncHandle>(lib_handle_,
217 "get_num_output_group");
218 using QueryFunc = size_t (*)(void);
219 QueryFunc query_func =
reinterpret_cast<QueryFunc
>(query_func_handle_);
220 CHECK(query_func !=
nullptr)
221 <<
"Dynamic shared library `" << name
222 <<
"' does not contain valid get_num_output_group() function";
223 num_output_group_ = query_func();
226 CHECK_GT(num_output_group_, 0) <<
"num_output_group cannot be zero";
227 if (num_output_group_ > 1) {
228 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
229 "predict_multiclass");
230 using PredFunc = size_t (*)(
Entry*, int,
float*);
231 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle_);
232 CHECK(pred_func !=
nullptr)
233 <<
"Dynamic shared library `" << name
234 <<
"' does not contain valid predict_multiclass() function";
236 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
"predict");
237 using PredFunc = float (*)(
Entry*, int);
238 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle_);
239 CHECK(pred_func !=
nullptr)
240 <<
"Dynamic shared library `" << name
241 <<
"' does not contain valid predict() function";
244 if (num_worker_thread_ == -1) {
245 num_worker_thread_ = std::thread::hardware_concurrency() - 1;
247 thread_pool_handle_ =
static_cast<ThreadPoolHandle
>(
248 new PredThreadPool(num_worker_thread_,
this,
253 while (incoming_queue->Pop(&input)) {
254 size_t query_result_size;
255 const size_t rbegin = input.rbegin;
256 const size_t rend = input.rend;
260 = PredictBatch_(batch, input.pred_margin, input.num_output_group,
261 input.pred_func_handle,
268 = PredictBatch_(batch, input.pred_margin, input.num_output_group,
269 input.pred_func_handle,
274 outgoing_queue->Push(OutputToken{query_result_size});
281 CloseLibrary(lib_handle_);
282 delete static_cast<PredThreadPool*
>(thread_pool_handle_);
285 template <
typename BatchType>
287 std::vector<size_t> SplitBatch(
const BatchType* batch,
size_t nthread) {
288 const size_t num_row = batch->num_row;
289 CHECK_LE(nthread, num_row);
290 const size_t portion = num_row / nthread;
291 const size_t remainder = num_row % nthread;
292 std::vector<size_t> workload(nthread, portion);
293 std::vector<size_t> row_ptr(nthread + 1, 0);
294 for (
size_t i = 0; i < remainder; ++i) {
298 for (
size_t i = 0; i < nthread; ++i) {
299 accum += workload[i];
300 row_ptr[i + 1] = accum;
305 template <
typename BatchType>
307 Predictor::PredictBatchBase_(
const BatchType* batch,
int verbose,
308 bool pred_margin,
float* out_result) {
309 static_assert( std::is_same<BatchType, DenseBatch>::value
310 || std::is_same<BatchType, CSRBatch>::value,
311 "PredictBatchBase_: unrecognized batch type");
312 const double tstart = dmlc::GetTime();
313 PredThreadPool* pool =
static_cast<PredThreadPool*
>(thread_pool_handle_);
314 InputToken request{std::is_same<BatchType, CSRBatch>::value,
315 static_cast<const void*
>(batch), pred_margin,
316 num_output_group_, pred_func_handle_,
317 0, batch->num_row, out_result};
318 OutputToken response;
319 CHECK_GT(batch->num_row, 0);
320 const int nthread = std::min(num_worker_thread_,
321 static_cast<int>(batch->num_row)
322 - (
int)(include_master_thread_));
323 const std::vector<size_t> row_ptr
324 = SplitBatch(batch, nthread + (
int)(include_master_thread_));
325 for (
int tid = 0; tid < nthread; ++tid) {
326 request.rbegin = row_ptr[tid];
327 request.rend = row_ptr[tid + 1];
328 pool->SubmitTask(tid, request);
330 size_t total_size = 0;
331 if (include_master_thread_) {
332 const size_t rbegin = row_ptr[nthread];
333 const size_t rend = row_ptr[nthread + 1];
334 const size_t query_result_size
335 = PredictBatch_(batch, pred_margin, num_output_group_,
337 rbegin, rend, QueryResultSize(batch, rbegin, rend),
339 total_size += query_result_size;
341 for (
int tid = 0; tid < nthread; ++tid) {
342 if (pool->WaitForTask(tid, &response)) {
343 total_size += response.query_result_size;
346 const double tend = dmlc::GetTime();
348 LOG(INFO) <<
"Treelite: Finished prediction in " 349 << tend - tstart <<
" sec";
356 bool pred_margin,
float* out_result) {
357 return PredictBatchBase_(batch, verbose, pred_margin, out_result);
362 bool pred_margin,
float* out_result) {
363 return PredictBatchBase_(batch, verbose, pred_margin, out_result);
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
sparse batch in Compressed Sparse Row (CSR) format
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
void Load(const char *name)
load the prediction function from dynamic shared library.
const float * data
feature values
size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin, float *out_result)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
float missing_value
value representing the missing value (usually nan)
a simple thread pool implementation
const float * data
feature values
compatibility wrapper for systems that don't support OpenMP
predictor class: wrapper for optimized prediction code
size_t num_col
number of columns (i.e. # of features used)
void Free()
unload the prediction function
size_t num_col
number of columns (i.e. # of features used)