/*!
 * \file predictor.cc
 * \brief Load prediction function exported as a shared library.
 */
#include <treelite/predictor.h>
#include <treelite/common.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
#include "common/math.h"
#include "common/filesystem.h"
#include "thread_pool/thread_pool.h"
23 
24 #ifdef _WIN32
25 #define NOMINMAX
26 #include <windows.h>
27 #else
28 #include <dlfcn.h>
29 #endif
30 
31 namespace {
32 
// Discriminates which concrete batch type InputToken::data points to.
enum class InputType : uint8_t {
  kSparseBatch = 0, kDenseBatch = 1
};
36 
/*! \brief Unit of work submitted to each worker thread in the pool */
struct InputToken {
  InputType input_type;  // which concrete batch type `data` points to
  const void* data;  // pointer to input data
  bool pred_margin;  // whether to store raw margin or transformed scores
  size_t num_feature;
  // # features (columns) accepted by the tree ensemble model
  size_t num_output_group;
  // size of output per instance (row)
  treelite::Predictor::PredFuncHandle pred_func_handle;
  size_t rbegin, rend;
  // range of instances (rows) assigned to each worker
  float* out_pred;
  // buffer to store output from each worker
};
51 
/*! \brief Per-task result reported back by each worker thread */
struct OutputToken {
  size_t query_result_size;  // number of output values produced by the task
};
55 
/*!
 * \brief Extract the protocol prefix (e.g. "file://", "s3://") from a path.
 * \param name path or URI of a file
 * \return protocol prefix including the "://" delimiter, or an empty
 *         string if the path carries no protocol
 */
inline std::string GetProtocol(const char* name) {
  const char* p = std::strstr(name, "://");
  if (p == nullptr) {
    return "";
  }
  // include the "://" delimiter in the returned prefix
  return std::string(name, p - name + 3);
}
64 
65 using PredThreadPool
67 
68 inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
69 #ifdef _WIN32
70  HMODULE handle = LoadLibraryA(name);
71 #else
72  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
73 #endif
74  return static_cast<treelite::Predictor::LibraryHandle>(handle);
75 }
76 
77 inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
78 #ifdef _WIN32
79  FreeLibrary(static_cast<HMODULE>(handle));
80 #else
81  dlclose(static_cast<void*>(handle));
82 #endif
83 }
84 
85 template <typename HandleType>
86 inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
87  const char* name) {
88 #ifdef _WIN32
89  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
90 #else
91  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
92 #endif
93  return static_cast<HandleType>(func_handle);
94 }
95 
96 template <typename PredFunc>
97 inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature,
98  size_t rbegin, size_t rend,
99  float* out_pred, PredFunc func) {
100  CHECK_LE(batch->num_col, num_feature);
101  std::vector<TreelitePredictorEntry> inst(
102  std::max(batch->num_col, num_feature), {-1});
103  CHECK(rbegin < rend && rend <= batch->num_row);
104  CHECK(sizeof(size_t) < sizeof(int64_t)
105  || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
106  && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
107  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
108  const int64_t rend_ = static_cast<int64_t>(rend);
109  const size_t num_col = batch->num_col;
110  const float* data = batch->data;
111  const uint32_t* col_ind = batch->col_ind;
112  const size_t* row_ptr = batch->row_ptr;
113  size_t total_output_size = 0;
114  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
115  const size_t ibegin = row_ptr[rid];
116  const size_t iend = row_ptr[rid + 1];
117  for (size_t i = ibegin; i < iend; ++i) {
118  inst[col_ind[i]].fvalue = data[i];
119  }
120  total_output_size += func(rid, &inst[0], out_pred);
121  for (size_t i = ibegin; i < iend; ++i) {
122  inst[col_ind[i]].missing = -1;
123  }
124  }
125  return total_output_size;
126 }
127 
/*!
 * \brief Run prediction over rows [rbegin, rend) of a dense batch.
 *        Entries equal to batch->missing_value (or NaN, when the missing
 *        value itself is NaN) are left as "missing" in the scratch row.
 * \param batch dense batch
 * \param num_feature number of features the compiled model expects
 * \param rbegin first row (inclusive) to process
 * \param rend last row (exclusive) to process
 * \param out_pred output buffer, shared across all rows
 * \param func callable (row id, dense entry array, out_pred) -> number of
 *             output values written for that row
 * \return total number of output values produced
 */
