7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 33 std::exception_ptr omp_exception_;
41 template <
typename Function,
typename... Parameters>
42 void Run(Function f, Parameters... params) {
46 std::lock_guard<std::mutex> lock(mutex_);
47 if (!omp_exception_) {
48 omp_exception_ = std::current_exception();
50 }
catch (std::exception &ex) {
51 std::lock_guard<std::mutex> lock(mutex_);
52 if (!omp_exception_) {
53 omp_exception_ = std::current_exception();
62 if (this->omp_exception_) std::rethrow_exception(this->omp_exception_);
70 template <
typename ElementType>
80 using LibraryHandle = HMODULE;
81 using FunctionHandle = FARPROC;
83 using LibraryHandle =
void*;
84 using FunctionHandle =
void*;
88 void Load(
const char* libpath);
89 FunctionHandle LoadFunction(
const char* name)
const;
90 template<
typename HandleType>
91 HandleType LoadFunctionWithSignature(
const char* name)
const;
94 LibraryHandle handle_;
100 static std::unique_ptr<PredFunction> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type,
105 virtual TypeInfo GetThresholdType()
const = 0;
106 virtual TypeInfo GetLeafOutputType()
const = 0;
107 virtual size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
111 template<
typename ThresholdType,
typename LeafOutputType>
114 using PredFuncHandle =
void*;
116 TypeInfo GetThresholdType()
const override;
117 TypeInfo GetLeafOutputType()
const override;
118 size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
122 PredFuncHandle handle_;
133 explicit Predictor(
int num_worker_thread = -1);
139 void Load(
const char* libpath);
165 TREELITE_CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
166 return dmat->GetNumRow() * num_class_;
177 TREELITE_CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
178 TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
179 return (rend - rbegin) * num_class_;
203 return pred_transform_;
210 return sigmoid_alpha_;
231 return threshold_type_;
238 return leaf_output_type_;
254 std::unique_ptr<PredFunction> pred_func_;
255 ThreadPoolHandle thread_pool_handle_;
258 std::string pred_transform_;
259 float sigmoid_alpha_;
262 int num_worker_thread_;
272 #endif // TREELITE_PREDICTOR_H_ std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
void Rethrow()
Should be called from the main thread to rethrow the saved exception, if any.
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Exception class that will be thrown by Treelite.
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void Run(Function f, Parameters... params)
Parallel OpenMP blocks should be placed within Run() so that any exception they throw is saved for later rethrowing.
void * PredictorOutputHandle
Handle to the output from the predictor.
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Logging facility for Treelite.
float QueryRatioC() const
Get the c value in the exponential standard ratio used to train the loaded model.
OMP exception class: catches, saves, and rethrows exceptions thrown from OpenMP blocks.
TypeInfo
Types used by thresholds and leaf outputs.
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
float QuerySigmoidAlpha() const
Get the alpha value of the sigmoid transformation used to train the loaded model.
float QueryGlobalBias() const
Get the global bias used to adjust predicted margin scores.
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Predictor class: wrapper for optimized prediction code.
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
void * ThreadPoolHandle
Opaque handle types.
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...