treelite
c_api_runtime.cc
Go to the documentation of this file.
1 
#include <treelite/predictor.h>
#include <memory>
#include <string>
#include "./c_api_error.h"

using namespace treelite;
14 
15 int TreeliteAssembleSparseBatch(const float* data,
16  const uint32_t* col_ind,
17  const size_t* row_ptr,
18  size_t num_row, size_t num_col,
19  CSRBatchHandle* out) {
20  API_BEGIN();
21  CSRBatch* batch = new CSRBatch();
22  batch->data = data;
23  batch->col_ind = col_ind;
24  batch->row_ptr = row_ptr;
25  batch->num_row = num_row;
26  batch->num_col = num_col;
27  *out = static_cast<CSRBatchHandle>(batch);
28  API_END();
29 }
30 
32  API_BEGIN();
33  delete static_cast<CSRBatch*>(handle);
34  API_END();
35 }
36 
37 int TreeliteAssembleDenseBatch(const float* data, float missing_value,
38  size_t num_row, size_t num_col,
39  DenseBatchHandle* out) {
40  API_BEGIN();
41  DenseBatch* batch = new DenseBatch();
42  batch->data = data;
43  batch->missing_value = missing_value;
44  batch->num_row = num_row;
45  batch->num_col = num_col;
46  *out = static_cast<DenseBatchHandle>(batch);
47  API_END();
48 }
49 
51  API_BEGIN();
52  delete static_cast<DenseBatch*>(handle);
53  API_END();
54 }
55 
56 int TreeliteBatchGetDimension(void* handle,
57  int batch_sparse,
58  size_t* out_num_row,
59  size_t* out_num_col) {
60  API_BEGIN();
61  if (batch_sparse) {
62  const CSRBatch* batch_ = static_cast<CSRBatch*>(handle);
63  *out_num_row = batch_->num_row;
64  *out_num_col = batch_->num_col;
65  } else {
66  const DenseBatch* batch_ = static_cast<DenseBatch*>(handle);
67  *out_num_row = batch_->num_row;
68  *out_num_col = batch_->num_col;
69  }
70  API_END();
71 }
72 
73 int TreelitePredictorLoad(const char* library_path,
74  int num_worker_thread,
75  PredictorHandle* out) {
76  API_BEGIN();
77  Predictor* predictor = new Predictor(num_worker_thread);
78  predictor->Load(library_path);
79  *out = static_cast<PredictorHandle>(predictor);
80  API_END();
81 }
82 
84  void* batch,
85  int batch_sparse,
86  int verbose,
87  int pred_margin,
88  float* out_result,
89  size_t* out_result_size) {
90  API_BEGIN();
91  Predictor* predictor_ = static_cast<Predictor*>(handle);
92  const size_t num_feature = predictor_->QueryNumFeature();
93  const std::string err_msg
94  = std::string("Too many columns (features) in the given batch. "
95  "Number of features must not exceed ")
96  + std::to_string(num_feature);
97  if (batch_sparse) {
98  const CSRBatch* batch_ = static_cast<CSRBatch*>(batch);
99  CHECK_LE(batch_->num_col, num_feature) << err_msg;
100  *out_result_size = predictor_->PredictBatch(batch_, verbose,
101  (pred_margin != 0), out_result);
102  } else {
103  const DenseBatch* batch_ = static_cast<DenseBatch*>(batch);
104  CHECK_LE(batch_->num_col, num_feature) << err_msg;
105  *out_result_size = predictor_->PredictBatch(batch_, verbose,
106  (pred_margin != 0), out_result);
107  }
108  API_END();
109 }
110 
112  union TreelitePredictorEntry* inst,
113  int pred_margin,
114  float* out_result, size_t* out_result_size) {
115  API_BEGIN();
116  Predictor* predictor_ = static_cast<Predictor*>(handle);
117  *out_result_size
118  = predictor_->PredictInst(inst, (pred_margin != 0), out_result);
119  API_END();
120 }
121 
123  void* batch,
124  int batch_sparse,
125  size_t* out) {
126  API_BEGIN();
127  const Predictor* predictor_ = static_cast<Predictor*>(handle);
128  if (batch_sparse) {
129  const CSRBatch* batch_ = static_cast<CSRBatch*>(batch);
130  *out = predictor_->QueryResultSize(batch_);
131  } else {
132  const DenseBatch* batch_ = static_cast<DenseBatch*>(batch);
133  *out = predictor_->QueryResultSize(batch_);
134  }
135  API_END();
136 }
137 
139  size_t* out) {
140  API_BEGIN();
141  const Predictor* predictor_ = static_cast<Predictor*>(handle);
142  *out = predictor_->QueryResultSizeSingleInst();
143  API_END();
144 }
145 
147  API_BEGIN();
148  const Predictor* predictor_ = static_cast<Predictor*>(handle);
149  *out = predictor_->QueryNumOutputGroup();
150  API_END();
151 }
152 
154  API_BEGIN();
155  const Predictor* predictor_ = static_cast<Predictor*>(handle);
156  *out = predictor_->QueryNumFeature();
157  API_END();
158 }
159 
161  API_BEGIN();
162  delete static_cast<Predictor*>(handle);
163  API_END();
164 }
Load prediction function exported as a shared library.
void * DenseBatchHandle
handle to batch of dense data rows
Definition: c_api_runtime.h:28
size_t QueryNumFeature() const
Get the width (number of features) of each instance used to train the loaded model.
Definition: predictor.h:176
void * CSRBatchHandle
handle to batch of sparse data rows
Definition: c_api_runtime.h:26
const uint32_t * col_ind
feature indices
Definition: predictor.h:27
size_t QueryResultSizeSingleInst() const
Query the necessary size of array to hold the prediction for a single data row.
Definition: predictor.h:156
C API of treelite, used for interfacing with other languages This header is used exclusively by the r...
int TreelitePredictorQueryResultSize(PredictorHandle handle, void *batch, int batch_sparse, size_t *out)
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
int TreelitePredictorQueryResultSizeSingleInst(PredictorHandle handle, size_t *out)
Query the necessary size of array to hold the prediction for a single data row.
int TreelitePredictorPredictBatch(PredictorHandle handle, void *batch, int batch_sparse, int verbose, int pred_margin, float *out_result, size_t *out_result_size)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
size_t PredictInst(TreelitePredictorEntry *inst, bool pred_margin, float *out_result)
Make predictions on a single data row (synchronously). The work will be scheduled to the calling thre...
Definition: predictor.cc:449
int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t *out)
Get the width (number of features) of each instance used to train the loaded model.
int TreeliteBatchGetDimension(void *handle, int batch_sparse, size_t *out_num_row, size_t *out_num_col)
get dimensions of a batch
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:23
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:29
void Load(const char *name)
load the prediction function from dynamic shared library.
Definition: predictor.cc:243
dense batch
Definition: predictor.h:37
const float * data
feature values
Definition: predictor.h:39
int TreelitePredictorLoad(const char *library_path, int num_worker_thread, PredictorHandle *out)
load prediction code into memory. This function assumes that the prediction code has been already com...
int TreeliteAssembleSparseBatch(const float *data, const uint32_t *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, CSRBatchHandle *out)
assemble a sparse batch
size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin, float *out_result)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
Definition: predictor.cc:437
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:41
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:105
int TreeliteDeleteDenseBatch(DenseBatchHandle handle)
delete a dense batch from memory
int TreeliteDeleteSparseBatch(CSRBatchHandle handle)
delete a sparse batch from memory
int TreeliteAssembleDenseBatch(const float *data, float missing_value, size_t num_row, size_t num_col, DenseBatchHandle *out)
assemble a dense batch
data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition: entry.h:11
void * PredictorHandle
handle to predictor class
Definition: c_api_runtime.h:24
const float * data
feature values
Definition: predictor.h:25
size_t num_row
number of rows
Definition: predictor.h:43
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:167
int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t *out)
Get the number of output groups in the loaded model The number is 1 for most tasks; it is greater tha...
predictor class: wrapper for optimized prediction code
Definition: predictor.h:49
int TreelitePredictorFree(PredictorHandle handle)
delete predictor from memory
size_t num_row
number of rows
Definition: predictor.h:31
int TreelitePredictorPredictInst(PredictorHandle handle, union TreelitePredictorEntry *inst, int pred_margin, float *out_result, size_t *out_result_size)
Make predictions on a single data row (synchronously). The work will be scheduled to the calling thre...
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:33
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:45