7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 13 #include <unordered_map> 17 #include <treelite/common.h> 18 #include <dmlc/logging.h> 28 Node() : sindex_(0), missing_category_to_zero_(
false) {}
43 return sindex_ & ((1U << 31) - 1U);
47 return (sindex_ >> 31) != 0;
55 return (this->info_).leaf_value;
62 return this->leaf_vector_;
68 return !(this->leaf_vector_.empty());
72 return (this->info_).threshold;
76 return parent_ & ((1U << 31) - 1);
80 return (parent_ & (1U << 31)) != 0;
92 return left_categories_;
100 return data_count_.has_value();
104 return data_count_.value();
108 return sum_hess_.has_value();
112 return sum_hess_.value();
116 return gain_.has_value();
120 return gain_.value();
125 return missing_category_to_zero_;
137 CHECK_LT(split_index, (1U << 31) - 1) <<
"split_index too big";
138 if (default_left) split_index |= (1U << 31);
140 (this->info_).threshold = threshold;
142 this->split_type_ = SplitFeatureType::kNumerical;
155 CHECK_LT(split_index, (1U << 31) - 1) <<
"split_index too big";
156 if (default_left) split_index |= (1U << 31);
159 std::sort(this->left_categories_.begin(), this->left_categories_.end());
160 this->split_type_ = SplitFeatureType::kCategorical;
171 this->split_type_ = SplitFeatureType::kNone;
182 this->split_type_ = SplitFeatureType::kNone;
212 this->parent_ = pidx;
226 std::vector<tl_float> leaf_vector_;
256 std::vector<uint32_t> left_categories_;
261 bool missing_category_to_zero_;
266 dmlc::optional<size_t> data_count_;
273 dmlc::optional<double> sum_hess_;
277 dmlc::optional<double> gain_;
282 std::vector<Node> nodes;
284 inline int AllocNode() {
286 CHECK_LT(
num_nodes, std::numeric_limits<int>::max())
287 <<
"number of nodes in the tree exceed 2^31";
315 nodes[0].set_leaf(0.0f);
316 nodes[0].set_parent(-1);
323 const int cleft = this->AllocNode();
324 const int cright = this->AllocNode();
325 nodes[nid].cleft_ =
cleft;
326 nodes[nid].cright_ =
cright;
327 nodes[
cleft].set_parent(nid,
true);
328 nodes[
cright].set_parent(nid,
false);
336 std::unordered_map<unsigned, bool> tmp;
337 for (
int nid = 0; nid <
num_nodes; ++nid) {
338 const Node& node = nodes[nid];
340 if (type != SplitFeatureType::kNone) {
341 const bool flag = (type == SplitFeatureType::kCategorical);
343 if (tmp.count(split_index) == 0) {
346 CHECK_EQ(tmp[split_index], flag) <<
"Feature " << split_index
347 <<
" cannot be simultaneously be categorical and numerical.";
351 std::vector<unsigned> result;
352 for (
const auto& kv : tmp) {
354 result.push_back(kv.first);
357 std::sort(result.begin(), result.end());
404 DMLC_DECLARE_FIELD(pred_transform).set_default(
"identity")
405 .describe(
"name of prediction transform function");
406 DMLC_DECLARE_FIELD(sigmoid_alpha).set_default(1.0f)
407 .set_lower_bound(0.0f)
408 .describe(
"scaling parameter for sigmoid function");
409 DMLC_DECLARE_FIELD(global_bias).set_default(0.0f)
410 .describe(
"global bias of the model");
414 inline void InitParamAndCheck(
ModelParam* param,
415 const std::vector<std::pair<std::string, std::string>> cfg) {
416 auto unknown = param->InitAllowUnknown(cfg);
417 if (unknown.size() > 0) {
418 std::ostringstream oss;
419 for (
const auto& kv : unknown) {
420 oss << kv.first <<
", ";
422 LOG(INFO) <<
"\033[1;31mWarning: Unknown parameters found; " 423 <<
"they have been ignored\u001B[0m: " << oss.str();
447 param.Init(std::vector<std::pair<std::string, std::string>>());
456 #endif // TREELITE_TREE_H_ int cright() const
index of right child
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
SplitFeatureType
feature split type
void set_data_count(size_t data_count)
set the data count of the node
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
void set_leaf(tl_float value)
set the leaf value of the node
std::vector< Tree > trees
member trees
void set_parent(int pidx, bool is_left_child=true)
set parent of the node
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int cleft() const
index of left child
void set_gain(double gain)
set the gain value of the node
ModelParam param
extra parameters
const std::vector< tl_float > & leaf_vector() const
in-memory representation of a decision tree
void set_categorical_split(unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
tl_float leaf_value() const
float global_bias
global bias of the model
const std::vector< uint32_t > & left_categories() const
get categories for left child node
bool has_sum_hess() const
test whether this node has hessian sum
Operator comparison_op() const
get comparison operator
size_t data_count() const
get data count
bool has_leaf_vector() const
unsigned split_index() const
feature index of split condition
bool default_left() const
when the feature value is unknown, whether to go to the left child
std::string pred_transform
name of prediction transform function
bool is_leaf() const
whether current node is leaf node
void set_numerical_split(unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
create a numerical split
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
void set_leaf_vector(const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
int num_nodes
number of nodes
double tl_float
float type to be used internally
tl_float threshold() const
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
bool has_data_count() const
test whether this node has data count
double gain() const
get gain value
defines configuration macros of treelite
Node & operator[](int nid)
get node given nid
bool missing_category_to_zero() const
test whether missing values should be converted into zero; only applicable for categorical splits ...
void AddChilds(int nid)
add child nodes to node
bool is_root() const
whether current node is root
void set_sum_hess(double sum_hess)
set the hessian sum of the node
bool has_gain() const
test whether this node has gain value
int cdefault() const
index of default child when feature is missing
bool is_left_child() const
whether current node is left child
double sum_hess() const
get hessian sum
SplitFeatureType split_type() const
get feature split type
const Node & operator[](int nid) const
get node given nid (const version)
int parent() const
get parent of the node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Model()
disable copy; use default move
Operator
comparison operators