Skip to content

Commit

Permalink
Add support for RF classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Oct 4, 2024
1 parent 6e41e81 commit 35dc9d2
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 92 deletions.
245 changes: 158 additions & 87 deletions python/treelite/sklearn/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

from ..core import TreeliteError
from ..model import Model
Expand All @@ -24,6 +26,96 @@ def _ensure_numpy(x: Any) -> np.ndarray:
raise ValueError(f"x is not a valid NumPy array. {x.type=}")

Check warning on line 26 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L26

Added line #L26 was not covered by tests


# One row per field of a node record: (field name, NumPy format, byte offset).
# Arrays of this dtype are handed to the scikit-learn Tree object as its
# "nodes" state. The layout is presumably chosen to mirror scikit-learn's
# internal node struct -- verify against the targeted scikit-learn version.
_node_field_spec = (
    ("left_child", "<i8", 0),
    ("right_child", "<i8", 8),
    ("feature", "<i8", 16),
    ("threshold", "<f8", 24),
    ("impurity", "<f8", 32),
    ("n_node_samples", "<i8", 40),
    ("weighted_n_node_samples", "<f8", 48),
    ("missing_go_to_left", "u1", 56),
)

# Structured dtype built from the table above. The record is padded to 64
# bytes: the last field is a single byte at offset 56.
_node_dtype = np.dtype(
    {
        "names": [name for name, _fmt, _offset in _node_field_spec],
        "formats": [fmt for _name, fmt, _offset in _node_field_spec],
        "offsets": [offset for _name, _fmt, offset in _node_field_spec],
        "itemsize": 64,
    }
)


def _export_tree(
    model, *, tree_id, n_features, n_classes, n_targets, tree_depths, subestimator_class
):
    """Build a scikit-learn decision tree estimator from one tree of ``model``.

    Parameters
    ----------
    model :
        Object exposing ``get_tree_accessor(tree_id)``; the returned accessor
        is queried through ``get_field(name)``.
    tree_id :
        Index of the tree to export. Also used to index ``tree_depths``.
    n_features, n_classes, n_targets :
        Forwarded to the scikit-learn ``Tree`` constructor. ``n_classes`` is
        indexed (``n_classes[0]`` is used as the per-target class count) and
        ``n_classes.tolist()`` is called for multi-target classifiers.
    tree_depths :
        Indexable by ``tree_id``; the entry becomes the tree's ``max_depth``.
    subestimator_class :
        Estimator class, instantiated without arguments. When it is
        ``DecisionTreeClassifier``, ``n_classes_`` is added to its state.

    Returns
    -------
    An instance of ``subestimator_class`` whose state holds the rebuilt tree.

    Raises
    ------
    TreeliteError
        If scikit-learn cannot be imported.
    NotImplementedError
        If the tree reports a categorical split.
    """
    # pylint: disable=too-many-locals
    try:
        from sklearn import __version__ as sklearn_version
        from sklearn.tree._tree import Tree as SKLearnTree
    except ImportError as e:
        raise TreeliteError("This function requires scikit-learn package") from e

    accessor = model.get_tree_accessor(tree_id)

    def field(name):
        # Shorthand for reading one field of this tree
        return accessor.get_field(name)

    if field("has_categorical_split").tolist()[0]:
        raise NotImplementedError(
            "Trees with categorical splits cannot yet be exported as scikit-learn"
        )

    tree = SKLearnTree(n_features, n_classes, n_targets)

    n_nodes = field("num_nodes").tolist()[0]
    nodes = np.empty(n_nodes, dtype=_node_dtype)

    # Columns copied straight from the tree accessor: (node column, field name)
    for column, source in (
        ("left_child", "cleft"),
        ("right_child", "cright"),
        ("feature", "split_index"),
        ("threshold", "threshold"),
        ("missing_go_to_left", "default_left"),
    ):
        nodes[column] = field(source)
    # These columns are not read from the accessor; they get placeholder values
    nodes["impurity"] = np.nan
    nodes["weighted_n_node_samples"] = np.nan
    nodes["n_node_samples"] = -1

    n_outputs_per_target = n_classes[0]
    if n_targets == 1 and n_outputs_per_target == 1:
        # One scalar per node: reshape into (n_nodes, 1, 1)
        values = field("leaf_value").astype("float64").reshape((-1, 1, 1))
    else:
        # Need to map leaf values to correct layout; nodes that do not own a
        # leaf vector keep the zero fill
        values = np.zeros((n_nodes, n_targets, n_outputs_per_target), dtype="float64")
        flat_vectors = field("leaf_vector").astype("float64")
        begins = field("leaf_vector_begin")
        ends = field("leaf_vector_end")
        for node_id in range(n_nodes):
            begin, end = begins[node_id], ends[node_id]
            if begin == end:
                continue
            # This node is a leaf node and outputs a vector
            values[node_id, :, :] = flat_vectors[begin:end].reshape(
                (n_targets, n_outputs_per_target)
            )

    tree.__setstate__(
        {
            "max_depth": tree_depths[tree_id],
            "node_count": n_nodes,
            "nodes": nodes,
            "values": values,
        }
    )

    estimator = subestimator_class()
    estimator_state = {
        "tree_": tree,
        "n_outputs_": n_targets,
        "_sklearn_version": sklearn_version,
    }
    if subestimator_class is DecisionTreeClassifier:
        # Scalar class count for a single target, list of counts otherwise
        estimator_state["n_classes_"] = (
            n_classes[0] if n_targets == 1 else n_classes.tolist()
        )
    estimator.__setstate__(estimator_state)
    return estimator


