Treelite
compiler.cc
Go to the documentation of this file.
1 
6 #include <treelite/compiler.h>
8 #include <treelite/logging.h>
9 #include <rapidjson/document.h>
10 #include <limits>
11 #include "./ast_native.h"
12 #include "./failsafe.h"
13 
14 namespace treelite {
15 
16 Compiler* Compiler::Create(const std::string& name, const char* param_json_str) {
17  compiler::CompilerParam param = compiler::CompilerParam::ParseFromJSON(param_json_str);
18  if (name == "ast_native") {
19  return new compiler::ASTNativeCompiler(param);
20  } else if (name == "failsafe") {
21  return new compiler::FailSafeCompiler(param);
22  } else {
23  TREELITE_LOG(FATAL) << "Unrecognized compiler '" << name << "'";
24  return nullptr;
25  }
26 }
27 
28 namespace compiler {
29 
30 CompilerParam
31 CompilerParam::ParseFromJSON(const char* param_json_str) {
32  CompilerParam param;
33  // Default values
34  param.annotate_in = "NULL";
35  param.quantize = 0;
36  param.parallel_comp = 0;
37  param.verbose = 0;
38  param.native_lib_name = "predictor";
39  param.code_folding_req = std::numeric_limits<double>::infinity();
40  param.dump_array_as_elf = 0;
41 
42  rapidjson::Document doc;
43  doc.Parse(param_json_str);
44  TREELITE_CHECK(doc.IsObject()) << "Got an invalid JSON string:\n" << param_json_str;
45  for (const auto& e : doc.GetObject()) {
46  const std::string key = e.name.GetString();
47  if (key == "annotate_in") {
48  TREELITE_CHECK(e.value.IsString()) << "Expected a string for 'annotate_in'";
49  param.annotate_in = e.value.GetString();
50  } else if (key == "quantize") {
51  TREELITE_CHECK(e.value.IsInt()) << "Expected an integer for 'quantize'";
52  param.quantize = e.value.GetInt();
53  TREELITE_CHECK_GE(param.quantize, 0) << "'quantize' must be 0 or greater";
54  } else if (key == "parallel_comp") {
55  TREELITE_CHECK(e.value.IsInt()) << "Expected an integer for 'parallel_comp'";
56  param.parallel_comp = e.value.GetInt();
57  TREELITE_CHECK_GE(param.parallel_comp, 0) << "'parallel_comp' must be 0 or greater";
58  } else if (key == "verbose") {
59  TREELITE_CHECK(e.value.IsInt()) << "Expected an integer for 'verbose'";
60  param.verbose = e.value.GetInt();
61  } else if (key == "native_lib_name") {
62  TREELITE_CHECK(e.value.IsString()) << "Expected a string for 'native_lib_name'";
63  param.native_lib_name = e.value.GetString();
64  } else if (key == "code_folding_req") {
65  TREELITE_CHECK(e.value.IsDouble())
66  << "Expected a floating-point decimal for 'code_folding_req'";
67  param.code_folding_req = e.value.GetDouble();
68  TREELITE_CHECK_GE(param.code_folding_req, 0) << "'code_folding_req' must be 0 or greater";
69  } else if (key == "dump_array_as_elf") {
70  TREELITE_CHECK(e.value.IsInt()) << "Expected an integer for 'dump_array_as_elf'";
71  param.dump_array_as_elf = e.value.GetInt();
72  TREELITE_CHECK_GE(param.dump_array_as_elf, 0) << "'dump_array_as_elf' must be 0 or greater";
73  } else {
74  TREELITE_LOG(FATAL) << "Unrecognized key '" << key << "' in JSON";
75  }
76  }
77 
78  return param;
79 }
80 
81 } // namespace compiler
82 
83 } // namespace treelite
Parameters for tree compiler.
parameters for tree compiler
logging facility for Treelite
interface of compiler
Definition: compiler.h:53
Interface of compiler that compiles a tree ensemble model.
C code generator (fail-safe). The generated code will mimic prediction logic found in XGBoost...
C code generator.
static Compiler * Create(const std::string &name, const char *param_json_str)
create a compiler from given name
Definition: compiler.cc:16