/*!
 * Treelite
 * \file predictor.cc
 * \brief Load prediction function exported as a shared library.
 */
8 #include <treelite/predictor.h>
9 #include <treelite/math.h>
10 #include <dmlc/logging.h>
11 #include <dmlc/io.h>
12 #include <dmlc/timer.h>
13 #include <cstdint>
14 #include <algorithm>
15 #include <memory>
16 #include <fstream>
17 #include <limits>
18 #include <functional>
19 #include <type_traits>
21 
22 #ifdef _WIN32
23 #include <windows.h>
24 #else
25 #include <dlfcn.h>
26 #endif
27 
28 namespace {
29 
// Tag telling a worker thread whether InputToken::data points to a
// treelite::CSRBatch (sparse) or a treelite::DenseBatch (dense).
enum class InputType : uint8_t {
  kSparseBatch = 0, kDenseBatch = 1
};
33 
// Unit of work submitted to each worker thread of the prediction thread pool.
struct InputToken {
  InputType input_type;  // discriminates the type behind `data`
  const void* data;  // pointer to input data
  bool pred_margin;  // whether to store raw margin or transformed scores
  size_t num_feature;
  // # features (columns) accepted by the tree ensemble model
  size_t num_output_group;
  // size of output per instance (row)
  treelite::Predictor::PredFuncHandle pred_func_handle;
  size_t rbegin, rend;
  // range of instances (rows) assigned to each worker
  float* out_pred;
  // buffer to store output from each worker
};
48 
// Result reported back by a worker: number of output values it produced.
struct OutputToken {
  size_t query_result_size;
};
52 
54 
55 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
56 #ifdef _WIN32
57  HMODULE handle = LoadLibraryA(name);
58 #else
59  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
60 #endif
61  return static_cast<treelite::Predictor::LibraryHandle>(handle);
62 }
63 
64 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
65 #ifdef _WIN32
66  FreeLibrary(static_cast<HMODULE>(handle));
67 #else
68  dlclose(static_cast<void*>(handle));
69 #endif
70 }
71 
72 template <typename HandleType>
73 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
74  const char* name) {
75 #ifdef _WIN32
76  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
77 #else
78  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
79 #endif
80  return static_cast<HandleType>(func_handle);
81 }
82 
83 template <typename PredFunc>
84 inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature,
85  size_t rbegin, size_t rend,
86  float* out_pred, PredFunc func) {
87  CHECK_LE(batch->num_col, num_feature);
88  std::vector<TreelitePredictorEntry> inst(
89  std::max(batch->num_col, num_feature), {-1});
90  CHECK(rbegin < rend && rend <= batch->num_row);
91  CHECK(sizeof(size_t) < sizeof(int64_t)
92  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
93  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
94  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
95  const int64_t rend_ = static_cast<int64_t>(rend);
96  const size_t num_col = batch->num_col;
97  const float* data = batch->data;
98  const uint32_t* col_ind = batch->col_ind;
99  const size_t* row_ptr = batch->row_ptr;
100  size_t total_output_size = 0;
101  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
102  const size_t ibegin = row_ptr[rid];
103  const size_t iend = row_ptr[rid + 1];
104  for (size_t i = ibegin; i < iend; ++i) {
105  inst[col_ind[i]].fvalue = data[i];
106  }
107  total_output_size += func(rid, &inst[0], out_pred);
108  for (size_t i = ibegin; i < iend; ++i) {
109  inst[col_ind[i]].missing = -1;
110  }
111  }
112  return total_output_size;
113 }
114 
115 template <typename PredFunc>
116 inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature,
117  size_t rbegin, size_t rend,
118  float* out_pred, PredFunc func) {
119  const bool nan_missing = treelite::math::CheckNAN(batch->missing_value);
120  CHECK_LE(batch->num_col, num_feature);
121  std::vector<TreelitePredictorEntry> inst(
122  std::max(batch->num_col, num_feature), {-1});
123  CHECK(rbegin < rend && rend <= batch->num_row);
124  CHECK(sizeof(size_t) < sizeof(int64_t)
125  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
126  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
127  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
128  const int64_t rend_ = static_cast<int64_t>(rend);
129  const size_t num_col = batch->num_col;
130  const float missing_value = batch->missing_value;
131  const float* data = batch->data;
132  const float* row;
133  size_t total_output_size = 0;
134  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
135  row = &data[rid * num_col];
136  for (size_t j = 0; j < num_col; ++j) {
137  if (treelite::math::CheckNAN(row[j])) {
138  CHECK(nan_missing)
139  << "The missing_value argument must be set to NaN if there is any "
140  << "NaN in the matrix.";
141  } else if (nan_missing || row[j] != missing_value) {
142  inst[j].fvalue = row[j];
143  }
144  }
145  total_output_size += func(rid, &inst[0], out_pred);
146  for (size_t j = 0; j < num_col; ++j) {
147  inst[j].missing = -1;
148  }
149  }
150  return total_output_size;
151 }
152 
153 template <typename BatchType>
154 inline size_t PredictBatch_(const BatchType* batch, bool pred_margin,
155  size_t num_feature, size_t num_output_group,
156  treelite::Predictor::PredFuncHandle pred_func_handle,
157  size_t rbegin, size_t rend,
158  size_t expected_query_result_size, float* out_pred) {
159  CHECK(pred_func_handle != nullptr)
160  << "A shared library needs to be loaded first using Load()";
161  /* Pass the correct prediction function to PredLoop.
162  We also need to specify how the function should be called. */
163  size_t query_result_size;
164  // Dimension of output vector:
165  // can be either [num_data] or [num_class]*[num_data].
166  // Note that size of prediction may be smaller than out_pred (this occurs
167  // when pred_function is set to "max_index").
168  if (num_output_group > 1) { // multi-class classification task
169  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
170  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
171  query_result_size =
172  PredLoop(batch, num_feature, rbegin, rend, out_pred,
173  [pred_func, num_output_group, pred_margin]
174  (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
175  return pred_func(inst, static_cast<int>(pred_margin),
176  &out_pred[rid * num_output_group]);
177  });
178  } else { // every other task
179  using PredFunc = float (*)(TreelitePredictorEntry*, int);
180  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
181  query_result_size =
182  PredLoop(batch, num_feature, rbegin, rend, out_pred,
183  [pred_func, pred_margin]
184  (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
185  out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
186  return 1;
187  });
188  }
189  return query_result_size;
190 }
191 
192 inline size_t PredictInst_(TreelitePredictorEntry* inst,
193  bool pred_margin, size_t num_output_group,
194  treelite::Predictor::PredFuncHandle pred_func_handle,
195  size_t expected_query_result_size, float* out_pred) {
196  CHECK(pred_func_handle != nullptr)
197  << "A shared library needs to be loaded first using Load()";
198  size_t query_result_size; // Dimention of output vector
199  if (num_output_group > 1) { // multi-class classification task
200  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
201  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
202  query_result_size = pred_func(inst, static_cast<int>(pred_margin), out_pred);
203  } else { // every other task
204  using PredFunc = float (*)(TreelitePredictorEntry*, int);
205  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
206  out_pred[0] = pred_func(inst, static_cast<int>(pred_margin));
207  query_result_size = 1;
208  }
209  return query_result_size;
210 }
211 
212 } // anonymous namespace
213 
214 namespace treelite {
215 
216 Predictor::Predictor(int num_worker_thread)
217  : lib_handle_(nullptr),
218  num_output_group_query_func_handle_(nullptr),
219  num_feature_query_func_handle_(nullptr),
220  pred_func_handle_(nullptr),
221  thread_pool_handle_(nullptr),
222  num_worker_thread_(num_worker_thread) {}
// Release the loaded shared library and the worker thread pool.
Predictor::~Predictor() {
  Free();
}
226 
/*!
 * \brief Load the prediction function from a dynamic shared library compiled
 *        by Treelite, query the model's metadata (# output groups, # features,
 *        pred_transform, sigmoid_alpha, global_bias), and start the worker
 *        thread pool used by PredictBatch().
 * \param name path of the shared library; aborts (LOG(FATAL)) if it cannot
 *        be opened or mandatory symbols are missing
 */
void
Predictor::Load(const char* name) {
  lib_handle_ = OpenLibrary(name);
  if (lib_handle_ == nullptr) {
    LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'";
  }

  /* 1. query # of output groups (mandatory symbol) */
  num_output_group_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_output_group");
  using UnsignedQueryFunc = size_t (*)(void);
  auto uint_query_func
    = reinterpret_cast<UnsignedQueryFunc>(num_output_group_query_func_handle_);
  CHECK(uint_query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = uint_query_func();

  /* 2. query # of features (mandatory symbol) */
  num_feature_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_feature");
  uint_query_func = reinterpret_cast<UnsignedQueryFunc>(num_feature_query_func_handle_);
  CHECK(uint_query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_feature() function";
  num_feature_ = uint_query_func();
  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";

  /* 3. query name of pred_transform (optional; falls back to "unknown") */
  pred_transform_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_pred_transform");
  using StringQueryFunc = const char* (*)(void);
  auto str_query_func =
    reinterpret_cast<StringQueryFunc>(pred_transform_query_func_handle_);
  if (str_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_pred_transform() function";
    pred_transform_ = "unknown";
  } else {
    pred_transform_ = str_query_func();
  }

  /* 4. query sigmoid_alpha (optional; falls back to NaN) */
  sigmoid_alpha_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_sigmoid_alpha");
  using FloatQueryFunc = float (*)(void);
  auto float_query_func =
    reinterpret_cast<FloatQueryFunc>(sigmoid_alpha_query_func_handle_);
  if (float_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_sigmoid_alpha() function";
    sigmoid_alpha_ = NAN;
  } else {
    sigmoid_alpha_ = float_query_func();
  }

  /* 5. query global_bias (optional; falls back to NaN) */
  global_bias_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_global_bias");
  float_query_func = reinterpret_cast<FloatQueryFunc>(global_bias_query_func_handle_);
  if (float_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_global_bias() function";
    global_bias_ = NAN;
  } else {
    global_bias_ = float_query_func();
  }

  /* 6. load appropriate function for margin prediction */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_multiclass");
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_multiclass() function";
  } else {  // everything else
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict() function";
  }

  if (num_worker_thread_ == -1) {
    num_worker_thread_ = std::thread::hardware_concurrency();
  }
  // Spawn (num_worker_thread_ - 1) workers; the calling thread serves as the
  // remaining worker inside PredictBatchBase_(). Each worker loops, popping
  // InputTokens, running prediction over its assigned row range, and pushing
  // back the number of output values it produced.
  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
    new PredThreadPool(num_worker_thread_ - 1, this,
      [](SpscQueue<InputToken>* incoming_queue,
         SpscQueue<OutputToken>* outgoing_queue,
         const Predictor* predictor) {
        InputToken input;
        while (incoming_queue->Pop(&input)) {
          size_t query_result_size;
          const size_t rbegin = input.rbegin;
          const size_t rend = input.rend;
          switch (input.input_type) {
           case InputType::kSparseBatch:
            {
              const CSRBatch* batch = static_cast<const CSRBatch*>(input.data);
              query_result_size
                = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                input.num_output_group, input.pred_func_handle,
                                rbegin, rend,
                                predictor->QueryResultSize(batch, rbegin, rend),
                                input.out_pred);
            }
            break;
           case InputType::kDenseBatch:
            {
              const DenseBatch* batch = static_cast<const DenseBatch*>(input.data);
              query_result_size
                = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                input.num_output_group, input.pred_func_handle,
                                rbegin, rend,
                                predictor->QueryResultSize(batch, rbegin, rend),
                                input.out_pred);
            }
            break;
          }
          outgoing_queue->Push(OutputToken{query_result_size});
        }
      }));
}
355 
356 void
358  CloseLibrary(lib_handle_);
359  delete static_cast<PredThreadPool*>(thread_pool_handle_);
360 }
361 
362 template <typename BatchType>
363 static inline
364 std::vector<size_t> SplitBatch(const BatchType* batch, size_t split_factor) {
365  const size_t num_row = batch->num_row;
366  CHECK_LE(split_factor, num_row);
367  const size_t portion = num_row / split_factor;
368  const size_t remainder = num_row % split_factor;
369  std::vector<size_t> workload(split_factor, portion);
370  std::vector<size_t> row_ptr(split_factor + 1, 0);
371  for (size_t i = 0; i < remainder; ++i) {
372  ++workload[i];
373  }
374  size_t accum = 0;
375  for (size_t i = 0; i < split_factor; ++i) {
376  accum += workload[i];
377  row_ptr[i + 1] = accum;
378  }
379  return row_ptr;
380 }
381 
/*!
 * \brief Common implementation of PredictBatch(): split the batch's rows
 *        across the worker pool (plus the calling thread), run prediction,
 *        and compact the output in-place when the model emits fewer values
 *        per row than num_output_group_.
 * \return total number of output values written to out_result
 */
