Treelite
predictor.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
9 #include <treelite/math.h>
10 #include <treelite/data.h>
11 #include <treelite/typeinfo.h>
12 #include <dmlc/logging.h>
13 #include <dmlc/io.h>
14 #include <dmlc/timer.h>
15 #include <cstdint>
16 #include <algorithm>
17 #include <memory>
18 #include <fstream>
19 #include <limits>
20 #include <functional>
21 #include <type_traits>
23 
24 #ifdef _WIN32
25 #include <windows.h>
26 #else
27 #include <dlfcn.h>
28 #endif
29 
30 namespace {
31 
32 struct InputToken {
33  const treelite::DMatrix* dmat; // input data
34  bool pred_margin; // whether to store raw margin or transformed scores
35  const treelite::predictor::PredFunction* pred_func_;
36  size_t rbegin, rend; // range of instances (rows) assigned to each worker
37  PredictorOutputHandle out_pred; // buffer to store output from each worker
38 };
39 
/*!
 * \brief Reply sent back by a prediction worker once it has finished its row range.
 *
 * The field is zero-initialized so that a default-constructed token is well defined
 * even if it is read before a worker has filled it in.
 */
struct OutputToken {
  size_t query_result_size{0};  // value returned by the worker's PredictBatch() call
};
43 
44 using PredThreadPool
46 
47 template <typename ElementType, typename ThresholdType, typename LeafOutputType, typename PredFunc>
48 inline size_t PredLoop(const treelite::CSRDMatrixImpl<ElementType>* dmat, int num_feature,
49  size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) {
50  CHECK_LE(dmat->num_col, static_cast<size_t>(num_feature));
51  std::vector<treelite::predictor::Entry<ThresholdType>> inst(
52  std::max(dmat->num_col, static_cast<size_t>(num_feature)), {-1});
53  CHECK(rbegin < rend && rend <= dmat->num_row);
54  const ElementType* data = dmat->data.data();
55  const uint32_t* col_ind = dmat->col_ind.data();
56  const size_t* row_ptr = dmat->row_ptr.data();
57  size_t total_output_size = 0;
58  for (size_t rid = rbegin; rid < rend; ++rid) {
59  const size_t ibegin = row_ptr[rid];
60  const size_t iend = row_ptr[rid + 1];
61  for (size_t i = ibegin; i < iend; ++i) {
62  inst[col_ind[i]].fvalue = static_cast<ThresholdType>(data[i]);
63  }
64  total_output_size += func(rid, &inst[0], out_pred);
65  for (size_t i = ibegin; i < iend; ++i) {
66  inst[col_ind[i]].missing = -1;
67  }
68  }
69  return total_output_size;
70 }
71 
72 template <typename ElementType, typename ThresholdType, typename LeafOutputType, typename PredFunc>
73 inline size_t PredLoop(const treelite::DenseDMatrixImpl<ElementType>* dmat, int num_feature,
74  size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) {
75  const bool nan_missing = treelite::math::CheckNAN(dmat->missing_value);
76  CHECK_LE(dmat->num_col, static_cast<size_t>(num_feature));
77  std::vector<treelite::predictor::Entry<ThresholdType>> inst(
78  std::max(dmat->num_col, static_cast<size_t>(num_feature)), {-1});
79  CHECK(rbegin < rend && rend <= dmat->num_row);
80  const size_t num_col = dmat->num_col;
81  const ElementType missing_value = dmat->missing_value;
82  const ElementType* data = dmat->data.data();
83  const ElementType* row = nullptr;
84  size_t total_output_size = 0;
85  for (size_t rid = rbegin; rid < rend; ++rid) {
86  row = &data[rid * num_col];
87  for (size_t j = 0; j < num_col; ++j) {
88  if (treelite::math::CheckNAN(row[j])) {
89  CHECK(nan_missing)
90  << "The missing_value argument must be set to NaN if there is any NaN in the matrix.";
91  } else if (nan_missing || row[j] != missing_value) {
92  inst[j].fvalue = static_cast<ThresholdType>(row[j]);
93  }
94  }
95  total_output_size += func(rid, &inst[0], out_pred);
96  for (size_t j = 0; j < num_col; ++j) {
97  inst[j].missing = -1;
98  }
99  }
100  return total_output_size;
101 }
102 
103 template <typename ElementType>
104 class PredLoopDispatcherWithDenseDMatrix {
105  public:
106  template <typename ThresholdType, typename LeafOutputType, typename PredFunc>
107  inline static size_t Dispatch(
108  const treelite::DMatrix* dmat, ThresholdType, int num_feature, size_t rbegin, size_t rend,
109  LeafOutputType* out_pred, PredFunc func) {
110  const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
111  return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
112  dmat_, num_feature, rbegin, rend, out_pred, func);
113  }
114 };
115 
116 template <typename ElementType>
117 class PredLoopDispatcherWithCSRDMatrix {
118  public:
119  template <typename ThresholdType, typename LeafOutputType, typename PredFunc>
120  inline static size_t Dispatch(
121  const treelite::DMatrix* dmat, ThresholdType, int num_feature, size_t rbegin, size_t rend,
122  LeafOutputType* out_pred, PredFunc func) {
123  const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
124  return PredLoop<ElementType, ThresholdType, LeafOutputType, PredFunc>(
125  dmat_, num_feature, rbegin, rend, out_pred, func);
126  }
127 };
128 
129 template <typename ThresholdType, typename LeafOutputType, typename PredFunc>
130 inline size_t PredLoop(const treelite::DMatrix* dmat, ThresholdType test_val, int num_feature,
131  size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) {
132  treelite::DMatrixType dmat_type = dmat->GetType();
133  switch (dmat_type) {
134  case treelite::DMatrixType::kDense: {
135  return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithDenseDMatrix>(
136  dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
137  }
138  case treelite::DMatrixType::kSparseCSR: {
139  return treelite::DispatchWithTypeInfo<PredLoopDispatcherWithCSRDMatrix>(
140  dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func);
141  }
142  default:
143  LOG(FATAL) << "Unrecognized data matrix type: " << static_cast<int>(dmat_type);
144  return 0;
145  }
146 }
147 
148 } // anonymous namespace
149 
150 namespace treelite {
151 namespace predictor {
152 
153 SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {}
154 
155 SharedLibrary::~SharedLibrary() {
156  if (handle_) {
157 #ifdef _WIN32
158  FreeLibrary(static_cast<HMODULE>(handle_));
159 #else
160  dlclose(static_cast<void*>(handle_));
161 #endif
162  }
163 }
164 
165 void
166 SharedLibrary::Load(const char* libpath) {
167 #ifdef _WIN32
168  HMODULE handle = LoadLibraryA(libpath);
169 #else
170  void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL);
171 #endif
172  CHECK(handle) << "Failed to load dynamic shared library `" << libpath << "'";
173  handle_ = static_cast<LibraryHandle>(handle);
174  libpath_ = std::string(libpath);
175 }
176 
177 SharedLibrary::FunctionHandle
178 SharedLibrary::LoadFunction(const char* name) const {
179 #ifdef _WIN32
180  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(handle_), name);
181 #else
182  void* func_handle = dlsym(static_cast<void*>(handle_), name);
183 #endif
184  CHECK(func_handle)
185  << "Dynamic shared library `" << libpath_ << "' does not contain a function " << name << "().";
186  return static_cast<SharedLibrary::FunctionHandle>(func_handle);
187 }
188 
189 template <typename HandleType>
190 HandleType
191 SharedLibrary::LoadFunctionWithSignature(const char* name) const {
192  auto func_handle = reinterpret_cast<HandleType>(LoadFunction(name));
193  CHECK(func_handle) << "Dynamic shared library `" << libpath_ << "' does not contain a function "
194  << name << "() with the requested signature";
195  return func_handle;
196 }
197 
198 template <typename ThresholdType, typename LeafOutputType>
200  public:
201  inline static std::unique_ptr<PredFunction> Dispatch(
202  const SharedLibrary& library, int num_feature, int num_class) {
203  return std::make_unique<PredFunctionImpl<ThresholdType, LeafOutputType>>(
204  library, num_feature, num_class);
205  }
206 };
207 
208 std::unique_ptr<PredFunction>
209 PredFunction::Create(
210  TypeInfo threshold_type, TypeInfo leaf_output_type, const SharedLibrary& library,
211  int num_feature, int num_class) {
212  return DispatchWithModelTypes<PredFunctionInitDispatcher>(
213  threshold_type, leaf_output_type, library, num_feature, num_class);
214 }
215 
216 template <typename ThresholdType, typename LeafOutputType>
218  const SharedLibrary& library, int num_feature, int num_class) {
219  CHECK_GT(num_class, 0) << "num_class cannot be zero";
220  if (num_class > 1) { // multi-class classification
221  handle_ = library.LoadFunction("predict_multiclass");
222  } else { // everything else
223  handle_ = library.LoadFunction("predict");
224  }
225  num_feature_ = num_feature;
226  num_class_ = num_class;
227 }
228 
229 template <typename ThresholdType, typename LeafOutputType>
230 TypeInfo
232  return TypeToInfo<ThresholdType>();
233 }
234 
235 template <typename ThresholdType, typename LeafOutputType>
236 TypeInfo
238  return TypeToInfo<LeafOutputType>();
239 }
240 
241 template <typename ThresholdType, typename LeafOutputType>
242 size_t
244  const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
245  PredictorOutputHandle out_pred) const {
246  /* Pass the correct prediction function to PredLoop.
247  We also need to specify how the function should be called. */
248  size_t result_size;
249  // Dimension of output vector:
250  // can be either [num_data] or [num_class]*[num_data].
251  // Note that size of prediction may be smaller than out_pred (this occurs
252  // when pred_function is set to "max_index").
253  CHECK(rbegin < rend && rend <= dmat->GetNumRow());
254  if (num_class_ > 1) { // multi-class classification
255  using PredFunc = size_t (*)(Entry<ThresholdType>*, int, LeafOutputType*);
256  auto pred_func = reinterpret_cast<PredFunc>(handle_);
257  CHECK(pred_func) << "The predict_multiclass() function has incorrect signature.";
258  auto pred_func_wrapper
259  = [pred_func, num_class = num_class_, pred_margin]
260  (int64_t rid, Entry<ThresholdType>* inst, LeafOutputType* out_pred) -> size_t {
261  return pred_func(inst, static_cast<int>(pred_margin),
262  &out_pred[rid * num_class]);
263  };
264  result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
265  static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
266  } else { // everything else
267  using PredFunc = LeafOutputType (*)(Entry<ThresholdType>*, int);
268  auto pred_func = reinterpret_cast<PredFunc>(handle_);
269  CHECK(pred_func) << "The predict() function has incorrect signature.";
270  auto pred_func_wrapper
271  = [pred_func, pred_margin]
272  (int64_t rid, Entry<ThresholdType>* inst, LeafOutputType* out_pred) -> size_t {
273  out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
274  return 1;
275  };
276  result_size = PredLoop(dmat, static_cast<ThresholdType>(0), num_feature_, rbegin, rend,
277  static_cast<LeafOutputType*>(out_pred), pred_func_wrapper);
278  }
279  return result_size;
280 }
281 
282 Predictor::Predictor(int num_worker_thread)
283  : pred_func_(nullptr),
284  thread_pool_handle_(nullptr),
285  num_class_(0),
286  num_feature_(0),
287  sigmoid_alpha_(std::numeric_limits<float>::quiet_NaN()),
288  global_bias_(std::numeric_limits<float>::quiet_NaN()),
289  num_worker_thread_(num_worker_thread),
290  threshold_type_(TypeInfo::kInvalid),
291  leaf_output_type_(TypeInfo::kInvalid) {}
292 Predictor::~Predictor() {
293  if (thread_pool_handle_) {
294  Free();
295  }
296 }
297 
298 void
299 Predictor::Load(const char* libpath) {
300  lib_.Load(libpath);
301 
302  using UnsignedQueryFunc = size_t (*)();
303  using StringQueryFunc = const char* (*)();
304  using FloatQueryFunc = float (*)();
305 
306  /* 1. query # of output groups */
307  auto* num_class_query_func
308  = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>("get_num_class");
309  num_class_ = num_class_query_func();
310 
311  /* 2. query # of features */
312  auto* num_feature_query_func
313  = lib_.LoadFunctionWithSignature<UnsignedQueryFunc>("get_num_feature");
314  num_feature_ = num_feature_query_func();
315  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";
316 
317  /* 3. query # of pred_transform name */
318  auto* pred_transform_query_func
319  = lib_.LoadFunctionWithSignature<StringQueryFunc>("get_pred_transform");
320  pred_transform_ = pred_transform_query_func();
321 
322  /* 4. query # of sigmoid_alpha */
323  auto* sigmoid_alpha_query_func
324  = lib_.LoadFunctionWithSignature<FloatQueryFunc>("get_sigmoid_alpha");
325  sigmoid_alpha_ = sigmoid_alpha_query_func();
326 
327  /* 5. query # of global_bias */
328  auto* global_bias_query_func = lib_.LoadFunctionWithSignature<FloatQueryFunc>("get_global_bias");
329  global_bias_ = global_bias_query_func();
330 
331  /* 6. Query the data type for thresholds and leaf outputs */
332  auto* threshold_type_query_func
333  = lib_.LoadFunctionWithSignature<StringQueryFunc>("get_threshold_type");
334  threshold_type_ = GetTypeInfoByName(threshold_type_query_func());
335  auto* leaf_output_type_query_func
336  = lib_.LoadFunctionWithSignature<StringQueryFunc>("get_leaf_output_type");
337  leaf_output_type_ = GetTypeInfoByName(leaf_output_type_query_func());
338 
339  /* 7. load appropriate function for margin prediction */
340  CHECK_GT(num_class_, 0) << "num_class cannot be zero";
341  pred_func_ = PredFunction::Create(
342  threshold_type_, leaf_output_type_, lib_,
343  static_cast<int>(num_feature_), static_cast<int>(num_class_));
344 
345  if (num_worker_thread_ == -1) {
346  num_worker_thread_ = static_cast<int>(std::thread::hardware_concurrency());
347  }
348  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
349  new PredThreadPool(num_worker_thread_ - 1, this,
350  [](SpscQueue<InputToken>* incoming_queue,
351  SpscQueue<OutputToken>* outgoing_queue,
352  const Predictor* predictor) {
353  predictor->exception_catcher_.Run([&]() {
354  InputToken input;
355  while (incoming_queue->Pop(&input)) {
356  const size_t rbegin = input.rbegin;
357  const size_t rend = input.rend;
358  size_t query_result_size
359  = predictor->pred_func_->PredictBatch(
360  input.dmat, rbegin, rend, input.pred_margin, input.out_pred);
361  outgoing_queue->Push(OutputToken{query_result_size});
362  }
363  });
364  }));
365 }
366 
367 void
368 Predictor::Free() {
369  delete static_cast<PredThreadPool*>(thread_pool_handle_);
370 }
371 
372 static inline
373 std::vector<size_t> SplitBatch(const DMatrix* dmat, size_t split_factor) {
374  const size_t num_row = dmat->GetNumRow();
375  CHECK_LE(split_factor, num_row);
376  const size_t portion = num_row / split_factor;
377  const size_t remainder = num_row % split_factor;
378  std::vector<size_t> workload(split_factor, portion);
379  std::vector<size_t> row_ptr(split_factor + 1, 0);
380  for (size_t i = 0; i < remainder; ++i) {
381  ++workload[i];
382  }
383  size_t accum = 0;
384  for (size_t i = 0; i < split_factor; ++i) {
385  accum += workload[i];
386  row_ptr[i + 1] = accum;
387  }
388  return row_ptr;
389 }
390 
391 template <typename LeafOutputType>
393  public:
394  inline static void Dispatch(
395  size_t num_row, size_t query_size_per_instance, size_t num_class,
396  PredictorOutputHandle out_result);
397 };
398 
399 template <typename LeafOutputType>
401  public:
402  inline static PredictorOutputHandle Dispatch(size_t size);
403 };
404 
405 template <typename LeafOutputType>
407  public:
408  inline static void Dispatch(PredictorOutputHandle output_vector);
409 };
410 
411 size_t
412 Predictor::PredictBatch(
413  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const {
414  const double tstart = dmlc::GetTime();
415 
416  const size_t num_row = dmat->GetNumRow();
417  auto* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
418  InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result};
419  OutputToken response;
420  CHECK_GT(num_row, 0);
421  const int nthread = std::min(num_worker_thread_, static_cast<int>(num_row));
422  const std::vector<size_t> row_ptr = SplitBatch(dmat, nthread);
423  for (int tid = 0; tid < nthread - 1; ++tid) {
424  request.rbegin = row_ptr[tid];
425  request.rend = row_ptr[tid + 1];
426  pool->SubmitTask(tid, request);
427  }
428  size_t total_size = 0;
429  {
430  // assign work to the main thread
431  const size_t rbegin = row_ptr[nthread - 1];
432  const size_t rend = row_ptr[nthread];
433  const size_t query_result_size
434  = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result);
435  total_size += query_result_size;
436  }
437  for (int tid = 0; tid < nthread - 1; ++tid) {
438  if (pool->WaitForTask(tid, &response)) {
439  total_size += response.query_result_size;
440  }
441  }
442  // re-shape output if total_size < dimension of out_result
443  if (total_size < QueryResultSize(dmat, 0, num_row)) {
444  CHECK_GT(num_class_, 1);
445  CHECK_EQ(total_size % num_row, 0);
446  const size_t query_size_per_instance = total_size / num_row;
447  CHECK_GT(query_size_per_instance, 0);
448  CHECK_LT(query_size_per_instance, num_class_);
449  DispatchWithTypeInfo<ShrinkResultToFit>(
450  leaf_output_type_, num_row, query_size_per_instance, num_class_, out_result);
451  }
452  const double tend = dmlc::GetTime();
453  if (verbose > 0) {
454  LOG(INFO) << "Treelite: Finished prediction in " << tend - tstart << " sec";
455  }
456  return total_size;
457 }
458 
460 Predictor::CreateOutputVector(const DMatrix* dmat) const {
461  const size_t output_vector_size = this->QueryResultSize(dmat);
462  return DispatchWithTypeInfo<AllocateOutputVector>(leaf_output_type_, output_vector_size);
463 }
464 
465 void
466 Predictor::DeleteOutputVector(PredictorOutputHandle output_vector) const {
467  DispatchWithTypeInfo<DeallocateOutputVector>(leaf_output_type_, output_vector);
468 }
469 
470 template <typename LeafOutputType>
471 void
473  size_t num_row, size_t query_size_per_instance, size_t num_class,
474  PredictorOutputHandle out_result) {
475  auto* out_result_ = static_cast<LeafOutputType*>(out_result);
476  for (size_t rid = 0; rid < num_row; ++rid) {
477  for (size_t k = 0; k < query_size_per_instance; ++k) {
478  out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_class + k];
479  }
480  }
481 }
482 
483 template <typename LeafOutputType>
486  return static_cast<PredictorOutputHandle>(new LeafOutputType[size]);
487 }
488 
489 template <typename LeafOutputType>
490 void
492  delete[] (static_cast<LeafOutputType*>(output_vector));
493 }
494 
495 } // namespace predictor
496 } // namespace treelite
Load prediction function exported as a shared library.
Some useful math utilities.
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:118
Input data structure of Treelite.
bool CheckNAN(T value)
check for NaN (Not a Number)
Definition: math.h:43
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:59
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
Definition: data.h:114
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:112
std::vector< ElementType > data
feature values
Definition: data.h:57
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< ElementType > data
feature values
Definition: data.h:110
Defines TypeInfo class and utilities.
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
Definition: typeinfo.cc:16
a simple thread pool implementation
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:26
predictor class: wrapper for optimized prediction code
Definition: predictor.h:78
void * ThreadPoolHandle
opaque handle types
Definition: predictor.h:81
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:63