Treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <treelite/error.h>
11 #include <treelite/omp_exception.h>
12 #include <treelite/logging.h>
13 #include <treelite/typeinfo.h>
14 #include <treelite/c_api_runtime.h>
15 #include <treelite/data.h>
16 #include <string>
17 #include <memory>
18 #include <mutex>
19 #include <cstdint>
20 
21 #ifdef _WIN32
22 #define NOMINMAX
23 #include <windows.h>
24 #endif // _WIN32
25 
26 namespace treelite {
27 namespace predictor {
28 
/*!
 * \brief Data layout of a single feature entry.
 *
 * The value -1 signifies the missing value: an entry whose `missing` field is
 * set to -1 is treated as missing.
 *
 * NOTE(review): because this is a union, only one member is active at a time;
 * reading `missing` after writing `fvalue` (or vice versa) is type punning.
 * Confirm how callers distinguish the two states.
 * \tparam ElementType type of the feature value
 */
template <typename ElementType>
union Entry {
  int missing;
  ElementType fvalue;
  // may contain extra fields later, such as qvalue
};
39 
/*!
 * \brief Handle to a dynamically loaded shared library.
 *
 * NOTE(review): this class declares a destructor and holds a raw
 * `LibraryHandle`, but declares no copy/move operations, so it is implicitly
 * copyable. If the destructor releases `handle_`, copying an instance would
 * release the same handle twice. Consider deleting the copy constructor and
 * copy assignment once all callers are confirmed not to copy it.
 */
class SharedLibrary {
 public:
#ifdef _WIN32
  using LibraryHandle = HMODULE;
  using FunctionHandle = FARPROC;
#else  // _WIN32
  using LibraryHandle = void*;
  using FunctionHandle = void*;
#endif  // _WIN32
  SharedLibrary();
  ~SharedLibrary();
  /*!
   * \brief Load the shared library at \p libpath.
   * \param libpath path to the shared library file
   */
  void Load(const char* libpath);
  /*!
   * \brief Look up the symbol \p name in the loaded library.
   * \param name name of the symbol to look up
   * \return platform-specific function handle
   */
  FunctionHandle LoadFunction(const char* name) const;
  /*!
   * \brief Look up the symbol \p name and return it as \p HandleType.
   * \tparam HandleType handle type to return (presumably the function pointer
   *         type matching the symbol's signature; TODO confirm in the .cc)
   * \param name name of the symbol to look up
   */
  template<typename HandleType>
  HandleType LoadFunctionWithSignature(const char* name) const;

 private:
  LibraryHandle handle_;
  std::string libpath_;
};
60 
61 class PredFunction {
62  public:
63  static std::unique_ptr<PredFunction> Create(TypeInfo threshold_type, TypeInfo leaf_output_type,
64  const SharedLibrary& library, int num_feature,
65  int num_class);
66  PredFunction() = default;
67  virtual ~PredFunction() = default;
68  virtual TypeInfo GetThresholdType() const = 0;
69  virtual TypeInfo GetLeafOutputType() const = 0;
70  virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
71  PredictorOutputHandle out_pred) const = 0;
72 };
73 
74 template<typename ThresholdType, typename LeafOutputType>
76  public:
77  using PredFuncHandle = void*;
78  PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_class);
79  TypeInfo GetThresholdType() const override;
80  TypeInfo GetLeafOutputType() const override;
81  size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
82  PredictorOutputHandle out_pred) const override;
83 
84  private:
85  PredFuncHandle handle_;
86  int num_feature_;
87  int num_class_;
88 };
89 
91 class Predictor {
92  public:
94  typedef void* ThreadPoolHandle;
95 
96  explicit Predictor(int num_worker_thread = -1);
97  ~Predictor();
102  void Load(const char* libpath);
106  void Free();
119  size_t PredictBatch(
120  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const;
127  inline size_t QueryResultSize(const DMatrix* dmat) const {
128  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
129  return dmat->GetNumRow() * num_class_;
130  }
139  inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const {
140  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
141  TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
142  return (rend - rbegin) * num_class_;
143  }
150  inline size_t QueryNumClass() const {
151  return num_class_;
152  }
158  inline size_t QueryNumFeature() const {
159  return num_feature_;
160  }
165  inline std::string QueryPredTransform() const {
166  return pred_transform_;
167  }
172  inline float QuerySigmoidAlpha() const {
173  return sigmoid_alpha_;
174  }
179  inline float QueryRatioC() const {
180  return ratio_c_;
181  }
186  inline float QueryGlobalBias() const {
187  return global_bias_;
188  }
193  inline TypeInfo QueryThresholdType() const {
194  return threshold_type_;
195  }
200  inline TypeInfo QueryLeafOutputType() const {
201  return leaf_output_type_;
202  }
208  PredictorOutputHandle CreateOutputVector(const DMatrix* dmat) const;
213  void DeleteOutputVector(PredictorOutputHandle output_vector) const;
214 
215  private:
216  SharedLibrary lib_;
217  std::unique_ptr<PredFunction> pred_func_;
218  ThreadPoolHandle thread_pool_handle_;
219  size_t num_class_;
220  size_t num_feature_;
221  std::string pred_transform_;
222  float sigmoid_alpha_;
223  float ratio_c_;
224  float global_bias_;
225  int num_worker_thread_;
226  TypeInfo threshold_type_;
227  TypeInfo leaf_output_type_;
228 
229  mutable OMPException exception_catcher_;
230 };
231 
232 } // namespace predictor
233 } // namespace treelite
234 
235 #endif // TREELITE_PREDICTOR_H_
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
Definition: predictor.h:165
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:139
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:158
logging facility for Treelite
float QueryRatioC() const
Get the c value in the exponential standard ratio used to train the loaded model.
Definition: predictor.h:179
Exception class used throughout the Treelite codebase.
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:23
Utility to propagate exceptions thrown inside an OpenMP block.
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
Definition: predictor.h:200
float QuerySigmoidAlpha() const
Get the alpha value in the sigmoid transformation used to train the loaded model.
Definition: predictor.h:172
OMP Exception class that catches, saves, and rethrows exceptions from OMP blocks.
Definition: omp_exception.h:19
float QueryGlobalBias() const
Get the global bias used to adjust predicted margin scores.
Definition: predictor.h:186
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:34
predictor class: wrapper for optimized prediction code
Definition: predictor.h:91
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:127
void * ThreadPoolHandle
opaque handle types
Definition: predictor.h:94
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
Definition: predictor.h:193
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...
Definition: predictor.h:150