10 #include <dmlc/logging.h> 11 #include <dmlc/timer.h> 16 #include "common/math.h" 27 inline treelite::Predictor::LibraryHandle OpenLibrary(
const char* name) {
29 HMODULE handle = LoadLibraryA(name);
31 void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
33 return static_cast<treelite::Predictor::LibraryHandle
>(handle);
36 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
38 FreeLibrary(static_cast<HMODULE>(handle));
40 dlclose(static_cast<void*>(handle));
44 template <
typename HandleType>
45 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
48 FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
50 void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
52 return static_cast<HandleType
>(func_handle);
55 template <
typename PredFunc>
57 float* out_pred, PredFunc func) {
58 std::vector<treelite::Predictor::Entry> inst(nthread * batch->
num_col, {-1});
59 CHECK(
sizeof(
size_t) <
sizeof(int64_t)
61 <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
62 const int64_t num_row =
static_cast<int64_t
>(batch->
num_row);
63 const size_t num_col = batch->
num_col;
64 const float* data = batch->
data;
65 const uint32_t* col_ind = batch->
col_ind;
66 const size_t* row_ptr = batch->
row_ptr;
67 size_t total_output_size = 0;
68 #pragma omp parallel for schedule(static) num_threads(nthread) \ 69 default(none) firstprivate(num_row, num_col, data, col_ind, row_ptr) \ 70 shared(inst, func, out_pred) \ 71 reduction(+:total_output_size) 72 for (int64_t rid = 0; rid < num_row; ++rid) {
73 const int tid = omp_get_thread_num();
74 const size_t off = num_col * tid;
75 const size_t ibegin = row_ptr[rid];
76 const size_t iend = row_ptr[rid + 1];
77 for (
size_t i = ibegin; i < iend; ++i) {
78 inst[off + col_ind[i]].fvalue = data[i];
80 total_output_size += func(rid, &inst[off], out_pred);
81 for (
size_t i = ibegin; i < iend; ++i) {
82 inst[off + col_ind[i]].missing = -1;
85 return total_output_size;
88 template <
typename PredFunc>
90 int verbose,
float* out_pred, PredFunc func) {
91 const bool nan_missing
93 std::vector<treelite::Predictor::Entry> inst(nthread * batch->
num_col, {-1});
94 CHECK(
sizeof(
size_t) <
sizeof(int64_t)
96 <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
97 const int64_t num_row =
static_cast<int64_t
>(batch->
num_row);
98 const size_t num_col = batch->
num_col;
100 const float* data = batch->
data;
102 size_t total_output_size = 0;
103 #pragma omp parallel for schedule(static) num_threads(nthread) \ 105 firstprivate(num_row, num_col, data, missing_value, nan_missing) \ 106 private(row) shared(inst, func, out_pred) \ 107 reduction(+:total_output_size) 108 for (int64_t rid = 0; rid < num_row; ++rid) {
109 const int tid = omp_get_thread_num();
110 const size_t off = num_col * tid;
111 row = &data[rid * num_col];
112 for (
size_t j = 0; j < num_col; ++j) {
113 if (treelite::common::math::CheckNAN(row[j])) {
115 <<
"The missing_value argument must be set to NaN if there is any " 116 <<
"NaN in the matrix.";
117 }
else if (nan_missing || row[j] != missing_value) {
118 inst[off + j].fvalue = row[j];
121 total_output_size += func(rid, &inst[off], out_pred);
122 for (
size_t j = 0; j < num_col; ++j) {
123 inst[off + j].missing = -1;
126 return total_output_size;
129 template <
typename BatchType>
130 inline size_t PredictBatch_(
const BatchType* batch,
int nthread,
int verbose,
131 bool pred_margin,
size_t num_output_group,
132 treelite::Predictor::PredFuncHandle pred_func_handle,
133 size_t expected_query_result_size,
float* out_pred) {
134 CHECK(pred_func_handle !=
nullptr)
135 <<
"A shared library needs to be loaded first using Load()";
136 const int max_thread = omp_get_max_threads();
137 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
140 LOG(INFO) <<
"Begin prediction";
142 double tstart = dmlc::GetTime();
146 size_t query_result_size;
151 if (num_output_group > 1) {
153 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
155 PredLoop(batch, nthread, verbose, out_pred,
156 [pred_func, num_output_group, pred_margin]
158 return pred_func(inst, (
int)pred_margin, &out_pred[rid * num_output_group]);
162 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle);
164 PredLoop(batch, nthread, verbose, out_pred,
165 [pred_func, pred_margin]
167 out_pred[rid] = pred_func(inst, (
int)pred_margin);
172 LOG(INFO) <<
"Finished prediction in " 173 << dmlc::GetTime() - tstart <<
" sec";
176 if (query_result_size < expected_query_result_size) {
177 CHECK_GT(num_output_group, 1);
178 CHECK_EQ(query_result_size % batch->num_row, 0);
179 const size_t query_size_per_instance = query_result_size / batch->num_row;
180 CHECK_GT(query_size_per_instance, 0);
181 CHECK_LT(query_size_per_instance, num_output_group);
182 for (
size_t rid = 0; rid < batch->num_row; ++rid) {
183 for (
size_t k = 0; k < query_size_per_instance; ++k) {
184 out_pred[rid * query_size_per_instance + k]
185 = out_pred[rid * num_output_group + k];
189 return query_result_size;
196 Predictor::Predictor() : lib_handle_(nullptr),
197 query_func_handle_(nullptr),
198 pred_func_handle_(nullptr) {}
199 Predictor::~Predictor() {
205 lib_handle_ = OpenLibrary(name);
206 CHECK(lib_handle_ !=
nullptr)
207 <<
"Failed to load dynamic shared library `" << name <<
"'";
210 query_func_handle_ = LoadFunction<QueryFuncHandle>(lib_handle_,
211 "get_num_output_group");
212 using QueryFunc = size_t (*)(void);
213 QueryFunc query_func =
reinterpret_cast<QueryFunc
>(query_func_handle_);
214 CHECK(query_func !=
nullptr)
215 <<
"Dynamic shared library `" << name
216 <<
"' does not contain valid get_num_output_group() function";
217 num_output_group_ = query_func();
220 CHECK_GT(num_output_group_, 0) <<
"num_output_group cannot be zero";
221 if (num_output_group_ > 1) {
222 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
223 "predict_multiclass");
224 using PredFunc = size_t (*)(
Entry*, int,
float*);
225 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle_);
226 CHECK(pred_func !=
nullptr)
227 <<
"Dynamic shared library `" << name
228 <<
"' does not contain valid predict_multiclass() function";
230 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
"predict");
231 using PredFunc = float (*)(
Entry*, int);
232 PredFunc pred_func =
reinterpret_cast<PredFunc
>(pred_func_handle_);
233 CHECK(pred_func !=
nullptr)
234 <<
"Dynamic shared library `" << name
235 <<
"' does not contain valid predict() function";
241 CloseLibrary(lib_handle_);
246 bool pred_margin,
float* out_result)
const {
247 return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
248 pred_func_handle_, QueryResultSize(batch), out_result);
253 bool pred_margin,
float* out_result)
const {
254 return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
255 pred_func_handle_, QueryResultSize(batch), out_result);
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
size_t PredictBatch(const CSRBatch *batch, int nthread, int verbose, bool pred_margin, float *out_result) const
make predictions on a batch of data rows
data layout for a single entry. The value -1 in the "missing" field signifies a missing value; otherwise the "fvalue" field holds the feature value.
sparse batch in Compressed Sparse Row (CSR) format
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
void Load(const char *name)
load the prediction function from dynamic shared library.
const float * data
feature values
float missing_value
value representing the missing value (usually nan)
const float * data
feature values
compatibility wrapper for systems that don't support OpenMP
size_t num_row
number of rows
size_t num_row
number of rows
size_t num_col
number of columns (i.e. # of features used)
void Free()
unload the prediction function
size_t num_col
number of columns (i.e. # of features used)