def export_model(model: Model):
"""
Export as scikit-learn RandomForest or GradientBoosting
Expand All @@ -46,111 +138,90 @@ def export_model(model: Model):
try:
from sklearn import __version__ as sklearn_version
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import Tree as SKLearnTree
except ImportError as e:
raise TreeliteError("This function requires scikit-learn package") from e

Check warning on line 142 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L141-L142

Added lines #L141 - L142 were not covered by tests

node_dtype = np.dtype(
{
"names": [
"left_child",
"right_child",
"feature",
"threshold",
"impurity",
"n_node_samples",
"weighted_n_node_samples",
"missing_go_to_left",
],
"formats": ["<i8", "<i8", "<i8", "<f8", "<f8", "<i8", "<f8", "u1"],
"offsets": [0, 8, 16, 24, 32, 40, 48, 56],
"itemsize": 64,
}
)

header_accessor = model.get_header_accessor()
# average_tree_output = (
# header_accessor.get_field("average_tree_output").tolist()[0] == 1
# )
average_tree_output = (
_ensure_scalar_int(header_accessor.get_field("average_tree_output")) == 1
)
n_features = _ensure_scalar_int(header_accessor.get_field("num_feature"))
n_trees = _ensure_scalar_int(header_accessor.get_field("num_tree"))
n_targets = _ensure_scalar_int(header_accessor.get_field("num_target"))
n_classes = _ensure_numpy(header_accessor.get_field("num_class"))
leaf_vector_shape = _ensure_numpy(header_accessor.get_field("leaf_vector_shape"))
target_id = _ensure_numpy(header_accessor.get_field("target_id"))
class_id = _ensure_numpy(header_accessor.get_field("class_id"))
tree_depths = model.get_tree_depth()

assert np.all(n_classes == 1)
assert np.array_equal(leaf_vector_shape, [n_targets, 1])
# Heuristics to ensure that the model can be represented as scikit-learn random forest
# 1. average_tree_output must be True
# 2. n_classes[i] must be identical for all targets
# 3. Each leaf must yield an output of shape (n_targets, n_classes)
# 4. target_id[i] must be either 0 or -1
# 5. class_id[i] must be either 0 or -1
def raise_not_rf_error():
raise NotImplementedError(

Check warning on line 164 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L164

Added line #L164 was not covered by tests
"Currently only random forests can be exported as scikit-learn model objects. "
"Gradient boosting models and other kinds of decision tree models are not yet "
"supported."
)

if not average_tree_output:
raise_not_rf_error()

Check warning on line 171 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L171

Added line #L171 was not covered by tests
if not np.all(n_classes == n_classes[0]):
raise_not_rf_error()

Check warning on line 173 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L173

Added line #L173 was not covered by tests
if not np.array_equal(leaf_vector_shape, [n_targets, n_classes.max()]):
raise_not_rf_error()

Check warning on line 175 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L175

Added line #L175 was not covered by tests
if not np.all((target_id == 0) | (target_id == -1)):
raise_not_rf_error()

Check warning on line 177 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L177

Added line #L177 was not covered by tests
if not np.all((class_id == 0) | (class_id == -1)):
raise_not_rf_error()

Check warning on line 179 in python/treelite/sklearn/exporter.py

View check run for this annotation

Codecov / codecov/patch

python/treelite/sklearn/exporter.py#L179

Added line #L179 was not covered by tests

# Heuristics for determining whether the model is a regressor or a classifier
if n_classes[0] == 1:
estimator_class = RandomForestRegressor
subestimator_class = DecisionTreeRegressor
else:
estimator_class = RandomForestClassifier
subestimator_class = DecisionTreeClassifier

estimators = []

for tree_id in range(n_trees):
tree_accessor = model.get_tree_accessor(tree_id)
has_categorical_split = tree_accessor.get_field(
"has_categorical_split"
).tolist()[0]
assert not has_categorical_split

tree = SKLearnTree(n_features, n_classes, n_targets)

n_nodes = tree_accessor.get_field("num_nodes").tolist()[0]
nodes = np.empty(n_nodes, dtype=node_dtype)

nodes["left_child"] = tree_accessor.get_field("cleft")
nodes["right_child"] = tree_accessor.get_field("cright")
nodes["feature"] = tree_accessor.get_field("split_index")
nodes["threshold"] = tree_accessor.get_field("threshold")
nodes["impurity"] = np.nan
nodes["n_node_samples"] = -1
nodes["weighted_n_node_samples"] = np.nan
nodes["missing_go_to_left"] = tree_accessor.get_field("default_left")
estimators.append(
_export_tree(
model,
tree_id=tree_id,
n_features=n_features,
n_classes=n_classes,
n_targets=n_targets,
tree_depths=tree_depths,
subestimator_class=subestimator_class,
)
)

clf = estimator_class()
state = {
"estimators_": estimators,
"n_outputs_": n_targets,
"_sklearn_version": sklearn_version,
}
if estimator_class is RandomForestClassifier:
if n_targets == 1:
leaf_value = (
tree_accessor.get_field("leaf_value")
.astype("float64")
.reshape((-1, 1, 1))
state.update(
{
"n_classes_": n_classes[0],
"classes_": np.arange(n_classes[0]),
}
)
else:
# Need to map leaf values to correct layout
leaf_value = np.zeros((n_nodes, n_targets, 1), dtype="float64")
leaf_value_raw = tree_accessor.get_field("leaf_vector").astype("float64")
leaf_vec_begin = tree_accessor.get_field("leaf_vector_begin")
leaf_vec_end = tree_accessor.get_field("leaf_vector_end")
for node_id in range(n_nodes):
if leaf_vec_begin[node_id] != leaf_vec_end[node_id]:
# This node is a leaf node and outputs a vector
leaf_value[node_id, :, :] = leaf_value_raw[
leaf_vec_begin[node_id] : leaf_vec_end[node_id]
].reshape((n_targets, 1))

state = {
"max_depth": tree_depths[tree_id],
"node_count": n_nodes,
"nodes": nodes,
"values": leaf_value,
}
tree.__setstate__(state)

reg = DecisionTreeRegressor()
reg.__setstate__(
{
"tree_": tree,
"n_outputs_": n_targets,
"_sklearn_version": sklearn_version,
}
)

estimators.append(reg)

clf = RandomForestRegressor()
clf.__setstate__(
{
"estimators_": estimators,
"n_outputs_": n_targets,
"_sklearn_version": sklearn_version,
}
)
state.update(
{
"n_classes_": n_classes.tolist(),
"classes_": [np.arange(n_classes[i]) for i in range(n_targets)],
}
)
clf.__setstate__(state)

return clf
63 changes: 58 additions & 5 deletions tests/python/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,16 @@ def test_skl_hist_gradient_boosting_with_string_categorical():


@given(
n_targets=integers(min_value=1, max_value=3),
dataset=standard_regression_datasets(
n_targets=integers(min_value=1, max_value=3),
),
max_depth=integers(min_value=3, max_value=15),
n_estimators=integers(min_value=5, max_value=10),
callback=hypothesis_callback(),
)
@settings(**standard_settings())
def test_skl_export_regressor(n_targets, max_depth, n_estimators, callback):
"""Round trip with scikit-learn regressor"""
X, y = callback.draw(standard_regression_datasets(n_targets=just(n_targets)))
def test_skl_export_rf_regressor(dataset, max_depth, n_estimators):
"""Round trip with scikit-learn RF regressor"""
X, y = dataset
clf = RandomForestRegressor(
max_depth=max_depth, random_state=0, n_estimators=n_estimators, n_jobs=-1
)
Expand All @@ -277,3 +278,55 @@ def test_skl_export_regressor(n_targets, max_depth, n_estimators, callback):
clf2 = treelite.sklearn.export_model(tl_model)
assert isinstance(clf2, RandomForestRegressor)
np.testing.assert_almost_equal(clf2.predict(X), clf.predict(X))


@given(
    dataset=standard_classification_datasets(
        n_classes=integers(min_value=2, max_value=4),
    ),
    max_depth=integers(min_value=3, max_value=15),
    n_estimators=integers(min_value=5, max_value=10),
)
@settings(**standard_settings())
def test_skl_export_rf_classifier(dataset, max_depth, n_estimators):
    """RF classifier survives a scikit-learn -> Treelite -> scikit-learn round trip"""
    features, labels = dataset
    original = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1, random_state=0
    )
    original.fit(features, labels)

    # Import into Treelite, then export straight back to scikit-learn
    restored = treelite.sklearn.export_model(treelite.sklearn.import_model(original))

    assert isinstance(restored, RandomForestClassifier)
    # The restored forest must predict the same labels as the fitted one
    expected = original.predict(features)
    np.testing.assert_almost_equal(restored.predict(features), expected)


@given(
    n_classes=integers(min_value=3, max_value=5),
    n_estimators=integers(min_value=3, max_value=10),
)
@settings(**standard_settings())
def test_skl_export_rf_multitarget_multiclass(n_classes, n_estimators):
    """Round trip of a multi-output, multi-class scikit-learn RF classifier"""
    features, base_labels = make_classification(
        n_samples=1000,
        n_features=100,
        n_informative=30,
        n_classes=n_classes,
        random_state=0,
    )
    # Three targets: the generated labels plus two shuffled copies of them
    label_columns = [base_labels] + [
        shuffle(base_labels, random_state=seed) for seed in (1, 2)
    ]
    labels = np.vstack(tuple(label_columns)).T

    original = RandomForestClassifier(
        max_depth=8, n_estimators=n_estimators, n_jobs=-1, random_state=4
    )
    original.fit(features, labels)

    # Import into Treelite, then export straight back to scikit-learn
    restored = treelite.sklearn.export_model(treelite.sklearn.import_model(original))

    assert isinstance(restored, RandomForestClassifier)
    # Compare class probabilities rather than hard labels
    expected = original.predict_proba(features)
    np.testing.assert_almost_equal(restored.predict_proba(features), expected)

0 comments on commit 35dc9d2

Please sign in to comment.