treelite
predictor.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
9 #include <treelite/omp.h>
10 #include <dmlc/logging.h>
11 #include <dmlc/timer.h>
12 #include <cstdint>
13 #include <algorithm>
14 #include <limits>
15 #include <functional>
16 #include "common/math.h"
17 
18 #ifdef _WIN32
19 #define NOMINMAX
20 #include <windows.h>
21 #else
22 #include <dlfcn.h>
23 #endif
24 
25 namespace {
26 
27 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
28 #ifdef _WIN32
29  HMODULE handle = LoadLibraryA(name);
30 #else
31  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
32 #endif
33  return static_cast<treelite::Predictor::LibraryHandle>(handle);
34 }
35 
36 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
37 #ifdef _WIN32
38  FreeLibrary(static_cast<HMODULE>(handle));
39 #else
40  dlclose(static_cast<void*>(handle));
41 #endif
42 }
43 
44 template <typename HandleType>
45 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
46  const char* name) {
47 #ifdef _WIN32
48  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
49 #else
50  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
51 #endif
52  return static_cast<HandleType>(func_handle);
53 }
54 
/*!
 * \brief Run a per-row prediction callback over every row of a sparse
 *        (CSR) batch, in parallel with OpenMP.
 * \tparam PredFunc callable with signature (int64_t rid, Entry* inst, float* out_pred)
 * \param batch sparse batch to predict on
 * \param nthread number of OpenMP threads (caller has already clamped it)
 * \param verbose verbosity level (unused inside this loop)
 * \param out_pred output buffer written by \c func
 * \param func callback invoked once per row with a densified feature row
 */
template <typename PredFunc>
inline void PredLoop(const treelite::CSRBatch* batch, int nthread, int verbose,
                     float* out_pred, PredFunc func) {
  // One dense scratch row per thread, laid out contiguously; {-1} in the
  // Entry's "missing" field marks every feature as initially missing.
  std::vector<treelite::Predictor::Entry> inst(nthread * batch->num_col, {-1});
  // Ensure num_row fits in int64_t before the cast below (only needs an
  // explicit check when size_t is at least as wide as int64_t).
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || batch->num_row
           <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
  const int64_t num_row = static_cast<int64_t>(batch->num_row);
  const size_t num_col = batch->num_col;
  const float* data = batch->data;
  const uint32_t* col_ind = batch->col_ind;
  const size_t* row_ptr = batch->row_ptr;
  // Rows are independent; each thread densifies into its own slice of inst.
  #pragma omp parallel for schedule(static) num_threads(nthread) \
    default(none) firstprivate(num_row, num_col, data, col_ind, row_ptr) \
    shared(inst, func, out_pred)
  for (int64_t rid = 0; rid < num_row; ++rid) {
    const int tid = omp_get_thread_num();
    const size_t off = num_col * tid;  // start of this thread's scratch row
    const size_t ibegin = row_ptr[rid];
    const size_t iend = row_ptr[rid + 1];
    // Scatter this row's nonzero entries into the dense scratch row.
    for (size_t i = ibegin; i < iend; ++i) {
      inst[off + col_ind[i]].fvalue = data[i];
    }
    func(rid, &inst[off], out_pred);
    // Re-mark the touched entries as missing so the scratch row is clean
    // for the next row handled by this thread.
    for (size_t i = ibegin; i < iend; ++i) {
      inst[off + col_ind[i]].missing = -1;
    }
  }
}
84 
/*!
 * \brief Run a per-row prediction callback over every row of a dense
 *        batch, in parallel with OpenMP.
 * \tparam PredFunc callable with signature (int64_t rid, Entry* inst, float* out_pred)
 * \param batch dense batch to predict on
 * \param nthread number of OpenMP threads (caller has already clamped it)
 * \param verbose verbosity level (unused inside this loop)
 * \param out_pred output buffer written by \c func
 * \param func callback invoked once per row
 */
