Treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <treelite/logging.h>
11 #include <treelite/typeinfo.h>
12 #include <treelite/c_api_runtime.h>
13 #include <treelite/data.h>
14 #include <string>
15 #include <memory>
16 #include <mutex>
17 #include <cstdint>
18 
19 namespace treelite {
20 namespace predictor {
21 
/*!
 * \brief OMP Exception class catches, saves and rethrows exceptions from OMP blocks.
 *
 * OpenMP requires an exception thrown inside a parallel region to be caught within
 * that region by the same thread. Worker code is therefore wrapped in Run(), which
 * records the first exception thrown by any thread. After the region, the main
 * thread calls Rethrow() to surface it.
 */
class OMPException {
 private:
  // First exception captured by Run(); empty while nothing has been thrown.
  std::exception_ptr omp_exception_;
  // Guards omp_exception_, which several worker threads may try to set concurrently.
  std::mutex mutex_;

 public:
  /*!
   * \brief Parallel OMP blocks should be placed within Run to save any exception they throw.
   * \param f callable to invoke
   * \param params arguments passed (by copy) to f
   *
   * Every exception type is captured, not only std::exception-derived ones, so nothing
   * can escape the enclosing OpenMP region. Only the first captured exception is kept;
   * later ones are discarded.
   */
  template <typename Function, typename... Parameters>
  void Run(Function f, Parameters... params) {
    try {
      f(params...);
    } catch (...) {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!omp_exception_) {
        omp_exception_ = std::current_exception();
      }
    }
  }

  /*!
   * \brief Should be called from the main thread to rethrow the saved exception, if any.
   *
   * The saved exception is cleared before it is rethrown. An instance reused across
   * several parallel regions therefore never reports a stale error from an earlier
   * region, and it can capture new errors afterwards.
   */
  void Rethrow() {
    std::exception_ptr pending;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      pending = omp_exception_;
      omp_exception_ = nullptr;  // consume: reset for the next parallel region
    }
    if (pending) {
      std::rethrow_exception(pending);
    }
  }
};
60 
/*!
 * \brief Data layout for a single feature value. The value -1 stored in the
 *        "missing" field signifies a missing value.
 * \tparam T type of the feature value
 *
 * NOTE: member order matters. Brace-initialization such as Entry<float>{-1}
 * initializes the first member ("missing"), so it must stay first.
 */
template <class T>
union Entry {
  int missing;  // -1 marks the feature as missing
  T fvalue;     // the feature value when present
  // may contain extra fields later, such as qvalue
};
71 
/*!
 * \brief Wrapper around a dynamically loaded shared library.
 *
 * The object stores a raw library handle and has a user-declared destructor
 * (defined out of line, presumably releasing that handle -- confirm in the .cc).
 * Copying is disabled: a copy would share the same handle and release it twice.
 */
class SharedLibrary {
 public:
  using LibraryHandle = void*;
  using FunctionHandle = void*;
  SharedLibrary();
  ~SharedLibrary();
  // Non-copyable: this object owns handle_.
  SharedLibrary(const SharedLibrary&) = delete;
  SharedLibrary& operator=(const SharedLibrary&) = delete;
  /*! \brief Load the shared library located at libpath. */
  void Load(const char* libpath);
  /*! \brief Look up the symbol `name` in the loaded library (returns an opaque handle). */
  FunctionHandle LoadFunction(const char* name) const;
  /*! \brief Look up the symbol `name`, presumably cast to HandleType -- confirm in the .cc. */
  template<typename HandleType>
  HandleType LoadFunctionWithSignature(const char* name) const;

