Treelite
c_api_runtime.cc
Go to the documentation of this file.
1 
8 #include <treelite/predictor.h>
10 #include <treelite/c_api_error.h>
11 #include <treelite/thread_local.h>
12 #include <string>
13 #include <cstring>
14 
15 using namespace treelite;
16 
17 namespace {
18 
20 struct TreeliteRuntimeAPIThreadLocalEntry {
22  std::string ret_str;
23 };
24 
25 // thread-local store for returning strings
26 using TreeliteRuntimeAPIThreadLocalStore = ThreadLocalStore<TreeliteRuntimeAPIThreadLocalEntry>;
27 
28 } // anonymous namespace
29 
30 int TreelitePredictorLoad(const char* library_path, int num_worker_thread, PredictorHandle* out) {
31  API_BEGIN();
32  auto predictor = std::make_unique<predictor::Predictor>(num_worker_thread);
33  predictor->Load(library_path);
34  *out = static_cast<PredictorHandle>(predictor.release());
35  API_END();
36 }
37 
39  PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin,
40  PredictorOutputHandle out_result, size_t* out_result_size) {
41  API_BEGIN();
42  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
43  const auto* dmat = static_cast<const DMatrix*>(batch);
44  const size_t num_feature = predictor->QueryNumFeature();
45  const std::string err_msg
46  = std::string("Too many columns (features) in the given batch. "
47  "Number of features must not exceed ") + std::to_string(num_feature);
48  TREELITE_CHECK_LE(dmat->GetNumCol(), num_feature) << err_msg;
49  *out_result_size = predictor->PredictBatch(dmat, verbose, (pred_margin != 0), out_result);
50  API_END();
51 }
52 
54  PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_vector) {
55  API_BEGIN();
56  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
57  const auto* dmat = static_cast<const DMatrix*>(batch);
58  *out_output_vector = predictor->CreateOutputVector(dmat);
59  API_END();
60 }
61 
63  PredictorHandle handle, PredictorOutputHandle output_vector) {
64  API_BEGIN();
65  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
66  predictor->DeleteOutputVector(output_vector);
67  API_END();
68 }
69 
71  API_BEGIN();
72  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
73  const auto* dmat = static_cast<const DMatrix*>(batch);
74  *out = predictor->QueryResultSize(dmat);
75  API_END();
76 }
77 
79  API_BEGIN();
80  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
81  *out = predictor->QueryNumClass();
82  API_END();
83 }
84 
86  API_BEGIN();
87  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
88  *out = predictor->QueryNumFeature();
89  API_END();
90 }
91 
93  API_BEGIN()
94  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
95  auto pred_transform = predictor->QueryPredTransform();
96  std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str;
97  ret_str = pred_transform;
98  *out = ret_str.c_str();
99  API_END();
100 }
101 
103  API_BEGIN()
104  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
105  *out = predictor->QuerySigmoidAlpha();
106  API_END();
107 }
108 
110  API_BEGIN()
111  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
112  *out = predictor->QueryGlobalBias();
113  API_END();
114 }
115 
116 int TreelitePredictorQueryThresholdType(PredictorHandle handle, const char** out) {
117  API_BEGIN()
118  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
119  std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str;
120  ret_str = TypeInfoToString(predictor->QueryThresholdType());
121  *out = ret_str.c_str();
122  API_END();
123 }
124 
125 int TreelitePredictorQueryLeafOutputType(PredictorHandle handle, const char** out) {
126  API_BEGIN()
127  const auto* predictor = static_cast<const predictor::Predictor*>(handle);
128  std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str;
129  ret_str = TypeInfoToString(predictor->QueryLeafOutputType());
130  *out = ret_str.c_str();
131  API_END();
132 }
133 
135  API_BEGIN();
136  delete static_cast<predictor::Predictor*>(handle);
137  API_END();
138 }
Load prediction function exported as a shared library.
std::string QueryPredTransform() const
Get name of post prediction transformation used to train the loaded model.
Definition: predictor.h:192
#define API_BEGIN()
Macro to guard the beginning and end sections of all functions.
Definition: c_api_error.h:14
int TreeliteCreatePredictorOutputVector(PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle *out_output_vector)
Convenience function to allocate an output vector that is able to hold the prediction result for a gi...
int TreelitePredictorPredictBatch(PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, PredictorOutputHandle out_result, size_t *out_result_size)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
C API of Treelite, used for interfacing with other languages. This header is used exclusively by the r...
void * PredictorOutputHandle
handle to output from predictor
Definition: c_api_runtime.h:25
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:185
int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char **out)
Get name of post prediction transformation used to train the loaded model.
void DeleteOutputVector(PredictorOutputHandle output_vector) const
Free an output vector from memory.
Definition: predictor.cc:470
int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t *out)
Get the width (number of features) of each instance used to train the loaded model.
void * DMatrixHandle
handle to a data matrix
Definition: c_api_common.h:30
Helper class for thread-local storage.
Error handling for C API.
int TreelitePredictorLoad(const char *library_path, int num_worker_thread, PredictorHandle *out)
Load prediction code into memory. This function assumes that the prediction code has already been com...
float QuerySigmoidAlpha() const
Get alpha value in sigmoid transformation used to train the loaded model.
Definition: predictor.h:199
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
void * PredictorHandle
handle to predictor class
Definition: c_api_runtime.h:23
int TreeliteDeletePredictorOutputVector(PredictorHandle handle, PredictorOutputHandle output_vector)
De-allocate an output vector.
float QueryGlobalBias() const
Get the global bias used for adjusting predicted margin scores.
Definition: predictor.h:206
predictor class: wrapper for optimized prediction code
Definition: predictor.h:118
int TreelitePredictorFree(PredictorHandle handle)
Delete predictor from memory.
A thread-local storage.
Definition: thread_local.h:17
int TreelitePredictorQueryNumClass(PredictorHandle handle, size_t *out)
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 for m...
int TreelitePredictorQueryResultSize(PredictorHandle handle, DMatrixHandle batch, size_t *out)
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
size_t QueryNumClass() const
Get the number of classes in the loaded model. The number is 1 for most tasks; it is greater than 1 fo...
Definition: predictor.h:177
int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float *out)
Get alpha value of sigmoid transformation used to train the loaded model.
#define API_END()
Every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:17
int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float *out)
Get the global bias used for adjusting predicted margin scores.