treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <dmlc/logging.h>
11 #include <cstdint>
12 
13 namespace treelite {
14 
/*! \brief sparse batch in Compressed Sparse Row (CSR) format */
struct CSRBatch {
  /*! \brief feature values */
  const float* data;

  /*! \brief feature indices */
  const uint32_t* col_ind;

  /*! \brief pointer to row headers; length of [num_row] + 1 */
  const size_t* row_ptr;

  /*! \brief number of rows */
  size_t num_row;

  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
28 
/*! \brief dense batch */
struct DenseBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief value representing the missing value (usually nan) */
  float missing_value;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
40 
42 class Predictor {
43  public:
48  union Entry {
49  int missing;
50  float fvalue;
51  // may contain extra fields later, such as qvalue
52  };
53 
55  typedef void* QueryFuncHandle;
56  typedef void* PredFuncHandle;
57  typedef void* LibraryHandle;
58  typedef void* ThreadPoolHandle;
59 
60  Predictor(int num_worker_thread = -1,
61  bool include_master_thread = false);
62  ~Predictor();
67  void Load(const char* name);
71  void Free();
72 
85  size_t PredictBatch(const CSRBatch* batch, int verbose,
86  bool pred_margin, float* out_result);
87  size_t PredictBatch(const DenseBatch* batch, int verbose,
88  bool pred_margin, float* out_result);
89 
96  inline size_t QueryResultSize(const CSRBatch* batch) const {
97  CHECK(pred_func_handle_ != nullptr)
98  << "A shared library needs to be loaded first using Load()";
99  return batch->num_row * num_output_group_;
100  }
101  inline size_t QueryResultSize(const DenseBatch* batch) const {
102  CHECK(pred_func_handle_ != nullptr)
103  << "A shared library needs to be loaded first using Load()";
104  return batch->num_row * num_output_group_;
105  }
106  inline size_t QueryResultSize(const CSRBatch* batch,
107  size_t rbegin, size_t rend) const {
108  CHECK(pred_func_handle_ != nullptr)
109  << "A shared library needs to be loaded first using Load()";
110  CHECK(rbegin < rend && rend <= batch->num_row);
111  return (rend - rbegin) * num_output_group_;
112  }
113  inline size_t QueryResultSize(const DenseBatch* batch,
114  size_t rbegin, size_t rend) const {
115  CHECK(pred_func_handle_ != nullptr)
116  << "A shared library needs to be loaded first using Load()";
117  CHECK(rbegin < rend && rend <= batch->num_row);
118  return (rend - rbegin) * num_output_group_;
119  }
126  inline size_t QueryNumOutputGroup() const {
127  return num_output_group_;
128  }
129 
130  private:
131  LibraryHandle lib_handle_;
132  QueryFuncHandle query_func_handle_;
133  PredFuncHandle pred_func_handle_;
134  ThreadPoolHandle thread_pool_handle_;
135  size_t num_output_group_;
136  int num_worker_thread_;
137  bool include_master_thread_; // run task on master thread?
138 
139  template <typename BatchType>
140  size_t PredictBatchBase_(const BatchType* batch, int verbose,
141  bool pred_margin, float* out_result);
142 };
143 
144 } // namespace treelite
145 
146 #endif // TREELITE_PREDICTOR_H_
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:96
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:48
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:126
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void * QueryFuncHandle
opaque handle types
Definition: predictor.h:55
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:34
const float * data
feature values
Definition: predictor.h:18
size_t num_row
number of rows
Definition: predictor.h:36
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
size_t num_row
number of rows
Definition: predictor.h:24
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38