template <typename PredFunc>
inline void PredLoop(const treelite::DenseBatch* batch, int nthread,
                     int verbose, float* out_pred, PredFunc func) {
  // When missing_value is NaN, NaN cells denote missing; otherwise NaN in
  // the data is an error (checked per-cell below).
  const bool nan_missing
      = treelite::common::math::CheckNAN(batch->missing_value);
  // One dense scratch row per thread; {-1} marks every feature as missing.
  std::vector<treelite::Predictor::Entry> inst(nthread * batch->num_col, {-1});
  // Ensure num_row fits in int64_t before the cast below.
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || batch->num_row
           <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
  const int64_t num_row = static_cast<int64_t>(batch->num_row);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row;
  #pragma omp parallel for schedule(static) num_threads(nthread) \
    default(none) \
    firstprivate(num_row, num_col, data, missing_value, nan_missing) \
    private(row) shared(inst, func, out_pred)
  for (int64_t rid = 0; rid < num_row; ++rid) {
    const int tid = omp_get_thread_num();
    const size_t off = num_col * tid;  // start of this thread's scratch row
    row = &data[rid * num_col];        // row-major layout
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::common::math::CheckNAN(row[j])) {
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        // Present value: copy into scratch; cells equal to missing_value
        // (or NaN when nan_missing) stay marked missing.
        inst[off + j].fvalue = row[j];
      }
    }
    func(rid, &inst[off], out_pred);
    // Reset the whole scratch row to "missing" for the next iteration.
    for (size_t j = 0; j < num_col; ++j) {
      inst[off + j].missing = -1;
    }
  }
}
122 
/*!
 * \brief Shared implementation behind Predictor::PredictBatch for both
 *        CSR and dense batches: clamp the thread count, dispatch the right
 *        margin-prediction function into PredLoop, then optionally apply
 *        the prediction transform.
 * \tparam BatchType treelite::CSRBatch or treelite::DenseBatch
 * \param batch batch of rows to predict on
 * \param nthread requested thread count; 0 means "use all available"
 * \param verbose if >0, log timing information
 * \param pred_margin if true, return raw margin scores (skip transform)
 * \param num_output_group number of output groups (>1 for multiclass)
 * \param pred_func_handle symbol for the per-row margin function
 * \param pred_transform_func_handle symbol for the batch transform function
 * \param query_result_size size of out_pred when no transform is applied
 * \param out_pred output buffer
 * \return number of valid floats written to out_pred
 */
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch, int nthread, int verbose,
                            bool pred_margin, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle pred_func_handle,
                            treelite::Predictor::PredTransformFuncHandle pred_transform_func_handle,
                            size_t query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  const int max_thread = omp_get_max_threads();
  // nthread == 0 means "use all"; otherwise never exceed the OMP maximum.
  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);

  if (verbose > 0) {
    LOG(INFO) << "Begin prediction";
  }
  double tstart = dmlc::GetTime();

  /* Pass the correct prediction function to PredLoop.
     We also need to specify how the function should be called. */
  if (num_output_group > 1) {
    // Multiclass: the exported function writes num_output_group scores
    // into the slice of out_pred belonging to this row.
    using PredFunc = void (*)(treelite::Predictor::Entry*, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    PredLoop(batch, nthread, verbose, out_pred,
      [pred_func, num_output_group]
      (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) {
        pred_func(inst, &out_pred[rid * num_output_group]);
      });
  } else {
    // Single output: the exported function returns one score per row.
    using PredFunc = float (*)(treelite::Predictor::Entry*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    PredLoop(batch, nthread, verbose, out_pred,
      [pred_func]
      (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) {
        out_pred[rid] = pred_func(inst);
      });
  }
  if (verbose > 0) {
    LOG(INFO) << "Finished prediction in "
              << dmlc::GetTime() - tstart << " sec";
  }

  if (pred_margin) {
    // Raw margins requested: everything written by PredLoop is the result.
    return query_result_size;
  } else {
    // Apply the model's prediction transform in place (e.g. sigmoid /
    // softmax); it reports how many floats of out_pred are valid.
    using PredTransformFunc = size_t(*)(float*, int64_t, int);
    PredTransformFunc pred_transform_func
      = reinterpret_cast<PredTransformFunc>(pred_transform_func_handle);
    return pred_transform_func(out_pred, batch->num_row, nthread);
  }
}
172 
173 } // namespace anonymous
174 
175 namespace treelite {
176 
177 Predictor::Predictor() : lib_handle_(nullptr),
178  query_func_handle_(nullptr),
179  pred_func_handle_(nullptr),
180  pred_transform_func_handle_(nullptr) {}
/*! \brief Destructor: unloads the shared library (if any) via Free(). */
Predictor::~Predictor() {
  Free();
}
184 
/*!
 * \brief Load the prediction functions from a compiled model library.
 *
 * Opens the shared library, queries the number of output groups, then
 * resolves the margin-prediction symbol appropriate for that count
 * (predict_margin_multiclass vs predict_margin) and the batch prediction
 * transform symbol. Aborts via CHECK if any required symbol is missing.
 * \param name path to the dynamic shared library
 */
void
Predictor::Load(const char* name) {
  lib_handle_ = OpenLibrary(name);
  CHECK(lib_handle_ != nullptr)
    << "Failed to load dynamic shared library `" << name << "'";

  /* 1. query # of output groups */
  query_func_handle_ = LoadFunction<QueryFuncHandle>(lib_handle_,
                                                     "get_num_output_group");
  using QueryFunc = size_t (*)(void);
  QueryFunc query_func = reinterpret_cast<QueryFunc>(query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = query_func();

  /* 2. load appropriate function for margin prediction */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    // Multiclass models export a function that fills an output array.
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_margin_multiclass");
    using PredFunc = void (*)(Entry*, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_margin_multiclass() function";
  } else {                      // everything else
    // Single-output models export a function returning one float.
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_margin");
    using PredFunc = float (*)(Entry*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_margin() function";
  }

  /* 3. load prediction transform function */
  pred_transform_func_handle_
    = LoadFunction<PredTransformFuncHandle>(lib_handle_,
                                            "pred_transform_batch");
  using PredTransformFunc = size_t (*)(float*, int64_t, int);
  PredTransformFunc pred_transform_func
    = reinterpret_cast<PredTransformFunc>(pred_transform_func_handle_);
  CHECK(pred_transform_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid pred_transform_batch() function";
}
232 
233 void
235  CloseLibrary(lib_handle_);
236 }
237 
/*!
 * \brief Make predictions on a sparse (CSR) batch of data rows.
 *        Thin forwarder to the shared PredictBatch_ implementation.
 * \param batch sparse batch
 * \param nthread number of threads (0 = use all available)
 * \param verbose if >0, log timing information
 * \param pred_margin if true, return raw margin scores untransformed
 * \param out_result output buffer; must hold QueryResultSize(batch) floats
 * \return number of valid floats written to out_result
 */
size_t
Predictor::PredictBatch(const CSRBatch* batch, int nthread, int verbose,
                        bool pred_margin, float* out_result) const {
  return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
                       pred_func_handle_, pred_transform_func_handle_,
                       QueryResultSize(batch), out_result);
}
245 
/*!
 * \brief Make predictions on a dense batch of data rows.
 *        Thin forwarder to the shared PredictBatch_ implementation.
 * \param batch dense batch
 * \param nthread number of threads (0 = use all available)
 * \param verbose if >0, log timing information
 * \param pred_margin if true, return raw margin scores untransformed
 * \param out_result output buffer; must hold QueryResultSize(batch) floats
 * \return number of valid floats written to out_result
 */
size_t
Predictor::PredictBatch(const DenseBatch* batch, int nthread, int verbose,
                        bool pred_margin, float* out_result) const {
  return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
                       pred_func_handle_, pred_transform_func_handle_,
                       QueryResultSize(batch), out_result);
}
253 
254 } // namespace treelite
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
size_t PredictBatch(const CSRBatch *batch, int nthread, int verbose, bool pred_margin, float *out_result) const
make predictions on a batch of data rows
Definition: predictor.cc:239
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:48
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:186
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:34
const float * data
feature values
Definition: predictor.h:18
compatibility wrapper for systems that don't support OpenMP
size_t num_row
number of rows
Definition: predictor.h:36
size_t num_row
number of rows
Definition: predictor.h:24
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
void Free()
unload the prediction function
Definition: predictor.cc:234
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38