7 #ifndef TREELITE_PREDICTOR_H_ 8 #define TREELITE_PREDICTOR_H_ 33 template <
typename ElementType>
43 using LibraryHandle = HMODULE;
44 using FunctionHandle = FARPROC;
46 using LibraryHandle =
void*;
47 using FunctionHandle =
void*;
51 void Load(
const char* libpath);
52 FunctionHandle LoadFunction(
const char* name)
const;
53 template<
typename HandleType>
54 HandleType LoadFunctionWithSignature(
const char* name)
const;
57 LibraryHandle handle_;
63 static std::unique_ptr<PredFunction> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type,
68 virtual TypeInfo GetThresholdType()
const = 0;
69 virtual TypeInfo GetLeafOutputType()
const = 0;
70 virtual size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
74 template<
typename ThresholdType,
typename LeafOutputType>
77 using PredFuncHandle =
void*;
79 TypeInfo GetThresholdType()
const override;
80 TypeInfo GetLeafOutputType()
const override;
81 size_t PredictBatch(
const DMatrix* dmat,
size_t rbegin,
size_t rend,
bool pred_margin,
85 PredFuncHandle handle_;
96 explicit Predictor(
int num_worker_thread = -1);
102 void Load(
const char* libpath);
128 TREELITE_CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
129 return dmat->GetNumRow() * num_class_;
140 TREELITE_CHECK(pred_func_) <<
"A shared library needs to be loaded first using Load()";
141 TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
142 return (rend - rbegin) * num_class_;
166 return pred_transform_;
173 return sigmoid_alpha_;
194 return threshold_type_;
201 return leaf_output_type_;
217 std::unique_ptr<PredFunction> pred_func_;
218 ThreadPoolHandle thread_pool_handle_;
221 std::string pred_transform_;
222 float sigmoid_alpha_;
225 int num_worker_thread_;
235 #endif // TREELITE_PREDICTOR_H_ std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void * PredictorOutputHandle
handle to output from predictor
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
logging facility for Treelite
float QueryRatioC() const
Get the c value in the exponential standard ratio used to train the loaded model.
Exception class used throughout the Treelite codebase.
TypeInfo
Types used by thresholds and leaf outputs.
Utility to propagate exceptions thrown inside an OpenMP block.
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
float QuerySigmoidAlpha() const
Get the alpha value in the sigmoid transformation used to train the loaded model.
OpenMP exception class that catches, saves, and rethrows exceptions from OpenMP blocks.
float QueryGlobalBias() const
Get the global bias, which adjusts predicted margin scores.
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
predictor class: wrapper for optimized prediction code
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
void * ThreadPoolHandle
opaque handle types
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...