treelite
predictor.cc
Go to the documentation of this file.
1 
#include <treelite/predictor.h>
#include <treelite/common.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <cstdint>
#include <cstring>
#include <algorithm>
#include <memory>
#include <fstream>
#include <limits>
#include <functional>
#include <thread>
#include <type_traits>
#include "common/math.h"
#include "common/filesystem.h"
#include "thread_pool/thread_pool.h"
23 
24 #ifdef _WIN32
25 #define NOMINMAX
26 #include <windows.h>
27 #else
28 #include <dlfcn.h>
29 #endif
30 
31 namespace {
32 
// Kind of payload carried by an InputToken; selects how a worker thread
// interprets InputToken::data.
enum class InputType : uint8_t {
  kSparseBatch = 0, kDenseBatch = 1, kSingleInst = 2
};
36 
// Work item handed to a worker thread through its incoming queue.
struct InputToken {
  InputType input_type;     // discriminates what `data` points to
  // non-owning pointer to a CSRBatch, DenseBatch, or TreelitePredictorEntry
  // array, depending on input_type
  const void* data;
  bool pred_margin;         // forwarded as an int flag to the compiled function
  size_t num_output_group;  // number of output groups of the loaded model
  treelite::Predictor::PredFuncHandle pred_func_handle;  // compiled predict fn
  size_t rbegin, rend;      // row range [rbegin, rend) assigned to this task
  float* out_pred;          // output buffer the prediction is written into
};
46 
// Result reported back by a worker thread after finishing one work item.
struct OutputToken {
  size_t query_result_size;  // number of output values the task produced
};
50 
/*!
 * \brief Extract the protocol prefix (e.g. "file://", "s3://") from a path.
 * \param name file path, possibly prefixed with a protocol
 * \return the protocol prefix including the "://" separator, or an empty
 *         string if the path has no protocol prefix
 */
inline std::string GetProtocol(const char* name) {
  const char* p = std::strstr(name, "://");
  if (p == nullptr) {
    return "";
  }
  // keep the "://" separator (3 chars) as part of the returned prefix
  return std::string(name, p - name + 3);
}
59 
60 using PredThreadPool
62 
63 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
64 #ifdef _WIN32
65  HMODULE handle = LoadLibraryA(name);
66 #else
67  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
68 #endif
69  return static_cast<treelite::Predictor::LibraryHandle>(handle);
70 }
71 
// Unload a shared library previously opened with OpenLibrary().
// NOTE(review): `handle` is not null-checked here; callers must pass a valid
// handle (dlclose/FreeLibrary on null is not safe on all platforms).
inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
#ifdef _WIN32
  FreeLibrary(static_cast<HMODULE>(handle));
#else
  dlclose(static_cast<void*>(handle));
