treelite
data.cc
#include <treelite/data.h>
#include <treelite/omp.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

14 namespace treelite {
15 
16 DMatrix*
17 DMatrix::Create(const char* filename, const char* format,
18  int nthread, int verbose) {
19  std::unique_ptr<dmlc::Parser<uint32_t>> parser(
20  dmlc::Parser<uint32_t>::Create(filename, 0, 1, format));
21  return Create(parser.get(), nthread, verbose);
22 }
23 
24 DMatrix*
25 DMatrix::Create(dmlc::Parser<uint32_t>* parser, int nthread, int verbose) {
26  const int max_thread = omp_get_max_threads();
27  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
28 
29  DMatrix* dmat = new DMatrix();
30  dmat->Clear();
31  auto& data_ = dmat->data;
32  auto& col_ind_ = dmat->col_ind;
33  auto& row_ptr_ = dmat->row_ptr;
34  auto& num_row_ = dmat->num_row;
35  auto& num_col_ = dmat->num_col;
36  auto& nelem_ = dmat->nelem;
37 
38  std::vector<size_t> max_col_ind(nthread, 0);
39  parser->BeforeFirst();
40  while (parser->Next()) {
41  const dmlc::RowBlock<uint32_t>& batch = parser->Value();
42  num_row_ += batch.size;
43  nelem_ += batch.offset[batch.size];
44  const size_t top = data_.size();
45  data_.resize(top + batch.offset[batch.size] - batch.offset[0]);
46  col_ind_.resize(top + batch.offset[batch.size] - batch.offset[0]);
47  CHECK_LT(static_cast<int64_t>(batch.offset[batch.size]),
48  std::numeric_limits<int64_t>::max());
49  #pragma omp parallel for schedule(static) num_threads(nthread)
50  for (int64_t i = static_cast<int64_t>(batch.offset[0]);
51  i < static_cast<int64_t>(batch.offset[batch.size]); ++i) {
52  const int tid = omp_get_thread_num();
53  const uint32_t index = batch.index[i];
54  const float fvalue = (batch.value == nullptr) ? 1.0f :
55  static_cast<float>(batch.value[i]);
56  const size_t offset = top + i - batch.offset[0];
57  data_[offset] = fvalue;
58  col_ind_[offset] = index;
59  max_col_ind[tid] = std::max(max_col_ind[tid],
60  static_cast<size_t>(index));
61  }
62  const size_t rtop = row_ptr_.size();
63  row_ptr_.resize(rtop + batch.size);
64  CHECK_LT(static_cast<int64_t>(batch.size),
65  std::numeric_limits<int64_t>::max());
66  #pragma omp parallel for schedule(static) num_threads(nthread)
67  for (int64_t i = 0; i < static_cast<int64_t>(batch.size); ++i) {
68  row_ptr_[rtop + i]
69  = row_ptr_[rtop - 1] + batch.offset[i + 1] - batch.offset[0];
70  }
71  if (verbose > 0) {
72  LOG(INFO) << num_row_ << " rows read into memory";
73  }
74  }
75  num_col_ = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1;
76  return dmat;
77 }
78 
79 } // namespace treelite
std::vector< float > data
feature values
Definition: data.h:18
Input data structure of treelite.
std::vector< uint32_t > col_ind
feature indices
Definition: data.h:20
static DMatrix * Create(const char *filename, const char *format, int nthread, int verbose)
construct a new DMatrix from a file
Definition: data.cc:17
size_t num_row
number of rows
Definition: data.h:24
a simple data matrix in CSR (Compressed Sparse Row) storage
Definition: data.h:16
void Clear()
clear all data fields
Definition: data.h:33
size_t num_col
number of columns
Definition: data.h:26
compatibility wrapper for systems that don't support OpenMP
size_t nelem
number of nonzero entries
Definition: data.h:28
std::vector< size_t > row_ptr
CSR row pointers: row i occupies entries [row_ptr[i], row_ptr[i+1]); length is num_row + 1
Definition: data.h:22