Treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <treelite/logging.h>
11 #include <treelite/typeinfo.h>
12 #include <treelite/c_api_runtime.h>
13 #include <treelite/data.h>
14 #include <string>
15 #include <memory>
16 #include <mutex>
17 #include <cstdint>
18 
19 namespace treelite {
20 namespace predictor {
21 
/*!
 * \brief OMPException catches, saves and rethrows exceptions raised inside OMP blocks.
 *
 * An exception must not propagate out of an OpenMP parallel region (the OpenMP spec requires it
 * to be caught by the thread that threw it). Each worker therefore wraps its work in Run(). The
 * first exception raised by any worker is stored, and the main thread calls Rethrow() after the
 * parallel region has finished.
 */
class OMPException {
 private:
  // First exception captured by Run(); null while no exception is pending.
  std::exception_ptr omp_exception_;
  // Serializes workers that try to record an exception at the same time.
  std::mutex mutex_;

 public:
  /*!
   * \brief Parallel OMP blocks should be placed within Run() so that exceptions are saved.
   * \param f callable to invoke
   * \param params arguments passed (by value) to f
   *
   * An exception of any type thrown by f is captured instead of propagated. Only the first
   * captured exception is kept; later ones are discarded.
   */
  template <typename Function, typename... Parameters>
  void Run(Function f, Parameters... params) {
    try {
      f(params...);
    } catch (...) {
      // Catch everything, not only std::exception: any exception escaping an OpenMP region
      // would terminate the program instead of reaching the caller.
      std::lock_guard<std::mutex> lock(mutex_);
      if (!omp_exception_) {
        omp_exception_ = std::current_exception();
      }
    }
  }

  /*!
   * \brief Should be called from the main thread to rethrow the saved exception, if any.
   *
   * The pending exception is cleared before it is rethrown. The same instance can then be
   * reused for later parallel regions without re-raising a stale error, and without silently
   * dropping new errors because an old one is still stored.
   */
  void Rethrow() {
    std::exception_ptr pending = omp_exception_;
    omp_exception_ = nullptr;
    if (pending) {
      std::rethrow_exception(pending);
    }
  }
};
60 
/*!
 * \brief Data layout of one feature slot handed to the compiled prediction code.
 *
 * Both members share the same storage. Storing -1 in `missing` marks the feature as missing;
 * otherwise `fvalue` carries the feature value.
 * \tparam ElementType type used to store the feature value
 */
template <class ElementType>
union Entry {
  int missing;         // -1 signifies a missing value
  ElementType fvalue;  // the feature value, when present
  // may contain extra fields later, such as qvalue
};
71 
/*!
 * \brief Wrapper around a dynamically loaded shared library.
 *
 * The object stores the raw library handle in handle_. A copy would alias that handle, and the
 * destructor (defined elsewhere, presumably releasing the library) would then run twice for it,
 * so copying is disabled.
 */
class SharedLibrary {
 public:
  using LibraryHandle = void*;
  using FunctionHandle = void*;
  SharedLibrary();
  ~SharedLibrary();
  // Non-copyable: handle_ must have a single owner.
  SharedLibrary(const SharedLibrary&) = delete;
  SharedLibrary& operator=(const SharedLibrary&) = delete;
  /*! \brief Load the shared library found at libpath (implementation not shown here). */
  void Load(const char* libpath);
  /*! \brief Look up an exported symbol by name and return it as an untyped handle. */
  FunctionHandle LoadFunction(const char* name) const;
  /*! \brief Look up an exported symbol by name, presumably cast to HandleType — confirm in .cc */
  template<typename HandleType>
  HandleType LoadFunctionWithSignature(const char* name) const;

 private:
  LibraryHandle handle_;
  std::string libpath_;
};
87 
88 class PredFunction {
89  public:
90  static std::unique_ptr<PredFunction> Create(TypeInfo threshold_type, TypeInfo leaf_output_type,
91  const SharedLibrary& library, int num_feature,
92  int num_class);
93  PredFunction() = default;
94  virtual ~PredFunction() = default;
95  virtual TypeInfo GetThresholdType() const = 0;
96  virtual TypeInfo GetLeafOutputType() const = 0;
97  virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
98  PredictorOutputHandle out_pred) const = 0;
99 };
100 
101 template<typename ThresholdType, typename LeafOutputType>
103  public:
104  using PredFuncHandle = void*;
105  PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_class);
106  TypeInfo GetThresholdType() const override;
107  TypeInfo GetLeafOutputType() const override;
108  size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
109  PredictorOutputHandle out_pred) const override;
110 
111  private:
112  PredFuncHandle handle_;
113  int num_feature_;
114  int num_class_;
115 };
116 
118 class Predictor {
119  public:
121  typedef void* ThreadPoolHandle;
122 
123  explicit Predictor(int num_worker_thread = -1);
124  ~Predictor();
129  void Load(const char* libpath);
133  void Free();
146  size_t PredictBatch(
147  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const;
154  inline size_t QueryResultSize(const DMatrix* dmat) const {
155  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
156  return dmat->GetNumRow() * num_class_;
157  }
166  inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const {
167  TREELITE_CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
168  TREELITE_CHECK(rbegin < rend && rend <= dmat->GetNumRow());
169  return (rend - rbegin) * num_class_;
170  }
177  inline size_t QueryNumClass() const {
178  return num_class_;
179  }
185  inline size_t QueryNumFeature() const {
186  return num_feature_;
187  }
192  inline std::string QueryPredTransform() const {
193  return pred_transform_;
194  }
199  inline float QuerySigmoidAlpha() const {
200  return sigmoid_alpha_;
201  }
206  inline float QueryGlobalBias() const {
207  return global_bias_;
208  }
213  inline TypeInfo QueryThresholdType() const {
214  return threshold_type_;
215  }
220  inline TypeInfo QueryLeafOutputType() const {
221  return leaf_output_type_;
222  }
228  PredictorOutputHandle CreateOutputVector(const DMatrix* dmat) const;
233  void DeleteOutputVector(PredictorOutputHandle output_vector) const;
234 
235  private:
236  SharedLibrary lib_;
237  std::unique_ptr<PredFunction> pred_func_;
238  ThreadPoolHandle thread_pool_handle_;
239  size_t num_class_;
240  size_t num_feature_;
241  std::string pred_transform_;
242  float sigmoid_alpha_;
243  float global_bias_;
244  int num_worker_thread_;
245  TypeInfo threshold_type_;
246  TypeInfo leaf_output_type_;
247 
248  mutable OMPException exception_catcher_;
249 };
250 
251 } // namespace predictor
252 } // namespace treelite
253 
254 #endif // TREELITE_PREDICTOR_H_
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
Definition: predictor.h:192
void Rethrow()
Should be called from the main thread to rethrow the saved exception.
Definition: predictor.h:56
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
Definition: predictor.h:166
Exception class that will be thrown by Treelite.
Definition: logging.h:24
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the runtime module.
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run() so that exceptions are saved.
Definition: predictor.h:37
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:185
logging facility for Treelite
OMPException class: catches, saves and rethrows exceptions from OMP blocks.
Definition: predictor.h:25
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
Definition: predictor.h:220
float QuerySigmoidAlpha() const
Get the alpha value of the sigmoid transformation used to train the loaded model.
Definition: predictor.h:199
float QueryGlobalBias() const
Get the global bias used to adjust predicted margin scores.
Definition: predictor.h:206
Data layout. The value -1 signifies a missing value. When the "missing" field is set to -1...
Definition: predictor.h:66
Predictor class: wrapper for optimized prediction code.
Definition: predictor.h:118
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
Definition: predictor.h:154
void * ThreadPoolHandle
Opaque handle types.
Definition: predictor.h:121
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
Definition: predictor.h:213
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 for multiclass classification.
Definition: predictor.h:177