 private:
  // Opaque library handle; null unless the constructor or Load() assigns it.
  LibraryHandle handle_{nullptr};
  // Path passed to Load().
  std::string libpath_;
};
87 
88 class PredFunction {
89  public:
90  static std::unique_ptr<PredFunction> Create(TypeInfo threshold_type, TypeInfo leaf_output_type,
91  const SharedLibrary& library, int num_feature,
92  int num_class);
93  PredFunction() = default;
94  virtual ~PredFunction() = default;
95  virtual TypeInfo GetThresholdType() const = 0;
96  virtual TypeInfo GetLeafOutputType() const = 0;
97  virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
98  PredictorOutputHandle out_pred) const = 0;
99 };
100 
101 template<typename ThresholdType, typename LeafOutputType>
103  public:
104  using PredFuncHandle = void*;
105  PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_class);
106  TypeInfo GetThresholdType() const override;
107  TypeInfo GetLeafOutputType() const override;
108  size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
109  PredictorOutputHandle out_pred) const override;
110 
111  private:
112  PredFuncHandle handle_;
113  int num_feature_;
114  int num_class_;
115 };
116 
118 class Predictor {
119  public:
121  typedef void* ThreadPoolHandle;
122 
123  explicit Predictor(int num_worker_thread = -1);
124  ~Predictor();
129  void Load(const char* libpath);
133  void Free();
146  size_t PredictBatch(
147  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const;
154  inline size_t QueryResultSize(const DMatrix* dmat) const {
155  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
156  return dmat->GetNumRow() * num_class_;
157  }
166  inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const {
167  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
168  TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
169  return (rend - rbegin) * num_class_;
170  }
177  inline size_t QueryNumClass() const {
178  return num_class_;
179  }
185  inline size_t QueryNumFeature() const {
186  return num_feature_;
187  }
192  inline std::string QueryPredTransform() const {
193  return pred_transform_;
194  }
199  inline float QuerySigmoidAlpha() const {
200  return sigmoid_alpha_;
201  }
206  inline float QueryRatioC() const {
207  return ratio_c_;
208  }
213  inline float QueryGlobalBias() const {
214  return global_bias_;
215  }
220  inline TypeInfo QueryThresholdType() const {
221  return threshold_type_;
222  }
227  inline TypeInfo QueryLeafOutputType() const {
228  return leaf_output_type_;
229  }
235  PredictorOutputHandle CreateOutputVector(const DMatrix* dmat) const;
240  void DeleteOutputVector(PredictorOutputHandle output_vector) const;
241 
242  private:
243  SharedLibrary lib_;
244  std::unique_ptr<PredFunction> pred_func_;
245  ThreadPoolHandle thread_pool_handle_;
246  size_t num_class_;
247  size_t num_feature_;
248  std::string pred_transform_;
249  float sigmoid_alpha_;
250  float ratio_c_;
251  float global_bias_;
252  int num_worker_thread_;
253  TypeInfo threshold_type_;
254  TypeInfo leaf_output_type_;
255 
256  mutable OMPException exception_catcher_;
257 };
258 
259 } // namespace predictor
260 } // namespace treelite
261 
262 #endif // TREELITE_PREDICTOR_H_
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
Definition: predictor.h:192
void Rethrow()
Should be called from the main thread to rethrow the saved exception.
Definition: predictor.h:56
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:166
Exception class that will be thrown by Treelite.
Definition: logging.h:24
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save any exception they throw.
Definition: predictor.h:37
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:185
logging facility for Treelite
float QueryRatioC() const
Get the c value in the exponential standard ratio used to train the loaded model.
Definition: predictor.h:206
OMP exception class: catches, saves, and rethrows exceptions from OMP blocks.
Definition: predictor.h:25
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
Definition: predictor.h:227
float QuerySigmoidAlpha() const
Get the alpha value of the sigmoid transformation used to train the loaded model.
Definition: predictor.h:199
float QueryGlobalBias() const
Get the global bias that adjusts the predicted margin scores.
Definition: predictor.h:213
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:66
predictor class: wrapper for optimized prediction code
Definition: predictor.h:118
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:154
void * ThreadPoolHandle
opaque handle types
Definition: predictor.h:121
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
Definition: predictor.h:220
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 for multiclass classification.
Definition: predictor.h:177