Treelite
predictor.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
9 #include <treelite/math.h>
10 #include <treelite/data.h>
11 #include <treelite/typeinfo.h>
12 #include <cstdint>
13 #include <algorithm>
14 #include <memory>
15 #include <fstream>
16 #include <limits>
17 #include <functional>
18 #include <type_traits>
19 #include <chrono>
21 
22 #ifdef _WIN32
23 #include <windows.h>
24 #else
25 #include <dlfcn.h>
26 #endif
27 
28 namespace {
29 
/*!
 * \brief Read a monotonic clock and return its reading in seconds.
 *
 * Only differences between two readings are meaningful. steady_clock is used
 * because high_resolution_clock may be an alias of the wall clock, which can
 * jump backwards and make an elapsed-time difference negative.
 * \return seconds elapsed since the clock's epoch, as a double
 */
inline double GetTime() {
  const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration<double>(since_epoch).count();
}
34 
35 struct InputToken {
36  const treelite::DMatrix* dmat; // input data
37  bool pred_margin; // whether to store raw margin or transformed scores
38  const treelite::predictor::PredFunction* pred_func_;
39  size_t rbegin, rend; // range of instances (rows) assigned to each worker
40  PredictorOutputHandle out_pred; // buffer to store output from each worker
41 };
42 
/*! \brief Reply from a worker once its share of rows has been processed */
struct OutputToken {
  /*! \brief value returned by PredictBatch() for the rows the worker handled */
  size_t query_result_size;
};
46 
47 using PredThreadPool
49 
50 template <typename ElementType, typename ThresholdType, typename LeafOutputType, typename PredFunc>
51 inline size_t PredLoop(const treelite::CSRDMatrixImpl<ElementType>* dmat, int num_feature,
52  size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) {
53  TREELITE_CHECK_LE(dmat->num_col, static_cast<size_t>(num_feature));
54  std::vector<treelite::predictor::Entry<ThresholdType>> inst(
55  std::max(dmat->num_col, static_cast<size_t>(num_feature)), {-1});
56  TREELITE_CHECK(rbegin < rend && rend <= dmat->num_row);
57  const ElementType* data = dmat->data.data();
58  const uint32_t* col_ind = dmat->col_ind.data();
59  const size_t* row_ptr = dmat->row_ptr.data();
60  size_t total_output_size = 0;
61  for (size_t rid = rbegin; rid < rend; ++rid) {
62  const size_t ibegin = row_ptr[rid];
63  const size_t iend = row_ptr[rid + 1];
64  for (size_t i = ibegin; i < iend; ++i) {
65  inst[col_ind[i]].fvalue = static_cast<ThresholdType>(data[i]);
66  }
67  total_output_size += func(rid, &inst[0], out_pred);
68  for (size_t i = ibegin; i < iend; ++i) {
69  inst[col_ind[i]].missing = -1;
70  }
71  }
72  return total_output_size;
73 }
74 
75 template <typename ElementType, typename ThresholdType, typename LeafOutputType, typename PredFunc>
76 inline size_t PredLoop(const treelite::DenseDMatrixImpl<ElementType>* dmat, int num_feature,
77  size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) {
78  const bool nan_missing = treelite::math::CheckNAN(dmat->missing_value);
79  TREELITE_CHECK_LE(dmat->num_col, static_cast<size_t>(num_feature));
80  std::vector<treelite::predictor::Entry<ThresholdType>> inst(
81  std::max(dmat->num_col, static_cast<size_t>(num_feature)), {-1});
82  TREELITE_CHECK(rbegin < rend && rend <= dmat->num_row);
83  const size_t num_col = dmat->num_col;
84  const ElementType missing_value = dmat->missing_value;
85  const ElementType* data = dmat->data.data();
86  const ElementType* row = nullptr;
87  size_t total_output_size = 0;
88  for (size_t rid = rbegin; rid < rend; ++rid) {
89  row = &data[rid * num_col];
90  for (size_t j = 0; j < num_col; ++j) {
91  if (treelite::math::CheckNAN(row[j])) {
92  TREELITE_CHECK(nan_missing)
93  << "The missing_value argument must be set to NaN if there is any NaN in the matrix.";
94  } else if (nan_missing || row[j] != missing_value) {
95  inst[j].fvalue = static_cast<ThresholdType>(row[j]);
96  }
97  }
98  total_output_size += func(rid, &inst[0], out_pred);
99  for (size_t j = 0; j < num_col; ++j) {
100  inst[j].missing = -1;
101  }
102  }
103  return total_output_size;
104 }
105 
106 template <typename ElementType>
107 class PredLoopDispatcherWithDenseDMatrix {
108  public:
109  template <typename ThresholdType, typename LeafOutputType, typename PredFunc>
110  inline static size_t Dispatch(
111  const treelite::DMatrix* dmat, ThresholdType, int num_feature, size_t rbegin, size_t rend,
112  LeafOutputType* out_pred, PredFunc func) {
113  const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
114  return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
115  dmat_, num_feature, rbegin, rend, out_pred, func);
116  }
117 };
118 
119 template <typename ElementType>
120 class PredLoopDispatcherWithCSRDMatrix {
121  public:
122  template <typename ThresholdType, typename LeafOutputType, typename PredFunc>
123  inline static size_t Dispatch(
124  const treelite::DMatrix* dmat, ThresholdType, int num_feature, size_t rbegin, size_t rend,
125  LeafOutputType* out_pred, PredFunc func) {
126  const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
127  return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
128  dmat_, num_feature, rbegin, rend, out_pred, func);
129  }
130 };
131 
132 template <typename ThresholdType, typename LeafOutputType, typename PredFunc>
133 inline size_t PredLoop(const treelite::DMatrix* dmat, ThresholdType test_val, int num_feature,
134  size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) {
135  treelite::DMatrixType dmat_type = dmat->GetType();
136  switch (dmat_type) {
137  case treelite::DMatrixType::kDense: {
138  return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithDenseDMatrix>(
139  dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
140  }
141  case treelite::DMatrixType::kSparseCSR: {
142  return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithCSRDMatrix>(
143  dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
144  }
145  default:
146  TREELITE_LOG(FATAL) << "Unrecognized data matrix type: " << static_cast<int>(dmat_type);
147  return 0;
148  }
149 }
150 
151 } // anonymous namespace
152 
153 namespace treelite {
154 namespace predictor {
155 
156 SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {}
157 
158 SharedLibrary::~SharedLibrary() {
159  if (handle_) {
160 #ifdef _WIN32
161  FreeLibrary(static_cast<HMODULE>(handle_));
162 #else
163  dlclose(static_cast<void*>(handle_));
164 #endif
165  }
166 }
167 
168 void
169 SharedLibrary::Load(const char* libpath) {
170 #ifdef _WIN32
171  HMODULE handle = LoadLibraryA(libpath);
172 #else
173  void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL);
174 #endif
175  TREELITE_CHECK(handle) << "Failed to load dynamic shared library `" << libpath << "'";
176  handle_ = static_cast<LibraryHandle>(handle);
177  libpath_ = std::string(libpath);
178 }
179 
180 SharedLibrary::FunctionHandle
181 SharedLibrary::LoadFunction(const char* name) const {
182 #ifdef _WIN32
183  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(handle_), name);
184 #else
185  void* func_handle = dlsym(static_cast<void*>(handle_), name);
186 #endif
187  TREELITE_CHECK(func_handle)
188  << "Dynamic shared library `" << libpath_ << "' does not contain a function " << name << "().";
189  return static_cast<SharedLibrary::FunctionHandle>(func_handle);
190 }
191 
192 template <typename HandleType>
193 HandleType
194 SharedLibrary::LoadFunctionWithSignature(const char* name) const {
195  auto func_handle = reinterpret_cast<HandleType>(LoadFunction(name));
196  TREELITE_CHECK(func_handle) << "Dynamic shared library `" << libpath_
197  << "' does not contain a function " << name
198  << "() with the requested signature";
199  return func_handle;
200 }
201 
202 template <typename ThresholdType, typename LeafOutputType>
204  public:
205  inline static std::unique_ptr<PredFunction> Dispatch(
206  const SharedLibrary& library, int num_feature, int num_class) {
207  return std::make_unique<PredFunctionImpl<ThresholdType, LeafOutputType>>(
208  library, num_feature, num_class);
209  }
210 };
211 
212 std::unique_ptr<PredFunction>
213 PredFunction::Create(
214  TypeInfo threshold_type, TypeInfo leaf_output_type, const SharedLibrary& library,
215  int num_feature, int num_class) {
216  return DispatchWithModelTypes<PredFunctionInitDispatcher>(
217  threshold_type, leaf_output_type, library, num_feature, num_class);
218 }
219 
220 template <typename ThresholdType, typename LeafOutputType>
222  const SharedLibrary& library, int num_feature, int num_class) {
223  TREELITE_CHECK_GT(num_class, 0) << "num_class cannot be zero";
224  if (num_class > 1) { // multi-class classification
225  handle_ = library.LoadFunction("predict_multiclass");
226  } else { // everything else
227  handle_ = library.LoadFunction("predict");
228  }
229  num_feature_ = num_feature;
230  num_class_ = num_class;
231 }
232 
233 template <typename ThresholdType, typename LeafOutputType>
234 TypeInfo
236  return TypeToInfo<ThresholdType>();
237 }
238 
239 template <typename ThresholdType, typename LeafOutputType>
240 TypeInfo
242  return TypeToInfo<LeafOutputType>();
243 }
244 
245 template <typename ThresholdType, typename LeafOutputType>
246 size_t
248  const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
249  PredictorOutputHandle out_pred) const {
250  /* Pass the correct prediction function to PredLoop.
251  We also need to specify how the function should be called. */
252  size_t result_size;
253  // Dimension of output vector:
254  // can be either [num_data] or [num_class]*[num_data].
255  // Note that size of prediction may be smaller than out_pred (this occurs
256  // when pred_function is set to "max_index").
257  TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
258  if (num_class_ > 1) { // multi-class classification
259  using PredFunc = size_t (*)(Entry<ThresholdType>*, int, LeafOutputType*);
260  auto pred_func = reinterpret_cast<PredFunc>(handle_);
261  TREELITE_CHECK(pred_func) << "The predict_multiclass() function has incorrect signature.";
262  auto pred_func_wrapper
263  = [pred_func, num_class = num_class_, pred_margin]
264  (int64_t rid, Entry<ThresholdType>* inst, LeafOutputType* out_pred) -> size_t {
265  return pred_func(inst, static_cast<int>(pred_margin),
266  &out_pred[rid * num_class]);
267  };
268  result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
269  static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
270  } else { // everything else
271  using PredFunc = LeafOutputType (*)(Entry<ThresholdType>*, int);
272  auto pred_func = reinterpret_cast<PredFunc>(handle_);
273  TREELITE_CHECK(pred_func) << "The predict() function has incorrect signature.";
274  auto pred_func_wrapper
275  = [pred_func, pred_margin]
276  (int64_t rid, Entry<ThresholdType>* inst, LeafOutputType* out_pred) -> size_t {
277  out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
278  return 1;
279  };
280  result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
281  static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
282  }
283  return result_size;
284 }
285 
286 Predictor::Predictor(int num_worker_thread)
287  : pred_func_(nullptr),
288  thread_pool_handle_(nullptr),
289  num_class_(0),
290  num_feature_(0),
291  sigmoid_alpha_(std::numeric_limits<float>::quiet_NaN()),
292  ratio_c_(std::numeric_limits<float>::quiet_NaN()),
293  global_bias_(std::numeric_limits<float>::quiet_NaN()),
294  num_worker_thread_(num_worker_thread),
295  threshold_type_(TypeInfo::kInvalid),
296  leaf_output_type_(TypeInfo::kInvalid) {}
297 Predictor::~Predictor() {
298  if (thread_pool_handle_) {
299  Free();
300  }
301 }
302 
303 void
304 Predictor::Load(const char* libpath) {
305  lib_.Load(libpath);
306 
307  using UnsignedQueryFunc = size_t (*)();
308  using StringQueryFunc = const char* (*)();
309  using FloatQueryFunc = float (*)();
310 
311  /* 1. query # of output groups */
312  auto* num_class_query_func
313  = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>("get_num_class");
314  num_class_ = num_class_query_func();
315 
316  /* 2. query # of features */
317  auto* num_feature_query_func
318  = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>("get_num_feature");
319  num_feature_ = num_feature_query_func();
320  TREELITE_CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";
321 
322  /* 3. query # of pred_transform name */
323  auto* pred_transform_query_func
324  = lib_.LoadFunctionWithSignature<StringQueryFunc>("get_pred_transform");
325  pred_transform_ = pred_transform_query_func();
326 
327  /* 4. query # of sigmoid_alpha */
328  auto* sigmoid_alpha_query_func
329  = lib_.LoadFunctionWithSignature<FloatQueryFunc>("get_sigmoid_alpha");
330  sigmoid_alpha_ = sigmoid_alpha_query_func();
331 
332  /* 5. query # of ratio_c */
333  auto* ratio_c_query_func
334  = lib_.LoadFunctionWithSignature<FloatQueryFunc>("get_ratio_c");
335  ratio_c_ = ratio_c_query_func();
336 
337  /* 6. query # of global_bias */
338  auto* global_bias_query_func = lib_.LoadFunctionWithSignature<FloatQueryFunc>("get_global_bias");
339  global_bias_ = global_bias_query_func();
340 
341  /* 7. Query the data type for thresholds and leaf outputs */
342  auto* threshold_type_query_func
343  = lib_.LoadFunctionWithSignature<StringQueryFunc>("get_threshold_type");
344  threshold_type_ = GetTypeInfoByName(threshold_type_query_func());
345  auto* leaf_output_type_query_func
346  = lib_.LoadFunctionWithSignature<StringQueryFunc>("get_leaf_output_type");
347  leaf_output_type_ = GetTypeInfoByName(leaf_output_type_query_func());
348 
349  /* 8. load appropriate function for margin prediction */
350  TREELITE_CHECK_GT(num_class_, 0) << "num_class cannot be zero";
351  pred_func_ = PredFunction::Create(
352  threshold_type_, leaf_output_type_, lib_,
353  static_cast<int>(num_feature_), static_cast<int>(num_class_));
354 
355  if (num_worker_thread_ == -1) {
356  num_worker_thread_ = static_cast<int>(std::thread::hardware_concurrency());
357  }
358  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
359  new PredThreadPool(num_worker_thread_ - 1, this,
360  [](SpscQueue<InputToken>* incoming_queue,
361  SpscQueue<OutputToken>* outgoing_queue,
362  const Predictor* predictor) {
363  predictor->exception_catcher_.Run([&]() {
364  InputToken input;
365  while (incoming_queue->Pop(&input)) {
366  const size_t rbegin = input.rbegin;
367  const size_t rend = input.rend;
368  size_t query_result_size
369  = predictor->pred_func_->PredictBatch(
370  input.dmat, rbegin, rend, input.pred_margin, input.out_pred);
371  outgoing_queue->Push(OutputToken{query_result_size});
372  }
373  });
374  }));
375 }
376 
377 void
378 Predictor::Free() {
379  delete static_cast<PredThreadPool*>(thread_pool_handle_);
380 }
381 
382 static inline
383 std::vector<size_t> SplitBatch(const DMatrix* dmat, size_t split_factor) {
384  const size_t num_row = dmat->GetNumRow();
385  TREELITE_CHECK_LE(split_factor, num_row);
386  const size_t portion = num_row / split_factor;
387  const size_t remainder = num_row % split_factor;
388  std::vector<size_t> workload(split_factor, portion);
389  std::vector<size_t> row_ptr(split_factor + 1, 0);
390  for (size_t i = 0; i < remainder; ++i) {
391  ++workload[i];
392  }
393  size_t accum = 0;
394  for (size_t i = 0; i < split_factor; ++i) {
395  accum += workload[i];
396  row_ptr[i + 1] = accum;
397  }
398  return row_ptr;
399 }
400 
401 template <typename LeafOutputType>
403  public:
404  inline static void Dispatch(
405  size_t num_row, size_t query_size_per_instance, size_t num_class,
406  PredictorOutputHandle out_result);
407 };
408 
409 template <typename LeafOutputType>
411  public:
412  inline static PredictorOutputHandle Dispatch(size_t size);
413 };
414 
415 template <typename LeafOutputType>
417  public:
418  inline static void Dispatch(PredictorOutputHandle output_vector);
419 };
420 
421 size_t
422 Predictor::PredictBatch(
423  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const {
424  const double tstart = GetTime();
425 
426  const size_t num_row = dmat->GetNumRow();
427  auto* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
428  InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result};
429  OutputToken response;
430  TREELITE_CHECK_GT(num_row, 0);
431  const int nthread = std::min(num_worker_thread_, static_cast<int>(num_row));
432  const std::vector<size_t> row_ptr = SplitBatch(dmat, nthread);
433  for (int tid = 0; tid < nthread - 1; ++tid) {
434  request.rbegin = row_ptr[tid];
435  request.rend = row_ptr[tid + 1];
436  pool->SubmitTask(tid, request);
437  }
438  size_t total_size = 0;
439  {
440  // assign work to the main thread
441  const size_t rbegin = row_ptr[nthread - 1];
442  const size_t rend = row_ptr[nthread];
443  const size_t query_result_size
444  = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result);
445  total_size += query_result_size;
446  }
447  for (int tid = 0; tid < nthread - 1; ++tid) {
448  if (pool->WaitForTask(tid, &response)) {
449  total_size += response.query_result_size;
450  }
451  }
452  // re-shape output if total_size < dimension of out_result
453  if (total_size < QueryResultSize(dmat, 0, num_row)) {
454  TREELITE_CHECK_GT(num_class_, 1);
455  TREELITE_CHECK_EQ(total_size % num_row, 0);
456  const size_t query_size_per_instance = total_size / num_row;
457  TREELITE_CHECK_GT(query_size_per_instance, 0);
458  TREELITE_CHECK_LT(query_size_per_instance, num_class_);
459  DispatchWithTypeInfo<ShrinkResultToFit>(
460  leaf_output_type_, num_row, query_size_per_instance, num_class_, out_result);
461  }
462  const double tend = GetTime();
463  if (verbose > 0) {
464  TREELITE_LOG(INFO) << "Treelite: Finished prediction in " << tend - tstart << " sec";
465  }
466  return total_size;
467 }
468 
470 Predictor::CreateOutputVector(const DMatrix* dmat) const {
471  const size_t output_vector_size = this->QueryResultSize(dmat);
472  return DispatchWithTypeInfo<AllocateOutputVector>(leaf_output_type_, output_vector_size);
473 }
474 
475 void
476 Predictor::DeleteOutputVector(PredictorOutputHandle output_vector) const {
477  DispatchWithTypeInfo<DeallocateOutputVector>(leaf_output_type_, output_vector);
478 }
479 
480 template <typename LeafOutputType>
481 void
483  size_t num_row, size_t query_size_per_instance, size_t num_class,
484  PredictorOutputHandle out_result) {
485  auto* out_result_ = static_cast<LeafOutputType*>(out_result);
486  for (size_t rid = 0; rid < num_row; ++rid) {
487  for (size_t k = 0; k < query_size_per_instance; ++k) {
488  out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_class + k];
489  }
490  }
491 }
492 
493 template <typename LeafOutputType>
496  return static_cast<PredictorOutputHandle>(new LeafOutputType[size]);
497 }
498 
499 template <typename LeafOutputType>
500 void
502  delete[] (static_cast<LeafOutputType*>(output_vector));
503 }
504 
505 } // namespace predictor
506 } // namespace treelite
Load prediction function exported as a shared library.
Some useful math utilities.
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:120
Input data structure of Treelite.
bool CheckNAN(T value)
check for NaN (Not a Number)
Definition: math.h:43
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:58
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save exception.
Definition: predictor.h:42
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
Definition: data.h:116
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:114
std::vector< ElementType > data
feature values
Definition: data.h:56
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< ElementType > data
feature values
Definition: data.h:112
Defines TypeInfo class and utilities.
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
Definition: typeinfo.cc:16
a simple thread pool implementation
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:71
predictor class: wrapper for optimized prediction code
Definition: predictor.h:128
void * ThreadPoolHandle
opaque handle types
Definition: predictor.h:131
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:62