7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 10 #include <dmlc/logging.h> 25 template <
typename ElementType>
34 using LibraryHandle =
void*;
35 using FunctionHandle =
void*;
38 void Load(
const char* libpath);
39 FunctionHandle LoadFunction(
const char* name)
const;
40 template<
typename HandleType>
41 HandleType LoadFunctionWithSignature(
const char* name)
const;
44 LibraryHandle handle_;
50 static std::unique_ptr<PredFunction> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type,
55 virtual TypeInfo GetThresholdType()
const = 0;
56 virtual TypeInfo GetLeafOutputType()
const = 0;
57 virtual size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
61 template<
typename ThresholdType,
typename LeafOutputType>
64 using PredFuncHandle =
void*;
66 TypeInfo GetThresholdType()
const override;
67 TypeInfo GetLeafOutputType()
const override;
68 size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
72 PredFuncHandle handle_;
83 explicit Predictor(
int num_worker_thread = -1);
89 void Load(
const char* libpath);
115 CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
116 return dmat->GetNumRow() * num_class_;
127 CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
128 CHECK(rbegin < rend && rend <= dmat->GetNumRow());
129 return (rend - rbegin) * num_class_;
153 return pred_transform_;
160 return sigmoid_alpha_;
174 return threshold_type_;
181 return leaf_output_type_;
197 std::unique_ptr<PredFunction> pred_func_;
198 ThreadPoolHandle thread_pool_handle_;
201 std::string pred_transform_;
202 float sigmoid_alpha_;
204 int num_worker_thread_;
208 mutable dmlc::OMPException exception_catcher_;
214 #endif // TREELITE_PREDICTOR_H_
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void * PredictorOutputHandle
handle to output from predictor
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
TypeInfo
Types used by thresholds and leaf outputs.
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
float QuerySigmoidAlpha() const
Get alpha value in sigmoid transformation used to train the loaded model.
float QueryGlobalBias() const
Get the global bias, which adjusts predicted margin scores.
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
predictor class: wrapper for optimized prediction code
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
void * ThreadPoolHandle
opaque handle types
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...