treelite
c_api_runtime.cc
Go to the documentation of this file.
1 
#include <memory>
#include <treelite/predictor.h>
#include "./c_api_error.h"
11 
12 using namespace treelite;
13 
14 int TreeliteAssembleSparseBatch(const float* data,
15  const uint32_t* col_ind,
16  const size_t* row_ptr,
17  size_t num_row, size_t num_col,
18  CSRBatchHandle* out) {
19  API_BEGIN();
20  CSRBatch* batch = new CSRBatch();
21  batch->data = data;
22  batch->col_ind = col_ind;
23  batch->row_ptr = row_ptr;
24  batch->num_row = num_row;
25  batch->num_col = num_col;
26  *out = static_cast<CSRBatchHandle>(batch);
27  API_END();
28 }
29 
31  API_BEGIN();
32  delete static_cast<CSRBatch*>(handle);
33  API_END();
34 }
35 
36 int TreeliteAssembleDenseBatch(const float* data, float missing_value,
37  size_t num_row, size_t num_col,
38  DenseBatchHandle* out) {
39  API_BEGIN();
40  DenseBatch* batch = new DenseBatch();
41  batch->data = data;
42  batch->missing_value = missing_value;
43  batch->num_row = num_row;
44  batch->num_col = num_col;
45  *out = static_cast<DenseBatchHandle>(batch);
46  API_END();
47 }
48 
50  API_BEGIN();
51  delete static_cast<DenseBatch*>(handle);
52  API_END();
53 }
54 
55 int TreeliteBatchGetDimension(void* handle,
56  int batch_sparse,
57  size_t* out_num_row,
58  size_t* out_num_col) {
59  API_BEGIN();
60  if (batch_sparse) {
61  const CSRBatch* batch_ = static_cast<CSRBatch*>(handle);
62  *out_num_row = batch_->num_row;
63  *out_num_col = batch_->num_col;
64  } else {
65  const DenseBatch* batch_ = static_cast<DenseBatch*>(handle);
66  *out_num_row = batch_->num_row;
67  *out_num_col = batch_->num_col;
68  }
69  API_END();
70 }
71 
72 int TreelitePredictorLoad(const char* library_path,
73  int num_worker_thread,
74  int include_master_thread,
75  PredictorHandle* out) {
76  API_BEGIN();
77  Predictor* predictor = new Predictor(num_worker_thread,
78  (bool)include_master_thread);
79  predictor->Load(library_path);
80  *out = static_cast<PredictorHandle>(predictor);
81  API_END();
82 }
83 
85  void* batch,
86  int batch_sparse,
87  int verbose,
88  int pred_margin,
89  float* out_result,
90  size_t* out_result_size) {
91  API_BEGIN();
92  Predictor* predictor_ = static_cast<Predictor*>(handle);
93  if (batch_sparse) {
94  const CSRBatch* batch_ = static_cast<CSRBatch*>(batch);
95  *out_result_size = predictor_->PredictBatch(batch_, verbose,
96  (pred_margin != 0), out_result);
97  } else {
98  const DenseBatch* batch_ = static_cast<DenseBatch*>(batch);
99  *out_result_size = predictor_->PredictBatch(batch_, verbose,
100  (pred_margin != 0), out_result);
101  }
102  API_END();
103 }
104 
106  void* batch,
107  int batch_sparse,
108  size_t* out) {
109  API_BEGIN();
110  const Predictor* predictor_ = static_cast<Predictor*>(handle);
111  if (batch_sparse) {
112  const CSRBatch* batch_ = static_cast<CSRBatch*>(batch);
113  *out = predictor_->QueryResultSize(batch_);
114  } else {
115  const DenseBatch* batch_ = static_cast<DenseBatch*>(batch);
116  *out = predictor_->QueryResultSize(batch_);
117  }
118  API_END();
119 }
120 
122  API_BEGIN();
123  const Predictor* predictor_ = static_cast<Predictor*>(handle);
124  *out = predictor_->QueryNumOutputGroup();
125  API_END();
126 }
127 
129  API_BEGIN();
130  delete static_cast<Predictor*>(handle);
131  API_END();
132 }
Load prediction function exported as a shared library.
void * DenseBatchHandle
handle to batch of dense data rows
Definition: c_api_runtime.h:27
void * CSRBatchHandle
handle to batch of sparse data rows
Definition: c_api_runtime.h:25
const uint32_t * col_ind
feature indices
Definition: predictor.h:20
#define API_BEGIN()
macro to guard the beginning and end sections of all functions
Definition: c_api_error.h:15
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
Definition: predictor.h:96
C API of treelite, used for interfacing with other languages. This header is used exclusively by the r...
int TreelitePredictorQueryResultSize(PredictorHandle handle, void *batch, int batch_sparse, size_t *out)
Given a batch of data rows, query the necessary size of array to hold predictions for all data points...
int TreelitePredictorLoad(const char *library_path, int num_worker_thread, int include_master_thread, PredictorHandle *out)
load prediction code into memory. This function assumes that the prediction code has been already com...
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition: predictor.h:126
int TreelitePredictorPredictBatch(PredictorHandle handle, void *batch, int batch_sparse, int verbose, int pred_margin, float *out_result, size_t *out_result_size)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
int TreeliteBatchGetDimension(void *handle, int batch_sparse, size_t *out_num_row, size_t *out_num_col)
get dimensions of a batch
sparse batch in Compressed Sparse Row (CSR) format
Definition: predictor.h:16
const size_t * row_ptr
pointer to row headers; length of [num_row] + 1
Definition: predictor.h:22
void Load(const char *name)
load the prediction function from a dynamic shared library.
Definition: predictor.cc:210
dense batch
Definition: predictor.h:30
const float * data
feature values
Definition: predictor.h:32
Error handling for C API.
int TreeliteAssembleSparseBatch(const float *data, const uint32_t *col_ind, const size_t *row_ptr, size_t num_row, size_t num_col, CSRBatchHandle *out)
assemble a sparse batch
size_t PredictBatch(const CSRBatch *batch, int verbose, bool pred_margin, float *out_result)
Make predictions on a batch of data rows (synchronously). This function internally divides the worklo...
Definition: predictor.cc:355
float missing_value
value representing the missing value (usually nan)
Definition: predictor.h:34
int TreeliteDeleteDenseBatch(DenseBatchHandle handle)
delete a dense batch from memory
int TreeliteDeleteSparseBatch(CSRBatchHandle handle)
delete a sparse batch from memory
int TreeliteAssembleDenseBatch(const float *data, float missing_value, size_t num_row, size_t num_col, DenseBatchHandle *out)
assemble a dense batch
void * PredictorHandle
handle to predictor class
Definition: c_api_runtime.h:23
const float * data
feature values
Definition: predictor.h:18
size_t num_row
number of rows
Definition: predictor.h:36
int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t *out)
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
predictor class: wrapper for optimized prediction code
Definition: predictor.h:42
int TreelitePredictorFree(PredictorHandle handle)
delete predictor from memory
size_t num_row
number of rows
Definition: predictor.h:24
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:26
#define API_END()
every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR ...
Definition: c_api_error.h:18
size_t num_col
number of columns (i.e. # of features used)
Definition: predictor.h:38