template <typename PredFunc>
inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature,
                       size_t rbegin, size_t rend,
                       float* out_pred, PredFunc func) {
  const bool nan_missing
    = treelite::common::math::CheckNAN(batch->missing_value);
  CHECK_LE(batch->num_col, num_feature);
  // Scratch row; every entry starts as "missing" (-1)
  std::vector<TreelitePredictorEntry> inst(
    std::max(batch->num_col, num_feature), {-1});
  CHECK(rbegin < rend && rend <= batch->num_row);
  // Guard against overflow when converting row indices to int64_t
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
            && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
  const int64_t rend_ = static_cast<int64_t>(rend);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row;
  size_t total_output_size = 0;
  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
    row = &data[rid * num_col];
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::common::math::CheckNAN(row[j])) {
        // NaN is only legal when NaN is the designated missing value
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        // Value present: copy it; otherwise the entry stays "missing"
        inst[j].fvalue = row[j];
      }
    }
    total_output_size += func(rid, &inst[0], out_pred);
    // Reset the entire scratch row to "missing" for the next instance
    for (size_t j = 0; j < num_col; ++j) {
      inst[j].missing = -1;
    }
  }
  return total_output_size;
}
166 
/*!
 * \brief Select the correct compiled prediction entry point (multi-class
 *        vs. single-output) and run PredLoop over rows [rbegin, rend).
 * \return actual number of output values produced; may be smaller than
 *         expected_query_result_size (see note below)
 */
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch, bool pred_margin,
                            size_t num_feature, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle pred_func_handle,
                            size_t rbegin, size_t rend,
                            size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  /* Pass the correct prediction function to PredLoop.
     We also need to specify how the function should be called. */
  size_t query_result_size;
  // Dimension of output vector:
  // can be either [num_data] or [num_class]*[num_data].
  // Note that size of prediction may be smaller than out_pred (this occurs
  // when pred_function is set to "max_index").
  if (num_output_group > 1) {  // multi-class classification task
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, num_feature, rbegin, rend, out_pred,
        [pred_func, num_output_group, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          // Row rid writes into its own num_output_group-wide slot
          return pred_func(inst, static_cast<int>(pred_margin),
                           &out_pred[rid * num_output_group]);
        });
  } else {  // every other task
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, num_feature, rbegin, rend, out_pred,
        [pred_func, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          // Exactly one output value per row
          out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
          return 1;
        });
  }
  return query_result_size;
}
205 
206 inline size_t PredictInst_(TreelitePredictorEntry* inst,
207  bool pred_margin, size_t num_output_group,
208  treelite::Predictor::PredFuncHandle pred_func_handle,
209  size_t expected_query_result_size, float* out_pred) {
210  CHECK(pred_func_handle != nullptr)
211  << "A shared library needs to be loaded first using Load()";
212  size_t query_result_size; // Dimention of output vector
213  if (num_output_group > 1) { // multi-class classification task
214  using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
215  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
216  query_result_size = pred_func(inst, (int)pred_margin, out_pred);
217  } else { // every other task
218  using PredFunc = float (*)(TreelitePredictorEntry*, int);
219  PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
220  out_pred[0] = pred_func(inst, (int)pred_margin);
221  query_result_size = 1;
222  }
223  return query_result_size;
224 }
225 
226 } // anonymous namespace
227 
228 namespace treelite {
229 
/*!
 * \brief Construct an empty predictor; Load() must be called before any
 *        prediction method.
 * \param num_worker_thread number of worker threads; -1 means "use all
 *        hardware threads" (resolved in Load())
 */
Predictor::Predictor(int num_worker_thread)
  : lib_handle_(nullptr),
    num_output_group_query_func_handle_(nullptr),
    num_feature_query_func_handle_(nullptr),
    pred_func_handle_(nullptr),
    thread_pool_handle_(nullptr),
    num_worker_thread_(num_worker_thread),
    tempdir_(nullptr) {}
Predictor::~Predictor() {
  // Release the shared library handle and the worker thread pool
  Free();
}
241 
/*!
 * \brief Load the compiled prediction functions from a shared library and
 *        start the worker thread pool. Paths with a non-file protocol are
 *        first downloaded to a temporary local file via dmlc::Stream.
 * \param name path or URI of the shared library
 */
void
Predictor::Load(const char* name) {
  const std::string protocol = GetProtocol(name);
  if (protocol == "file://" || protocol.empty()) {
    // local file
    lib_handle_ = OpenLibrary(name);
  } else {
    // remote file: copy into a temporary directory, then open the copy
    tempdir_.reset(new common::filesystem::TemporaryDirectory());
    temp_libfile_ = tempdir_->AddFile(common::filesystem::GetBasename(name));
    {
      std::unique_ptr<dmlc::Stream> strm(dmlc::Stream::Create(name, "r"));
      dmlc::istream is(strm.get());
      std::ofstream of(temp_libfile_);
      of << is.rdbuf();
    }
    lib_handle_ = OpenLibrary(temp_libfile_.c_str());
  }
  if (lib_handle_ == nullptr) {
    LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'";
  }

  /* 1. query # of output groups */
  num_output_group_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_output_group");
  using QueryFunc = size_t (*)(void);
  QueryFunc query_func
    = reinterpret_cast<QueryFunc>(num_output_group_query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = query_func();

  /* 2. query # of features */
  num_feature_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_feature");
  query_func = reinterpret_cast<QueryFunc>(num_feature_query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_feature() function";
  num_feature_ = query_func();
  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";

  /* 3. load appropriate function for margin prediction */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_,
                                                     "predict_multiclass");
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_multiclass() function";
  } else {  // everything else
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict() function";
  }

  // -1 requests one worker per hardware thread
  if (num_worker_thread_ == -1) {
    num_worker_thread_ = std::thread::hardware_concurrency();
  }
  // Spawn (num_worker_thread_ - 1) pool workers; the calling thread also
  // performs prediction work, so total parallelism is num_worker_thread_.
  // Each worker loops: pop a task, predict over its row range, push the
  // number of outputs it produced.
  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
    new PredThreadPool(num_worker_thread_ - 1, this,
      [](SpscQueue<InputToken>* incoming_queue,
         SpscQueue<OutputToken>* outgoing_queue,
         const Predictor* predictor) {
        InputToken input;
        while (incoming_queue->Pop(&input)) {
          size_t query_result_size;
          const size_t rbegin = input.rbegin;
          const size_t rend = input.rend;
          switch (input.input_type) {
           case InputType::kSparseBatch:
            {
              const CSRBatch* batch = static_cast<const CSRBatch*>(input.data);
              query_result_size
                = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                input.num_output_group, input.pred_func_handle,
                                rbegin, rend,
                                predictor->QueryResultSize(batch, rbegin, rend),
                                input.out_pred);
            }
            break;
           case InputType::kDenseBatch:
            {
              const DenseBatch* batch = static_cast<const DenseBatch*>(input.data);
              query_result_size
                = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                input.num_output_group, input.pred_func_handle,
                                rbegin, rend,
                                predictor->QueryResultSize(batch, rbegin, rend),
                                input.out_pred);
            }
            break;
          }
          outgoing_queue->Push(OutputToken{query_result_size});
        }
      }));
}
345 
346 void
347 Predictor::Free() {
348  CloseLibrary(lib_handle_);
349  delete static_cast<PredThreadPool*>(thread_pool_handle_);
350 }
351 
352 template <typename BatchType>
353 static inline
354 std::vector<size_t> SplitBatch(const BatchType* batch, size_t split_factor) {
355  const size_t num_row = batch->num_row;
356  CHECK_LE(split_factor, num_row);
357  const size_t portion = num_row / split_factor;
358  const size_t remainder = num_row % split_factor;
359  std::vector<size_t> workload(split_factor, portion);
360  std::vector<size_t> row_ptr(split_factor + 1, 0);
361  for (size_t i = 0; i < remainder; ++i) {
362  ++workload[i];
363  }
364  size_t accum = 0;
365  for (size_t i = 0; i < split_factor; ++i) {
366  accum += workload[i];
367  row_ptr[i + 1] = accum;
368  }
369  return row_ptr;
370 }
371 
/*!
 * \brief Shared implementation of batch prediction for both batch types.
 *        Splits the rows across the pool workers plus the calling thread,
 *        gathers the per-slice output counts, and compacts the output
 *        buffer in place when the model emitted fewer values per row than
 *        num_output_group_.
 * \return total number of values written to out_result
 */
