treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <dmlc/logging.h>
11 #include <cstdint>
12 
13 namespace treelite {
14 
/*! \brief sparse batch in Compressed Sparse Row (CSR) format */
struct CSRBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief feature indices */
  const uint32_t* col_ind;
  /*! \brief pointer to row headers; length of [num_row] + 1 */
  const size_t* row_ptr;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
28 
/*! \brief dense batch */
struct DenseBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief value representing the missing value (usually nan) */
  float missing_value;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
40 
42 class Predictor {
43  public:
48  union Entry {
49  int missing;
50  float fvalue;
51  // may contain extra fields later, such as qvalue
52  };
53 
55  typedef void* QueryFuncHandle;
56  typedef void* PredFuncHandle;
57  typedef void* PredTransformFuncHandle;
58  typedef void* LibraryHandle;
59 
60  Predictor();
61  ~Predictor();
66  void Load(const char* name);
70  void Free();
71 
84  size_t PredictBatch(const CSRBatch* batch, int nthread, int verbose,
85  bool pred_margin, float* out_result) const;
86  size_t PredictBatch(const DenseBatch* batch, int nthread, int verbose,
87  bool pred_margin, float* out_result) const;
88 
95  inline size_t QueryResultSize(const CSRBatch* batch) const {
96  CHECK(pred_func_handle_ != nullptr)
97  << "A shared library needs to be loaded first using Load()";
98  return batch->num_row * num_output_group_;
99  }
100  inline size_t QueryResultSize(const DenseBatch* batch) const {
101  CHECK(pred_func_handle_ != nullptr)
102  << "A shared library needs to be loaded first using Load()";
103  return batch->num_row * num_output_group_;
104  }
111  inline size_t QueryNumOutputGroup() const {
112  return num_output_group_;
113  }
114 
115  private:
116  LibraryHandle lib_handle_;
117  QueryFuncHandle query_func_handle_;
118  PredFuncHandle pred_func_handle_;
119  PredTransformFuncHandle pred_transform_func_handle_;
120  size_t num_output_group_;
121 };
122 
123 } // namespace treelite
124 
125 #endif // TREELITE_PREDICTOR_H_
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:95
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:48
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:111
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void * QueryFuncHandle
opaque handle types
Definition: predictor.h:55
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:34
const float * data
feature values
Definition: predictor.h:18
size_t num_row
number of rows
Definition: predictor.h:36
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
size_t num_row
number of rows
Definition: predictor.h:24
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38