#endif
}
79 
80 template <typename HandleType>
81 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
82  const char* name) {
83 #ifdef _WIN32
84  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
85 #else
86  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
87 #endif
88  return static_cast<HandleType>(func_handle);
89 }
90 
91 template <typename PredFunc>
92 inline size_t PredLoop(const treelite::CSRBatch* batch,
93  size_t rbegin, size_t rend,
94  float* out_pred, PredFunc func) {
95  std::vector<TreelitePredictorEntry> inst(batch->num_col, {-1});
96  CHECK(rbegin < rend && rend <= batch->num_row);
97  CHECK(sizeof(size_t) < sizeof(int64_t)
98  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
99  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
100  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
101  const int64_t rend_ = static_cast<int64_t>(rend);
102  const size_t num_col = batch->num_col;
103  const float* data = batch->data;
104  const uint32_t* col_ind = batch->col_ind;
105  const size_t* row_ptr = batch->row_ptr;
106  size_t total_output_size = 0;
107  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
108  const size_t ibegin = row_ptr[rid];
109  const size_t iend = row_ptr[rid + 1];
110  for (size_t i = ibegin; i < iend; ++i) {
111  inst[col_ind[i]].fvalue = data[i];
112  }
113  total_output_size += func(rid, &inst[0], out_pred);
114  for (size_t i = ibegin; i < iend; ++i) {
115  inst[col_ind[i]].missing = -1;
116  }
117  }
118  return total_output_size;
119 }
120 
// Run prediction over rows [rbegin, rend) of a dense batch.
// For each row, valid feature values are copied into a dense per-row entry
// buffer, `func` is invoked on that buffer, and the buffer is reset to
// all-missing afterwards. Returns the total number of outputs produced.
template <typename PredFunc>
inline size_t PredLoop(const treelite::DenseBatch* batch,
                       size_t rbegin, size_t rend,
                       float* out_pred, PredFunc func) {
  // if missing_value is NaN, then NaN cells denote missing values
  const bool nan_missing
    = treelite::common::math::CheckNAN(batch->missing_value);
  // dense working buffer for one row; -1 marks an entry as missing
  std::vector<TreelitePredictorEntry> inst(batch->num_col, {-1});
  CHECK(rbegin < rend && rend <= batch->num_row);
  // row ids are handed to `func` as int64_t; guard against overflow
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
            && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
  const int64_t rend_ = static_cast<int64_t>(rend);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row;
  size_t total_output_size = 0;
  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
    row = &data[rid * num_col];
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::common::math::CheckNAN(row[j])) {
        // NaN is only legal as the designated missing value
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        // cell holds a valid (non-missing) feature value
        inst[j].fvalue = row[j];
      }
      // otherwise the cell equals missing_value: entry stays marked missing
    }
    total_output_size += func(rid, &inst[0], out_pred);
    // reset the whole buffer to all-missing for the next row
    for (size_t j = 0; j < num_col; ++j) {
      inst[j].missing = -1;
    }
  }
  return total_output_size;
}
157 
// Dispatch batch prediction over rows [rbegin, rend) to the correct compiled
// entry point, depending on whether the model has multiple output groups.
// `expected_query_result_size` is currently unused here; it is kept for
// interface symmetry with the callers that compute it.
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch,
                            bool pred_margin, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle pred_func_handle,
                            size_t rbegin, size_t rend,
                            size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  /* Pass the correct prediction function to PredLoop.
     We also need to specify how the function should be called. */
  size_t query_result_size;
  // Dimension of output vector:
  // can be either [num_data] or [num_class]*[num_data].
  // Note that size of prediction may be smaller than out_pred (this occurs
  // when pred_function is set to "max_index").
  if (num_output_group > 1) {  // multi-class classification task
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, rbegin, rend, out_pred,
        [pred_func, num_output_group, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          // each row writes its outputs at a stride of num_output_group
          return pred_func(inst, static_cast<int>(pred_margin),
                           &out_pred[rid * num_output_group]);
        });
  } else {  // every other task
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, rbegin, rend, out_pred,
        [pred_func, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          // exactly one output value per row
          out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
          return 1;
        });
  }
  return query_result_size;
}
196 
197 inline size_t PredictInst_(TreelitePredictorEntry* inst,
198  bool pred_margin, size_t num_output_group,
199  treelite::Predictor::PredFuncHandle pred_func_handle,
200  size_t expected_query_result_size, float* out_pred) {
201  CHECK(pred_func_handle != nullptr)
202  << "A shared library needs to be loaded first using Load()";
203  /* Pass the correct prediction function to PredLoop */
204  size_t query_result_size; // Dimention of output vector
205  if (num_output_group > 1) { // multi-class classification task
206  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
207  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
208  query_result_size = pred_func(inst, (int)pred_margin, out_pred);
209  } else { // every other task
210  using PredFunc = float (*)(TreelitePredictorEntry*, int);
211  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
212  out_pred[0] = pred_func(inst, (int)pred_margin);
213  query_result_size = 1;
214  }
215  return query_result_size;
216 }
217 
218 } // anonymous namespace
219 
220 namespace treelite {
221 
// Construct a predictor in an empty state; call Load() to attach a compiled
// model library and spawn the worker pool.
// \param num_worker_thread number of worker threads; -1 means autodetect
//        from hardware concurrency (see Load())
// \param include_master_thread whether the calling thread also processes a
//        share of each batch prediction
Predictor::Predictor(int num_worker_thread,
                     bool include_master_thread)
  : lib_handle_(nullptr),
    num_output_group_query_func_handle_(nullptr),
    num_feature_query_func_handle_(nullptr),
    pred_func_handle_(nullptr),
    thread_pool_handle_(nullptr),
    include_master_thread_(include_master_thread),
    num_worker_thread_(num_worker_thread),
    tempdir_(nullptr) {}
// Destructor releases the shared library and the thread pool via Free().
Predictor::~Predictor() {
  Free();
}
235 
236 void
237 Predictor::Load(const char* name) {
238  const std::string protocol = GetProtocol(name);
239  if (protocol == "file://" || protocol.empty()) {
240  // local file
241  lib_handle_ = OpenLibrary(name);
242  } else {
243  // remote file
244  tempdir_.reset(new common::filesystem::TemporaryDirectory());
245  temp_libfile_ = tempdir_->AddFile(common::filesystem::GetBasename(name));
246  {
247  std::unique_ptr<dmlc::Stream> strm(dmlc::Stream::Create(name, "r"));
248  dmlc::istream is(strm.get());
249  std::ofstream of(temp_libfile_);
250  of << is.rdbuf();
251  }
252  lib_handle_ = OpenLibrary(temp_libfile_.c_str());
253  }
254  if (lib_handle_ == nullptr) {
255  LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'";
256  }
257 
258  /* 1. query # of output groups */
259  num_output_group_query_func_handle_
260  = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_output_group");
261  using QueryFunc = size_t (*)(void);
262  QueryFunc query_func
263  = reinterpret_cast<QueryFunc>(num_output_group_query_func_handle_);
264  CHECK(query_func != nullptr)
265  << "Dynamic shared library `" << name
266  << "' does not contain valid get_num_output_group() function";
267  num_output_group_ = query_func();
268 
269  /* 2. query # of features */
270  num_feature_query_func_handle_
271  = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_feature");
272  query_func = reinterpret_cast<QueryFunc>(num_feature_query_func_handle_);
273  CHECK(query_func != nullptr)
274  << "Dynamic shared library `" << name
275  << "' does not contain valid get_num_feature() function";
276  num_feature_ = query_func();
277  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";
278 
279  /* 3. load appropriate function for margin prediction */
280  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
281  if (num_output_group_ > 1) { // multi-class classification
282  pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
283  "predict_multiclass");
284  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
285  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
286  CHECK(pred_func != nullptr)
287  << "Dynamic shared library `" << name
288  << "' does not contain valid predict_multiclass() function";
289  } else { // everything else
290  pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
291  using PredFunc = float (*)(TreelitePredictorEntry*, int);
292  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
293  CHECK(pred_func != nullptr)
294  << "Dynamic shared library `" << name
295  << "' does not contain valid predict() function";
296  }
297 
298  if (num_worker_thread_ == -1) {
299  num_worker_thread_
300  = std::thread::hardware_concurrency() - (int)include_master_thread_;
301  }
302  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
303  new PredThreadPool(num_worker_thread_, this,
304  [](SpscQueue<InputToken>* incoming_queue,
305  SpscQueue<OutputToken>* outgoing_queue,
306  const Predictor* predictor) {
307  InputToken input;
308  while (incoming_queue->Pop(&input)) {
309  size_t query_result_size;
310  const size_t rbegin = input.rbegin;
311  const size_t rend = input.rend;
312  switch (input.input_type) {
313  case InputType::kSparseBatch:
314  {
315  const CSRBatch* batch = static_cast<const CSRBatch*>(input.data);
316  query_result_size
317  = PredictBatch_(batch, input.pred_margin, input.num_output_group,
318  input.pred_func_handle,
319  rbegin, rend,
320  predictor->QueryResultSize(batch, rbegin, rend),
321  input.out_pred);
322  }
323  break;
324  case InputType::kDenseBatch:
325  {
326  const DenseBatch* batch = static_cast<const DenseBatch*>(input.data);
327  query_result_size
328  = PredictBatch_(batch, input.pred_margin, input.num_output_group,
329  input.pred_func_handle,
330  rbegin, rend,
331  predictor->QueryResultSize(batch, rbegin, rend),
332  input.out_pred);
333  }
334  break;
335  case InputType::kSingleInst:
336  default:
337  {
339  = const_cast<TreelitePredictorEntry*>(
340  static_cast<const TreelitePredictorEntry*>(input.data));
341  query_result_size
342  = PredictInst_(inst, input.pred_margin, input.num_output_group,
343  input.pred_func_handle,
344  predictor->QueryResultSizeSingleInst(),
345  input.out_pred);
346  }
347  break;
348  }
349  outgoing_queue->Push(OutputToken{query_result_size});
350  }
351  }));
352 }
353 
354 void
355 Predictor::Free() {
356  CloseLibrary(lib_handle_);
357  delete static_cast<PredThreadPool*>(thread_pool_handle_);
358 }
359 
360 template <typename BatchType>
361 static inline
362 std::vector<size_t> SplitBatch(const BatchType* batch, size_t nthread) {
363  const size_t num_row = batch->num_row;
364  CHECK_LE(nthread, num_row);
365  const size_t portion = num_row / nthread;
366  const size_t remainder = num_row % nthread;
367  std::vector<size_t> workload(nthread, portion);
368  std::vector<size_t> row_ptr(nthread + 1, 0);
369  for (size_t i = 0; i < remainder; ++i) {
370  ++workload[i];
371  }
372  size_t accum = 0;
373  for (size_t i = 0; i < nthread; ++i) {
374  accum += workload[i];
375  row_ptr[i + 1] = accum;
376  }
377  return row_ptr;
378 }
379 
// Shared implementation of batch prediction for CSR and dense batches.
// Splits the batch rows across the worker threads (plus, optionally, the
// calling thread), waits for all slices, then compacts the output buffer in
// place if fewer than num_output_group_ values were produced per row.
template <typename BatchType>
inline size_t
Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
                             bool pred_margin, float* out_result) {
  static_assert(std::is_same<BatchType, DenseBatch>::value
                || std::is_same<BatchType, CSRBatch>::value,
                "PredictBatchBase_: unrecognized batch type");
  const double tstart = dmlc::GetTime();
  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
  const InputType input_type
    = std::is_same<BatchType, CSRBatch>::value
      ? InputType::kSparseBatch : InputType::kDenseBatch;
  InputToken request{input_type, static_cast<const void*>(batch), pred_margin,
                     num_output_group_, pred_func_handle_,
                     0, batch->num_row, out_result};
  OutputToken response;
  CHECK_GT(batch->num_row, 0);
  // cap the worker count so that every participating thread gets >= 1 row
  const int nthread = std::min(num_worker_thread_,
                               static_cast<int>(batch->num_row)
                               - static_cast<int>(include_master_thread_));
  const std::vector<size_t> row_ptr
    = SplitBatch(batch, nthread + static_cast<int>(include_master_thread_));
  // hand each worker its row slice [row_ptr[tid], row_ptr[tid + 1])
  for (int tid = 0; tid < nthread; ++tid) {
    request.rbegin = row_ptr[tid];
    request.rend = row_ptr[tid + 1];
    pool->SubmitTask(tid, request);
  }
  size_t total_size = 0;
  if (include_master_thread_) {
    // the calling thread processes the final slice itself
    const size_t rbegin = row_ptr[nthread];
    const size_t rend = row_ptr[nthread + 1];
    const size_t query_result_size
      = PredictBatch_(batch, pred_margin, num_output_group_,
                      pred_func_handle_,
                      rbegin, rend, QueryResultSize(batch, rbegin, rend),
                      out_result);
    total_size += query_result_size;
  }
  // wait for every worker and accumulate its output count
  for (int tid = 0; tid < nthread; ++tid) {
    if (pool->WaitForTask(tid, &response)) {
      total_size += response.query_result_size;
    }
  }
  // re-shape output if total_size < dimension of out_result
  if (total_size < QueryResultSize(batch, 0, batch->num_row)) {
    CHECK_GT(num_output_group_, 1);
    CHECK_EQ(total_size % batch->num_row, 0);
    const size_t query_size_per_instance = total_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group_);
    // compact rows in place: each row was written at a stride of
    // num_output_group_ but holds only query_size_per_instance valid values
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        out_result[rid * query_size_per_instance + k]
          = out_result[rid * num_output_group_ + k];
      }
    }
  }
  const double tend = dmlc::GetTime();
  if (verbose > 0) {
    LOG(INFO) << "Treelite: Finished prediction in "
              << tend - tstart << " sec";
  }
  return total_size;
}
444 
// Predict over a sparse (CSR) batch; see PredictBatchBase_ for the
// threading and output-compaction details. Returns number of outputs.
size_t
Predictor::PredictBatch(const CSRBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
450 
// Predict over a dense batch; see PredictBatchBase_ for the threading and
// output-compaction details. Returns number of outputs.
size_t
Predictor::PredictBatch(const DenseBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
456 
457 size_t
458 Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
459  float* out_result) {
460  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
461  const InputType input_type = InputType::kSingleInst;
462  InputToken request{input_type, static_cast<const void*>(inst), pred_margin,
463  num_output_group_, pred_func_handle_,
464  0, 1, out_result};
465  OutputToken response;
466  size_t total_size;
467  total_size = PredictInst_(inst, pred_margin, num_output_group_,
468  pred_func_handle_,
469  QueryResultSizeSingleInst(), out_result);
470  return total_size;
471 }
472 
473 } // namespace treelite
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
size_t QueryResultSizeSingleInst() const
Query the necessary size of array to hold the prediction for a single data row.
Definition: predictor.h:157
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:106
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:29
dense batch
Definition: predictor.h:37
const float * data
feature values
Definition: predictor.h:39
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
data layout. The value -1 signifies the missing value. When the "missing" field of an entry is set to -1, the corresponding feature value is treated as missing.
Definition: entry.h:11
a simple thread pool implementation
const float * data
feature values
Definition: predictor.h:25
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45