treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include <dmlc/logging.h>
#include <treelite/entry.h>
13 
14 namespace treelite {
15 
16 namespace common {
17 namespace filesystem {
18 class TemporaryDirectory; // forward declaration; the Predictor class below only holds a std::unique_ptr to this type
19 } // namespace filesystem
20 } // namespace common
21 
/*! \brief sparse batch in Compressed Sparse Row (CSR) format */
struct CSRBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief feature indices */
  const uint32_t* col_ind;
  /*! \brief pointer to row headers; length of [num_row] + 1 */
  const size_t* row_ptr;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
35 
/*! \brief dense batch */
struct DenseBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief value representing the missing value (usually nan) */
  float missing_value;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
47 
49 class Predictor {
50  public:
52  typedef void* QueryFuncHandle;
53  typedef void* PredFuncHandle;
54  typedef void* LibraryHandle;
55  typedef void* ThreadPoolHandle;
56 
57  Predictor(int num_worker_thread = -1);
58  ~Predictor();
63  void Load(const char* name);
67  void Free();
68 
81  size_t PredictBatch(const CSRBatch* batch, int verbose,
82  bool pred_margin, float* out_result);
83  size_t PredictBatch(const DenseBatch* batch, int verbose,
84  bool pred_margin, float* out_result);
96  size_t PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
97  float* out_result);
98 
105  inline size_t QueryResultSize(const CSRBatch* batch) const {
106  CHECK(pred_func_handle_ != nullptr)
107  << "A shared library needs to be loaded first using Load()";
108  return batch->num_row * num_output_group_;
109  }
116  inline size_t QueryResultSize(const DenseBatch* batch) const {
117  CHECK(pred_func_handle_ != nullptr)
118  << "A shared library needs to be loaded first using Load()";
119  return batch->num_row * num_output_group_;
120  }
129  inline size_t QueryResultSize(const CSRBatch* batch,
130  size_t rbegin, size_t rend) const {
131  CHECK(pred_func_handle_ != nullptr)
132  << "A shared library needs to be loaded first using Load()";
133  CHECK(rbegin < rend && rend <= batch->num_row);
134  return (rend - rbegin) * num_output_group_;
135  }
144  inline size_t QueryResultSize(const DenseBatch* batch,
145  size_t rbegin, size_t rend) const {
146  CHECK(pred_func_handle_ != nullptr)
147  << "A shared library needs to be loaded first using Load()";
148  CHECK(rbegin < rend && rend <= batch->num_row);
149  return (rend - rbegin) * num_output_group_;
150  }
156  inline size_t QueryResultSizeSingleInst() const {
157  CHECK(pred_func_handle_ != nullptr)
158  << "A shared library needs to be loaded first using Load()";
159  return num_output_group_;
160  }
167  inline size_t QueryNumOutputGroup() const {
168  return num_output_group_;
169  }
170 
176  inline size_t QueryNumFeature() const {
177  return num_feature_;
178  }
179 
180  private:
181  LibraryHandle lib_handle_;
182  QueryFuncHandle num_output_group_query_func_handle_;
183  QueryFuncHandle num_feature_query_func_handle_;
184  PredFuncHandle pred_func_handle_;
185  ThreadPoolHandle thread_pool_handle_;
186  size_t num_output_group_;
187  size_t num_feature_;
188  int num_worker_thread_;
189 
190  bool using_remote_lib_; // load lib from remote location?
191  // information for temporary file to cache remote lib
192  std::unique_ptr<common::filesystem::TemporaryDirectory> tempdir_;
193  std::string temp_libfile_;
194 
195  template <typename BatchType>
196  size_t PredictBatchBase_(const BatchType* batch, int verbose,
197  bool pred_margin, float* out_result);
198 };
199 
200 } // namespace treelite
201 
202 #endif // TREELITE_PREDICTOR_H_
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:176
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
size_t QueryResultSizeSingleInst() const
Query the necessary size of array to hold the prediction for a single data row.
Definition: predictor.h:156
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:29
void * QueryFuncHandle
opaque handle types
Definition: predictor.h:52
dense batch
Definition: predictor.h:37
const float * data
feature values
Definition: predictor.h:39
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:105
size_t QueryResultSize(const CSRBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:129
size_t QueryResultSize(const DenseBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:144
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:11
const float * data
feature values
Definition: predictor.h:25
size_t QueryResultSize(const DenseBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:116
size_t num_row
number of rows
Definition: predictor.h:43
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:167
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
size_t num_row
number of rows
Definition: predictor.h:31
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45