Treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <treelite/logging.h>
11 #include <treelite/typeinfo.h>
12 #include <treelite/c_api_runtime.h>
13 #include <treelite/data.h>
14 #include <string>
15 #include <memory>
16 #include <mutex>
17 #include <cstdint>
18 
19 #ifdef _WIN32
20 #define NOMINMAX
21 #include <windows.h>
22 #endif // _WIN32
23 
24 namespace treelite {
25 namespace predictor {
26 
/*!
 * \brief Catches, saves and rethrows exceptions raised inside OpenMP blocks.
 *
 * An exception must not leave an OpenMP structured block, so the body of the block is
 * wrapped in Run(), which records the first exception raised. The main thread calls
 * Rethrow() afterwards to surface it.
 */
class OMPException {
 private:
  // first exception captured by Run(); empty when nothing is pending
  std::exception_ptr omp_exception_;
  // guards omp_exception_, since several threads may fail at the same time
  std::mutex mutex_;

 public:
  /*!
   * \brief Invoke f(params...) and save the exception if one is thrown.
   *        Parallel OMP blocks should be placed within Run to save exception.
   * \param f callable to invoke
   * \param params arguments passed to f
   *
   * Every exception type is captured, not only those derived from std::exception, so
   * that nothing can escape the enclosing block. Only the first exception is kept;
   * later ones are dropped.
   */
  template <typename Function, typename... Parameters>
  void Run(Function f, Parameters... params) {
    try {
      f(params...);
    } catch (...) {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!omp_exception_) {
        omp_exception_ = std::current_exception();
      }
    }
  }

  /*!
   * \brief Rethrow the saved exception, if any. Should be called from the main thread.
   *
   * The saved exception is cleared before it is rethrown, so the same object can be
   * reused for a later parallel block without reporting a stale failure.
   */
  void Rethrow() {
    std::exception_ptr pending;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      pending.swap(omp_exception_);
    }
    if (pending) std::rethrow_exception(pending);
  }
};
65 
/*!
 * \brief Data layout of a single feature slot.
 *
 * The two fields share storage. The value -1 in the "missing" field signifies a
 * missing value; otherwise "fvalue" carries the feature value.
 * \tparam ElementType type of the feature value
 */
template <class ElementType>
union Entry {
  int missing;         // -1 signifies the missing value
  ElementType fvalue;  // feature value
  // may contain extra fields later, such as qvalue
};
76 
/*!
 * \brief Wrapper around a platform-specific handle to a shared library.
 *
 * The handle types map to HMODULE / FARPROC on Windows and to void* elsewhere.
 * The member functions are defined outside this header.
 * NOTE(review): the class has a user-declared destructor but remains copyable; if the
 * destructor releases handle_, copies would release it twice -- confirm against the
 * implementation file.
 */
class SharedLibrary {
 public:
#ifdef _WIN32
  using LibraryHandle = HMODULE;
  using FunctionHandle = FARPROC;
#else  // _WIN32
  using LibraryHandle = void*;
  using FunctionHandle = void*;
#endif  // _WIN32
  SharedLibrary();
  ~SharedLibrary();
  /*!
   * \brief Load the shared library located at libpath.
   * \param libpath path of the library file
   */
  void Load(const char* libpath);
  /*!
   * \brief Look up a function in the library by name.
   * \param name name of the function
   * \return untyped handle to the function
   */
  FunctionHandle LoadFunction(const char* name) const;
  /*!
   * \brief Look up a function by name and return it as HandleType.
   * \tparam HandleType type of the returned handle
   * \param name name of the function
   */
  template<typename HandleType>
  HandleType LoadFunctionWithSignature(const char* name) const;

