/*!
 * \file predictor.cc
 * \brief Load prediction function exported as a shared library
 */
#include <treelite/predictor.h>
#include <dmlc/logging.h>
#include <dmlc/timer.h>
#include <dmlc/omp.h>      // compatibility wrapper for systems without OpenMP
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>
#include "common/math.h"

#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif

namespace {

// Open a dynamic shared library, using the platform-appropriate API.
inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
#ifdef _WIN32
  HMODULE handle = LoadLibraryA(name);
#else
  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
#endif
  return static_cast<treelite::Predictor::LibraryHandle>(handle);
}

// Close a previously opened dynamic shared library.
inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
#ifdef _WIN32
  FreeLibrary(static_cast<HMODULE>(handle));
#else
  dlclose(static_cast<void*>(handle));
#endif
}

// Look up an exported symbol by name and cast it to the requested handle type.
template <typename HandleType>
inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
                               const char* name) {
#ifdef _WIN32
  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
#else
  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
#endif
  return static_cast<HandleType>(func_handle);
}
// Prediction loop over a sparse (CSR) batch: each thread fills its own dense
// row buffer from the CSR row, invokes the prediction function, then resets
// the touched entries to "missing" for reuse.
template <typename PredFunc>
inline void PredLoop(const treelite::CSRBatch* batch, int nthread, int verbose,
                     float* out_pred, PredFunc func) {
  std::vector<treelite::Predictor::Entry> inst(nthread * batch->num_col, {-1});
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || batch->num_row
           <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
  const int64_t num_row = static_cast<int64_t>(batch->num_row);
  const size_t num_col = batch->num_col;
  const float* data = batch->data;
  const uint32_t* col_ind = batch->col_ind;
  const size_t* row_ptr = batch->row_ptr;
  #pragma omp parallel for schedule(static) num_threads(nthread) \
    default(none) firstprivate(num_row, num_col, data, col_ind, row_ptr) \
    shared(inst, func, out_pred)
  for (int64_t rid = 0; rid < num_row; ++rid) {
    const int tid = omp_get_thread_num();
    const size_t off = num_col * tid;
    const size_t ibegin = row_ptr[rid];
    const size_t iend = row_ptr[rid + 1];
    for (size_t i = ibegin; i < iend; ++i) {  // scatter CSR row into buffer
      inst[off + col_ind[i]].fvalue = data[i];
    }
    func(rid, &inst[off], out_pred);
    for (size_t i = ibegin; i < iend; ++i) {  // mark entries missing again
      inst[off + col_ind[i]].missing = -1;
    }
  }
}
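/* Illustration (not from the original source): how the CSR fields drive the
 * loop above, for a hypothetical 2 x 4 batch whose rows are
 *   row 0: {feature 1 = 0.5f, feature 3 = 2.0f}
 *   row 1: {feature 0 = 1.5f}
 * The CSR layout would be
 *   data    = {0.5f, 2.0f, 1.5f}
 *   col_ind = {1, 3, 0}
 *   row_ptr = {0, 2, 3}   // length num_row + 1
 * so for rid = 0 we get ibegin = 0, iend = 2 (buffer entries 1 and 3 filled),
 * and for rid = 1 we get ibegin = 2, iend = 3 (buffer entry 0 filled).
 */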
// Prediction loop over a dense batch: each thread copies one row into its
// dense buffer, skipping entries equal to missing_value (or NaN), invokes the
// prediction function, then resets the buffer to "missing".
template <typename PredFunc>
inline void PredLoop(const treelite::DenseBatch* batch, int nthread,
                     int verbose, float* out_pred, PredFunc func) {
  const bool nan_missing
    = treelite::common::math::CheckNAN(batch->missing_value);
  std::vector<treelite::Predictor::Entry> inst(nthread * batch->num_col, {-1});
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || batch->num_row
           <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
  const int64_t num_row = static_cast<int64_t>(batch->num_row);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row = nullptr;
  #pragma omp parallel for schedule(static) num_threads(nthread) \
    default(none) \
    firstprivate(num_row, num_col, data, missing_value, nan_missing) \
    private(row) shared(inst, func, out_pred)
  for (int64_t rid = 0; rid < num_row; ++rid) {
    const int tid = omp_get_thread_num();
    const size_t off = num_col * tid;
    row = &data[rid * num_col];
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::common::math::CheckNAN(row[j])) {
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        inst[off + j].fvalue = row[j];
      }
    }
    func(rid, &inst[off], out_pred);
    for (size_t j = 0; j < num_col; ++j) {  // mark entries missing again
      inst[off + j].missing = -1;
    }
  }
}
// Common driver for both batch types: pick the right prediction function,
// run the parallel loop, and optionally apply the prediction transform.
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch, int nthread, int verbose,
                            bool pred_margin, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle
                              pred_func_handle,
                            treelite::Predictor::PredTransformFuncHandle
                              pred_transform_func_handle,
                            size_t query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  const int max_thread = omp_get_max_threads();
  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);

  if (verbose > 0) {
    LOG(INFO) << "Begin prediction";
  }
  double tstart = dmlc::GetTime();

  /* Pass the appropriate prediction function to PredLoop. */
  if (num_output_group > 1) {  // multi-class classification
    using PredFunc = void (*)(treelite::Predictor::Entry*, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    PredLoop(batch, nthread, verbose, out_pred,
      [pred_func, num_output_group]
      (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) {
        // margins for row rid occupy a contiguous block of
        // num_output_group entries in out_pred
        pred_func(inst, &out_pred[rid * num_output_group]);
      });
  } else {  // single-output prediction
    using PredFunc = float (*)(treelite::Predictor::Entry*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    PredLoop(batch, nthread, verbose, out_pred,
      [pred_func]
      (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) {
        out_pred[rid] = pred_func(inst);
      });
  }
  if (verbose > 0) {
    LOG(INFO) << "Finished prediction in "
              << dmlc::GetTime() - tstart << " sec";
  }

  if (pred_margin) {  // raw margins requested: no transformation
    return query_result_size;
  } else {  // transform margins (e.g. into probabilities)
    using PredTransformFunc = size_t (*)(float*, int64_t, int);
    PredTransformFunc pred_transform_func
      = reinterpret_cast<PredTransformFunc>(pred_transform_func_handle);
    return pred_transform_func(out_pred, batch->num_row, nthread);
  }
}

}  // anonymous namespace
namespace treelite {

Predictor::Predictor() : lib_handle_(nullptr),
                         query_func_handle_(nullptr),
                         pred_func_handle_(nullptr),
                         pred_transform_func_handle_(nullptr) {}

Predictor::~Predictor() {
  Free();  // unload the shared library, if any is loaded
}
void
Predictor::Load(const char* name) {
  lib_handle_ = OpenLibrary(name);
  CHECK(lib_handle_ != nullptr)
    << "Failed to load dynamic shared library `" << name << "'";

  /* 1. query the number of output groups */
  query_func_handle_ = LoadFunction<QueryFuncHandle>(lib_handle_,
                                                     "get_num_output_group");
  using QueryFunc = size_t (*)(void);
  QueryFunc query_func = reinterpret_cast<QueryFunc>(query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = query_func();

  /* 2. load the appropriate function for margin prediction */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                               "predict_margin_multiclass");
    using PredFunc = void (*)(Entry*, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_margin_multiclass() function";
  } else {  // single-output prediction
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_margin");
    using PredFunc = float (*)(Entry*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_margin() function";
  }

  /* 3. load the prediction transform function */
  pred_transform_func_handle_
    = LoadFunction<PredTransformFuncHandle>(lib_handle_,
                                            "pred_transform_batch");
  using PredTransformFunc = size_t (*)(float*, int64_t, int);
  PredTransformFunc pred_transform_func
    = reinterpret_cast<PredTransformFunc>(pred_transform_func_handle_);
  CHECK(pred_transform_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid pred_transform_batch() function";
}
void
Predictor::Free() {
  CloseLibrary(lib_handle_);
}
size_t
Predictor::PredictBatch(const CSRBatch* batch, int nthread, int verbose,
                        bool pred_margin, float* out_result) const {
  return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
                       pred_func_handle_, pred_transform_func_handle_,
                       QueryResultSize(batch), out_result);
}

size_t
Predictor::PredictBatch(const DenseBatch* batch, int nthread, int verbose,
                        bool pred_margin, float* out_result) const {
  return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
                       pred_func_handle_, pred_transform_func_handle_,
                       QueryResultSize(batch), out_result);
}

}  // namespace treelite
/* Data structures referenced above:
 *   Entry      -- per-feature slot in the dense row buffer; setting the
 *                 `missing` field to -1 marks the feature value as missing.
 *   CSRBatch   -- sparse batch in Compressed Sparse Row (CSR) format:
 *                 data (feature values), col_ind (feature indices),
 *                 row_ptr (row headers, length num_row + 1),
 *                 num_row (number of rows),
 *                 num_col (number of columns, i.e. # of features used).
 *   DenseBatch -- dense batch: data (feature values, row-major),
 *                 missing_value (value representing a missing value,
 *                 usually NaN), num_row, num_col.
 */
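/* Usage sketch (not part of this file); the library path is hypothetical and
 * QueryResultSize() is assumed to be callable by the user, as it is inside
 * PredictBatch() above:
 *
 *   treelite::Predictor predictor;
 *   predictor.Load("./compiled_model.so");
 *   std::vector<float> out(predictor.QueryResultSize(batch));
 *   // nthread = 0 uses all available threads; verbose = 1 logs timing;
 *   // pred_margin = false applies pred_transform_batch to the raw margins
 *   size_t len = predictor.PredictBatch(batch, 0, 1, false, out.data());
 *   predictor.Free();
 */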