Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
21 
namespace {

/*!
 * \brief Convert a value to text with std::to_string.
 *        float and double are specialized below to print round-trippable digits.
 */
template <typename T>
inline std::string GetString(T x) {
  const std::string text = std::to_string(x);
  return text;
}

/*!
 * \brief Format a floating-point value using max_digits10 significant digits,
 *        so that parsing the text back yields the same value.
 */
template <typename FloatT>
inline std::string FormatRoundTrip(FloatT value) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<FloatT>::max_digits10);
  stream << value;
  return stream.str();
}

template <>
inline std::string GetString<float>(float x) {
  return FormatRoundTrip(x);
}

template <>
inline std::string GetString<double>(double x) {
  return FormatRoundTrip(x);
}

} // anonymous namespace
44 
45 namespace treelite {
46 
47 template <typename T>
48 ContiguousArray<T>::ContiguousArray()
49  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
50 
51 template <typename T>
52 ContiguousArray<T>::~ContiguousArray() {
53  if (buffer_ && owned_buffer_) {
54  std::free(buffer_);
55  }
56 }
57 
58 template <typename T>
59 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
60  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
61  owned_buffer_(other.owned_buffer_) {
62  other.buffer_ = nullptr;
63  other.size_ = other.capacity_ = 0;
64 }
65 
66 template <typename T>
67 ContiguousArray<T>&
68 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
69  if (buffer_ && owned_buffer_) {
70  std::free(buffer_);
71  }
72  buffer_ = other.buffer_;
73  size_ = other.size_;
74  capacity_ = other.capacity_;
75  owned_buffer_ = other.owned_buffer_;
76  other.buffer_ = nullptr;
77  other.size_ = other.capacity_ = 0;
78  return *this;
79 }
80 
81 template <typename T>
82 inline ContiguousArray<T>
83 ContiguousArray<T>::Clone() const {
84  ContiguousArray clone;
85  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
86  if (!clone.buffer_) {
87  throw std::runtime_error("Could not allocate memory for the clone");
88  }
89  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
90  clone.size_ = size_;
91  clone.capacity_ = capacity_;
92  clone.owned_buffer_ = true;
93  return clone;
94 }
95 
96 template <typename T>
97 inline void
98 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, size_t size) {
99  if (buffer_ && owned_buffer_) {
100  std::free(buffer_);
101  }
102  buffer_ = static_cast<T*>(prealloc_buf);
103  size_ = size;
104  capacity_ = size;
105  owned_buffer_ = false;
106 }
107 
108 template <typename T>
109 inline T*
110 ContiguousArray<T>::Data() {
111  return buffer_;
112 }
113 
114 template <typename T>
115 inline const T*
116 ContiguousArray<T>::Data() const {
117  return buffer_;
118 }
119 
120 template <typename T>
121 inline T*
122 ContiguousArray<T>::End() {
123  return &buffer_[Size()];
124 }
125 
126 template <typename T>
127 inline const T*
128 ContiguousArray<T>::End() const {
129  return &buffer_[Size()];
130 }
131 
132 template <typename T>
133 inline T&
134 ContiguousArray<T>::Back() {
135  return buffer_[Size() - 1];
136 }
137 
138 template <typename T>
139 inline const T&
140 ContiguousArray<T>::Back() const {
141  return buffer_[Size() - 1];
142 }
143 
144 template <typename T>
145 inline size_t
146 ContiguousArray<T>::Size() const {
147  return size_;
148 }
149 
150 template <typename T>
151 inline void
152 ContiguousArray<T>::Reserve(size_t newsize) {
153  if (!owned_buffer_) {
154  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
155  }
156  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
157  if (!newbuf) {
158  throw std::runtime_error("Could not expand buffer");
159  }
160  buffer_ = newbuf;
161  capacity_ = newsize;
162 }
163 
164 template <typename T>
165 inline void
166 ContiguousArray<T>::Resize(size_t newsize) {
167  if (!owned_buffer_) {
168  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
169  }
170  if (newsize > capacity_) {
171  size_t newcapacity = capacity_;
172  if (newcapacity == 0) {
173  newcapacity = 1;
174  }
175  while (newcapacity <= newsize) {
176  newcapacity *= 2;
177  }
178  Reserve(newcapacity);
179  }
180  size_ = newsize;
181 }
182 
183 template <typename T>
184 inline void
185 ContiguousArray<T>::Resize(size_t newsize, T t) {
186  if (!owned_buffer_) {
187  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
188  }
189  size_t oldsize = Size();
190  Resize(newsize);
191  for (size_t i = oldsize; i < newsize; ++i) {
192  buffer_[i] = t;
193  }
194 }
195 
196 template <typename T>
197 inline void
198 ContiguousArray<T>::Clear() {
199  if (!owned_buffer_) {
200  throw std::runtime_error("Cannot clear when using a foreign buffer; clone first");
201  }
202  Resize(0);
203 }
204 
205 template <typename T>
206 inline void
207 ContiguousArray<T>::PushBack(T t) {
208  if (!owned_buffer_) {
209  throw std::runtime_error("Cannot add element when using a foreign buffer; clone first");
210  }
211  if (size_ == capacity_) {
212  Reserve(capacity_ * 2);
213  }
214  buffer_[size_++] = t;
215 }
216 
217 template <typename T>
218 inline void
219 ContiguousArray<T>::Extend(const std::vector<T>& other) {
220  if (!owned_buffer_) {
221  throw std::runtime_error("Cannot add elements when using a foreign buffer; clone first");
222  }
223  size_t newsize = size_ + other.size();
224  if (newsize > capacity_) {
225  size_t newcapacity = capacity_;
226  if (newcapacity == 0) {
227  newcapacity = 1;
228  }
229  while (newcapacity <= newsize) {
230  newcapacity *= 2;
231  }
232  Reserve(newcapacity);
233  }
234  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
235  size_ = newsize;
236 }
237 
238 template <typename T>
239 inline T&
240 ContiguousArray<T>::operator[](size_t idx) {
241  return buffer_[idx];
242 }
243 
244 template <typename T>
245 inline const T&
246 ContiguousArray<T>::operator[](size_t idx) const {
247  return buffer_[idx];
248 }
249 
250 template<typename Container>
251 inline std::vector<std::pair<std::string, std::string> >
252 ModelParam::InitAllowUnknown(const Container& kwargs) {
253  std::vector<std::pair<std::string, std::string>> unknowns;
254  for (const auto& e : kwargs) {
255  if (e.first == "pred_transform") {
256  std::strncpy(this->pred_transform, e.second.c_str(),
257  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
258  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
259  } else if (e.first == "sigmoid_alpha") {
260  this->sigmoid_alpha = dmlc::stof(e.second, nullptr);
261  } else if (e.first == "global_bias") {
262  this->global_bias = dmlc::stof(e.second, nullptr);
263  }
264  }
265  return unknowns;
266 }
267 
268 inline std::map<std::string, std::string>
269 ModelParam::__DICT__() const {
270  std::map<std::string, std::string> ret;
271  ret.emplace("pred_transform", std::string(this->pred_transform));
272  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
273  ret.emplace("global_bias", GetString(this->global_bias));
274  return ret;
275 }
276 
277 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
278  size_t itemsize, size_t nitem) {
279  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
280 }
281 
/*!
 * \brief Infer a Python struct / PEP 3118 format string for an arithmetic type T.
 *
 * Size 1 or 2 maps to B/b or H/h by signedness.
 * Size 4 or 8 maps to L/l or Q/q for integers, and to f or d for floating point.
 * Every string carries the "=" prefix.
 *
 * \throw std::runtime_error for a 4/8-byte type that is neither integral nor
 *        floating point, or for any other size
 */
