treelite
predictor.cc
Go to the documentation of this file.
1 
#include <treelite/predictor.h>
#include <treelite/common.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <algorithm>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
#include <cmath>
#include <cstdint>
#include <cstring>
#include "common/math.h"
#include "common/filesystem.h"
#include "thread_pool/thread_pool.h"
23 
24 #ifdef _WIN32
25 #define NOMINMAX
26 #include <windows.h>
27 #else
28 #include <dlfcn.h>
29 #endif
30 
31 namespace {
32 
// Tag identifying which concrete batch type InputToken::data points to.
enum class InputType : uint8_t {
  kSparseBatch = 0, kDenseBatch = 1
};
36 
// Work item handed to a pool worker. NOTE: this struct is positionally
// brace-initialized elsewhere in this file, so member order must not change.
struct InputToken {
  InputType input_type;   // discriminates the type behind `data`
  const void* data;  // pointer to input data (CSRBatch* or DenseBatch*)
  bool pred_margin;  // whether to store raw margin or transformed scores
  size_t num_feature;
  // # features (columns) accepted by the tree ensemble model
  size_t num_output_group;
  // size of output per instance (row); > 1 for multi-class classification
  treelite::Predictor::PredFuncHandle pred_func_handle;
  size_t rbegin, rend;
  // range of instances (rows) assigned to each worker: [rbegin, rend)
  float* out_pred;
  // buffer to store output from each worker (shared; workers write disjoint slices)
};
51 
// Result reported back by a pool worker: number of outputs it produced.
struct OutputToken {
  size_t query_result_size;
};
55 
/*!
 * \brief Extract the protocol prefix (e.g. "file://", "hdfs://") from a path.
 * \param name file path, possibly prefixed with a protocol
 * \return protocol prefix including the "://" separator, or an empty string
 *         if the path carries no protocol
 */
inline std::string GetProtocol(const char* name) {
  const char* p = std::strstr(name, "://");
  if (p == nullptr) {  // no separator: plain local path
    return "";
  }
  return std::string(name, p - name + 3);  // keep the "://" itself
}
64 
65 using PredThreadPool
67 
68 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
69 #ifdef _WIN32
70  HMODULE handle = LoadLibraryA(name);
71 #else
72  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
73 #endif
74  return static_cast<treelite::Predictor::LibraryHandle>(handle);
75 }
76 
77 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
78 #ifdef _WIN32
79  FreeLibrary(static_cast<HMODULE>(handle));
80 #else
81  dlclose(static_cast<void*>(handle));
82 #endif
83 }
84 
/*!
 * \brief Look up a named symbol in an opened shared library.
 * \tparam HandleType function-pointer-like type to cast the symbol to
 * \param lib_handle handle returned by OpenLibrary()
 * \param name symbol name to resolve
 * \return resolved symbol cast to HandleType; null if the symbol is absent
 */
