treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <dmlc/logging.h>
11 #include <cstdint>
12 
13 namespace treelite {
14 
/*! \brief sparse batch in Compressed Sparse Row (CSR) format */
16 struct CSRBatch {
  /*! \brief feature values */
18  const float* data;
  /*! \brief feature indices */
20  const uint32_t* col_ind;
  /*! \brief pointer to row headers; length of [num_row] + 1 */
22  const size_t* row_ptr;
  /*! \brief number of rows */
24  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
26  size_t num_col;
27 };
28 
/*! \brief dense batch */
/* NOTE(review): the generated symbol index for this file lists a member
 * `float missing_value` ("value representing the missing value (usually nan)")
 * defined at predictor.h:34, but this listing jumps from line 32 to line 36
 * and does not show it. Confirm against the real header before relying on
 * the member list below being complete. */
30 struct DenseBatch {
  /*! \brief feature values */
32  const float* data;
  /*! \brief number of rows */
36  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
38  size_t num_col;
39 };
40 
/*! \brief predictor class: wrapper for optimized prediction code */
42 class Predictor {
43  public:
  /*!
   * \brief data layout of a single entry. Per the generated summary for this
   *        union, the value -1 signifies the missing value (via the `missing`
   *        field); the remainder of that description is truncated in this
   *        listing.
   */
48  union Entry {
49  int missing;
50  float fvalue;
51  // may contain extra fields later, such as qvalue
52  };
53 
  /*! \brief opaque handle types (all are plain void*) */
55  typedef void* QueryFuncHandle;
56  typedef void* PredFuncHandle;
57  typedef void* LibraryHandle;
58 
59  Predictor();
60  ~Predictor();
  /*!
   * \brief load something identified by name; the CHECK messages in
   *        QueryResultSize() refer to "a shared library" being loaded via
   *        Load(), so this is presumably the shared-library loader.
   * \param name NOTE(review): presumably the library name/path -- the
   *             implementation is not visible here; confirm.
   */
65  void Load(const char* name);
  /*!
   * \brief NOTE(review): presumably releases whatever Load() acquired; only
   *        the declaration is visible here -- confirm against the definition.
   */
69  void Free();
70 
  /*!
   * \brief make predictions on a batch; overloads exist for sparse (CSR) and
   *        dense input. Only the declarations are visible in this listing.
   * \param batch the batch of data rows
   * \param nthread NOTE(review): presumably the number of threads -- confirm
   * \param verbose NOTE(review): presumably a verbosity flag -- confirm
   * \param pred_margin NOTE(review): presumably selects raw margin output
   *                    instead of transformed predictions -- confirm
   * \param out_result output buffer; QueryResultSize() reports the array size
   *                   needed to hold predictions for a batch
   * \return a size_t; NOTE(review): meaning not shown here -- confirm
   */
83  size_t PredictBatch(const CSRBatch* batch, int nthread, int verbose,
84  bool pred_margin, float* out_result) const;
85  size_t PredictBatch(const DenseBatch* batch, int nthread, int verbose,
86  bool pred_margin, float* out_result) const;
87 
  /*!
   * \brief given a batch of data rows, query the necessary size of the array
   *        to hold predictions for all data points
   * \param batch batch of rows; dereferenced without a null check
   * \return batch->num_row * num_output_group_
   * \note fails the CHECK if pred_func_handle_ is null, i.e. if no shared
   *       library has been loaded with Load()
   */
94  inline size_t QueryResultSize(const CSRBatch* batch) const {
95  CHECK(pred_func_handle_ != nullptr)
96  << "A shared library needs to be loaded first using Load()";
97  return batch->num_row * num_output_group_;
98  }
  /*! \brief dense-batch overload of QueryResultSize(); same check and formula */
99  inline size_t QueryResultSize(const DenseBatch* batch) const {
100  CHECK(pred_func_handle_ != nullptr)
101  << "A shared library needs to be loaded first using Load()";
102  return batch->num_row * num_output_group_;
103  }
  /*!
   * \brief get the number of output groups in the loaded model. Per the
   *        generated summary the number is 1 for most tasks (the rest of
   *        that description is truncated in this listing).
   * \return num_output_group_ as stored; unlike QueryResultSize(), this does
   *         not CHECK that a library has been loaded
   */
110  inline size_t QueryNumOutputGroup() const {
111  return num_output_group_;
112  }
113 
114  private:
  // Opaque handles plus the output-group count. pred_func_handle_ being
  // non-null is what QueryResultSize() treats as "library loaded".
115  LibraryHandle lib_handle_;
116  QueryFuncHandle query_func_handle_;
117  PredFuncHandle pred_func_handle_;
118  size_t num_output_group_;
119 };
120 
121 } // namespace treelite
122 
123 #endif // TREELITE_PREDICTOR_H_
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:94
treelite::Predictor::Entry
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:48
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:110
treelite::CSRBatch
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void * QueryFuncHandle
opaque handle types
Definition: predictor.h:55
treelite::DenseBatch
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
float missing_value
value used to represent a missing entry (usually NaN)
Definition: predictor.h:34
const float * data
feature values
Definition: predictor.h:18
size_t num_row
number of rows
Definition: predictor.h:36
treelite::Predictor
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
size_t num_row
number of rows
Definition: predictor.h:24
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38