Editing tree models (Advanced)
The field accessor API allows users to inspect and edit tree model objects after they have been constructed. Here are some examples:
import treelite
import numpy as np
# model is treelite.Model object
# Get the "num_feature" field in the header
model.get_header_accessor().get_field("num_feature")
# Modify the "num_feature" field in the header. Use length-1 array to indicate scalar
new_value = np.array([100], dtype=np.int32)
model.get_header_accessor().set_field("num_feature", new_value)
# Get the "postprocessor" field in the header
model.get_header_accessor().get_field("postprocessor")
# Modify the "postprocessor" field in the header
model.get_header_accessor().set_field("postprocessor", "identity")
# Get the "leaf_value" field in the first tree
model.get_tree_accessor(0).get_field("leaf_value")
# Modify the "leaf_value" field in the first tree
new_value = np.array([0, 0, 0.5, 1, -0.5], dtype=np.float32)
model.get_tree_accessor(0).set_field("leaf_value", new_value)
Consult Treelite Serialization Format v4 for the list of fields.
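The get/set pair above can be combined into a read-modify-write step, which keeps the array length and dtype identical to what the model already stores. The following is a minimal sketch, not part of Treelite: `scale_leaf_values` is a hypothetical helper name, and it assumes that `get_field("leaf_value")` returns a NumPy array (or something `np.asarray` can convert).

```python
import numpy as np


def scale_leaf_values(tree_accessor, factor):
    """Multiply every entry of "leaf_value" by `factor`.

    `tree_accessor` is assumed to expose get_field() / set_field(),
    like the object returned by model.get_tree_accessor(i).
    """
    # Assumption: get_field() returns an array-like for this field
    old = np.asarray(tree_accessor.get_field("leaf_value"))
    # Cast back to the original dtype so the stored type does not change
    new = (old * factor).astype(old.dtype)
    tree_accessor.set_field("leaf_value", new)
    return new
```

With a real model this would be called as `scale_leaf_values(model.get_tree_accessor(0), 0.5)`. Because the new array is derived from the old one, its length always matches the existing number of nodes.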
How to use setter methods
There are many pitfalls involved in using set_field() to modify trees. We must start with the following notice:
Note
Modifying a field is an unsafe operation
Treelite does not prevent users from assigning an invalid value to a field. Setting an invalid value may cause undefined behavior. Always consult the model spec to carefully examine model invariants and constraints on fields.
Make sure to follow the rules below to prevent errors and silent crashes:
Always pass in a NumPy array when calling set_field(), even when the field being set is a scalar.
Make sure to use the correct dtype when passing in the NumPy array. For example, since num_feature has type int32_t according to the model spec, use np.array([...], dtype=np.int32).
Most of the fields accessed through the tree accessor TreeAccessor must have their values set to arrays of length num_nodes, where num_nodes is the number of nodes. Setting a shorter array will likely cause undefined behavior and silent crashes.
When adding additional nodes, make sure to update the num_nodes field as well as all tree fields. Example:
import numpy as np
import treelite
from treelite.model_builder import (
    Metadata,
    ModelBuilder,
    PostProcessorFunc,
    TreeAnnotation,
)
# Tree stump with 3 nodes
builder = ModelBuilder(
    threshold_type="float32",
    leaf_output_type="float32",
    metadata=Metadata(
        num_feature=2,
        task_type="kRegressor",
        average_tree_output=False,
        num_target=1,
        num_class=[1],
        leaf_vector_shape=(1, 1),
    ),
    tree_annotation=TreeAnnotation(num_tree=1, target_id=[0], class_id=[0]),
    postprocessor=PostProcessorFunc(name="identity"),
    base_scores=[0.0],
)
builder.start_tree()
builder.start_node(0)
builder.numerical_test(
    feature_id=0,
    threshold=0.0,
    default_left=False,
    opname="<=",
    left_child_key=1,
    right_child_key=2,
)
builder.end_node()
builder.start_node(1)
builder.leaf(-1.0)
builder.end_node()
builder.start_node(2)
builder.leaf(1.0)
builder.end_node()
builder.end_tree()
model = builder.commit()
# Get the accessor for the first (and only) tree
tree = model.get_tree_accessor(0)
# Add a test node. The tree now has 5 nodes total
tree.set_field("num_nodes", np.array([5], dtype=np.int32))
tree.set_field("node_type", np.array([1, 0, 1, 0, 0], dtype=np.int8))
tree.set_field("cleft", np.array([1, -1, 3, -1, -1], dtype=np.int32))
tree.set_field("cright", np.array([2, -1, 4, -1, -1], dtype=np.int32))
tree.set_field("split_index", np.array([0, -1, 1, -1, 1], dtype=np.int32))
tree.set_field("default_left", np.array([0, 0, 0, 0, 0], dtype=np.int8))
tree.set_field("leaf_value", np.array([0.0, 1.0, 0.0, 2.0, 3.0], dtype=np.float32))
tree.set_field("threshold", np.array([1.0, 0.0, 2.0, 0.0, 0.0], dtype=np.float32))
tree.set_field("cmp", np.array([2, 0, 2, 0, 0], dtype=np.int8))
tree.set_field("category_list_right_child", np.array([0] * 5, dtype=np.uint8))
tree.set_field("leaf_vector_begin", np.array([0] * 5, dtype=np.uint64))
tree.set_field("leaf_vector_end", np.array([0] * 5, dtype=np.uint64))
tree.set_field("category_list_begin", np.array([0] * 5, dtype=np.uint64))
tree.set_field("category_list_end", np.array([0] * 5, dtype=np.uint64))
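Because Treelite itself does not validate assigned values, the first three rules can be enforced on the caller's side before anything reaches set_field(). The wrapper below is a minimal sketch under stated assumptions: `checked_set_field` is a hypothetical name, not a Treelite function, and the expected dtype and length must be supplied by the caller after looking them up in the model spec.

```python
import numpy as np


def checked_set_field(accessor, name, value, dtype, length=None):
    """Validate `value`, then forward it to accessor.set_field().

    `dtype` and `length` are NOT discovered automatically; the caller
    is assumed to take them from the model spec.
    """
    # Rule 1: always a NumPy array, even for scalar fields
    if not isinstance(value, np.ndarray):
        raise TypeError(
            f"{name}: expected a NumPy array, got {type(value).__name__}"
        )
    # Rule 2: dtype must match the type given in the spec
    if value.dtype != np.dtype(dtype):
        raise TypeError(
            f"{name}: expected dtype {np.dtype(dtype)}, got {value.dtype}"
        )
    # Rule 3: one-dimensional, with the expected number of entries
    if value.ndim != 1:
        raise ValueError(f"{name}: expected a 1-D array, got ndim={value.ndim}")
    if length is not None and value.shape[0] != length:
        raise ValueError(
            f"{name}: expected length {length}, got {value.shape[0]}"
        )
    accessor.set_field(name, value)
```

For a scalar header field, one would pass `length=1`; for a per-node tree field, `length` would be the tree's num_nodes. The check only catches type and shape mistakes; it cannot tell whether the values themselves satisfy the model invariants.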
If possible, avoid changing the number of nodes in a tree, since doing so is error-prone. Actions that don’t change the tree structure, such as re-numbering feature IDs and changing leaf outputs, are much safer.
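As an illustration of a structure-preserving edit, re-numbering feature IDs only touches the split_index field and leaves its length unchanged. The sketch below makes some assumptions: `renumber_features` is a hypothetical helper, `get_field("split_index")` is assumed to return an array-like, and leaf entries are assumed to hold values (such as -1) that do not appear as keys in the mapping.

```python
import numpy as np


def renumber_features(tree_accessor, mapping):
    """Replace feature IDs in "split_index" according to `mapping`.

    `mapping` is a dict {old_feature_id: new_feature_id}. Entries whose
    value is not a key in `mapping` are left untouched.
    """
    # Assumption: get_field() returns an array-like for this field
    split_index = np.asarray(tree_accessor.get_field("split_index"))
    # Copy keeps the original dtype and length
    new_index = split_index.copy()
    for old_id, new_id in mapping.items():
        # Masks are computed on the original array, so swaps such as
        # {0: 1, 1: 0} do not overwrite each other
        new_index[split_index == old_id] = new_id
    tree_accessor.set_field("split_index", new_index)
    return new_index
```

If the new IDs exceed the current feature count, the num_feature header field would presumably need to be raised as well, using a length-1 int32 array.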
Currently, it is not possible to add or remove trees using the field accessor API.