treelite
predictor.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
9 #include <treelite/omp.h>
10 #include <dmlc/logging.h>
11 #include <dmlc/timer.h>
12 #include <cstdint>
13 #include <algorithm>
14 #include <limits>
15 #include <functional>
16 #include <type_traits>
17 #include "common/math.h"
19 
20 #ifdef _WIN32
21 #define NOMINMAX
22 #include <windows.h>
23 #else
24 #include <dlfcn.h>
25 #endif
26 
27 namespace {
28 
// Work unit handed to a worker thread: predict rows [rbegin, rend) of a batch.
// NOTE: instances are created with positional brace-initialization elsewhere
// in this file -- do not reorder fields.
struct InputToken {
 bool sparse;        // true -> `batch` points to a CSRBatch; false -> DenseBatch
 const void* batch;  // type-erased batch pointer; concrete type given by `sparse`
 bool pred_margin;   // forwarded (as int) to the compiled prediction function
 size_t num_output_group;  // >1 indicates a multi-class model
 treelite::Predictor::PredFuncHandle pred_func_handle;  // prediction entry point from the loaded shared library
 size_t rbegin, rend;  // half-open row range [rbegin, rend) assigned to this task
 float* out_pred;      // shared output buffer; each worker writes only its own rows
};
38 
// Result returned by a worker thread: how many prediction values it produced.
// Positionally brace-initialized (OutputToken{...}); keep the single field.
struct OutputToken {
 size_t query_result_size;
};
42 
43 using PredThreadPool
45 
46 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
47 #ifdef _WIN32
48  HMODULE handle = LoadLibraryA(name);
49 #else
50  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
51 #endif
52  return static_cast<treelite::Predictor::LibraryHandle>(handle);
53 }
54 
55 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
56 #ifdef _WIN32
57  FreeLibrary(static_cast<HMODULE>(handle));
58 #else
59  dlclose(static_cast<void*>(handle));
60 #endif
61 }
62 
63 template <typename HandleType>
64 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
65  const char* name) {
66 #ifdef _WIN32
67  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
68 #else
69  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
70 #endif
71  return static_cast<HandleType>(func_handle);
72 }
73 
74 template <typename PredFunc>
75 inline size_t PredLoop(const treelite::CSRBatch* batch,
76  size_t rbegin, size_t rend,
77  float* out_pred, PredFunc func) {
78  std::vector<treelite::Predictor::Entry> inst(batch->num_col, {-1});
79  CHECK(rbegin < rend && rend <= batch->num_row);
80  CHECK(sizeof(size_t) < sizeof(int64_t)
81  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
82  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
83  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
84  const int64_t rend_ = static_cast<int64_t>(rend);
85  const size_t num_col = batch->num_col;
86  const float* data = batch->data;
87  const uint32_t* col_ind = batch->col_ind;
88  const size_t* row_ptr = batch->row_ptr;
89  size_t total_output_size = 0;
90  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
91  const size_t ibegin = row_ptr[rid];
92  const size_t iend = row_ptr[rid + 1];
93  for (size_t i = ibegin; i < iend; ++i) {
94  inst[col_ind[i]].fvalue = data[i];
95  }
96  total_output_size += func(rid, &inst[0], out_pred);
97  for (size_t i = ibegin; i < iend; ++i) {
98  inst[col_ind[i]].missing = -1;
99  }
100  }
101  return total_output_size;
102 }
103 
// Run the prediction callback over rows [rbegin, rend) of a dense batch.
// Feature values equal to batch->missing_value (or NaN, when missing_value is
// itself NaN) are left as "missing" in the working vector. Returns the total
// number of output values written by `func`.
template <typename PredFunc>
inline size_t PredLoop(const treelite::DenseBatch* batch,
 size_t rbegin, size_t rend,
 float* out_pred, PredFunc func) {
 // true when the batch designates NaN as the missing-value marker
 const bool nan_missing
 = treelite::common::math::CheckNAN(batch->missing_value);
 // Working vector of feature values; {-1} marks every slot "missing"
 std::vector<treelite::Predictor::Entry> inst(batch->num_col, {-1});
 CHECK(rbegin < rend && rend <= batch->num_row);
 // Row ids are passed to `func` as int64_t; make sure the range fits
 CHECK(sizeof(size_t) < sizeof(int64_t)
 || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
 && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
 const int64_t rbegin_ = static_cast<int64_t>(rbegin);
 const int64_t rend_ = static_cast<int64_t>(rend);
 const size_t num_col = batch->num_col;
 const float missing_value = batch->missing_value;
 const float* data = batch->data;
 const float* row;
 size_t total_output_size = 0;
 for (int64_t rid = rbegin_; rid < rend_; ++rid) {
 // row-major layout: row `rid` starts at data + rid * num_col
 row = &data[rid * num_col];
 for (size_t j = 0; j < num_col; ++j) {
 if (treelite::common::math::CheckNAN(row[j])) {
 // NaN in the data is only legal when NaN is the missing-value marker
 CHECK(nan_missing)
 << "The missing_value argument must be set to NaN if there is any "
 << "NaN in the matrix.";
 } else if (nan_missing || row[j] != missing_value) {
 // present value: setting fvalue overwrites the "missing" marker
 inst[j].fvalue = row[j];
 }
 }
 total_output_size += func(rid, &inst[0], out_pred);
 // Reset every slot back to "missing" before processing the next row
 for (size_t j = 0; j < num_col; ++j) {
 inst[j].missing = -1;
 }
 }
 return total_output_size;
}
140 
// Dispatch rows [rbegin, rend) of a batch (CSR or dense) to the compiled
// prediction function, choosing the function signature by num_output_group,
// then compact the output in place if fewer values were produced than
// expected. Returns the number of prediction values written.
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch,
 bool pred_margin, size_t num_output_group,
 treelite::Predictor::PredFuncHandle pred_func_handle,
 size_t rbegin, size_t rend,
 size_t expected_query_result_size, float* out_pred) {
 CHECK(pred_func_handle != nullptr)
 << "A shared library needs to be loaded first using Load()";
 /* Pass the correct prediction function to PredLoop.
 We also need to specify how the function should be called. */
 size_t query_result_size;
 // Dimension of output vector:
 // can be either [num_data] or [num_class]*[num_data].
 // Note that size of prediction may be smaller than out_pred (this occurs
 // when pred_function is set to "max_index").
 if (num_output_group > 1) { // multi-class classification task
 // multi-class signature: returns the number of values it wrote for the row
 using PredFunc = size_t (*)(treelite::Predictor::Entry*, int, float*);
 PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
 query_result_size =
 PredLoop(batch, rbegin, rend, out_pred,
 [pred_func, num_output_group, pred_margin]
 (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) -> size_t {
 // each row owns a stride of num_output_group slots in out_pred
 return pred_func(inst, (int)pred_margin, &out_pred[rid * num_output_group]);
 });
 } else { // every other task
 // single-output signature: returns the prediction value itself
 using PredFunc = float (*)(treelite::Predictor::Entry*, int);
 PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
 query_result_size =
 PredLoop(batch, rbegin, rend, out_pred,
 [pred_func, pred_margin]
 (int64_t rid, treelite::Predictor::Entry* inst, float* out_pred) -> size_t {
 out_pred[rid] = pred_func(inst, (int)pred_margin);
 return 1;
 });
 }
 // re-shape output if query_result_size < dimension of out_pred
 if (query_result_size < expected_query_result_size) {
 CHECK_GT(num_output_group, 1);
 CHECK_EQ(query_result_size % batch->num_row, 0);
 const size_t query_size_per_instance = query_result_size / batch->num_row;
 CHECK_GT(query_size_per_instance, 0);
 CHECK_LT(query_size_per_instance, num_output_group);
 // compact in place: row rid's values move from stride num_output_group
 // to stride query_size_per_instance (safe because the destination index
 // never exceeds the source index)
 for (size_t rid = 0; rid < batch->num_row; ++rid) {
 for (size_t k = 0; k < query_size_per_instance; ++k) {
 out_pred[rid * query_size_per_instance + k]
 = out_pred[rid * num_output_group + k];
 }
 }
 }
 return query_result_size;
}
192 
193 } // namespace anonymous
194 
195 namespace treelite {
196 
// Construct an unloaded predictor. All library/function/pool handles start
// null; Load() must be called before any prediction. num_worker_thread may
// be -1, meaning "pick automatically" (resolved inside Load()).
Predictor::Predictor(int num_worker_thread,
 bool include_master_thread)
 : lib_handle_(nullptr),
 query_func_handle_(nullptr),
 pred_func_handle_(nullptr),
 thread_pool_handle_(nullptr),
 include_master_thread_(include_master_thread),
 num_worker_thread_(num_worker_thread) {}
// Release the loaded shared library and the worker thread pool.
Predictor::~Predictor() {
 Free();
}
208 
// Load the compiled prediction functions from the dynamic shared library
// `name` and spin up the worker thread pool. Aborts (via CHECK) if the
// library or any required symbol is missing.
void
Predictor::Load(const char* name) {
 lib_handle_ = OpenLibrary(name);
 CHECK(lib_handle_ != nullptr)
 << "Failed to load dynamic shared library `" << name << "'";

 /* 1. query # of output groups */
 query_func_handle_ = LoadFunction<QueryFuncHandle>(lib_handle_,
 "get_num_output_group");
 using QueryFunc = size_t (*)(void);
 QueryFunc query_func = reinterpret_cast<QueryFunc>(query_func_handle_);
 CHECK(query_func != nullptr)
 << "Dynamic shared library `" << name
 << "' does not contain valid get_num_output_group() function";
 num_output_group_ = query_func();

 /* 2. load appropriate function for margin prediction */
 CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
 if (num_output_group_ > 1) { // multi-class classification
 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
 "predict_multiclass");
 using PredFunc = size_t (*)(Entry*, int, float*);
 PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
 CHECK(pred_func != nullptr)
 << "Dynamic shared library `" << name
 << "' does not contain valid predict_multiclass() function";
 } else { // everything else
 pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
 using PredFunc = float (*)(Entry*, int);
 PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
 CHECK(pred_func != nullptr)
 << "Dynamic shared library `" << name
 << "' does not contain valid predict() function";
 }

 // -1 means "auto": use all hardware threads minus one for the master.
 // NOTE(review): hardware_concurrency() may return 0 on some platforms,
 // which would make num_worker_thread_ negative here -- confirm upstream
 // guarantees or clamp to at least 1.
 if (num_worker_thread_ == -1) {
 num_worker_thread_ = std::thread::hardware_concurrency() - 1;
 }
 // 3. start the worker pool; each worker loops popping InputTokens,
 // predicting its assigned row range, and pushing back an OutputToken
 thread_pool_handle_ = static_cast<ThreadPoolHandle>(
 new PredThreadPool(num_worker_thread_, this,
 [](SpscQueue<InputToken>* incoming_queue,
 SpscQueue<OutputToken>* outgoing_queue,
 const treelite::Predictor* predictor) {
 InputToken input;
 while (incoming_queue->Pop(&input)) {
 size_t query_result_size;
 const size_t rbegin = input.rbegin;
 const size_t rend = input.rend;
 if (input.sparse) {
 const CSRBatch* batch = static_cast<const CSRBatch*>(input.batch);
 query_result_size
 = PredictBatch_(batch, input.pred_margin, input.num_output_group,
 input.pred_func_handle,
 rbegin, rend,
 predictor->QueryResultSize(batch, rbegin, rend),
 input.out_pred);
 } else {
 const DenseBatch* batch = static_cast<const DenseBatch*>(input.batch);
 query_result_size
 = PredictBatch_(batch, input.pred_margin, input.num_output_group,
 input.pred_func_handle,
 rbegin, rend,
 predictor->QueryResultSize(batch, rbegin, rend),
 input.out_pred);
 }
 outgoing_queue->Push(OutputToken{query_result_size});
 }
 }));
}
278 
279 void
281  CloseLibrary(lib_handle_);
282  delete static_cast<PredThreadPool*>(thread_pool_handle_);
283 }
284 
285 template <typename BatchType>
286 static inline
287 std::vector<size_t> SplitBatch(const BatchType* batch, size_t nthread) {
288  const size_t num_row = batch->num_row;
289  CHECK_LE(nthread, num_row);
290  const size_t portion = num_row / nthread;
291  const size_t remainder = num_row % nthread;
292  std::vector<size_t> workload(nthread, portion);
293  std::vector<size_t> row_ptr(nthread + 1, 0);
294  for (size_t i = 0; i < remainder; ++i) {
295  ++workload[i];
296  }
297  size_t accum = 0;
298  for (size_t i = 0; i < nthread; ++i) {
299  accum += workload[i];
300  row_ptr[i + 1] = accum;
301  }
302  return row_ptr;
303 }
304 
// Shared implementation for both PredictBatch overloads: split the batch's
// rows across the worker pool (and optionally the calling thread), run the
// compiled prediction function on each chunk, and return the total number of
// prediction values written to out_result.
template <typename BatchType>
inline size_t
Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
 bool pred_margin, float* out_result) {
 static_assert( std::is_same<BatchType, DenseBatch>::value
 || std::is_same<BatchType, CSRBatch>::value,
 "PredictBatchBase_: unrecognized batch type");
 const double tstart = dmlc::GetTime();
 PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
 // Template for per-thread tasks; rbegin/rend are overwritten per chunk.
 // Fields are positional -- must match InputToken's declaration order.
 InputToken request{std::is_same<BatchType, CSRBatch>::value,
 static_cast<const void*>(batch), pred_margin,
 num_output_group_, pred_func_handle_,
 0, batch->num_row, out_result};
 OutputToken response;
 CHECK_GT(batch->num_row, 0);
 // Don't use more threads than there are rows (the master thread, when
 // participating, also claims one chunk)
 const int nthread = std::min(num_worker_thread_,
 static_cast<int>(batch->num_row)
 - (int)(include_master_thread_));
 const std::vector<size_t> row_ptr
 = SplitBatch(batch, nthread + (int)(include_master_thread_));
 // Dispatch one chunk to each worker thread
 for (int tid = 0; tid < nthread; ++tid) {
 request.rbegin = row_ptr[tid];
 request.rend = row_ptr[tid + 1];
 pool->SubmitTask(tid, request);
 }
 size_t total_size = 0;
 // The master thread processes the last chunk itself while workers run
 if (include_master_thread_) {
 const size_t rbegin = row_ptr[nthread];
 const size_t rend = row_ptr[nthread + 1];
 const size_t query_result_size
 = PredictBatch_(batch, pred_margin, num_output_group_,
 pred_func_handle_,
 rbegin, rend, QueryResultSize(batch, rbegin, rend),
 out_result);
 total_size += query_result_size;
 }
 // Collect per-thread result sizes (also synchronizes with the workers)
 for (int tid = 0; tid < nthread; ++tid) {
 if (pool->WaitForTask(tid, &response)) {
 total_size += response.query_result_size;
 }
 }
 const double tend = dmlc::GetTime();
 if (verbose > 0) {
 LOG(INFO) << "Treelite: Finished prediction in "
 << tend - tstart << " sec";
 }
 return total_size;
}
353 
// Predict on a sparse (CSR) batch; see PredictBatchBase_ for details.
size_t
Predictor::PredictBatch(const CSRBatch* batch, int verbose,
 bool pred_margin, float* out_result) {
 return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
359 
// Predict on a dense batch; see PredictBatchBase_ for details.
size_t
Predictor::PredictBatch(const DenseBatch* batch, int verbose,
 bool pred_margin, float* out_result) {
 return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
365 
366 } // namespace treelite
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:96
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:48
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:210
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin, float *out_result)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
Definition: predictor.cc:355
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:34
a simple thread pool implementation
const float * data
feature values
Definition: predictor.h:18
compatibility wrapper for systems that don't support OpenMP
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
void Free()
unload the prediction function
Definition: predictor.cc:280
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38