7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 10 #include <dmlc/logging.h> 11 #include <treelite/entry.h> 17 namespace filesystem {
18 class TemporaryDirectory;
// Opaque handle types (all void*), used for the members pred_func_handle_,
// lib_handle_ and thread_pool_handle_ of this class.
// NOTE(review): presumably PredFuncHandle / LibraryHandle refer to the
// prediction function and shared library brought in by Load(), and
// ThreadPoolHandle to the worker thread pool — confirm in predictor.cc.
53 typedef void* PredFuncHandle;
54 typedef void* LibraryHandle;
55 typedef void* ThreadPoolHandle;
// Load a compiled shared library containing the prediction code.
// Must be called before querying result sizes: the query methods CHECK that
// pred_func_handle_ is non-null and abort with
// "A shared library needs to be loaded first using Load()" otherwise.
// @param name  presumably the path of the shared library to load — confirm.
63 void Load(
const char* name);
// Run prediction on a batch of data rows and write the results to out_result.
// @param batch        rows in sparse Compressed Sparse Row (CSR) format
// @param verbose      verbosity level; NOTE(review): presumably nonzero enables
//                     diagnostic output — confirm
// @param pred_margin  NOTE(review): presumably true to emit raw margin scores
//                     instead of applying the pred_transform_ post-processing
//                     — confirm
// @param out_result   output buffer; QueryResultSize(batch) gives the length
//                     needed to hold predictions for all rows in the batch
// @return NOTE(review): presumably the number of values written — confirm
81 size_t PredictBatch(
const CSRBatch* batch,
int verbose,
82 bool pred_margin,
float* out_result);
// Same as above, for a dense batch (num_row x num_col feature values with a
// designated missing_value).
83 size_t PredictBatch(
const DenseBatch* batch,
int verbose,
84 bool pred_margin,
float* out_result);
106 CHECK(pred_func_handle_ !=
nullptr)
107 <<
"A shared library needs to be loaded first using Load()";
108 return batch->
num_row * num_output_group_;
117 CHECK(pred_func_handle_ !=
nullptr)
118 <<
"A shared library needs to be loaded first using Load()";
119 return batch->
num_row * num_output_group_;
130 size_t rbegin,
size_t rend)
const {
131 CHECK(pred_func_handle_ !=
nullptr)
132 <<
"A shared library needs to be loaded first using Load()";
133 CHECK(rbegin < rend && rend <= batch->num_row);
134 return (rend - rbegin) * num_output_group_;
145 size_t rbegin,
size_t rend)
const {
146 CHECK(pred_func_handle_ !=
nullptr)
147 <<
"A shared library needs to be loaded first using Load()";
148 CHECK(rbegin < rend && rend <= batch->num_row);
149 return (rend - rbegin) * num_output_group_;
157 CHECK(pred_func_handle_ !=
nullptr)
158 <<
"A shared library needs to be loaded first using Load()";
159 return num_output_group_;
168 return num_output_group_;
185 return pred_transform_;
193 return sigmoid_alpha_;
// Handle to the loaded shared library.
205 LibraryHandle lib_handle_;
// Handles to metadata query functions. NOTE(review): presumably resolved from
// the loaded library and used to fill num_output_group_, the feature count,
// pred_transform_, sigmoid_alpha_ and the global bias — confirm in Load().
206 QueryFuncHandle num_output_group_query_func_handle_;
207 QueryFuncHandle num_feature_query_func_handle_;
208 QueryFuncHandle pred_transform_query_func_handle_;
209 QueryFuncHandle sigmoid_alpha_query_func_handle_;
210 QueryFuncHandle global_bias_query_func_handle_;
// Prediction function. The query methods treat nullptr as "no library loaded
// yet" and fail their CHECK.
211 PredFuncHandle pred_func_handle_;
// Handle to the worker thread pool.
212 ThreadPoolHandle thread_pool_handle_;
// Number of output groups; result sizes are computed as
// (number of rows) * num_output_group_.
213 size_t num_output_group_;
// Name of the post-prediction transformation of the loaded model.
215 std::string pred_transform_;
// Alpha value used in the sigmoid transformation.
216 float sigmoid_alpha_;
// Number of worker threads.
218 int num_worker_thread_;
// NOTE(review): presumably set when the library was obtained from a remote
// location rather than a local path — confirm.
220 bool using_remote_lib_;
// Temporary directory and library file path.
// NOTE(review): presumably used to hold a local copy of the library when
// using_remote_lib_ is set — confirm.
222 std::unique_ptr<common::filesystem::TemporaryDirectory> tempdir_;
223 std::string temp_libfile_;
// Batch prediction templated on the batch type. Its parameters mirror those of
// the PredictBatch() overloads.
// NOTE(review): presumably the shared implementation that both PredictBatch()
// overloads (CSRBatch / DenseBatch) forward to — confirm in predictor.cc.
225 template <
typename BatchType>
226 size_t PredictBatchBase_(
const BatchType* batch,
int verbose,
227 bool pred_margin,
float* out_result);
232 #endif // TREELITE_PREDICTOR_H_ size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
const uint32_t * col_ind
feature indices
size_t QueryResultSizeSingleInst() const
Query the necessary size of array to hold the prediction for a single data row.
sparse batch in Compressed Sparse Row (CSR) format
const size_t * row_ptr
pointer to row headers; its length is [num_row] + 1
void * QueryFuncHandle
opaque handle types
float QueryGlobalBias() const
Get the global bias, which adjusts the predicted margin scores.
const float * data
feature values
float missing_value
value representing a missing value (usually NaN)
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
size_t QueryResultSize(const CSRBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
size_t QueryResultSize(const DenseBatch *batch, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
const float * data
feature values
size_t QueryResultSize(const DenseBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
std::string QueryPredTransform() const
Get name of post prediction transformation used to train the loaded model.
float QuerySigmoidAlpha() const
Get alpha value in sigmoid transformation used to train the loaded model.
size_t num_row
number of rows
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
predictor class: wrapper for optimized prediction code
size_t num_row
number of rows
size_t num_col
number of columns (i.e. # of features used)
size_t num_col
number of columns (i.e. # of features used)