treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include <dmlc/logging.h>
#include <treelite/entry.h>
13 
14 namespace treelite {
15 
namespace common {
namespace filesystem {
/*!
 * \brief forward declaration only; the complete type is defined elsewhere
 *        and is referenced here solely through a pointer-like member
 */
class TemporaryDirectory;
}  // namespace filesystem
}  // namespace common
21 
/*!
 * \brief sparse batch in Compressed Sparse Row (CSR) format
 *
 * Presumably follows the usual CSR convention: the entries of row i live in
 * [row_ptr[i], row_ptr[i + 1]) of both `data` and `col_ind` (confirm against
 * the predictor implementation).
 */
struct CSRBatch {
  /*! \brief feature values */
  float const* data;
  /*! \brief feature indices */
  uint32_t const* col_ind;
  /*! \brief pointer to row headers; its length is num_row + 1 */
  size_t const* row_ptr;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. number of features used) */
  size_t num_col;
};
35 
/*!
 * \brief dense batch
 *
 * NOTE(review): `data` presumably holds num_row * num_col values laid out
 * row by row -- confirm against the predictor implementation.
 */
struct DenseBatch {
  /*! \brief feature values */
  const float* data;
  /*!
   * \brief value representing the missing value (usually nan)
   *
   * This member was absent from the listing although the header documents it
   * between `data` and `num_row`; it is restored in that position so the
   * struct layout and aggregate-initialization order match the real header.
   */
  float missing_value;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. number of features used) */
  size_t num_col;
};
47 
49 class Predictor {
50  public:
52  typedef void* QueryFuncHandle;
53  typedef void* PredFuncHandle;
54  typedef void* LibraryHandle;
55  typedef void* ThreadPoolHandle;
56 
57  Predictor(int num_worker_thread = -1,
58  bool include_master_thread = false);
59  ~Predictor();
64  void Load(const char* name);
68  void Free();
69 
82  size_t PredictBatch(const CSRBatch* batch, int verbose,
83  bool pred_margin, float* out_result);
84  size_t PredictBatch(const DenseBatch* batch, int verbose,
85  bool pred_margin, float* out_result);
97  size_t PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
98  float* out_result);
99 
106  inline size_t QueryResultSize(const CSRBatch* batch) const {
107  CHECK(pred_func_handle_ != nullptr)
108  << "A shared library needs to be loaded first using Load()";
109  return batch->num_row * num_output_group_;
110  }
117  inline size_t QueryResultSize(const DenseBatch* batch) const {
118  CHECK(pred_func_handle_ != nullptr)
119  << "A shared library needs to be loaded first using Load()";
120  return batch->num_row * num_output_group_;
121  }
130  inline size_t QueryResultSize(const CSRBatch* batch,
131  size_t rbegin, size_t rend) const {
132  CHECK(pred_func_handle_ != nullptr)
133  << "A shared library needs to be loaded first using Load()";
134  CHECK(rbegin < rend && rend <= batch->num_row);
135  return (rend - rbegin) * num_output_group_;
136  }
145  inline size_t QueryResultSize(const DenseBatch* batch,
146  size_t rbegin, size_t rend) const {
147  CHECK(pred_func_handle_ != nullptr)
148  << "A shared library needs to be loaded first using Load()";
149  CHECK(rbegin < rend && rend <= batch->num_row);
150  return (rend - rbegin) * num_output_group_;
151  }
157  inline size_t QueryResultSizeSingleInst() const {
158  CHECK(pred_func_handle_ != nullptr)
159  << "A shared library needs to be loaded first using Load()";
160  return num_output_group_;
161  }
168  inline size_t QueryNumOutputGroup() const {
169  return num_output_group_;
170  }
171 
177  inline size_t QueryNumFeature() const {
178  return num_feature_;
179  }
180 
181  private:
182  LibraryHandle lib_handle_;
183  QueryFuncHandle num_output_group_query_func_handle_;
184  QueryFuncHandle num_feature_query_func_handle_;
185  PredFuncHandle pred_func_handle_;
186  ThreadPoolHandle thread_pool_handle_;
187  size_t num_output_group_;
188  size_t num_feature_;
189  int num_worker_thread_;
190  bool include_master_thread_; // run task on master thread?
191 
192  bool using_remote_lib_; // load lib from remote location?
193  // information for temporary file to cache remote lib
194  std::unique_ptr<common::filesystem::TemporaryDirectory> tempdir_;
195  std::string temp_libfile_;
196 
197  template <typename BatchType>
198  size_t PredictBatchBase_(const BatchType* batch, int verbose,
199  bool pred_margin, float* out_result);
200 };
201 
202 } // namespace treelite
203 
204 #endif // TREELITE_PREDICTOR_H_
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
size_t QueryResultSizeSingleInst() const
Query the necessary size of array to hold the prediction for a single data row.
Definition: predictor.h:157
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:106
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:177
size_t QueryResultSize(const CSRBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:130
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:168
size_t QueryResultSize(const DenseBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:145
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
size_t QueryResultSize(const DenseBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:117
const size_t * row_ptr
pointer to row headers; its length is num_row + 1
Definition: predictor.h:29
void * QueryFuncHandle
opaque handle types
Definition: predictor.h:52
dense batch
Definition: predictor.h:37
const float * data
feature values
Definition: predictor.h:39
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:11
const float * data
feature values
Definition: predictor.h:25
size_t num_row
number of rows
Definition: predictor.h:43
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
size_t num_row
number of rows
Definition: predictor.h:31
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45