9 #ifndef TREELITE_GTIL_H_
10 #define TREELITE_GTIL_H_
66 template <
typename InputT>
67 void Predict(
Model const& model, InputT
const* input, std::uint64_t num_row, InputT* output,
85 template <
typename InputT>
87 std::uint64_t
const* row_ptr, std::uint64_t num_row, InputT* output,
106 std::uint64_t
const*, std::uint64_t,
float*,
Configuration const&);
108 std::uint64_t
const*, std::uint64_t,
double*,
Configuration const&);
Model class for tree ensemble model.
Definition: tree.h:446
template void PredictSparse< float >(Model const &, float const *, std::uint64_t const *, std::uint64_t const *, std::uint64_t, float *, Configuration const &)
PredictKind
Prediction type.
Definition: gtil.h:26
@ kPredictLeafID
Output one (integer) leaf ID per tree. Expected output dimensions: (num_row, num_tree)
@ kPredictRaw
Sum over trees, but don't apply post-processing; get raw margin scores instead. Expected output dimen...
@ kPredictPerTree
Output one or more margin scores per tree. Expected output dimensions: (num_row, num_tree,...
@ kPredictDefault
Usual prediction method: sum over trees and apply post-processing. Expected output dimensions: (num_r...
template void Predict< double >(Model const &, double const *, std::uint64_t, double *, Configuration const &)
template void PredictSparse< double >(Model const &, double const *, std::uint64_t const *, std::uint64_t const *, std::uint64_t, double *, Configuration const &)
std::vector< std::uint64_t > GetOutputShape(Model const &model, std::uint64_t num_row, Configuration const &config)
Given a data matrix, query the necessary shape of array to hold predictions for all data points.
template void Predict< float >(Model const &, float const *, std::uint64_t, float *, Configuration const &)
void PredictSparse(Model const &model, InputT const *data, std::uint64_t const *col_ind, std::uint64_t const *row_ptr, std::uint64_t num_row, InputT *output, Configuration const &config)
Predict with sparse data with CSR (compressed sparse row) layout.
void Predict(Model const &model, InputT const *input, std::uint64_t num_row, InputT *output, Configuration const &config)
Predict with dense data.
Definition: contiguous_array.h:14
Configuration class.
Definition: gtil.h:50
Configuration(std::string const &config_json)
int nthread
Definition: gtil.h:51
PredictKind pred_kind
Definition: gtil.h:52