treelite
gtil.h
Go to the documentation of this file.
1 
#ifndef TREELITE_GTIL_H_
#define TREELITE_GTIL_H_

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

namespace treelite {

class Model;

namespace gtil {

/*!
 * \brief Prediction type
 *
 * Carried in Configuration::pred_kind, which is passed to Predict(), PredictSparse() and
 * GetOutputShape().
 */
enum class PredictKind : std::int8_t {
  /*!
   * \brief Usual prediction method: sum over trees and apply post-processing.
   * Expected output dimensions: (num_row, ...)
   * NOTE(review): the remaining axes are truncated in the source listing; confirm upstream.
   * GetOutputShape() reports the exact shape.
   */
  kPredictDefault = 0,
  /*!
   * \brief Sum over trees, but don't apply post-processing; get raw margin scores instead.
   * NOTE(review): expected output dimensions are truncated in the source listing; confirm
   * upstream. GetOutputShape() reports the exact shape.
   */
  kPredictRaw = 1,
  /*!
   * \brief Output one (integer) leaf ID per tree.
   * Expected output dimensions: (num_row, num_tree)
   */
  kPredictLeafID = 2,
  /*!
   * \brief Output one or more margin scores per tree.
   * Expected output dimensions: (num_row, num_tree, ...)
   * NOTE(review): the trailing axes are truncated in the source listing; confirm upstream.
   */
  kPredictPerTree = 3
};

/*! \brief Configuration class */
struct Configuration {
  /*! \brief Number of threads; the default of 0 uses all threads */
  int nthread{0};
  /*!
   * \brief Kind of prediction to perform.
   * Explicitly initialized so a default-constructed Configuration never holds an
   * indeterminate value.
   */
  PredictKind pred_kind{PredictKind::kPredictDefault};

  Configuration() = default;
  /*!
   * \brief Construct a configuration from a JSON string
   * \param config_json JSON document describing the configuration
   * NOTE(review): the accepted keys are defined by the out-of-line implementation, which is not
   * visible in this header.
   */
  explicit Configuration(std::string const& config_json);
};

/*!
 * \brief Predict with dense data.
 * \tparam InputT Element type of the input and output buffers (float and double are
 *                instantiated; see the extern template declarations at the end of this namespace)
 * \param model Model to predict with
 * \param input Dense input data holding num_row rows.
 *              NOTE(review): presumably row-major with one column per feature; confirm.
 * \param num_row Number of rows (data points) in input
 * \param output Output buffer; must be able to hold an array of the shape returned by
 *               GetOutputShape(model, num_row, config)
 * \param config Configuration (thread count and prediction kind)
 */
template <typename InputT>
void Predict(Model const& model, InputT const* input, std::uint64_t num_row, InputT* output,
    Configuration const& config);

/*!
 * \brief Predict with sparse data with CSR (compressed sparse row) layout.
 * \tparam InputT Element type of the data values and the output buffer (float and double are
 *                instantiated; see the extern template declarations at the end of this namespace)
 * \param model Model to predict with
 * \param data Nonzero values of the CSR matrix
 * \param col_ind Column (feature) index of each entry in data
 * \param row_ptr Row offsets into data / col_ind.
 *                NOTE(review): presumably num_row + 1 entries per CSR convention; confirm.
 * \param num_row Number of rows (data points)
 * \param output Output buffer; must be able to hold an array of the shape returned by
 *               GetOutputShape(model, num_row, config)
 * \param config Configuration (thread count and prediction kind)
 */
template <typename InputT>
void PredictSparse(Model const& model, InputT const* data, std::uint64_t const* col_ind,
    std::uint64_t const* row_ptr, std::uint64_t num_row, InputT* output,
    Configuration const& config);

/*!
 * \brief Given a data matrix, query the necessary shape of array to hold predictions for all
 *        data points.
 * \param model Model to predict with
 * \param num_row Number of rows (data points) in the data matrix
 * \param config Configuration that will be used for prediction
 * \return Shape of the output array, one entry per dimension
 */
std::vector<std::uint64_t> GetOutputShape(
    Model const& model, std::uint64_t num_row, Configuration const& config);

// Explicit instantiation declarations: the definitions for float and double are compiled in
// another translation unit, so including this header does not re-instantiate them.
extern template void Predict<float>(
    Model const&, float const*, std::uint64_t, float*, Configuration const&);
extern template void Predict<double>(
    Model const&, double const*, std::uint64_t, double*, Configuration const&);
extern template void PredictSparse<float>(Model const&, float const*, std::uint64_t const*,
    std::uint64_t const*, std::uint64_t, float*, Configuration const&);
extern template void PredictSparse<double>(Model const&, double const*, std::uint64_t const*,
    std::uint64_t const*, std::uint64_t, double*, Configuration const&);

}  // namespace gtil
}  // namespace treelite

#endif  // TREELITE_GTIL_H_
Model class for tree ensemble model.
Definition: tree.h:446
template void PredictSparse< float >(Model const &, float const *, std::uint64_t const *, std::uint64_t const *, std::uint64_t, float *, Configuration const &)
PredictKind
Prediction type.
Definition: gtil.h:26
@ kPredictLeafID
Output one (integer) leaf ID per tree. Expected output dimensions: (num_row, num_tree)
@ kPredictRaw
Sum over trees, but don't apply post-processing; get raw margin scores instead. Expected output dimen...
@ kPredictPerTree
Output one or more margin scores per tree. Expected output dimensions: (num_row, num_tree,...
@ kPredictDefault
Usual prediction method: sum over trees and apply post-processing. Expected output dimensions: (num_r...
template void Predict< double >(Model const &, double const *, std::uint64_t, double *, Configuration const &)
template void PredictSparse< double >(Model const &, double const *, std::uint64_t const *, std::uint64_t const *, std::uint64_t, double *, Configuration const &)
std::vector< std::uint64_t > GetOutputShape(Model const &model, std::uint64_t num_row, Configuration const &config)
Given a data matrix, query the necessary shape of array to hold predictions for all data points.
template void Predict< float >(Model const &, float const *, std::uint64_t, float *, Configuration const &)
void PredictSparse(Model const &model, InputT const *data, std::uint64_t const *col_ind, std::uint64_t const *row_ptr, std::uint64_t num_row, InputT *output, Configuration const &config)
Predict with sparse data with CSR (compressed sparse row) layout.
void Predict(Model const &model, InputT const *input, std::uint64_t num_row, InputT *output, Configuration const &config)
Predict with dense data.
Definition: contiguous_array.h:14
Configuration class.
Definition: gtil.h:50
Configuration(std::string const &config_json)
int nthread
Definition: gtil.h:51
PredictKind pred_kind
Definition: gtil.h:52