// Predictor header prologue: include guard, logging include (presumably the
// source of the CHECK macro used by the query methods), and the opaque handle
// typedefs used by the Predictor class.
// NOTE(review): this is an extraction fragment -- the original source line
// numbers (7, 8, 10, 48, ...) are fused into the code text and must be
// stripped before this can compile.
// All handles are untyped void* aliases. Going by their names: PredFuncHandle
// refers to the prediction function, LibraryHandle to the loaded shared
// library, ThreadPoolHandle to the worker thread pool -- confirm in the .cc.
7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 10 #include <dmlc/logging.h> 48 typedef void* PredFuncHandle;
49 typedef void* LibraryHandle;
50 typedef void* ThreadPoolHandle;
// Public Predictor interface (extraction fragment).
// Constructor; num_worker_thread defaults to -1. Presumably stored in
// num_worker_thread_. NOTE(review): the meaning of -1 (e.g. "use all
// available cores") is not visible here -- confirm against the implementation.
52 explicit Predictor(
int num_worker_thread = -1);
// Load(name): loads the shared library the predictor runs. The result-size
// queries CHECK that pred_func_handle_ is non-null and fail with a message
// saying a shared library must first be loaded using Load(), so call this
// first. NOTE(review): `name` is presumably a filesystem path -- confirm.
58 void Load(
const char* name);
// PredictBatch: run prediction on a sparse (CSRBatch) or dense (DenseBatch)
// batch, writing results into out_result. Size out_result with the matching
// QueryResultSize() overload.
// NOTE(review): presumably pred_margin=true yields raw margin scores (skipping
// the pred_transform_ step), verbose is an int-valued logging flag, and the
// size_t return is the number of entries written -- confirm all three.
76 size_t PredictBatch(
const CSRBatch* batch,
int verbose,
77 bool pred_margin,
float* out_result);
78 size_t PredictBatch(
const DenseBatch* batch,
int verbose,
79 bool pred_margin,
float* out_result);
// NOTE(review): fragment -- body of one single-argument QueryResultSize(batch)
// overload (CSRBatch or DenseBatch; the signature and closing brace are not
// visible here). Fails via CHECK unless a library has been loaded
// (pred_func_handle_ non-null). Otherwise it returns one output slot per row
// per output group: batch->num_row * num_output_group_.
101 CHECK(pred_func_handle_ !=
nullptr)
102 <<
"A shared library needs to be loaded first using Load()";
103 return batch->
num_row * num_output_group_;
// NOTE(review): fragment -- body of the other single-argument
// QueryResultSize(batch) overload (signature not visible). It is identical to
// its sibling: CHECK that a library is loaded, then return
// batch->num_row * num_output_group_ (one slot per row per output group).
112 CHECK(pred_func_handle_ !=
nullptr)
113 <<
"A shared library needs to be loaded first using Load()";
114 return batch->
num_row * num_output_group_;
// NOTE(review): fragment -- tail of a ranged QueryResultSize(batch, rbegin,
// rend) const overload (the leading parameter and closing brace are not
// visible). Requires:
//   - a loaded library;
//   - a non-empty row range with rbegin < rend <= batch->num_row.
// Returns (rend - rbegin) * num_output_group_, i.e. the range is treated as
// half-open [rbegin, rend).
// NOTE(review): an empty range (rbegin == rend) fails the CHECK instead of
// returning 0 -- confirm that is intended.
125 size_t rbegin,
size_t rend)
const {
126 CHECK(pred_func_handle_ !=
nullptr)
127 <<
"A shared library needs to be loaded first using Load()";
128 CHECK(rbegin < rend && rend <= batch->
num_row);
129 return (rend - rbegin) * num_output_group_;
// NOTE(review): fragment -- tail of the second ranged QueryResultSize(batch,
// rbegin, rend) const overload. It is identical to its sibling:
//   - requires a loaded library and rbegin < rend <= batch->num_row;
//   - returns (rend - rbegin) * num_output_group_ for the half-open row range
//     [rbegin, rend).
// An empty range fails the CHECK.
140 size_t rbegin,
size_t rend)
const {
141 CHECK(pred_func_handle_ !=
nullptr)
142 <<
"A shared library needs to be loaded first using Load()";
143 CHECK(rbegin < rend && rend <= batch->
num_row);
144 return (rend - rbegin) * num_output_group_;
// NOTE(review): fragment -- presumably the body of QueryResultSizeSingleInst().
// After the same load CHECK, a single data row needs num_output_group_ slots.
152 CHECK(pred_func_handle_ !=
nullptr)
153 <<
"A shared library needs to be loaded first using Load()";
154 return num_output_group_;
// NOTE(review): fragment -- presumably the body of QueryNumOutputGroup().
// Unlike the result-size queries it has no Load() CHECK, so before Load() it
// returns whatever num_output_group_ was initialized to (initialization not
// visible here) -- confirm that is intended.
163 return num_output_group_;
// NOTE(review): fragment -- presumably QueryPredTransform(); returns the cached
// pred_transform_ string.
180 return pred_transform_;
// NOTE(review): fragment -- presumably QuerySigmoidAlpha(); returns the cached
// sigmoid_alpha_ value.
188 return sigmoid_alpha_;
// Private Predictor state (extraction fragment).
// Handle to the loaded shared library.
200 LibraryHandle lib_handle_;
// Handles to the library's per-property query functions (named after the
// property each one queries). The QueryFuncHandle typedef is not among the
// visible code lines; the member listing describes it as an opaque handle type.
201 QueryFuncHandle num_output_group_query_func_handle_;
202 QueryFuncHandle num_feature_query_func_handle_;
203 QueryFuncHandle pred_transform_query_func_handle_;
204 QueryFuncHandle sigmoid_alpha_query_func_handle_;
205 QueryFuncHandle global_bias_query_func_handle_;
// Prediction function handle. The result-size queries treat nullptr as
// "no shared library loaded yet".
206 PredFuncHandle pred_func_handle_;
207 ThreadPoolHandle thread_pool_handle_;
// Cached model properties returned by the accessor bodies above.
208 size_t num_output_group_;
// NOTE(review): source lines 209 and 212 are missing from this fragment. No
// cached num_feature_ or global_bias_ member is visible, even though query
// handles exist for both -- confirm against the full header.
210 std::string pred_transform_;
211 float sigmoid_alpha_;
// Presumably the worker-thread count passed to the constructor.
213 int num_worker_thread_;
// Template helper with the same parameters as PredictBatch, templated on the
// batch type. It is presumably the shared implementation behind both
// PredictBatch overloads (BatchType = CSRBatch or DenseBatch) -- confirm in
// the implementation file.
215 template <
typename BatchType>
216 size_t PredictBatchBase_(
const BatchType* batch,
int verbose,
217 bool pred_margin,
float* out_result);
// Closes the TREELITE_PREDICTOR_H_ include guard. NOTE(review): the text after
// the #endif comment is a documentation heading fused in by the extraction.
222 #endif // TREELITE_PREDICTOR_H_ size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
const uint32_t * col_ind
feature indices
size_t QueryResultSizeSingleInst() const
Query the size of the array needed to hold the prediction for a single data row.
sparse batch in Compressed Sparse Row (CSR) format
const size_t * row_ptr
pointer to row headers; length is num_row + 1
Entry type for Treelite predictor.
void * QueryFuncHandle
opaque handle types
float QueryGlobalBias() const
Get the global bias used to adjust predicted margin scores.
const float * data
feature values
float missing_value
value representing the missing value (usually NaN)
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
size_t QueryResultSize(const CSRBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
size_t QueryResultSize(const DenseBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
const float * data
feature values
size_t QueryResultSize(const DenseBatch *batch) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
float QuerySigmoidAlpha() const
Get the alpha value of the sigmoid transformation used to train the loaded model.
size_t num_row
number of rows
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater than 1 ... (description truncated in this copy)
predictor class: wrapper for optimized prediction code
size_t num_row
number of rows
size_t num_col
number of columns (i.e., number of features used)
size_t num_col
number of columns (i.e., number of features used)