template <typename T>
inline const char* InferFormatString() {
  const bool is_unsigned = std::is_unsigned<T>::value;
  const bool is_integral = std::is_integral<T>::value;
  const bool is_floating = std::is_floating_point<T>::value;
  if (sizeof(T) == 1) {
    return is_unsigned ? "=B" : "=b";
  }
  if (sizeof(T) == 2) {
    return is_unsigned ? "=H" : "=h";
  }
  if (sizeof(T) == 4 || sizeof(T) == 8) {
    const bool wide = (sizeof(T) == 8);
    if (is_integral) {
      if (wide) {
        return is_unsigned ? "=Q" : "=q";
      }
      return is_unsigned ? "=L" : "=l";
    }
    if (!is_floating) {
      throw std::runtime_error("Could not infer format string");
    }
    return wide ? "=d" : "=f";
  }
  throw std::runtime_error("Unrecognized type");
}
313 
314 template <typename T>
315 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
316  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
317 }
318 
319 template <typename T>
320 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
321  static_assert(std::is_arithmetic<T>::value,
322  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
323  return GetPyBufferFromArray(vec, InferFormatString<T>());
324 }
325 
326 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, size_t itemsize) {
327  return GetPyBufferFromArray(data, format, itemsize, 1);
328 }
329 
330 template <typename T>
331 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
332  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
333 }
334 
335 template <typename T>
336 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
337  static_assert(std::is_arithmetic<T>::value,
338  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
339  "specify format string manually");
340  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
341 }
342 
343 template <typename T>
344 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame buffer) {
345  if (sizeof(T) != buffer.itemsize) {
346  throw std::runtime_error("Incorrect itemsize");
347  }
348  vec->UseForeignBuffer(buffer.buf, buffer.nitem);
349 }
350 
351 template <typename T>
352 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
353  if (sizeof(T) != buffer.itemsize) {
354  throw std::runtime_error("Incorrect itemsize");
355  }
356  if (buffer.nitem != 1) {
357  throw std::runtime_error("nitem must be 1 for a scalar");
358  }
359  T* t = static_cast<T*>(buffer.buf);
360  *scalar = *t;
361 }
362 
// Number of PyBufferFrame entries that Tree::GetPyBuffer() produces and
// Tree::InitFromPyBuffer() consumes for one tree.
constexpr size_t kNumFramePerTree{6};
364 
365 inline std::vector<PyBufferFrame>
366 Tree::GetPyBuffer() {
367  return {
368  GetPyBufferFromScalar(&num_nodes),
369  GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}"),
370  GetPyBufferFromArray(&leaf_vector_),
371  GetPyBufferFromArray(&leaf_vector_offset_),
372  GetPyBufferFromArray(&left_categories_),
373  GetPyBufferFromArray(&left_categories_offset_)
374  };
375 }
376 
377 inline void
378 Tree::InitFromPyBuffer(std::vector<PyBufferFrame> frames) {
379  size_t frame_id = 0;
380  InitScalarFromPyBuffer(&num_nodes, frames[frame_id++]);
381  InitArrayFromPyBuffer(&nodes_, frames[frame_id++]);
382  if (num_nodes != nodes_.Size()) {
383  throw std::runtime_error("Could not load the correct number of nodes");
384  }
385  InitArrayFromPyBuffer(&leaf_vector_, frames[frame_id++]);
386  InitArrayFromPyBuffer(&leaf_vector_offset_, frames[frame_id++]);
387  InitArrayFromPyBuffer(&left_categories_, frames[frame_id++]);
388  InitArrayFromPyBuffer(&left_categories_offset_, frames[frame_id++]);
389  if (frame_id != kNumFramePerTree) {
390  throw std::runtime_error("Wrong number of frames loaded");
391  }
392 }
393 
394 inline std::vector<PyBufferFrame>
395 Model::GetPyBuffer() {
396  /* Header */
397  std::vector<PyBufferFrame> frames{
398  GetPyBufferFromScalar(&num_feature),
399  GetPyBufferFromScalar(&num_output_group),
400  GetPyBufferFromScalar(&random_forest_flag),
401  GetPyBufferFromScalar(&param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")
402  };
403 
404  /* Body */
405  for (auto& tree : trees) {
406  auto tree_frames = tree.GetPyBuffer();
407  frames.insert(frames.end(), tree_frames.begin(), tree_frames.end());
408  }
409  return frames;
410 }
411 
412 inline void
413 Model::InitFromPyBuffer(std::vector<PyBufferFrame> frames) {
414  /* Header */
415  size_t frame_id = 0;
416  InitScalarFromPyBuffer(&num_feature, frames[frame_id++]);
417  InitScalarFromPyBuffer(&num_output_group, frames[frame_id++]);
418  InitScalarFromPyBuffer(&random_forest_flag, frames[frame_id++]);
419  InitScalarFromPyBuffer(&param, frames[frame_id++]);
420  /* Body */
421  const size_t num_frame = frames.size();
422  if ((num_frame - frame_id) % kNumFramePerTree != 0) {
423  throw std::runtime_error("Wrong number of frames");
424  }
425  trees.clear();
426  for (; frame_id < num_frame; frame_id += kNumFramePerTree) {
427  std::vector<PyBufferFrame> tree_frames(frames.begin() + frame_id,
428  frames.begin() + frame_id + kNumFramePerTree);
429  trees.emplace_back();
430  trees.back().InitFromPyBuffer(tree_frames);
431  }
432 }
433 
434 inline void Tree::Node::Init() {
435  cleft_ = cright_ = -1;
436  sindex_ = 0;
437  info_.leaf_value = 0.0f;
438  info_.threshold = 0.0f;
439  data_count_ = 0;
440  sum_hess_ = gain_ = 0.0;
441  missing_category_to_zero_ = false;
442  data_count_present_ = sum_hess_present_ = gain_present_ = false;
443  split_type_ = SplitFeatureType::kNone;
444  cmp_ = Operator::kNone;
445  pad_ = 0;
446 }
447 
448 inline int
449 Tree::AllocNode() {
450  int nd = num_nodes++;
451  if (nodes_.Size() != static_cast<size_t>(nd)) {
452  throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes");
453  }
454  for (int nid = nd; nid < num_nodes; ++nid) {
455  leaf_vector_offset_.PushBack(leaf_vector_offset_.Back());
456  left_categories_offset_.PushBack(left_categories_offset_.Back());
457  nodes_.Resize(nodes_.Size() + 1);
458  nodes_.Back().Init();
459  }
460  return nd;
461 }
462 
463 inline Tree
464 Tree::Clone() const {
465  Tree tree;
466  tree.num_nodes = num_nodes;
467  tree.nodes_ = nodes_.Clone();
468  tree.leaf_vector_ = leaf_vector_.Clone();
469  tree.leaf_vector_offset_ = leaf_vector_offset_.Clone();
470  tree.left_categories_ = left_categories_.Clone();
471  tree.left_categories_offset_ = left_categories_offset_.Clone();
472  return tree;
473 }
474 
475 inline void
477  num_nodes = 1;
478  leaf_vector_.Clear();
479  leaf_vector_offset_.Resize(2, 0);
480  left_categories_.Clear();
481  left_categories_offset_.Resize(2, 0);
482  nodes_.Resize(1);
483  nodes_[0].Init();
484  SetLeaf(0, 0.0f);
485 }
486 
487 inline void
488 Tree::AddChilds(int nid) {
489  const int cleft = this->AllocNode();
490  const int cright = this->AllocNode();
491  nodes_[nid].cleft_ = cleft;
492  nodes_[nid].cright_ = cright;
493 }
494 
495 inline std::vector<unsigned>
497  std::unordered_map<unsigned, bool> tmp;
498  for (int nid = 0; nid < num_nodes; ++nid) {
499  const SplitFeatureType type = SplitType(nid);
500  if (type != SplitFeatureType::kNone) {
501  const bool flag = (type == SplitFeatureType::kCategorical);
502  const uint32_t split_index = SplitIndex(nid);
503  if (tmp.count(split_index) == 0) {
504  tmp[split_index] = flag;
505  } else {
506  if (tmp[split_index] != flag) {
507  throw std::runtime_error("Feature " + std::to_string(split_index) +
508  " cannot be simultaneously be categorical and numerical.");
509  }
510  }
511  }
512  }
513  std::vector<unsigned> result;
514  for (const auto& kv : tmp) {
515  if (kv.second) {
516  result.push_back(kv.first);
517  }
518  }
519  std::sort(result.begin(), result.end());
520  return result;
521 }
522 
523 inline int
524 Tree::LeftChild(int nid) const {
525  return nodes_[nid].cleft_;
526 }
527 
528 inline int
529 Tree::RightChild(int nid) const {
530  return nodes_[nid].cright_;
531 }
532 
533 inline int
534 Tree::DefaultChild(int nid) const {
535  return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
536 }
537 
538 inline uint32_t
539 Tree::SplitIndex(int nid) const {
540  return (nodes_[nid].sindex_ & ((1U << 31U) - 1U));
541 }
542 
543 inline bool
544 Tree::DefaultLeft(int nid) const {
545  return (nodes_[nid].sindex_ >> 31U) != 0;
546 }
547 
548 inline bool
549 Tree::IsLeaf(int nid) const {
550  return nodes_[nid].cleft_ == -1;
551 }
552 
553 inline tl_float
554 Tree::LeafValue(int nid) const {
555  return (nodes_[nid].info_).leaf_value;
556 }
557 
558 inline std::vector<tl_float>
559 Tree::LeafVector(int nid) const {
560  if (nid > leaf_vector_offset_.Size()) {
561  throw std::runtime_error("nid too large");
562  }
563  return std::vector<tl_float>(&leaf_vector_[leaf_vector_offset_[nid]],
564  &leaf_vector_[leaf_vector_offset_[nid + 1]]);
565 }
566 
567 inline bool
568 Tree::HasLeafVector(int nid) const {
569  if (nid > leaf_vector_offset_.Size()) {
570  throw std::runtime_error("nid too large");
571  }
572  return leaf_vector_offset_[nid] != leaf_vector_offset_[nid + 1];
573 }
574 
575 inline tl_float
576 Tree::Threshold(int nid) const {
577  return (nodes_[nid].info_).threshold;
578 }
579 
580 inline Operator
581 Tree::ComparisonOp(int nid) const {
582  return nodes_[nid].cmp_;
583 }
584 
585 inline std::vector<uint32_t>
586 Tree::LeftCategories(int nid) const {
587  if (nid > left_categories_offset_.Size()) {
588  throw std::runtime_error("nid too large");
589  }
590  return std::vector<uint32_t>(&left_categories_[left_categories_offset_[nid]],
591  &left_categories_[left_categories_offset_[nid + 1]]);
592 }
593 
594 inline SplitFeatureType
595 Tree::SplitType(int nid) const {
596  return nodes_[nid].split_type_;
597 }
598 
599 inline bool
600 Tree::HasDataCount(int nid) const {
601  return nodes_[nid].data_count_present_;
602 }
603 
604 inline uint64_t
605 Tree::DataCount(int nid) const {
606  return nodes_[nid].data_count_;
607 }
608 
609 inline bool
610 Tree::HasSumHess(int nid) const {
611  return nodes_[nid].sum_hess_present_;
612 }
613 
614 inline double
615 Tree::SumHess(int nid) const {
616  return nodes_[nid].sum_hess_;
617 }
618 
619 inline bool
620 Tree::HasGain(int nid) const {
621  return nodes_[nid].gain_present_;
622 }
623 
624 inline double
625 Tree::Gain(int nid) const {
626  return nodes_[nid].gain_;
627 }
628 
629 inline bool
631  return nodes_[nid].missing_category_to_zero_;
632 }
633 
634 inline void
635 Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold,
636  bool default_left, Operator cmp) {
637  Node& node = nodes_[nid];
638  if (split_index >= ((1U << 31U) - 1)) {
639  throw std::runtime_error("split_index too big");
640  }
641  if (default_left) split_index |= (1U << 31U);
642  node.sindex_ = split_index;
643  (node.info_).threshold = threshold;
644  node.cmp_ = cmp;
645  node.split_type_ = SplitFeatureType::kNumerical;
646 }
647 
648 inline void
649 Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
650  bool missing_category_to_zero,
651  const std::vector<uint32_t>& node_left_categories) {
652  if (split_index >= ((1U << 31U) - 1)) {
653  throw std::runtime_error("split_index too big");
654  }
655 
656  const size_t end_oft = left_categories_offset_.Back();
657  const size_t new_end_oft = end_oft + node_left_categories.size();
658  if (end_oft != left_categories_.Size()) {
659  throw std::runtime_error("Invariant violated");
660  }
661  if (!std::all_of(&left_categories_offset_[nid + 1], left_categories_offset_.End(),
662  [end_oft](size_t x) { return (x == end_oft); })) {
663  throw std::runtime_error("Invariant violated");
664  }
665  // Hopefully we won't have to move any element as we add node_left_categories for node nid
666  left_categories_.Extend(node_left_categories);
667  if (new_end_oft != left_categories_.Size()) {
668  throw std::runtime_error("Invariant violated");
669  }
670  std::for_each(&left_categories_offset_[nid + 1], left_categories_offset_.End(),
671  [new_end_oft](size_t& x) { x = new_end_oft; });
672  std::sort(&left_categories_[end_oft], left_categories_.End());
673 
674  Node& node = nodes_[nid];
675  if (default_left) split_index |= (1U << 31U);
676  node.sindex_ = split_index;
677  node.split_type_ = SplitFeatureType::kCategorical;
678  node.missing_category_to_zero_ = missing_category_to_zero;
679 }
680 
681 inline void
682 Tree::SetLeaf(int nid, tl_float value) {
683  Node& node = nodes_[nid];
684  (node.info_).leaf_value = value;
685  node.cleft_ = -1;
686  node.cright_ = -1;
687  node.split_type_ = SplitFeatureType::kNone;
688 }
689 
690 inline void
691 Tree::SetLeafVector(int nid, const std::vector<tl_float>& node_leaf_vector) {
692  const size_t end_oft = leaf_vector_offset_.Back();
693  const size_t new_end_oft = end_oft + node_leaf_vector.size();
694  if (end_oft != leaf_vector_.Size()) {
695  throw std::runtime_error("Invariant violated");
696  }
697  if (!std::all_of(&leaf_vector_offset_[nid + 1], leaf_vector_offset_.End(),
698  [end_oft](size_t x) { return (x == end_oft); })) {
699  throw std::runtime_error("Invariant violated");
700  }
701  // Hopefully we won't have to move any element as we add leaf vector elements for node nid
702  leaf_vector_.Extend(node_leaf_vector);
703  if (new_end_oft != leaf_vector_.Size()) {
704  throw std::runtime_error("Invariant violated");
705  }
706  std::for_each(&leaf_vector_offset_[nid + 1], leaf_vector_offset_.End(),
707  [new_end_oft](size_t& x) { x = new_end_oft; });
708 
709  Node& node = nodes_[nid];
710  node.cleft_ = -1;
711  node.cright_ = -1;
712  node.split_type_ = SplitFeatureType::kNone;
713 }
714 
715 inline void
716 Tree::SetSumHess(int nid, double sum_hess) {
717  Node& node = nodes_[nid];
718  node.sum_hess_ = sum_hess;
719  node.sum_hess_present_ = true;
720 }
721 
722 inline void
723 Tree::SetDataCount(int nid, uint64_t data_count) {
724  Node& node = nodes_[nid];
725  node.data_count_ = data_count;
726  node.data_count_present_ = true;
727 }
728 
729 inline void
730 Tree::SetGain(int nid, double gain) {
731  Node& node = nodes_[nid];
732  node.gain_ = gain;
733  node.gain_present_ = true;
734 }
735 
736 inline Model
737 Model::Clone() const {
738  Model model;
739  for (const Tree& t : trees) {
740  model.trees.push_back(t.Clone());
741  }
742  model.num_feature = num_feature;
745  model.param = param;
746  return model;
747 }
748 
749 inline void InitParamAndCheck(ModelParam* param,
750  const std::vector<std::pair<std::string, std::string>>& cfg) {
751  auto unknown = param->InitAllowUnknown(cfg);
752  if (!unknown.empty()) {
753  std::ostringstream oss;
754  for (const auto& kv : unknown) {
755  oss << kv.first << ", ";
756  }
757  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
758  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
759  }
760 }
761 
762 } // namespace treelite
763 #endif // TREELITE_TREE_IMPL_H_
SplitFeatureType split_type_
feature split type
Definition: tree.h:118
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree_impl.h:581
bool gain_present_
whether gain_present_ field is present
Definition: tree.h:135
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:419
SplitFeatureType
feature split type
Definition: base.h:20
void Init()
initialize the model with a single root node
Definition: tree_impl.h:476
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:105
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree_impl.h:600
std::vector< tl_float > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree_impl.h:559
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree_impl.h:620
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:124
std::vector< Tree > trees
member trees
Definition: tree.h:411
tl_float Threshold(int nid) const
get threshold of the node
Definition: tree_impl.h:576
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
Definition: tree_impl.h:534
ModelParam param
extra parameters
Definition: tree.h:424
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:131
tree node
Definition: tree.h:83
int32_t cleft_
pointer to left and right children
Definition: tree.h:93
std::vector< uint32_t > LeftCategories(int nid) const
Get list of all categories belonging to the left child node. Categories not in this list will belong ...
Definition: tree_impl.h:586
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree_impl.h:716
in-memory representation of a decision tree
Definition: tree.h:80
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistics.
Definition: tree.h:112
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:116
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree_impl.h:539
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree_impl.h:496
bool MissingCategoryToZero(int nid) const
test whether missing values should be converted into zero; only applicable for categorical splits ...
Definition: tree_impl.h:630
double SumHess(int nid) const
get hessian sum
Definition: tree_impl.h:615
void SetLeafVector(int nid, const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree_impl.h:691
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
Definition: tree_impl.h:649
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree_impl.h:730
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:422
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree_impl.h:723
void Init()
Initialization method. Use this in lieu of constructor (POD types cannot have a non-trivial construct...
Definition: tree_impl.h:434
tl_float LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree_impl.h:554
int num_nodes
number of nodes
Definition: tree.h:167
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree_impl.h:595
void SetLeaf(int nid, tl_float value)
set the leaf value of the node
Definition: tree_impl.h:682
void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:635
int LeftChild(int nid) const
Getters.
Definition: tree_impl.h:524
uint64_t DataCount(int nid) const
get data count
Definition: tree_impl.h:605
double Gain(int nid) const
get gain value
Definition: tree_impl.h:625
int RightChild(int nid) const
index of the node's right child
Definition: tree_impl.h:529
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:488
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:133
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree_impl.h:544
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree_impl.h:610
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree_impl.h:549
uint32_t sindex_
feature index used for the split highest bit indicates default direction for missing values ...
Definition: tree.h:98
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree_impl.h:568
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:416
Info info_
storage for leaf value or decision threshold
Definition: tree.h:100
Operator
comparison operators
Definition: base.h:24