template <typename BatchType>
inline size_t
Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
                             bool pred_margin, float* out_result) {
  static_assert(std::is_same<BatchType, DenseBatch>::value
                || std::is_same<BatchType, CSRBatch>::value,
                "PredictBatchBase_: unrecognized batch type");
  const double tstart = dmlc::GetTime();
  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
  const InputType input_type
    = std::is_same<BatchType, CSRBatch>::value
      ? InputType::kSparseBatch : InputType::kDenseBatch;
  InputToken request{input_type, static_cast<const void*>(batch), pred_margin,
                     num_feature_, num_output_group_, pred_func_handle_,
                     0, batch->num_row, out_result};
  OutputToken response;
  CHECK_GT(batch->num_row, 0);
  // Never use more threads than there are rows
  const int nthread = std::min(num_worker_thread_,
                               static_cast<int>(batch->num_row));
  const std::vector<size_t> row_ptr = SplitBatch(batch, nthread);
  // Hand the first (nthread - 1) slices to the pool workers
  for (int tid = 0; tid < nthread - 1; ++tid) {
    request.rbegin = row_ptr[tid];
    request.rend = row_ptr[tid + 1];
    pool->SubmitTask(tid, request);
  }
  size_t total_size = 0;
  {
    // assign work to master
    const size_t rbegin = row_ptr[nthread - 1];
    const size_t rend = row_ptr[nthread];
    const size_t query_result_size
      = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_,
                      pred_func_handle_,
                      rbegin, rend, QueryResultSize(batch, rbegin, rend),
                      out_result);
    total_size += query_result_size;
  }
  // Wait for the workers and accumulate their output counts
  for (int tid = 0; tid < nthread - 1; ++tid) {
    if (pool->WaitForTask(tid, &response)) {
      total_size += response.query_result_size;
    }
  }
  // re-shape output if total_size < dimension of out_result
  // (rows are compacted left-to-right in place, which is safe because the
  //  destination index never overtakes the source index)
  if (total_size < QueryResultSize(batch, 0, batch->num_row)) {
    CHECK_GT(num_output_group_, 1);
    CHECK_EQ(total_size % batch->num_row, 0);
    const size_t query_size_per_instance = total_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group_);
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        out_result[rid * query_size_per_instance + k]
          = out_result[rid * num_output_group_ + k];
      }
    }
  }
  const double tend = dmlc::GetTime();
  if (verbose > 0) {
    LOG(INFO) << "Treelite: Finished prediction in "
              << tend - tstart << " sec";
  }
  return total_size;
}
435 
/*!
 * \brief Predict over a sparse (CSR) batch.
 * \return total number of values written to out_result
 */
size_t
Predictor::PredictBatch(const CSRBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  // Thin wrapper: dispatch sparse batches to the common implementation
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
441 
/*!
 * \brief Predict over a dense batch.
 * \return total number of values written to out_result
 */
size_t
Predictor::PredictBatch(const DenseBatch* batch, int verbose,
                        bool pred_margin, float* out_result) {
  // Thin wrapper: dispatch dense batches to the common implementation
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}
447 
448 size_t
449 Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
450  float* out_result) {
451  size_t total_size;
452  total_size = PredictInst_(inst, pred_margin, num_output_group_,
453  pred_func_handle_,
454  QueryResultSizeSingleInst(), out_result);
455  return total_size;
456 }
457 
458 } // namespace treelite
Load prediction function exported as a shared library.
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:29
dense batch
Definition: predictor.h:37
const float * data
feature values
Definition: predictor.h:39
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:105
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:11
a simple thread pool implementation
const float * data
feature values
Definition: predictor.h:25
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45