Treelite
c_api_runtime.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
10 #include <treelite/c_api_error.h>
11 #include <treelite/thread_local.h>
12 #include <string>
13 #include <cstring>
14 
15 using namespace treelite;
16 
17 namespace {
18 
20 struct TreeliteRuntimeAPIThreadLocalEntry {
22  std::string ret_str;
23 };
24 
25 // thread-local store for returning strings
26 using TreeliteRuntimeAPIThreadLocalStore = ThreadLocalStore<TreeliteRuntimeAPIThreadLocalEntry>;
27 
28 } // anonymous namespace
29 
30 int TreelitePredictorLoad(const char* library_path, int num_worker_thread, PredictorHandle* out) {
31  API_BEGIN();
32  auto predictor = std::make_unique<predictor::Predictor>(num_worker_thread);
33  predictor->Load(library_path);
34  *out = static_cast<PredictorHandle>(predictor.release());
35  API_END();
36 }
37 
39  PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin,
40  PredictorOutputHandle out_result, size_t* out_result_size) {
41  API_BEGIN();
42  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
43  const auto* dmat = static_cast<const DMatrix*>(batch);
44  const size_t num_feature = predictor->QueryNumFeature();
45  const std::string err_msg
46  = std::string("Too many columns (features) in the given batch. "
47  "Number of features must not exceed ") + std::to_string(num_feature);
48  TREELITE_CHECK_LE(dmat->GetNumCol(), num_feature) << err_msg;
49  *out_result_size = predictor->PredictBatch(dmat, verbose, (pred_margin != 0), out_result);
50  API_END();
51 }
52 
54  PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_vector) {
55  API_BEGIN();
56  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
57  const auto* dmat = static_cast<const DMatrix*>(batch);
58  *out_output_vector = predictor->CreateOutputVector(dmat);
59  API_END();
60 }
61 
63  PredictorHandle handle, PredictorOutputHandle output_vector) {
64  API_BEGIN();
65  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
66  predictor->DeleteOutputVector(output_vector);
67  API_END();
68 }
69 
71  API_BEGIN();
72  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
73  const auto* dmat = static_cast<const DMatrix*>(batch);
74  *out = predictor->QueryResultSize(dmat);
75  API_END();
76 }
77 
79  API_BEGIN();
80  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
81  *out = predictor->QueryNumClass();
82  API_END();
83 }
84 
86  API_BEGIN();
87  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
88  *out = predictor->QueryNumFeature();
89  API_END();
90 }
91 
93  API_BEGIN()
94  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
95  auto pred_transform = predictor->QueryPredTransform();
96  std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str;
97  ret_str = pred_transform;
98  *out = ret_str.c_str();
99  API_END();
100 }
101 
103  API_BEGIN()
104  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
105  *out = predictor->QuerySigmoidAlpha();
106  API_END();
107 }
108 
110  API_BEGIN()
111  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
112  *out = predictor->QueryRatioC();
113  API_END();
114 }
115 
117  API_BEGIN()
118  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
119  *out = predictor->QueryGlobalBias();
120  API_END();
121 }
122 
123 int TreelitePredictorQueryThresholdType(PredictorHandle handle, const char** out) {
124  API_BEGIN()
125  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
126  std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str;
127  ret_str = TypeInfoToString(predictor->QueryThresholdType());
128  *out = ret_str.c_str();
129  API_END();
130 }
131 
132 int TreelitePredictorQueryLeafOutputType(PredictorHandle handle, const char** out) {
133  API_BEGIN()
134  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
135  std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str;
136  ret_str = TypeInfoToString(predictor->QueryLeafOutputType());
137  *out = ret_str.c_str();
138  API_END();
139 }
140 
142  API_BEGIN();
143  delete static_cast<predictor::Predictor*>(handle);
144  API_END();
145 }
Load prediction function exported as a shared library.
std::string QueryPredTransform() const
Get name of post prediction transformation used to train the loaded model.
Definition: predictor.h:202
#define API_BEGIN()
Macro to guard the beginning and end sections of all functions.
Definition: c_api_error.h:14
int TreeliteCreatePredictorOutputVector(PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle *out_output_vector)
Convenience function to allocate an output vector that is able to hold the prediction result for a gi...
int TreelitePredictorPredictBatch(PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, PredictorOutputHandle out_result, size_t *out_result_size)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:195
int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char **out)
Get name of post prediction transformation used to train the loaded model.
float QueryRatioC() const
Get the c value of the exponential standard ratio transformation used to train the loaded model.
Definition: predictor.h:216
void DeleteOutputVector(PredictorOutputHandle output_vector) const
Free an output vector from memory.
Definition: predictor.cc:476
int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t *out)
Get the width (number of features) of each instance used to train the loaded model.
void * DMatrixHandle
handle to a data matrix
Definition: c_api_common.h:30
Helper class for thread-local storage.
Error handling for C API.
int TreelitePredictorLoad(const char *library_path, int num_worker_thread, PredictorHandle *out)
Load prediction code into memory. This function assumes that the prediction code has already been com...
float QuerySigmoidAlpha() const
Get alpha value in sigmoid transformation used to train the loaded model.
Definition: predictor.h:209
int TreelitePredictorQueryRatioC(PredictorHandle handle, float *out)
Get the c value of the exponential standard ratio transformation used to train the loaded model...
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
void * PredictorHandle
handle to predictor class
Definition: c_api_runtime.h:23
int TreeliteDeletePredictorOutputVector(PredictorHandle handle, PredictorOutputHandle output_vector)
De-allocate an output vector.
float QueryGlobalBias() const
Get the global bias, which adjusts the predicted margin scores.
Definition: predictor.h:223
predictor class: wrapper for optimized prediction code
Definition: predictor.h:128
int TreelitePredictorFree(PredictorHandle handle)
Delete predictor from memory.
A thread-local storage.
Definition: thread_local.h:17
int TreelitePredictorQueryNumClass(PredictorHandle handle, size_t *out)
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 for m...
int TreelitePredictorQueryResultSize(PredictorHandle handle, DMatrixHandle batch, size_t *out)
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...
Definition: predictor.h:187
int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float *out)
Get alpha value of sigmoid transformation used to train the loaded model.
#define API_END()
Every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:17
int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float *out)
Get the global bias, which adjusts the predicted margin scores.