Treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <dmlc/logging.h>
11 #include <treelite/entry.h>
12 #include <string>
13 #include <cstdint>
14 
15 namespace treelite {
16 
/*!
 * \brief sparse batch in Compressed Sparse Row (CSR) format
 *
 * A non-owning view: every pointer member refers to caller-managed memory.
 */
struct CSRBatch {
  float const* data;        ///< feature values
  uint32_t const* col_ind;  ///< feature indices
  size_t const* row_ptr;    ///< pointer to row headers; array length is num_row + 1
  size_t num_row;           ///< number of rows
  size_t num_col;           ///< number of columns (i.e. number of features used)
};
30 
/*!
 * \brief dense batch
 *
 * A non-owning view over caller-managed feature values.
 * NOTE(review): `data` presumably holds num_row * num_col values in
 * row-major order -- confirm against the predictor implementation.
 */
struct DenseBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief value representing the missing value (usually NaN) */
  float missing_value;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. number of features used) */
  size_t num_col;
};
42 
44 class Predictor {
45  public:
47  typedef void* QueryFuncHandle;
48  typedef void* PredFuncHandle;
49  typedef void* LibraryHandle;
50  typedef void* ThreadPoolHandle;
51 
52  explicit Predictor(int num_worker_thread = -1);
53  ~Predictor();
58  void Load(const char* name);
62  void Free();
63 
76  size_t PredictBatch(const CSRBatch* batch, int verbose,
77  bool pred_margin, float* out_result);
78  size_t PredictBatch(const DenseBatch* batch, int verbose,
79  bool pred_margin, float* out_result);
91  size_t PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
92  float* out_result);
93 
100  inline size_t QueryResultSize(const CSRBatch* batch) const {
101  CHECK(pred_func_handle_ != nullptr)
102  << "A shared library needs to be loaded first using Load()";
103  return batch->num_row * num_output_group_;
104  }
111  inline size_t QueryResultSize(const DenseBatch* batch) const {
112  CHECK(pred_func_handle_ != nullptr)
113  << "A shared library needs to be loaded first using Load()";
114  return batch->num_row * num_output_group_;
115  }
124  inline size_t QueryResultSize(const CSRBatch* batch,
125  size_t rbegin, size_t rend) const {
126  CHECK(pred_func_handle_ != nullptr)
127  << "A shared library needs to be loaded first using Load()";
128  CHECK(rbegin < rend && rend <= batch->num_row);
129  return (rend - rbegin) * num_output_group_;
130  }
139  inline size_t QueryResultSize(const DenseBatch* batch,
140  size_t rbegin, size_t rend) const {
141  CHECK(pred_func_handle_ != nullptr)
142  << "A shared library needs to be loaded first using Load()";
143  CHECK(rbegin < rend && rend <= batch->num_row);
144  return (rend - rbegin) * num_output_group_;
145  }
151  inline size_t QueryResultSizeSingleInst() const {
152  CHECK(pred_func_handle_ != nullptr)
153  << "A shared library needs to be loaded first using Load()";
154  return num_output_group_;
155  }
162  inline size_t QueryNumOutputGroup() const {
163  return num_output_group_;
164  }
165 
171  inline size_t QueryNumFeature() const {
172  return num_feature_;
173  }
174 
179  inline std::string QueryPredTransform() const {
180  return pred_transform_;
181  }
182 
187  inline float QuerySigmoidAlpha() const {
188  return sigmoid_alpha_;
189  }
190 
195  inline float QueryGlobalBias() const {
196  return global_bias_;
197  }
198 
199  private:
200  LibraryHandle lib_handle_;
201  QueryFuncHandle num_output_group_query_func_handle_;
202  QueryFuncHandle num_feature_query_func_handle_;
203  QueryFuncHandle pred_transform_query_func_handle_;
204  QueryFuncHandle sigmoid_alpha_query_func_handle_;
205  QueryFuncHandle global_bias_query_func_handle_;
206  PredFuncHandle pred_func_handle_;
207  ThreadPoolHandle thread_pool_handle_;
208  size_t num_output_group_;
209  size_t num_feature_;
210  std::string pred_transform_;
211  float sigmoid_alpha_;
212  float global_bias_;
213  int num_worker_thread_;
214 
215  template <typename BatchType>
216  size_t PredictBatchBase_(const BatchType* batch, int verbose,
217  bool pred_margin, float* out_result);
218 };
219 
220 } // namespace treelite
221 
222 #endif // TREELITE_PREDICTOR_H_
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:171
const uint32_t * col_ind
feature indices
Definition: predictor.h:22
size_t QueryResultSizeSingleInst() const
Query the size of the array necessary to hold the prediction for a single data row.
Definition: predictor.h:151
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:18
const size_t * row_ptr
pointer to row headers; its length is num_row + 1
Definition: predictor.h:24
Entry type for Treelite predictor.
void * QueryFuncHandle
opaque handle types
Definition: predictor.h:47
dense batch
Definition: predictor.h:32
float QueryGlobalBias() const
Get the global bias used to adjust predicted margin scores.
Definition: predictor.h:195
const float * data
feature values
Definition: predictor.h:34
float missing_value
value representing the missing value (usually NaN)
Definition: predictor.h:36
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:100
size_t QueryResultSize(const CSRBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:124
size_t QueryResultSize(const DenseBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:139
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:14
const float * data
feature values
Definition: predictor.h:20
size_t QueryResultSize(const DenseBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:111
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
Definition: predictor.h:179
float QuerySigmoidAlpha() const
Get the alpha value of the sigmoid transformation used to train the loaded model.
Definition: predictor.h:187
size_t num_row
number of rows
Definition: predictor.h:38
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:162
predictor class: wrapper for optimized prediction code
Definition: predictor.h:44
size_t num_row
number of rows
Definition: predictor.h:26
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:28
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:40