7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 28 std::exception_ptr omp_exception_;
36 template <
typename Function,
typename... Parameters>
37 void Run(Function f, Parameters... params) {
41 std::lock_guard<std::mutex> lock(mutex_);
42 if (!omp_exception_) {
43 omp_exception_ = std::current_exception();
45 }
catch (std::exception &ex) {
46 std::lock_guard<std::mutex> lock(mutex_);
47 if (!omp_exception_) {
48 omp_exception_ = std::current_exception();
57 if (this->omp_exception_) std::rethrow_exception(this->omp_exception_);
65 template <
typename ElementType>
74 using LibraryHandle =
void*;
75 using FunctionHandle =
void*;
78 void Load(
const char* libpath);
79 FunctionHandle LoadFunction(
const char* name)
const;
80 template<
typename HandleType>
81 HandleType LoadFunctionWithSignature(
const char* name)
const;
84 LibraryHandle handle_;
90 static std::unique_ptr<PredFunction> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type,
95 virtual TypeInfo GetThresholdType()
const = 0;
96 virtual TypeInfo GetLeafOutputType()
const = 0;
97 virtual size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
101 template<
typename ThresholdType,
typename LeafOutputType>
104 using PredFuncHandle =
void*;
106 TypeInfo GetThresholdType()
const override;
107 TypeInfo GetLeafOutputType()
const override;
108 size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
112 PredFuncHandle handle_;
123 explicit Predictor(
int num_worker_thread = -1);
129 void Load(
const char* libpath);
155 TREELITE_CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
156 return dmat->GetNumRow() * num_class_;
167 TREELITE_CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
168 TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
169 return (rend - rbegin) * num_class_;
193 return pred_transform_;
200 return sigmoid_alpha_;
221 return threshold_type_;
228 return leaf_output_type_;
244 std::unique_ptr<PredFunction> pred_func_;
245 ThreadPoolHandle thread_pool_handle_;
248 std::string pred_transform_;
249 float sigmoid_alpha_;
252 int num_worker_thread_;
262 #endif // TREELITE_PREDICTOR_H_ std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
void Rethrow()
Should be called from the main thread to rethrow the saved exception.
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Exception class that will be thrown by Treelite.
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run() so that any exception they throw is saved.
void * PredictorOutputHandle
handle to output from predictor
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
logging facility for Treelite
float QueryRatioC() const
Get the c value in the exponential standard ratio used to train the loaded model.
OMP exception class: catches, saves, and rethrows exceptions from OMP blocks.
TypeInfo
Types used by thresholds and leaf outputs.
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
float QuerySigmoidAlpha() const
Get the alpha value in the sigmoid transformation used to train the loaded model.
float QueryGlobalBias() const
Get the global bias, which adjusts predicted margin scores.
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Predictor class: wrapper for optimized prediction code.
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
void * ThreadPoolHandle
opaque handle types
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...