Treelite
reference_serializer.cc
Go to the documentation of this file.
1 
#include <treelite/tree.h>
#include <dmlc/io.h>
#include <dmlc/serializer.h>

#include <cstddef>
#include <cstdint>
#include <limits>
11 
12 namespace dmlc {
13 namespace serializer {
14 
15 template <typename T>
16 struct Handler<treelite::ContiguousArray<T>> {
17  inline static void Write(Stream* strm, const treelite::ContiguousArray<T>& data) {
18  uint64_t sz = static_cast<uint64_t>(data.Size());
19  strm->Write(sz);
20  strm->Write(data.Data(), sz * sizeof(T));
21  }
22 
23  inline static bool Read(Stream* strm, treelite::ContiguousArray<T>* data) {
24  uint64_t sz;
25  bool status = strm->Read(&sz);
26  if (!status) {
27  return false;
28  }
29  data->Resize(sz);
30  return strm->Read(data->Data(), sz * sizeof(T));
31  }
32 };
33 
34 } // namespace serializer
35 } // namespace dmlc
36 
37 namespace treelite {
38 
39 template <typename ThresholdType, typename LeafOutputType>
40 void Tree<ThresholdType, LeafOutputType>::ReferenceSerialize(dmlc::Stream* fo) const {
41  fo->Write(num_nodes);
42  fo->Write(leaf_vector_);
43  fo->Write(leaf_vector_offset_);
44  fo->Write(matching_categories_);
45  fo->Write(matching_categories_offset_);
46  uint64_t sz = static_cast<uint64_t>(nodes_.Size());
47  fo->Write(sz);
48  fo->Write(nodes_.Data(), sz * sizeof(Tree::Node));
49 
50  // Sanity check
51  CHECK_EQ(nodes_.Size(), num_nodes);
52  CHECK_EQ(nodes_.Size() + 1, leaf_vector_offset_.Size());
53  CHECK_EQ(leaf_vector_offset_.Back(), leaf_vector_.Size());
54  CHECK_EQ(nodes_.Size() + 1, matching_categories_offset_.Size());
55  CHECK_EQ(matching_categories_offset_.Back(), matching_categories_.Size());
56 }
57 
58 template <typename ThresholdType, typename LeafOutputType>
59 void ModelImpl<ThresholdType, LeafOutputType>::ReferenceSerialize(dmlc::Stream* fo) const {
60  fo->Write(num_feature);
61  fo->Write(static_cast<uint8_t>(task_type));
62  fo->Write(average_tree_output);
63  fo->Write(&task_param, sizeof(task_param));
64  fo->Write(&param, sizeof(param));
65  uint64_t sz = static_cast<uint64_t>(trees.size());
66  fo->Write(sz);
67  for (const Tree<ThresholdType, LeafOutputType>& tree : trees) {
68  tree.ReferenceSerialize(fo);
69  }
70 }
71 
72 template void Tree<float, uint32_t>::ReferenceSerialize(dmlc::Stream* fo) const;
73 template void Tree<float, float>::ReferenceSerialize(dmlc::Stream* fo) const;
74 template void Tree<double, uint32_t>::ReferenceSerialize(dmlc::Stream* fo) const;
75 template void Tree<double, double>::ReferenceSerialize(dmlc::Stream* fo) const;
76 
77 template void ModelImpl<float, uint32_t>::ReferenceSerialize(dmlc::Stream* fo) const;
78 template void ModelImpl<float, float>::ReferenceSerialize(dmlc::Stream* fo) const;
79 template void ModelImpl<double, uint32_t>::ReferenceSerialize(dmlc::Stream* fo) const;
80 template void ModelImpl<double, double>::ReferenceSerialize(dmlc::Stream* fo) const;
81 
82 } // namespace treelite
model structure for tree ensemble
Definition: tree.h:29