treelite
c_api_runtime.cc
Go to the documentation of this file.
1 
#include <treelite/predictor.h>
#include <cstring>
#include <memory>
#include <string>
#include "./c_api_error.h"
13 
14 using namespace treelite;
15 
16 int TreeliteAssembleSparseBatch(const float* data,
17  const uint32_t* col_ind,
18  const size_t* row_ptr,
19  size_t num_row, size_t num_col,
20  CSRBatchHandle* out) {
21  API_BEGIN();
22  CSRBatch* batch = new CSRBatch();
23  batch->data = data;
24  batch->col_ind = col_ind;
25  batch->row_ptr = row_ptr;
26  batch->num_row = num_row;
27  batch->num_col = num_col;
28  *out = static_cast<CSRBatchHandle>(batch);
29  API_END();
30 }
31 
33  API_BEGIN();
34  delete static_cast<CSRBatch*>(handle);
35  API_END();
36 }
37 
38 int TreeliteAssembleDenseBatch(const float* data, float missing_value,
39  size_t num_row, size_t num_col,
40  DenseBatchHandle* out) {
41  API_BEGIN();
42  DenseBatch* batch = new DenseBatch();
43  batch->data = data;
44  batch->missing_value = missing_value;
45  batch->num_row = num_row;
46  batch->num_col = num_col;
47  *out = static_cast<DenseBatchHandle>(batch);
48  API_END();
49 }
50 
52  API_BEGIN();
53  delete static_cast<DenseBatch*>(handle);
54  API_END();
55 }
56 
57 int TreeliteBatchGetDimension(void* handle,
58  int batch_sparse,
59  size_t* out_num_row,
60  size_t* out_num_col) {
61  API_BEGIN();
62  if (batch_sparse) {
63  const CSRBatch* batch_ = static_cast<CSRBatch*>(handle);
64  *out_num_row = batch_->num_row;
65  *out_num_col = batch_->num_col;
66  } else {
67  const DenseBatch* batch_ = static_cast<DenseBatch*>(handle);
68  *out_num_row = batch_->num_row;
69  *out_num_col = batch_->num_col;
70  }
71  API_END();
72 }
73 
74 int TreelitePredictorLoad(const char* library_path,
75  int num_worker_thread,
76  PredictorHandle* out) {
77  API_BEGIN();
78  Predictor* predictor = new Predictor(num_worker_thread);
79  predictor->Load(library_path);
80  *out = static_cast<PredictorHandle>(predictor);
81  API_END();
82 }
83 
85  void* batch,
86  int batch_sparse,
87  int verbose,
88  int pred_margin,
89  float* out_result,
90  size_t* out_result_size) {
91  API_BEGIN();
92  Predictor* predictor_ = static_cast<Predictor*>(handle);
93  const size_t num_feature = predictor_->QueryNumFeature();
94  const std::string err_msg
95  = std::string("Too many columns (features) in the given batch. "
96  "Number of features must not exceed ")
97  + std::to_string(num_feature);
98  if (batch_sparse) {
99  const CSRBatch* batch_ = static_cast<CSRBatch*>(batch);
100  CHECK_LE(batch_->num_col, num_feature) << err_msg;
101  *out_result_size = predictor_->PredictBatch(batch_, verbose,
102  (pred_margin != 0), out_result);
103  } else {
104  const DenseBatch* batch_ = static_cast<DenseBatch*>(batch);
105  CHECK_LE(batch_->num_col, num_feature) << err_msg;
106  *out_result_size = predictor_->PredictBatch(batch_, verbose,
107  (pred_margin != 0), out_result);
108  }
109  API_END();
110 }
111 
113  union TreelitePredictorEntry* inst,
114  int pred_margin,
115  float* out_result, size_t* out_result_size) {
116  API_BEGIN();
117  Predictor* predictor_ = static_cast<Predictor*>(handle);
118  *out_result_size
119  = predictor_->PredictInst(inst, (pred_margin != 0), out_result);
120  API_END();
121 }
122 
124  void* batch,
125  int batch_sparse,
126  size_t* out) {
127  API_BEGIN();
128  const Predictor* predictor_ = static_cast<Predictor*>(handle);
129  if (batch_sparse) {
130  const CSRBatch* batch_ = static_cast<CSRBatch*>(batch);
131  *out = predictor_->QueryResultSize(batch_);
132  } else {
133  const DenseBatch* batch_ = static_cast<DenseBatch*>(batch);
134  *out = predictor_->QueryResultSize(batch_);
135  }
136  API_END();
137 }
138 
140  size_t* out) {
141  API_BEGIN();
142  const Predictor* predictor_ = static_cast<Predictor*>(handle);
143  *out = predictor_->QueryResultSizeSingleInst();
144  API_END();
145 }
146 
148  API_BEGIN();
149  const Predictor* predictor_ = static_cast<Predictor*>(handle);
150  *out = predictor_->QueryNumOutputGroup();
151  API_END();
152 }
153 
155  API_BEGIN();
156  const Predictor* predictor_ = static_cast<Predictor*>(handle);
157  *out = predictor_->QueryNumFeature();
158  API_END();
159 }
160 
162  API_BEGIN()
163  const Predictor* predictor_ = static_cast<Predictor*>(handle);
164  auto predTransform = predictor_->QueryPredTransform();
165  *out = new char[predTransform.length() + 1];
166  strcpy(*out, predTransform.c_str());
167  API_END();
168 }
169 
171  API_BEGIN()
172  const Predictor* predictor_ = static_cast<Predictor*>(handle);
173  *out = predictor_->QuerySigmoidAlpha();
174  API_END();
175 }
176 
178  API_BEGIN()
179  const Predictor* predictor_ = static_cast<Predictor*>(handle);
180  *out = predictor_->QueryGlobalBias();
181  API_END();
182 }
183 
185  API_BEGIN();
186  delete static_cast<Predictor*>(handle);
187  API_END();
188 }
Load prediction function exported as a shared library.
void * DenseBatchHandle
handle to batch of dense data rows
Definition: c_api_runtime.h:28
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:176
void * CSRBatchHandle
handle to batch of sparse data rows
Definition: c_api_runtime.h:26
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
size_t QueryResultSizeSingleInst() const
Query the necessary size of array to hold the prediction for a single data row.
Definition: predictor.h:156
C API of treelite, used for interfacing with other languages. This header is used exclusively by the r...
int TreelitePredictorQueryResultSize(PredictorHandle handle, void *batch, int batch_sparse, size_t *out)
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
int TreelitePredictorQueryResultSizeSingleInst(PredictorHandle handle, size_t *out)
Query the necessary size of array to hold the prediction for a single data row.
int TreelitePredictorPredictBatch(PredictorHandle handle, void *batch, int batch_sparse, int verbose, int pred_margin, float *out_result, size_t *out_result_size)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
int TreelitePredictorQueryPredTransform(PredictorHandle handle, char **out)
Get the name of the post-prediction transformation used to train the loaded model.
size_t PredictInst(TreelitePredictorEntry *inst, bool pred_margin, float *out_result)
Make predictions on a single data row (synchronously). The work will be scheduled to the calling thre...
Definition: predictor.cc:489
int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t *out)
Get the width (number of features) of each instance used to train the loaded model.
int TreeliteBatchGetDimension(void *handle, int batch_sparse, size_t *out_num_row, size_t *out_num_col)
get dimensions of a batch
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:29
void Load(const char *name)
Load the prediction function from a dynamic shared library.
Definition: predictor.cc:243
dense batch
Definition: predictor.h:37
float QueryGlobalBias() const
Get the global bias that adjusts the predicted margin scores.
Definition: predictor.h:200
const float * data
feature values
Definition: predictor.h:39
int TreelitePredictorLoad(const char *library_path, int num_worker_thread, PredictorHandle *out)
Load prediction code into memory. This function assumes that the prediction code has already been com...
int TreeliteAssembleSparseBatch(const float *data, const uint32_t *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, CSRBatchHandle *out)
assemble a sparse batch
size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin, float *out_result)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
Definition: predictor.cc:477
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:105
int TreeliteDeleteDenseBatch(DenseBatchHandle handle)
delete a dense batch from memory
int TreeliteDeleteSparseBatch(CSRBatchHandle handle)
delete a sparse batch from memory
int TreeliteAssembleDenseBatch(const float *data, float missing_value, size_t num_row, size_t num_col, DenseBatchHandle *out)
assemble a dense batch
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:11
void * PredictorHandle
handle to predictor class
Definition: c_api_runtime.h:24
const float * data
feature values
Definition: predictor.h:25
std::string QueryPredTransform() const
Get the name of the post-prediction transformation used to train the loaded model.
Definition: predictor.h:184
float QuerySigmoidAlpha() const
Get the alpha value of the sigmoid transformation used to train the loaded model.
Definition: predictor.h:192
size_t num_row
number of rows
Definition: predictor.h:43
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:167
int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t *out)
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
int TreelitePredictorFree(PredictorHandle handle)
delete predictor from memory
size_t num_row
number of rows
Definition: predictor.h:31
int TreelitePredictorPredictInst(PredictorHandle handle, union TreelitePredictorEntry *inst, int pred_margin, float *out_result, size_t *out_result_size)
Make predictions on a single data row (synchronously). The work will be scheduled to the calling thre...
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float *out)
Get the alpha value of the sigmoid transformation used to train the loaded model.
int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float *out)
Get the global bias that adjusts the predicted margin scores.
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45