template <typename HandleType>
inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
                               const char* name) {
#ifdef _WIN32
  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
#else
  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
#endif
  return static_cast<HandleType>(func_handle);
}
95 
96 template <typename PredFunc>
97 inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature,
98  size_t rbegin, size_t rend,
99  float* out_pred, PredFunc func) {
100  CHECK_LE(batch->num_col, num_feature);
101  std::vector<TreelitePredictorEntry> inst(
102  std::max(batch->num_col, num_feature), {-1});
103  CHECK(rbegin < rend && rend <= batch->num_row);
104  CHECK(sizeof(size_t) < sizeof(int64_t)
105  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
106  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
107  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
108  const int64_t rend_ = static_cast<int64_t>(rend);
109  const size_t num_col = batch->num_col;
110  const float* data = batch->data;
111  const uint32_t* col_ind = batch->col_ind;
112  const size_t* row_ptr = batch->row_ptr;
113  size_t total_output_size = 0;
114  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
115  const size_t ibegin = row_ptr[rid];
116  const size_t iend = row_ptr[rid + 1];
117  for (size_t i = ibegin; i < iend; ++i) {
118  inst[col_ind[i]].fvalue = data[i];
119  }
120  total_output_size += func(rid, &inst[0], out_pred);
121  for (size_t i = ibegin; i < iend; ++i) {
122  inst[col_ind[i]].missing = -1;
123  }
124  }
125  return total_output_size;
126 }
127 
128 template <typename PredFunc>
129 inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature,
130  size_t rbegin, size_t rend,
131  float* out_pred, PredFunc func) {
132  const bool nan_missing
133  = treelite::common::math::CheckNAN(batch->missing_value);
134  CHECK_LE(batch->num_col, num_feature);
135  std::vector<TreelitePredictorEntry> inst(
136  std::max(batch->num_col, num_feature), {-1});
137  CHECK(rbegin < rend && rend <= batch->num_row);
138  CHECK(sizeof(size_t) < sizeof(int64_t)
139  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
140  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
141  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
142  const int64_t rend_ = static_cast<int64_t>(rend);
143  const size_t num_col = batch->num_col;
144  const float missing_value = batch->missing_value;
145  const float* data = batch->data;
146  const float* row;
147  size_t total_output_size = 0;
148  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
149  row = &data[rid * num_col];
150  for (size_t j = 0; j < num_col; ++j) {
151  if (treelite::common::math::CheckNAN(row[j])) {
152  CHECK(nan_missing)
153  << "The missing_value argument must be set to NaN if there is any "
154  << "NaN in the matrix.";
155  } else if (nan_missing || row[j] != missing_value) {
156  inst[j].fvalue = row[j];
157  }
158  }
159  total_output_size += func(rid, &inst[0], out_pred);
160  for (size_t j = 0; j < num_col; ++j) {
161  inst[j].missing = -1;
162  }
163  }
164  return total_output_size;
165 }
166 
167 template <typename BatchType>
168 inline size_t PredictBatch_(const BatchType* batch, bool pred_margin,
169  size_t num_feature, size_t num_output_group,
170  treelite::Predictor::PredFuncHandle pred_func_handle,
171  size_t rbegin, size_t rend,
172  size_t expected_query_result_size, float* out_pred) {
173  CHECK(pred_func_handle != nullptr)
174  << "A shared library needs to be loaded first using Load()";
175  /* Pass the correct prediction function to PredLoop.
176  We also need to specify how the function should be called. */
177  size_t query_result_size;
178  // Dimension of output vector:
179  // can be either [num_data] or [num_class]*[num_data].
180  // Note that size of prediction may be smaller than out_pred (this occurs
181  // when pred_function is set to "max_index").
182  if (num_output_group > 1) { // multi-class classification task
183  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
184  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
185  query_result_size =
186  PredLoop(batch, num_feature, rbegin, rend, out_pred,
187  [pred_func, num_output_group, pred_margin]
188  (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
189  return pred_func(inst, static_cast<int>(pred_margin),
190  &out_pred[rid * num_output_group]);
191  });
192  } else { // every other task
193  using PredFunc = float (*)(TreelitePredictorEntry*, int);
194  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
195  query_result_size =
196  PredLoop(batch, num_feature, rbegin, rend, out_pred,
197  [pred_func, pred_margin]
198  (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
199  out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
200  return 1;
201  });
202  }
203  return query_result_size;
204 }
205 
206 inline size_t PredictInst_(TreelitePredictorEntry* inst,
207  bool pred_margin, size_t num_output_group,
208  treelite::Predictor::PredFuncHandle pred_func_handle,
209  size_t expected_query_result_size, float* out_pred) {
210  CHECK(pred_func_handle != nullptr)
211  << "A shared library needs to be loaded first using Load()";
212  size_t query_result_size; // Dimention of output vector
213  if (num_output_group > 1) { // multi-class classification task
214  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
215  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
216  query_result_size = pred_func(inst, (int)pred_margin, out_pred);
217  } else { // every other task
218  using PredFunc = float (*)(TreelitePredictorEntry*, int);
219  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
220  out_pred[0] = pred_func(inst, (int)pred_margin);
221  query_result_size = 1;
222  }
223  return query_result_size;
224 }
225 
226 } // anonymous namespace
227 
228 namespace treelite {
229 
230 Predictor::Predictor(int num_worker_thread)
231  : lib_handle_(nullptr),
232  num_output_group_query_func_handle_(nullptr),
233  num_feature_query_func_handle_(nullptr),
234  pred_func_handle_(nullptr),
235  thread_pool_handle_(nullptr),
236  num_worker_thread_(num_worker_thread),
237  tempdir_(nullptr) {}
// Release the shared library and the worker thread pool.
Predictor::~Predictor() {
  Free();
}
241 
/*!
 * \brief Load a compiled model from a shared library and bind its exported
 *        entry points. A remote path (any protocol other than file://) is
 *        first copied into a local temporary directory via dmlc::Stream.
 *        Also spins up the worker thread pool used by PredictBatch().
 * \param name path of the shared library, possibly protocol-prefixed
 */
void
Predictor::Load(const char* name) {
  const std::string protocol = GetProtocol(name);
  if (protocol == "file://" || protocol.empty()) {
    // local file: open directly
    lib_handle_ = OpenLibrary(name);
  } else {
    // remote file: download to a temporary local copy first, since the
    // dynamic loader can only open local files
    tempdir_.reset(new common::filesystem::TemporaryDirectory());
    temp_libfile_ = tempdir_->AddFile(common::filesystem::GetBasename(name));
    {
      std::unique_ptr<dmlc::Stream> strm(dmlc::Stream::Create(name, "r"));
      dmlc::istream is(strm.get());
      std::ofstream of(temp_libfile_);
      of << is.rdbuf();
    }  // streams closed here, before the library is opened
    lib_handle_ = OpenLibrary(temp_libfile_.c_str());
  }
  if (lib_handle_ == nullptr) {
    LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'";
  }

  /* 1. query # of output groups */
  num_output_group_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_output_group");
  using UnsignedQueryFunc = size_t (*)(void);
  auto uint_query_func
    = reinterpret_cast<UnsignedQueryFunc>(num_output_group_query_func_handle_);
  CHECK(uint_query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = uint_query_func();

  /* 2. query # of features */
  num_feature_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_feature");
  uint_query_func = reinterpret_cast<UnsignedQueryFunc>(num_feature_query_func_handle_);
  CHECK(uint_query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_feature() function";
  num_feature_ = uint_query_func();
  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";

  /* 3. query name of pred_transform (optional; "unknown" if absent) */
  pred_transform_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_pred_transform");
  using StringQueryFunc = const char* (*)(void);
  auto str_query_func =
    reinterpret_cast<StringQueryFunc>(pred_transform_query_func_handle_);
  if (str_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_pred_transform() function";
    pred_transform_ = "unknown";
  } else {
    pred_transform_ = str_query_func();
  }

  /* 4. query sigmoid_alpha (optional; NAN if absent) */
  sigmoid_alpha_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_sigmoid_alpha");
  using FloatQueryFunc = float (*)(void);
  auto float_query_func =
    reinterpret_cast<FloatQueryFunc>(sigmoid_alpha_query_func_handle_);
  if (float_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_sigmoid_alpha() function";
    sigmoid_alpha_ = NAN;
  } else {
    sigmoid_alpha_ = float_query_func();
  }

  /* 5. query global_bias (optional; NAN if absent) */
  global_bias_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_global_bias");
  float_query_func = reinterpret_cast<FloatQueryFunc>(global_bias_query_func_handle_);
  if (float_query_func == nullptr) {
    LOG(INFO) << "Dynamic shared library `" << name
              << "' does not contain valid get_global_bias() function";
    global_bias_ = NAN;
  } else {
    global_bias_ = float_query_func();
  }

  /* 6. load the prediction function matching the output dimensionality */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_multiclass");
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_multiclass() function";
  } else {  // everything else
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict() function";
  }

  if (num_worker_thread_ == -1) {
    num_worker_thread_ = std::thread::hardware_concurrency();
  }
  // Spawn (num_worker_thread_ - 1) workers; the calling ("master") thread
  // processes the remaining shard itself in PredictBatchBase_().
  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
    new PredThreadPool(num_worker_thread_ - 1, this,
      [](SpscQueue<InputToken>* incoming_queue,
         SpscQueue<OutputToken>* outgoing_queue,
         const Predictor* predictor) {
        // Worker loop: pop work items until the queue is shut down.
        InputToken input;
        while (incoming_queue->Pop(&input)) {
          size_t query_result_size;
          const size_t rbegin = input.rbegin;
          const size_t rend = input.rend;
          switch (input.input_type) {
          case InputType::kSparseBatch:
            {
              const CSRBatch* batch = static_cast<const CSRBatch*>(input.data);
              query_result_size
                = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                input.num_output_group, input.pred_func_handle,
                                rbegin, rend,
                                predictor->QueryResultSize(batch, rbegin, rend),
                                input.out_pred);
            }
            break;
          case InputType::kDenseBatch:
            {
              const DenseBatch* batch = static_cast<const DenseBatch*>(input.data);
              query_result_size
                = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                input.num_output_group, input.pred_func_handle,
                                rbegin, rend,
                                predictor->QueryResultSize(batch, rbegin, rend),
                                input.out_pred);
            }
            break;
          }
          outgoing_queue->Push(OutputToken{query_result_size});
        }
      }));
}
385 
386 void
387 Predictor::Free() {
388  CloseLibrary(lib_handle_);
389  delete static_cast<PredThreadPool*>(thread_pool_handle_);
390 }
391 
392 template <typename BatchType>
393 static inline
394 std::vector<size_t> SplitBatch(const BatchType* batch, size_t split_factor) {
395  const size_t num_row = batch->num_row;
396  CHECK_LE(split_factor, num_row);
397  const size_t portion = num_row / split_factor;
398  const size_t remainder = num_row % split_factor;
399  std::vector<size_t> workload(split_factor, portion);
400  std::vector<size_t> row_ptr(split_factor + 1, 0);
401  for (size_t i = 0; i < remainder; ++i) {
402  ++workload[i];
403  }
404  size_t accum = 0;
405  for (size_t i = 0; i < split_factor; ++i) {
406  accum += workload[i];
407  row_ptr[i + 1] = accum;
408  }
409  return row_ptr;
410 }
411 
/*!
 * \brief Common implementation behind both PredictBatch() overloads: split
 *        the batch across the thread pool plus the calling thread, run
 *        prediction, then compact the output buffer if the model produced
 *        fewer values per row than num_output_group_.
 * \param batch CSRBatch or DenseBatch (enforced by static_assert)
 * \param verbose if > 0, log the elapsed prediction time
 * \param pred_margin whether to emit raw margins instead of transformed scores
 * \param out_result output buffer sized for QueryResultSize(batch)
 * \return total number of outputs written
 */
