Treelite
predictor.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_H_
8 #define TREELITE_PREDICTOR_H_
9 
10 #include <dmlc/logging.h>
11 #include <treelite/typeinfo.h>
12 #include <treelite/c_api_runtime.h>
13 #include <treelite/data.h>
14 #include <string>
15 #include <memory>
16 #include <cstdint>
17 
18 namespace treelite {
19 namespace predictor {
20 
/*!
 * \brief data layout for a single feature slot.
 *
 * A union, so `missing` and `fvalue` share the same storage: a slot holds
 * either the integer marker or a feature value of type ElementType, never
 * both. The value -1 in `missing` signifies the missing value.
 *
 * \tparam ElementType type used to store a feature value
 */
25 template <typename ElementType>
26 union Entry {
 /*! \brief marker field; -1 signifies that the feature value is missing */
27  int missing;
 /*! \brief the feature value, when present */
28  ElementType fvalue;
29  // may contain extra fields later, such as qvalue
30 };
31 
/*!
 * \brief Holder for an opaque library handle together with a library path.
 *
 * Only declarations are visible in this header; all member functions are
 * defined elsewhere.
 *
 * NOTE(review): the class declares a destructor and stores a raw handle but
 * declares no copy/move operations. If the destructor releases handle_, a
 * copy would release it twice -- confirm against the implementation.
 */
class SharedLibrary {
 public:
  /*! \brief opaque handle to a library */
  using LibraryHandle = void*;
  /*! \brief opaque, untyped handle to a function looked up by name */
  using FunctionHandle = void*;

  SharedLibrary();
  ~SharedLibrary();

  /*!
   * \brief load a library
   * \param libpath path of the library (presumably a shared library on disk
   *                -- confirm against the implementation)
   */
  void Load(const char* libpath);

  /*!
   * \brief look up a function by name
   * \param name name of the function
   * \return untyped handle to the function
   */
  FunctionHandle LoadFunction(const char* name) const;

  /*!
   * \brief look up a function by name and return it as a typed handle
   * \tparam HandleType type of the returned handle (presumably a function
   *                    pointer type -- confirm against the implementation)
   * \param name name of the function
   * \return handle of type HandleType
   */
  template<typename HandleType>
  HandleType LoadFunctionWithSignature(const char* name) const;

 private:
  /*! \brief opaque library handle */
  LibraryHandle handle_;
  /*! \brief library path (presumably the argument last given to Load()) */
  std::string libpath_;
};
47 
/*!
 * \brief Abstract interface for a prediction function.
 *
 * All three query/predict members are pure virtual; concrete objects are
 * obtained through Create().
 */
48 class PredFunction {
49  public:
 /*!
  * \brief factory returning an owning pointer to a PredFunction
  * \param threshold_type type information for thresholds
  * \param leaf_output_type type information for leaf outputs
  * \param library library object the function is associated with
  * \param num_feature number of features
  * \param num_class number of classes
  * \return owning pointer to the created object
  */
50  static std::unique_ptr<PredFunction> Create(TypeInfo threshold_type, TypeInfo leaf_output_type,
51  const SharedLibrary& library, int num_feature,
52  int num_class);
53  PredFunction() = default;
 /*! \brief virtual so that derived objects can be destroyed through a base pointer */
54  virtual ~PredFunction() = default;
 /*! \brief get the type of the thresholds */
55  virtual TypeInfo GetThresholdType() const = 0;
 /*! \brief get the type of the leaf outputs */
56  virtual TypeInfo GetLeafOutputType() const = 0;
 /*!
  * \brief run prediction for a range of rows of a data matrix
  * \param dmat input data matrix
  * \param rbegin first row of the range
  * \param rend end of the row range (presumably exclusive -- TODO confirm)
  * \param pred_margin flag forwarded to the prediction code (presumably
  *                    selects raw margin output -- TODO confirm)
  * \param out_pred handle to the output from the predictor
  * \return a size_t count (presumably the number of output entries written
  *         -- TODO confirm)
  */
57  virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
58  PredictorOutputHandle out_pred) const = 0;
59 };
60 
61 template<typename ThresholdType, typename LeafOutputType>
63  public:
64  using PredFuncHandle = void*;
65  PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_class);
66  TypeInfo GetThresholdType() const override;
67  TypeInfo GetLeafOutputType() const override;
68  size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin,
69  PredictorOutputHandle out_pred) const override;
70 
71  private:
72  PredFuncHandle handle_;
73  int num_feature_;
74  int num_class_;
75 };
76 
78 class Predictor {
79  public:
81  typedef void* ThreadPoolHandle;
82 
83  explicit Predictor(int num_worker_thread = -1);
84  ~Predictor();
89  void Load(const char* libpath);
93  void Free();
106  size_t PredictBatch(
107  const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const;
114  inline size_t QueryResultSize(const DMatrix* dmat) const {
115  CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
116  return dmat->GetNumRow() * num_class_;
117  }
126  inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const {
127  CHECK(pred_func_) << "A shared library needs to be loaded first using Load()";
128  CHECK(rbegin < rend && rend <= dmat->GetNumRow());
129  return (rend - rbegin) * num_class_;
130  }
137  inline size_t QueryNumClass() const {
138  return num_class_;
139  }
145  inline size_t QueryNumFeature() const {
146  return num_feature_;
147  }
152  inline std::string QueryPredTransform() const {
153  return pred_transform_;
154  }
159  inline float QuerySigmoidAlpha() const {
160  return sigmoid_alpha_;
161  }
166  inline float QueryGlobalBias() const {
167  return global_bias_;
168  }
173  inline TypeInfo QueryThresholdType() const {
174  return threshold_type_;
175  }
180  inline TypeInfo QueryLeafOutputType() const {
181  return leaf_output_type_;
182  }
188  PredictorOutputHandle CreateOutputVector(const DMatrix* dmat) const;
193  void DeleteOutputVector(PredictorOutputHandle output_vector) const;
194 
195  private:
196  SharedLibrary lib_;
197  std::unique_ptr<PredFunction> pred_func_;
198  ThreadPoolHandle thread_pool_handle_;
199  size_t num_class_;
200  size_t num_feature_;
201  std::string pred_transform_;
202  float sigmoid_alpha_;
203  float global_bias_;
204  int num_worker_thread_;
205  TypeInfo threshold_type_;
206  TypeInfo leaf_output_type_;
207 
208  mutable dmlc::OMPException exception_catcher_;
209 };
210 
211 } // namespace predictor
212 } // namespace treelite
213 
214 #endif // TREELITE_PREDICTOR_H_
std::string QueryPredTransform() const
Get name of post prediction transformation used to train the loaded model.
Definition: predictor.h:152
size_t QueryResultSize(const DMatrix *dmat, size_t rbegin, size_t rend) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:126
Input data structure of Treelite.
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:145
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Defines TypeInfo class and utilities.
TypeInfo QueryLeafOutputType() const
Get the type of the leaf outputs.
Definition: predictor.h:180
float QuerySigmoidAlpha() const
Get alpha value in sigmoid transformation used to train the loaded model.
Definition: predictor.h:159
float QueryGlobalBias() const
Get the global bias, which adjusts predicted margin scores.
Definition: predictor.h:166
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: predictor.h:26
predictor class: wrapper for optimized prediction code
Definition: predictor.h:78
size_t QueryResultSize(const DMatrix *dmat) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:114
void * ThreadPoolHandle
opaque handle types
Definition: predictor.h:81
TypeInfo QueryThresholdType() const
Get the type of the split thresholds.
Definition: predictor.h:173
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...
Definition: predictor.h:137