treelite
predictor.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
9 #include <treelite/omp.h>
10 #include <dmlc/logging.h>
11 #include <dmlc/timer.h>
12 #include <cstdint>
13 #include <algorithm>
14 #include <limits>
15 #include <functional>
16 #include "common/math.h"
17 
18 #ifdef _WIN32
19 #define NOMINMAX
20 #include <windows.h>
21 #else
22 #include <dlfcn.h>
23 #endif
24 
25 namespace {
26 
27 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
28 #ifdef _WIN32
29  HMODULE handle = LoadLibraryA(name);
30 #else
31  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
32 #endif
33  return static_cast<treelite::Predictor::LibraryHandle>(handle);
34 }
35 
36 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
37 #ifdef _WIN32
38  FreeLibrary(static_cast<HMODULE>(handle));
39 #else
40  dlclose(static_cast<void*>(handle));
41 #endif
42 }
43 
44 template <typename HandleType>
45 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
46  const char* name) {
47 #ifdef _WIN32
48  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
49 #else
50  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
51 #endif
52  return static_cast<HandleType>(func_handle);
53 }
54 
/*!
 * \brief run a per-row prediction function over a sparse (CSR) batch in
 *        parallel.
 * \param batch sparse batch in CSR layout
 * \param nthread number of OpenMP threads to use (caller has already clamped)
 * \param verbose verbosity level (unused here; kept for interface symmetry
 *        with the dense overload)
 * \param out_pred output buffer; layout is decided by \a func
 * \param func callable (rid, dense-row Entry*, out_pred) -> number of outputs
 *        written for that row
 * \return total number of output values written across all rows
 */
template <typename PredFunc>
inline size_t PredLoop(const treelite::CSRBatch* batch, int nthread, int verbose,
                       float* out_pred, PredFunc func) {
  // One dense scratch row per thread, laid out contiguously; {-1} initializes
  // every Entry's "missing" field to -1 (i.e. feature absent).
  std::vector<treelite::Predictor::Entry> inst(nthread * batch->num_col, {-1});
  // The OpenMP loop index must be signed; verify num_row fits in int64_t
  // (trivially true when size_t is narrower than int64_t).
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || batch->num_row
           <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
  const int64_t num_row = static_cast<int64_t>(batch->num_row);
  const size_t num_col = batch->num_col;
  const float* data = batch->data;
  const uint32_t* col_ind = batch->col_ind;
  const size_t* row_ptr = batch->row_ptr;
  size_t total_output_size = 0;
  #pragma omp parallel for schedule(static) num_threads(nthread) \
    default(none) firstprivate(num_row, num_col, data, col_ind, row_ptr) \
    shared(inst, func, out_pred) \
    reduction(+:total_output_size)
  for (int64_t rid = 0; rid < num_row; ++rid) {
    const int tid = omp_get_thread_num();
    // Offset of this thread's private scratch row inside `inst`.
    const size_t off = num_col * tid;
    const size_t ibegin = row_ptr[rid];
    const size_t iend = row_ptr[rid + 1];
    // Scatter the sparse row's nonzeros into the dense scratch row.
    for (size_t i = ibegin; i < iend; ++i) {
      inst[off + col_ind[i]].fvalue = data[i];
    }
    total_output_size += func(rid, &inst[off], out_pred);
    // Reset only the entries touched above back to "missing" so the scratch
    // row is clean for the next iteration handled by this thread.
    for (size_t i = ibegin; i < iend; ++i) {
      inst[off + col_ind[i]].missing = -1;
    }
  }
  return total_output_size;
}
87 
/*!
 * \brief run a per-row prediction function over a dense batch in parallel.
 * \param batch dense batch (row-major, num_row x num_col)
 * \param nthread number of OpenMP threads to use (caller has already clamped)
 * \param verbose verbosity level (unused here)
 * \param out_pred output buffer; layout is decided by \a func
 * \param func callable (rid, dense-row Entry*, out_pred) -> number of outputs
 *        written for that row
 * \return total number of output values written across all rows
 */