template <typename BatchType>
inline size_t
Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
                             bool pred_margin, float* out_result) {
  static_assert(std::is_same<BatchType, DenseBatch>::value
                || std::is_same<BatchType, CSRBatch>::value,
                "PredictBatchBase_: unrecognized batch type");
  const double tstart = dmlc::GetTime();
  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
  const InputType input_type
    = std::is_same<BatchType, CSRBatch>::value
      ? InputType::kSparseBatch : InputType::kDenseBatch;
  // template request; rbegin/rend are overwritten per shard below
  InputToken request{input_type, static_cast<const void*>(batch), pred_margin,
                     num_feature_, num_output_group_, pred_func_handle_,
                     0, batch->num_row, out_result};
  OutputToken response;
  CHECK_GT(batch->num_row, 0);
  // never use more threads than there are rows
  const int nthread = std::min(num_worker_thread_,
                               static_cast<int>(batch->num_row));
  const std::vector<size_t> row_ptr = SplitBatch(batch, nthread);
  // hand out the first (nthread - 1) shards to pool workers
  for (int tid = 0; tid < nthread - 1; ++tid) {
    request.rbegin = row_ptr[tid];
    request.rend = row_ptr[tid + 1];
    pool->SubmitTask(tid, request);
  }
  size_t total_size = 0;
  {
    // assign the last shard to the calling ("master") thread
    const size_t rbegin = row_ptr[nthread - 1];
    const size_t rend = row_ptr[nthread];
    const size_t query_result_size
      = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_,
                      pred_func_handle_,
                      rbegin, rend, QueryResultSize(batch, rbegin, rend),
                      out_result);
    total_size += query_result_size;
  }
  // collect each worker's output count
  for (int tid = 0; tid < nthread - 1; ++tid) {
    if (pool->WaitForTask(tid, &response)) {
      total_size += response.query_result_size;
    }
  }
  // re-shape output if total_size < dimension of out_result
  // (occurs when the model emits fewer than num_output_group_ values per row,
  // e.g. pred_function "max_index"): compact rows in place, left to right
  if (total_size < QueryResultSize(batch, 0, batch->num_row)) {
    CHECK_GT(num_output_group_, 1);
    CHECK_EQ(total_size % batch->num_row, 0);
    const size_t query_size_per_instance = total_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group_);
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        // safe in-place copy: destination index never exceeds source index
        out_result[rid * query_size_per_instance + k]
          = out_result[rid * num_output_group_ + k];
      }
    }
  }
  const double tend = dmlc::GetTime();
  if (verbose > 0) {
    LOG(INFO) << "Treelite: Finished prediction in "
              << tend - tstart << " sec";
  }
  return total_size;
}
475 
// Predict over a sparse (CSR) batch; see PredictBatchBase_ for details.
size_t
Predictor::PredictBatch(const CSRBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
481 
// Predict over a dense batch; see PredictBatchBase_ for details.
size_t
Predictor::PredictBatch(const DenseBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
487 
488 size_t
489 Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
490  float* out_result) {
491  size_t total_size;
492  total_size = PredictInst_(inst, pred_margin, num_output_group_,
493  pred_func_handle_,
494  QueryResultSizeSingleInst(), out_result);
495  return total_size;
496 }
497 
498 } // namespace treelite
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:29
dense batch
Definition: predictor.h:37
const float * data
feature values
Definition: predictor.h:39
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:105
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:11
a simple thread pool implementation
const float * data
feature values
Definition: predictor.h:25
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45