template <typename BatchType>
inline size_t
Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
                             bool pred_margin, float* out_result) {
  static_assert(std::is_same<BatchType, DenseBatch>::value
                || std::is_same<BatchType, CSRBatch>::value,
                "PredictBatchBase_: unrecognized batch type");
  const double tstart = dmlc::GetTime();
  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
  const InputType input_type
    = std::is_same<BatchType, CSRBatch>::value
      ? InputType::kSparseBatch : InputType::kDenseBatch;
  InputToken request{input_type, static_cast<const void*>(batch), pred_margin,
                     num_feature_, num_output_group_, pred_func_handle_,
                     0, batch->num_row, out_result};
  OutputToken response;
  CHECK_GT(batch->num_row, 0);
  // never use more threads than there are rows
  const int nthread = std::min(num_worker_thread_,
                               static_cast<int>(batch->num_row));
  const std::vector<size_t> row_ptr = SplitBatch(batch, nthread);
  // dispatch rows [row_ptr[tid], row_ptr[tid+1]) to worker tid
  for (int tid = 0; tid < nthread - 1; ++tid) {
    request.rbegin = row_ptr[tid];
    request.rend = row_ptr[tid + 1];
    pool->SubmitTask(tid, request);
  }
  size_t total_size = 0;
  {
    // assign work to master
    const size_t rbegin = row_ptr[nthread - 1];
    const size_t rend = row_ptr[nthread];
    const size_t query_result_size
      = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_,
                      pred_func_handle_,
                      rbegin, rend, QueryResultSize(batch, rbegin, rend),
                      out_result);
    total_size += query_result_size;
  }
  // collect per-worker output counts
  for (int tid = 0; tid < nthread - 1; ++tid) {
    if (pool->WaitForTask(tid, &response)) {
      total_size += response.query_result_size;
    }
  }
  // re-shape output if total_size < dimension of out_result
  if (total_size < QueryResultSize(batch, 0, batch->num_row)) {
    CHECK_GT(num_output_group_, 1);
    CHECK_EQ(total_size % batch->num_row, 0);
    // each row produced query_size_per_instance (< num_output_group_) values
    // at stride num_output_group_; pack them densely from the front
    const size_t query_size_per_instance = total_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group_);
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        out_result[rid * query_size_per_instance + k]
          = out_result[rid * num_output_group_ + k];
      }
    }
  }
  const double tend = dmlc::GetTime();
  if (verbose > 0) {
    LOG(INFO) << "Treelite: Finished prediction in "
              << tend - tstart << " sec";
  }
  return total_size;
}
445 
// Make predictions on a sparse (CSR) batch; see PredictBatchBase_ for the
// threading/compaction details. Returns the number of values written.
size_t
Predictor::PredictBatch(const CSRBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
451 
// Make predictions on a dense batch; see PredictBatchBase_ for the
// threading/compaction details. Returns the number of values written.
size_t
Predictor::PredictBatch(const DenseBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
457 
458 size_t
460  float* out_result) {
461  size_t total_size;
462  total_size = PredictInst_(inst, pred_margin, num_output_group_,
463  pred_func_handle_,
464  QueryResultSizeSingleInst(), out_result);
465  return total_size;
466 }
467 
468 } // namespace treelite
/* Doxygen cross-reference appendix (generated documentation artifacts,
 * preserved as a comment so they do not read as stray code):
 * Load prediction function exported as a shared library.
 * Some useful math utilities.
 * const uint32_t * col_ind — feature indices. Definition: predictor.h:22
 * size_t QueryResultSizeSingleInst() const — Query the necessary size of
 *   array to hold the prediction for a single data row.
 *   Definition: predictor.h:151
 * bool CheckNAN(T value) — check for NaN (Not a Number).
 *   Definition: math.h:43
 * size_t PredictInst(TreelitePredictorEntry *inst, bool pred_margin,
 *   float *out_result) — Make predictions on a single data row
 *   (synchronously). The work will be scheduled to the calling thread.
 *   Definition: predictor.cc:459
 * sparse batch in Compressed Sparse Row (CSR) format.
 *   Definition: predictor.h:18
 * const size_t * row_ptr — pointer to row headers; length of [num_row] + 1.
 *   Definition: predictor.h:24
 * void Load(const char *name) — load the prediction function from dynamic
 *   shared library. Definition: predictor.cc:228
 * dense batch. Definition: predictor.h:32
 * const float * data — feature values. Definition: predictor.h:34
 * size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin,
 *   float *out_result) — Make predictions on a batch of data rows
 *   (synchronously). This function internally divides the workload.
 *   Definition: predictor.cc:447
 * float missing_value — value representing the missing value (usually nan).
 *   Definition: predictor.h:36
 * size_t QueryResultSize(const CSRBatch *batch) const — Given a batch of
 *   data rows, query the necessary size of array to hold predictions for
 *   all data points. Definition: predictor.h:100
 * data layout. The value -1 signifies the missing value. When the "missing"
 *   field is set to -1, the "fvalue" field is ignored.
 *   Definition: entry.h:14
 * a simple thread pool implementation
 * const float * data — feature values. Definition: predictor.h:20
 * predictor class: wrapper for optimized prediction code.
 *   Definition: predictor.h:44
 * size_t num_col — number of columns (i.e. # of features used).
 *   Definition: predictor.h:28
 * void Free() — unload the prediction function.
 *   Definition: predictor.cc:357
 * size_t num_col — number of columns (i.e. # of features used).
 *   Definition: predictor.h:40
 */