template <typename PredFunc>
inline size_t PredLoop(const treelite::DenseBatch* batch, int nthread,
                       int verbose, float* out_pred, PredFunc func) {
  // Whether the batch encodes missing values as NaN.
  const bool nan_missing
    = treelite::common::math::CheckNAN(batch->missing_value);
  // One dense scratch row per thread; {-1} marks every feature as missing.
  std::vector<treelite::Predictor::Entry> inst(nthread * batch->num_col, {-1});
  // The OpenMP loop index must be signed; verify num_row fits in int64_t.
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || batch->num_row
           <= static_cast<size_t>(std::numeric_limits<int64_t>::max()));
  const int64_t num_row = static_cast<int64_t>(batch->num_row);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row;
  size_t total_output_size = 0;
  #pragma omp parallel for schedule(static) num_threads(nthread) \
    default(none) \
    firstprivate(num_row, num_col, data, missing_value, nan_missing) \
    private(row) shared(inst, func, out_pred) \
    reduction(+:total_output_size)
  for (int64_t rid = 0; rid < num_row; ++rid) {
    const int tid = omp_get_thread_num();
    // Offset of this thread's private scratch row inside `inst`.
    const size_t off = num_col * tid;
    row = &data[rid * num_col];
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::common::math::CheckNAN(row[j])) {
        // A NaN cell is only legal when the batch declared NaN as its
        // missing-value marker.
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        // Cell holds a real value; cells equal to missing_value (when the
        // marker is not NaN) are left as "missing" in the scratch row.
        inst[off + j].fvalue = row[j];
      }
    }
    total_output_size += func(rid, &inst[off], out_pred);
    // Reset the whole scratch row back to "missing" for the next iteration.
    for (size_t j = 0; j < num_col; ++j) {
      inst[off + j].missing = -1;
    }
  }
  return total_output_size;
}
128 
/*!
 * \brief common driver for batch prediction, shared by the CSR and dense
 *        overloads of Predictor::PredictBatch().
 * \param batch batch of rows to predict (CSRBatch or DenseBatch)
 * \param nthread requested thread count; 0 means use all available threads
 * \param verbose if positive, log timing information
 * \param pred_margin whether to return raw margin scores instead of
 *        transformed probabilities
 * \param num_output_group number of output groups (>1 means multi-class)
 * \param pred_func_handle symbol loaded from the shared library by Load()
 * \param expected_query_result_size upper bound on the output size, as
 *        reported by QueryResultSize()
 * \param out_pred output buffer, sized for expected_query_result_size
 * \return actual number of prediction values written to out_pred
 */
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch, int nthread, int verbose,
                            bool pred_margin, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle pred_func_handle,
                            size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  // Clamp the requested thread count; 0 selects the OpenMP maximum.
  const int max_thread = omp_get_max_threads();
  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);

  if (verbose > 0) {
    LOG(INFO) << "Begin prediction";
  }
  double tstart = dmlc::GetTime();

  /* Pass the correct prediction function to PredLoop.
     We also need to specify how the function should be called. */
  size_t query_result_size;
  // Dimension of output vector:
  // can be either [num_data] or [num_class]*[num_data].
  // Note that size of prediction may be smaller than out_pred (this occurs
  // when pred_function is set to "max_index").
  if (num_output_group > 1) {  // multi-class classification task
    // predict_multiclass(): fills num_output_group scores per row and
    // returns how many it actually wrote.
    using PredFunc = size_t (*)(treelite::Predictor::Entry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, nthread, verbose, out_pred,
               [pred_func, num_output_group, pred_margin]
               (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) -> size_t {
        return pred_func(inst, (int)pred_margin, &out_pred[rid * num_output_group]);
      });
  } else {  // every other task
    // predict(): single score per row.
    using PredFunc = float (*)(treelite::Predictor::Entry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, nthread, verbose, out_pred,
               [pred_func, pred_margin]
               (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) -> size_t {
        out_pred[rid] = pred_func(inst, (int)pred_margin);
        return 1;
      });
  }
  if (verbose > 0) {
    LOG(INFO) << "Finished prediction in "
              << dmlc::GetTime() - tstart << " sec";
  }
  // re-shape output if query_result_size < dimension of out_pred
  // (compact the per-row stride from num_output_group down to the number of
  //  values actually produced per instance, in place, front to back)
  if (query_result_size < expected_query_result_size) {
    CHECK_GT(num_output_group, 1);
    CHECK_EQ(query_result_size % batch->num_row, 0);
    const size_t query_size_per_instance = query_result_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group);
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        out_pred[rid * query_size_per_instance + k]
          = out_pred[rid * num_output_group + k];
      }
    }
  }
  return query_result_size;
}
191 
192 } // namespace anonymous
193 
194 namespace treelite {
195 
196 Predictor::Predictor() : lib_handle_(nullptr),
197  query_func_handle_(nullptr),
198  pred_func_handle_(nullptr) {}
// Unload the shared library (if any) when the predictor is destroyed.
Predictor::~Predictor() {
  Free();
}
202 
/*!
 * \brief load the prediction function from a dynamic shared library
 *        produced by the treelite compiler.
 * \param name path of the shared library (e.g. .so / .dll)
 *
 * Aborts (via CHECK) if the library cannot be opened or does not export the
 * expected symbols.
 */
void
Predictor::Load(const char* name) {
  lib_handle_ = OpenLibrary(name);
  CHECK(lib_handle_ != nullptr)
    << "Failed to load dynamic shared library `" << name << "'";

  /* 1. query # of output groups */
  query_func_handle_ = LoadFunction<QueryFuncHandle>(lib_handle_,
                                                     "get_num_output_group");
  using QueryFunc = size_t (*)(void);
  QueryFunc query_func = reinterpret_cast<QueryFunc>(query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = query_func();

  /* 2. load appropriate function for margin prediction */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    // multi-class models export predict_multiclass(): writes one score per
    // class into a caller-provided array
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_multiclass");
    using PredFunc = size_t (*)(Entry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_multiclass() function";
  } else {  // everything else
    // all other models export predict(): one score per instance
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
    using PredFunc = float (*)(Entry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict() function";
  }
}
238 
239 void
241  CloseLibrary(lib_handle_);
242 }
243 
/*!
 * \brief make predictions on a sparse (CSR) batch of data rows.
 * \param batch sparse batch
 * \param nthread number of threads (0 = all available)
 * \param verbose if positive, log timing information
 * \param pred_margin whether to return raw margin scores
 * \param out_result output buffer sized per QueryResultSize(batch)
 * \return number of prediction values written
 */
size_t
Predictor::PredictBatch(const CSRBatch* batch, int nthread, int verbose,
                        bool pred_margin, float* out_result) const {
  return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
                       pred_func_handle_, QueryResultSize(batch), out_result);
}
250 
/*!
 * \brief make predictions on a dense batch of data rows.
 * \param batch dense batch
 * \param nthread number of threads (0 = all available)
 * \param verbose if positive, log timing information
 * \param pred_margin whether to return raw margin scores
 * \param out_result output buffer sized per QueryResultSize(batch)
 * \return number of prediction values written
 */
size_t
Predictor::PredictBatch(const DenseBatch* batch, int nthread, int verbose,
                        bool pred_margin, float* out_result) const {
  return PredictBatch_(batch, nthread, verbose, pred_margin, num_output_group_,
                       pred_func_handle_, QueryResultSize(batch), out_result);
}
257 
258 } // namespace treelite
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
size_t PredictBatch(const CSRBatch *batch, int nthread, int verbose, bool pred_margin, float *out_result) const
make predictions on a batch of data rows
Definition: predictor.cc:245
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:48
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:204
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:34
const float * data
feature values
Definition: predictor.h:18
compatibility wrapper for systems that don't support OpenMP
size_t num_row
number of rows
Definition: predictor.h:36
size_t num_row
number of rows
Definition: predictor.h:24
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
void Free()
unload the prediction function
Definition: predictor.cc:240
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38