 private:
  LibraryHandle handle_;  // platform handle of the library
  std::string libpath_;   // presumably the path given to Load() -- TODO confirm
};
97 
98 class PredFunction {
99  public:
100  static std::unique_ptr<PredFunction> Create(TypeInfo threshold_type, TypeInfo leaf_output_type,
101  const SharedLibrary& library, int num_feature,
102  int num_class);
103  PredFunction() = default;
104  virtual ~PredFunction() = default;
105  virtual TypeInfo GetThresholdType() const = 0;
106  virtual TypeInfo GetLeafOutputType() const = 0;
107  virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
108  PredictorOutputHandle out_pred) const = 0;
109 };
110 
111 template<typename ThresholdType, typename LeafOutputType>
113  public:
114  using PredFuncHandle = void*;
115  PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_class);
116  TypeInfo GetThresholdType() const override;
117  TypeInfo GetLeafOutputType() const override;
118  size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
119  PredictorOutputHandle out_pred) const override;
120 
121  private:
122  PredFuncHandle handle_;
123  int num_feature_;
124  int num_class_;
125 };
126 
128 class Predictor {
129  public:
131  typedef void* ThreadPoolHandle;
132 
133  explicit Predictor(int num_worker_thread = -1);
134  ~Predictor();
139  void Load(const char* libpath);
143  void Free();
156  size_t PredictBatch(
157  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const;
164  inline size_t QueryResultSize(const DMatrix* dmat) const {
165  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
166  return dmat->GetNumRow() * num_class_;
167  }
176  inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const {
177  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
178  TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
179  return (rend - rbegin) * num_class_;
180  }
187  inline size_t QueryNumClass() const {
188  return num_class_;
189  }
195  inline size_t QueryNumFeature() const {
196  return num_feature_;
197  }
202  inline std::string QueryPredTransform() const {
203  return pred_transform_;
204  }
209  inline float QuerySigmoidAlpha() const {
210  return sigmoid_alpha_;
211  }
216  inline float QueryRatioC() const {
217  return ratio_c_;
218  }
223  inline float QueryGlobalBias() const {
224  return global_bias_;
225  }
230  inline TypeInfo QueryThresholdType() const {
231  return threshold_type_;
232  }
237  inline TypeInfo QueryLeafOutputType() const {
238  return leaf_output_type_;
239  }
245  PredictorOutputHandle CreateOutputVector(const DMatrix* dmat) const;
250  void DeleteOutputVector(PredictorOutputHandle output_vector) const;
251 
252  private:
253  SharedLibrary lib_;
254  std::unique_ptr<PredFunction> pred_func_;
255  ThreadPoolHandle thread_pool_handle_;
256  size_t num_class_;
257  size_t num_feature_;
258  std::string pred_transform_;
259  float sigmoid_alpha_;
260  float ratio_c_;
261  float global_bias_;
262  int num_worker_thread_;
263  TypeInfo threshold_type_;
264  TypeInfo leaf_output_type_;
265 
266  mutable OMPException exception_catcher_;
267 };
268 
269 } // namespace predictor
270 } // namespace treelite
271 
272 #endif // TREELITE_PREDICTOR_H_
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
Definition: predictor.h:202
void Rethrow()
should be called from the main thread to rethrow the exception
Definition: predictor.h:61
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:176
Exception class that will be thrown by Treelite.
Definition: logging.h:24
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save exception.
Definition: predictor.h:42
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:195
logging facility for Treelite
float QueryRatioC() const
Get c value in exponential standard ratio used to train the loaded model.
Definition: predictor.h:216
OMP Exception class catches, saves and rethrows exception from OMP blocks.
Definition: predictor.h:30
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
Definition: predictor.h:237
float QuerySigmoidAlpha() const
Get alpha value in sigmoid transformation used to train the loaded model.
Definition: predictor.h:209
float QueryGlobalBias() const
Get the global bias that adjusts predicted margin scores.
Definition: predictor.h:223
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:71
predictor class: wrapper for optimized prediction code
Definition: predictor.h:128
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:164
void * ThreadPoolHandle
opaque handle types
Definition: predictor.h:131
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
Definition: predictor.h:230
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...
Definition: predictor.h:187