// Include guard, the dmlc logging header (presumably the source of the CHECK
// macro used further down -- confirm), and the handle typedefs.
// Every handle below is a plain void* alias, so the pointed-to type is not
// visible through this header.
7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 10 #include <dmlc/logging.h> 56 typedef void* PredFuncHandle;
57 typedef void* LibraryHandle;
58 typedef void* ThreadPoolHandle;
61 bool include_master_thread =
false);
// Declaration only: takes a C string `name` and returns nothing.
// The CHECK messages in the QueryResultSize() overloads state that a shared
// library must be loaded with Load() before result sizes can be queried.
// NOTE(review): whether `name` is a file path or a bare library name is not
// visible in this header -- confirm against the definition.
67 void Load(
const char* name);
// Prediction over a CSRBatch. Declaration only; the definition is not part of
// this header.
//   batch       - input batch, taken by pointer-to-const
//   verbose     - NOTE(review): presumably a verbosity level/flag -- confirm
//   pred_margin - NOTE(review): presumably selects raw margin output -- confirm
//   out_result  - float output pointer supplied by the caller;
//                 NOTE(review): presumably sized with QueryResultSize() -- confirm
// Returns a size_t; NOTE(review): the meaning of the returned count is not
// shown here -- confirm against the definition.
85 size_t PredictBatch(
const CSRBatch* batch,
int verbose,
86 bool pred_margin,
float* out_result);
// Prediction over a DenseBatch; same parameter list as the CSRBatch overload
// apart from the batch type. Declaration only.
//   batch       - input batch, taken by pointer-to-const
//   verbose     - NOTE(review): presumably a verbosity level/flag -- confirm
//   pred_margin - NOTE(review): presumably selects raw margin output -- confirm
//   out_result  - float output pointer supplied by the caller;
//                 NOTE(review): presumably sized with QueryResultSize() -- confirm
// Returns a size_t; NOTE(review): the meaning of the returned count is not
// shown here -- confirm against the definition.
87 size_t PredictBatch(
const DenseBatch* batch,
int verbose,
88 bool pred_margin,
float* out_result);
// NOTE(review): the signature line for this body is missing from this listing;
// it is presumably QueryResultSize(const CSRBatch* batch) const -- confirm.
// CHECKs that pred_func_handle_ is non-null, then returns
// batch->num_row * num_output_group_. `batch` is dereferenced without a
// null check.
97 CHECK(pred_func_handle_ !=
nullptr)
98 <<
"A shared library needs to be loaded first using Load()";
99 return batch->
num_row * num_output_group_;
// Result size for a whole DenseBatch.
// CHECKs that pred_func_handle_ is non-null (the failure message tells the
// caller to use Load() first), then returns
// batch->num_row * num_output_group_, i.e. one slot per row per output group.
// `batch` is dereferenced without a null check.
101 inline size_t QueryResultSize(
const DenseBatch* batch)
const {
102 CHECK(pred_func_handle_ !=
nullptr)
103 <<
"A shared library needs to be loaded first using Load()";
104 return batch->
num_row * num_output_group_;
// Result size for a sub-range of rows of a CSRBatch.
// CHECKs, in order:
//   1. pred_func_handle_ is non-null (message: use Load() first);
//   2. rbegin < rend and rend <= batch->num_row, so an empty or reversed
//      range is rejected and rend may equal num_row (rend is exclusive).
// Returns (rend - rbegin) * num_output_group_.
// `batch` is dereferenced without a null check.
106 inline size_t QueryResultSize(
const CSRBatch* batch,
107 size_t rbegin,
size_t rend)
const {
108 CHECK(pred_func_handle_ !=
nullptr)
109 <<
"A shared library needs to be loaded first using Load()";
110 CHECK(rbegin < rend && rend <= batch->
num_row);
111 return (rend - rbegin) * num_output_group_;
// Result size for a sub-range of rows of a DenseBatch.
// CHECKs, in order:
//   1. pred_func_handle_ is non-null (message: use Load() first);
//   2. rbegin < rend and rend <= batch->num_row, so an empty or reversed
//      range is rejected and rend may equal num_row (rend is exclusive).
// Returns (rend - rbegin) * num_output_group_.
// `batch` is dereferenced without a null check.
113 inline size_t QueryResultSize(
const DenseBatch* batch,
114 size_t rbegin,
size_t rend)
const {
115 CHECK(pred_func_handle_ !=
nullptr)
116 <<
"A shared library needs to be loaded first using Load()";
117 CHECK(rbegin < rend && rend <= batch->
num_row);
118 return (rend - rbegin) * num_output_group_;
// NOTE(review): the signature line for this body is missing from this listing;
// it is presumably QueryNumOutputGroup() const -- confirm.
// Returns the stored num_output_group_ member unchanged.
127 return num_output_group_;
// Data members. All *Handle types are void* aliases.
// NOTE(review): presumably the handle of the loaded shared library -- confirm.
131 LibraryHandle lib_handle_;
// NOTE(review): presumably a function looked up in the loaded library -- confirm.
132 QueryFuncHandle query_func_handle_;
// Must be non-null before any QueryResultSize() overload is used (CHECKed there).
133 PredFuncHandle pred_func_handle_;
// NOTE(review): presumably the worker thread pool used for prediction -- confirm.
134 ThreadPoolHandle thread_pool_handle_;
// Per-row multiplier used by QueryResultSize(); also the value returned at
// listing line 127.
135 size_t num_output_group_;
// NOTE(review): presumably the number of worker threads -- confirm.
136 int num_worker_thread_;
// NOTE(review): presumably mirrors the `include_master_thread` argument that
// defaults to false -- confirm.
137 bool include_master_thread_;
// Template declaration parameterised on the batch type; apart from that its
// parameter list matches the two PredictBatch() overloads.
// NOTE(review): presumably the shared implementation that both PredictBatch()
// overloads forward to -- confirm against the definition, which is not in
// this header.
139 template <
typename BatchType>
140 size_t PredictBatchBase_(
const BatchType* batch,
int verbose,
141 bool pred_margin,
float* out_result);
146 #endif // TREELITE_PREDICTOR_H_ const uint32_t * col_ind
feature indices
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of the array to hold predictions for all data points.
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater than 1 for multiclass classification.
sparse batch in Compressed Sparse Row (CSR) format
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
void * QueryFuncHandle
opaque handle types
const float * data
feature values
float missing_value
value representing the missing value (usually nan)
const float * data
feature values
size_t num_row
number of rows
predictor class: wrapper for optimized prediction code
size_t num_row
number of rows
size_t num_col
number of columns (i.e. # of features used)
size_t num_col
number of columns (i.e. # of features used)