treelite
Main Page
Modules
Classes
Files
File List
File Members
include
treelite
predictor.h
Go to the documentation of this file.
1
7
#ifndef TREELITE_PREDICTOR_H_
8
#define TREELITE_PREDICTOR_H_
9
10
#include <dmlc/logging.h>
11
#include <cstdint>
12
13
namespace
treelite
{
14
16
/*!
 * \brief sparse batch in Compressed Sparse Row (CSR) format
 *
 * The struct only stores pointers to the arrays; it has no destructor, so
 * the memory behind data/col_ind/row_ptr is managed by whoever fills it in.
 */
struct CSRBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief feature indices */
  const std::uint32_t* col_ind;
  /*! \brief pointer to row headers; length of [num_row] + 1 */
  const size_t* row_ptr;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
28
30
/*!
 * \brief dense batch
 *
 * Like CSRBatch, this struct only records a pointer to the feature values
 * together with the batch dimensions; it never frees `data` itself.
 */
struct DenseBatch {
  /*! \brief feature values */
  const float* data;
  /*! \brief value representing the missing value (usually nan) */
  float missing_value;
  /*! \brief number of rows */
  size_t num_row;
  /*! \brief number of columns (i.e. # of features used) */
  size_t num_col;
};
40
42
class
Predictor
{
43
public
:
48
union
Entry
{
49
int
missing;
50
float
fvalue;
51
// may contain extra fields later, such as qvalue
52
};
53
55
typedef
void
*
QueryFuncHandle
;
56
typedef
void
* PredFuncHandle;
57
typedef
void
* PredTransformFuncHandle;
58
typedef
void
* LibraryHandle;
59
60
Predictor
();
61
~
Predictor
();
66
void
Load(
const
char
* name);
70
void
Free();
71
84
size_t
PredictBatch(
const
CSRBatch
* batch,
int
nthread,
int
verbose,
85
bool
pred_margin,
float
* out_result)
const
;
86
size_t
PredictBatch(
const
DenseBatch
* batch,
int
nthread,
int
verbose,
87
bool
pred_margin,
float
* out_result)
const
;
88
95
inline
size_t
QueryResultSize
(
const
CSRBatch
* batch)
const
{
96
CHECK(pred_func_handle_ !=
nullptr
)
97
<<
"A shared library needs to be loaded first using Load()"
;
98
return
batch->
num_row
* num_output_group_;
99
}
100
inline
size_t
QueryResultSize(
const
DenseBatch
* batch)
const
{
101
CHECK(pred_func_handle_ !=
nullptr
)
102
<<
"A shared library needs to be loaded first using Load()"
;
103
return
batch->
num_row
* num_output_group_;
104
}
111
inline
size_t
QueryNumOutputGroup
()
const
{
112
return
num_output_group_;
113
}
114
115
private
:
116
LibraryHandle lib_handle_;
117
QueryFuncHandle query_func_handle_;
118
PredFuncHandle pred_func_handle_;
119
PredTransformFuncHandle pred_transform_func_handle_;
120
size_t
num_output_group_;
121
};
122
123
}
// namespace treelite
124
125
#endif // TREELITE_PREDICTOR_H_
treelite::CSRBatch::col_ind
const uint32_t * col_ind
feature indices
Definition:
predictor.h:20
treelite::Predictor::QueryResultSize
size_t QueryResultSize(const CSRBatch *batch) const
Given a batch of data rows, query the size of the array needed to hold predictions for all data points...
Definition:
predictor.h:95
treelite::Predictor::Entry
Data layout. The value -1 signifies the missing value. When the "missing" field is set to -1...
Definition:
predictor.h:48
treelite
Definition:
annotator.h:13
treelite::Predictor::QueryNumOutputGroup
size_t QueryNumOutputGroup() const
Get the number of output groups in the loaded model. The number is 1 for most tasks; it is greater tha...
Definition:
predictor.h:111
treelite::CSRBatch
sparse batch in Compressed Sparse Row (CSR) format
Definition:
predictor.h:16
treelite::CSRBatch::row_ptr
const size_t * row_ptr
pointer to row headers (offsets into data and col_ind); length is num_row + 1
Definition:
predictor.h:22
treelite::Predictor::QueryFuncHandle
void * QueryFuncHandle
opaque handle types
Definition:
predictor.h:55
treelite::DenseBatch
dense batch
Definition:
predictor.h:30
treelite::DenseBatch::data
const float * data
feature values
Definition:
predictor.h:32
treelite::DenseBatch::missing_value
float missing_value
value representing the missing value (usually NaN)
Definition:
predictor.h:34
treelite::CSRBatch::data
const float * data
feature values
Definition:
predictor.h:18
treelite::DenseBatch::num_row
size_t num_row
number of rows
Definition:
predictor.h:36
treelite::Predictor
predictor class: wrapper for optimized prediction code
Definition:
predictor.h:42
treelite::CSRBatch::num_row
size_t num_row
number of rows
Definition:
predictor.h:24
treelite::CSRBatch::num_col
size_t num_col
number of columns (i.e., the number of features used)
Definition:
predictor.h:26
treelite::DenseBatch::num_col
size_t num_col
number of columns (i.e., the number of features used)
Definition:
predictor.h:38
Generated by Doxygen 1.8.11