[Breaking] Set 1st dim of prediction output to be row ID (#549)
* [Breaking] Set 1st dim of prediction output to be row ID

* Fix tests
hcho3 authored Feb 15, 2024
1 parent 5b1bbff commit cb1aa86
Showing 12 changed files with 59 additions and 60 deletions.
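
The practical effect for downstream users: arrays returned by GTIL prediction are now indexed as (row, target, class) rather than (target, row, class). A minimal migration sketch in Python follows; `model` and `X` are hypothetical placeholders for any Treelite model and a dense feature matrix.

```python
import numpy as np
import treelite

# `model` and `X` are placeholders: any Treelite model plus a dense
# feature matrix of shape (num_row, num_feature).
pred = treelite.gtil.predict(model, X)

# New layout (this commit): pred.shape == (num_row, num_target, max_num_class)
row0_scores = pred[0]  # all targets and classes for the first row

# Code written against the old (num_target, num_row, max_num_class) layout
# can recover it by swapping the first two axes:
legacy_pred = np.transpose(pred, axes=(1, 0, 2))
```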
4 changes: 2 additions & 2 deletions include/treelite/gtil.h
@@ -26,12 +26,12 @@ namespace gtil {
enum class PredictKind : std::int8_t {
/*!
* \brief Usual prediction method: sum over trees and apply post-processing.
- * Expected output dimensions: (num_target, num_row, max_num_class)
+ * Expected output dimensions: (num_row, num_target, max_num_class)
*/
kPredictDefault = 0,
/*!
* \brief Sum over trees, but don't apply post-processing; get raw margin scores instead.
- * Expected output dimensions: (num_target, num_row, max_num_class)
+ * Expected output dimensions: (num_row, num_target, max_num_class)
*/
kPredictRaw = 1,
/*!
2 changes: 1 addition & 1 deletion python/treelite/gtil/gtil.py
@@ -74,7 +74,7 @@ def predict(
Returns
-------
prediction : :py:class:`numpy.ndarray` array
- Prediction output. Expected dimensions: (num_target, num_row, max(num_class))
+ Prediction output. Expected dimensions: (num_row, num_target, max(num_class))
"""
predict_type = "raw" if pred_margin else "default"

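As a usage note on the docstring above: for a single-target model the middle axis is a singleton, so (for example) a 3-class classifier now returns shape (num_row, 1, 3). A hedged sketch of squeezing that back to a 2-D probability matrix, with `model` and `X` hypothetical as before:

```python
import numpy as np
import treelite

pred = treelite.gtil.predict(model, X)  # shape: (num_row, 1, 3) for one target
proba = np.squeeze(pred, axis=1)        # shape: (num_row, 3)
```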
4 changes: 2 additions & 2 deletions src/gtil/output_shape.cc
@@ -23,9 +23,9 @@ std::vector<std::uint64_t> GetOutputShape(
case PredictKind::kPredictDefault:
case PredictKind::kPredictRaw:
if (model.num_target > 1) {
- return {static_cast<std::uint64_t>(model.num_target), num_row, max_num_class};
+ return {num_row, static_cast<std::uint64_t>(model.num_target), max_num_class};
} else {
- return {1, num_row, max_num_class};
+ return {num_row, 1, max_num_class};
}
case PredictKind::kPredictLeafID:
return {num_row, num_tree};
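The shape logic above can be restated in a few lines of Python. This is an illustrative mirror of the function, not part of the library API, and the `kind` names are assumptions:

```python
def expected_output_shape(num_row, num_target, max_num_class, num_tree, kind):
    # Illustrative mirror of GetOutputShape; not library code.
    if kind in ("default", "raw"):
        # Row ID leads. The C++ branches on num_target only to emit the
        # literal 1; both branches produce the same tuple here.
        return (num_row, num_target, max_num_class)
    if kind == "leaf_id":
        return (num_row, num_tree)
    raise ValueError(f"unknown prediction kind: {kind}")
```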
53 changes: 26 additions & 27 deletions src/gtil/predict.cc
@@ -125,7 +125,7 @@ void OutputLeafVector(Model const& model, Tree<ThresholdT, LeafOutputT> const& t
auto leaf_view = Array2DView<LeafOutputT>(leaf_out.data(), model.num_target, max_num_class);
for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
for (std::int32_t class_id = 0; class_id < model.num_class[target_id]; ++class_id) {
- output_view(target_id, row_id, class_id) += leaf_view(target_id, class_id);
+ output_view(row_id, target_id, class_id) += leaf_view(target_id, class_id);
}
}
} else if (model.target_id[tree_id] == -1) {
@@ -135,7 +135,7 @@ void OutputLeafVector(Model const& model, Tree<ThresholdT, LeafOutputT> const& t
auto leaf_view = Array2DView<LeafOutputT>(leaf_out.data(), model.num_target, 1);
auto const class_id = model.class_id[tree_id];
for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
- output_view(target_id, row_id, class_id) += leaf_view(target_id, 0);
+ output_view(row_id, target_id, class_id) += leaf_view(target_id, 0);
}
} else if (model.class_id[tree_id] == -1) {
std::vector<std::int32_t> const expected_leaf_shape{1, max_num_class};
@@ -144,15 +144,15 @@ void OutputLeafVector(Model const& model, Tree<ThresholdT, LeafOutputT> const& t
auto leaf_view = Array2DView<LeafOutputT>(leaf_out.data(), 1, max_num_class);
auto const target_id = model.target_id[tree_id];
for (std::int32_t class_id = 0; class_id < model.num_class[target_id]; ++class_id) {
- output_view(target_id, row_id, class_id) += leaf_view(0, class_id);
+ output_view(row_id, target_id, class_id) += leaf_view(0, class_id);
}
} else {
std::vector<std::int32_t> const expected_leaf_shape{1, 1};
TREELITE_CHECK(model.leaf_vector_shape.AsVector() == expected_leaf_shape);

auto const target_id = model.target_id[tree_id];
auto const class_id = model.class_id[tree_id];
- output_view(target_id, row_id, class_id) += leaf_out[0];
+ output_view(row_id, target_id, class_id) += leaf_out[0];
}
}

@@ -166,7 +166,7 @@ void OutputLeafValue(Model const& model, Tree<ThresholdT, LeafOutputT> const& tr
std::vector<std::int32_t> const expected_leaf_shape{1, 1};
TREELITE_CHECK(model.leaf_vector_shape.AsVector() == expected_leaf_shape);

- output_view(target_id, row_id, class_id) += tree.LeafValue(leaf_id);
+ output_view(row_id, target_id, class_id) += tree.LeafValue(leaf_id);
}

template <typename InputT>
Expand All @@ -175,7 +175,7 @@ void PredictRaw(Model const& model, InputT const* input, std::uint64_t num_row,
auto input_view = CArray2DView<InputT>(input, num_row, model.num_feature);
auto max_num_class
= *std::max_element(model.num_class.Data(), model.num_class.Data() + model.num_target);
- auto output_view = Array3DView<InputT>(output, model.num_target, num_row, max_num_class);
+ auto output_view = Array3DView<InputT>(output, num_row, model.num_target, max_num_class);
std::size_t const num_tree = model.GetNumTree();
std::fill_n(output, output_view.size(), InputT{}); // Fill with 0's
std::visit(
@@ -223,27 +223,27 @@ void PredictRaw(Model const& model, InputT const* input, std::uint64_t num_row,
average_factor_view(model.target_id[tree_id], model.class_id[tree_id]) += 1;
}
}
- for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
- detail::threading_utils::ParallelFor(std::uint64_t(0), num_row, thread_config,
- detail::threading_utils::ParallelSchedule::Static(), [&](std::uint64_t row_id, int) {
+ detail::threading_utils::ParallelFor(std::uint64_t(0), num_row, thread_config,
+ detail::threading_utils::ParallelSchedule::Static(), [&](std::uint64_t row_id, int) {
+ for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
for (std::int32_t class_id = 0; class_id < model.num_class[target_id]; ++class_id) {
- output_view(target_id, row_id, class_id)
+ output_view(row_id, target_id, class_id)
/= static_cast<InputT>(average_factor_view(target_id, class_id));
}
- });
- }
+ }
+ });
}
// Apply base scores
auto base_score_view
= CArray2DView<double>(model.base_scores.Data(), model.num_target, max_num_class);
- for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
- detail::threading_utils::ParallelFor(std::uint64_t(0), num_row, thread_config,
- detail::threading_utils::ParallelSchedule::Static(), [&](std::uint64_t row_id, int) {
+ detail::threading_utils::ParallelFor(std::uint64_t(0), num_row, thread_config,
+ detail::threading_utils::ParallelSchedule::Static(), [&](std::uint64_t row_id, int) {
+ for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
for (std::int32_t class_id = 0; class_id < model.num_class[target_id]; ++class_id) {
- output_view(target_id, row_id, class_id) += base_score_view(target_id, class_id);
+ output_view(row_id, target_id, class_id) += base_score_view(target_id, class_id);
}
- });
- }
+ }
+ });
}

template <typename InputT>
Expand All @@ -252,17 +252,16 @@ void ApplyPostProcessor(Model const& model, InputT* output, std::uint64_t num_ro
auto postprocessor_func = gtil::GetPostProcessorFunc<InputT>(model.postprocessor);
auto max_num_class
= *std::max_element(model.num_class.Data(), model.num_class.Data() + model.num_target);
- auto output_view = Array3DView<InputT>(output, model.num_target, num_row, max_num_class);
+ auto output_view = Array3DView<InputT>(output, num_row, model.num_target, max_num_class);

- for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
- std::int32_t const num_class = model.num_class[target_id];
- detail::threading_utils::ParallelFor(std::uint64_t(0), num_row, thread_config,
- detail::threading_utils::ParallelSchedule::Static(), [&](std::size_t row_id, int) {
- auto row = stdex::submdspan(output_view, target_id, row_id, stdex::full_extent);
+ detail::threading_utils::ParallelFor(std::uint64_t(0), num_row, thread_config,
+ detail::threading_utils::ParallelSchedule::Static(), [&](std::size_t row_id, int) {
+ for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
+ auto row = stdex::submdspan(output_view, row_id, target_id, stdex::full_extent);
static_assert(std::is_same_v<decltype(row), Array1DView<InputT>>);
- postprocessor_func(model, num_class, row.data_handle());
- });
- }
+ postprocessor_func(model, model.num_class[target_id], row.data_handle());
+ }
+ });
}

template <typename InputT>
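A side effect of the reordering in predict.cc: the parallel loop over rows is now outermost, with targets iterated inside each row's task, so a single ParallelFor covers all targets. In numpy terms, the per-(target, class) base-score addition becomes a plain broadcast over the leading row axis. An illustrative analogue, not library code:

```python
import numpy as np

num_row, num_target, max_num_class = 4, 2, 3
output = np.zeros((num_row, num_target, max_num_class), dtype=np.float32)
base_scores = np.full((num_target, max_num_class), 0.5, dtype=np.float32)

# Analogue of the C++ base-score loop: the row-leading layout broadcasts
# directly against a (num_target, max_num_class) table.
output += base_scores[np.newaxis, :, :]
```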
8 changes: 4 additions & 4 deletions tests/python/test_field_accessor.py
@@ -170,7 +170,7 @@ def test_tree_editing():
treelite.gtil.predict(
model, np.array([[-1.0, 0.0], [1.0, 0.0]], dtype=np.float32)
),
- np.array([[[-1.0], [1.0]]], dtype=np.float32),
+ np.array([[[-1.0]], [[1.0]]], dtype=np.float32),
)

# Change leaf values
@@ -180,7 +180,7 @@ def test_tree_editing():
treelite.gtil.predict(
model, np.array([[-1.0, 0.0], [1.0, 0.0]], dtype=np.float32)
),
- np.array([[[-100.0], [100.0]]], dtype=np.float32),
+ np.array([[[-100.0]], [[100.0]]], dtype=np.float32),
)

# Change numerical test
@@ -190,7 +190,7 @@ def test_tree_editing():
treelite.gtil.predict(
model, np.array([[0.0, 0.0], [0.0, 2.0]], dtype=np.float32)
),
- np.array([[[-100.0], [100.0]]], dtype=np.float32),
+ np.array([[[-100.0]], [[100.0]]], dtype=np.float32),
)

# Add a test node
@@ -222,5 +222,5 @@ def test_tree_editing():
dtype=np.float32,
),
),
- np.array([[[1.0], [1.0], [2.0], [3.0]]], dtype=np.float32),
+ np.array([[[1.0]], [[1.0]], [[2.0]], [[3.0]]], dtype=np.float32),
)
12 changes: 6 additions & 6 deletions tests/python/test_lightgbm_integration.py
@@ -77,7 +77,7 @@ def test_lightgbm_regression(
valid_sets=[dtrain],
valid_names=["train"],
)
- expected_pred = lgb_model.predict(X_pred).reshape((1, -1, 1))
+ expected_pred = lgb_model.predict(X_pred).reshape((-1, 1, 1))

with TemporaryDirectory() as tmpdir:
lgb_model_path = pathlib.Path(tmpdir) / "lightgbm_model.txt"
@@ -131,7 +131,7 @@ def test_lightgbm_binary_classification(
valid_sets=[dtrain],
valid_names=["train"],
)
- expected_prob = lgb_model.predict(X_pred).reshape((1, -1, 1))
+ expected_prob = lgb_model.predict(X_pred).reshape((-1, 1, 1))

with TemporaryDirectory() as tmpdir:
lgb_model_path = pathlib.Path(tmpdir) / "breast_cancer_lightgbm.txt"
@@ -192,7 +192,7 @@ def test_lightgbm_multiclass_classification(
valid_sets=[dtrain],
valid_names=["train"],
)
- expected_pred = lgb_model.predict(X_pred).reshape((1, -1, num_class))
+ expected_pred = lgb_model.predict(X_pred).reshape((-1, 1, num_class))

with TemporaryDirectory() as tmpdir:
lgb_model_path = pathlib.Path(tmpdir) / "iris_lightgbm.txt"
@@ -210,7 +210,7 @@ def test_lightgbm_categorical_data():
tl_model = treelite.frontend.load_lightgbm_model(lgb_model_path)

X, _ = load_svmlight_file(dataset_db[dataset].dtest, zero_based=True)
- expected_pred = load_txt(dataset_db[dataset].expected_margin).reshape((1, -1, 1))
+ expected_pred = load_txt(dataset_db[dataset].expected_margin).reshape((-1, 1, 1))
out_pred = treelite.gtil.predict(tl_model, X.toarray())
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)

@@ -238,7 +238,7 @@ def test_lightgbm_sparse_ranking_model(tmpdir):

dtrain = lgb.Dataset(X, label=y, group=[X.shape[0]])
lgb_model = lgb.train(params, dtrain, num_boost_round=1)
- lgb_out = lgb_model.predict(X).reshape((1, -1, 1))
+ lgb_out = lgb_model.predict(X).reshape((-1, 1, 1))
lgb_model.save_model(lgb_model_path)

tl_model = treelite.frontend.load_lightgbm_model(lgb_model_path)
@@ -260,7 +260,7 @@ def test_lightgbm_sparse_categorical_model():
X, _ = load_svmlight_file(
dataset_db[dataset].dtest, zero_based=True, n_features=tl_model.num_feature
)
- expected_pred = load_txt(dataset_db[dataset].expected_margin).reshape((1, -1, 1))
+ expected_pred = load_txt(dataset_db[dataset].expected_margin).reshape((-1, 1, 1))
# GTIL doesn't yet support sparse matrix; so use NaN to represent missing values
Xa = X.toarray()
Xa[Xa == 0] = "nan"
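The test updates in this file all follow one mechanical pattern: the singleton target axis moves from the front of the reshape to the middle. A toy illustration of the rule:

```python
import numpy as np

raw_pred = np.arange(6, dtype=np.float32)  # per-row scalar predictions
old_layout = raw_pred.reshape((1, -1, 1))  # (num_target=1, num_row, num_class=1)
new_layout = raw_pred.reshape((-1, 1, 1))  # (num_row, num_target=1, num_class=1)
assert new_layout.shape == (6, 1, 1)
```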
4 changes: 2 additions & 2 deletions tests/python/test_model_builder.py
@@ -130,10 +130,10 @@ def make_tree_stump(left_child_val, right_child_val):
model = builder.commit()
dmat = np.array([[1.0], [-1.0]], dtype=np.float32)
if predict_kind in "default":
- expected_pred = np.array([[[100.0, 200.5, 300.5], [101.0, 200.0, 300.0]]])
+ expected_pred = np.array([[[100.0, 200.5, 300.5]], [[101.0, 200.0, 300.0]]])
pred = treelite.gtil.predict(model, dmat, pred_margin=False)
elif predict_kind == "raw":
- expected_pred = np.array([[[100.0, 200.5, 300.5], [101.0, 200.0, 300.0]]])
+ expected_pred = np.array([[[100.0, 200.5, 300.5]], [[101.0, 200.0, 300.0]]])
pred = treelite.gtil.predict(model, dmat, pred_margin=True)
else:
expected_pred = np.array([[2, 2], [1, 1]])
6 changes: 3 additions & 3 deletions tests/python/test_serializer.py
@@ -71,7 +71,7 @@ def test_serialize_as_bytes(clazz, dataset, n_estimators, max_depth, callback):
)
clf = clazz(**kwargs)
clf.fit(X, y)
- expected_prob = clf.predict_proba(X)[np.newaxis, :, :]
+ expected_prob = clf.predict_proba(X)[:, np.newaxis, :]
if (
clazz in [GradientBoostingClassifier, HistGradientBoostingClassifier]
and n_classes == 2
@@ -129,9 +129,9 @@ def test_serialize_as_checkpoint(clazz, n_estimators, max_depth, callback):
clf = clazz(**kwargs)
clf.fit(X, y)
if n_targets > 1:
- expected_pred = np.transpose(clf.predict(X)[:, :, np.newaxis], axes=(1, 0, 2))
+ expected_pred = clf.predict(X)[:, :, np.newaxis]
else:
- expected_pred = clf.predict(X).reshape((1, X.shape[0], -1))
+ expected_pred = clf.predict(X).reshape((X.shape[0], 1, -1))

with TemporaryDirectory() as tmpdir:
# Prediction should be correct after a round-trip
10 changes: 5 additions & 5 deletions tests/python/test_sklearn_integration.py
@@ -75,9 +75,9 @@ def test_skl_regressor(clazz, n_estimators, callback):
tl_model = treelite.sklearn.import_model(clf)
out_pred = treelite.gtil.predict(tl_model, X)
if n_targets > 1:
- expected_pred = np.transpose(clf.predict(X)[:, :, np.newaxis], axes=(1, 0, 2))
+ expected_pred = clf.predict(X)[:, :, np.newaxis]
else:
- expected_pred = clf.predict(X).reshape((1, X.shape[0], -1))
+ expected_pred = clf.predict(X).reshape((X.shape[0], 1, -1))
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=3)


@@ -126,7 +126,7 @@ def test_skl_classifier(clazz, dataset, n_estimators, callback):
np.testing.assert_equal(base_scores, np.zeros(expected_base_scores_shape))

out_prob = treelite.gtil.predict(tl_model, X)
- expected_prob = clf.predict_proba(X)[np.newaxis, :, :]
+ expected_prob = clf.predict_proba(X)[:, np.newaxis, :]
if (
clazz in [GradientBoostingClassifier, HistGradientBoostingClassifier]
and n_classes == 2
@@ -148,7 +148,7 @@ def test_skl_converter_iforest(dataset):
)
clf.fit(X)
expected_pred = clf._compute_chunked_score_samples(X) # pylint: disable=W0212
- expected_pred = expected_pred.reshape((1, -1, 1))
+ expected_pred = expected_pred.reshape((-1, 1, 1))

tl_model = treelite.sklearn.import_model(clf)
out_pred = treelite.gtil.predict(tl_model, X)
@@ -194,5 +194,5 @@ def test_skl_hist_gradient_boosting_with_categorical(

tl_model = treelite.sklearn.import_model(clf)
out_pred = treelite.gtil.predict(tl_model, X_pred)
- expected_pred = clf.predict_proba(X_pred)[np.newaxis, :, 1:]
+ expected_pred = clf.predict_proba(X_pred)[:, np.newaxis, 1:]
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=4)
12 changes: 6 additions & 6 deletions tests/python/test_xgboost_integration.py
@@ -121,7 +121,7 @@ def test_xgb_regressor(
out_pred = treelite.gtil.predict(tl_model, X_pred, pred_margin=pred_margin)
expected_pred = xgb_model.predict(
xgb.DMatrix(X_pred), output_margin=pred_margin, validate_features=False
- ).reshape((1, X.shape[0], -1))
+ ).reshape((X.shape[0], 1, -1))
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=3)


@@ -198,7 +198,7 @@ def test_xgb_multiclass_classifier(
out_pred = treelite.gtil.predict(tl_model, X_pred, pred_margin=pred_margin)
expected_pred = xgb_model.predict(
xgb.DMatrix(X_pred), output_margin=pred_margin, validate_features=False
- ).reshape((1, -1, num_class))
+ ).reshape((-1, 1, num_class))
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@@ -278,7 +278,7 @@ def test_xgb_nonlinear_objective(
out_pred = treelite.gtil.predict(tl_model, X_pred, pred_margin=True)
expected_pred = xgb_model.predict(
xgb.DMatrix(X_pred), output_margin=True, validate_features=False
- ).reshape((1, X.shape[0], -1))
+ ).reshape((X.shape[0], 1, -1))
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@@ -316,7 +316,7 @@ def test_xgb_dart(dataset, model_format, num_boost_round):
tl_model = treelite.frontend.from_xgboost(xgb_model)
out_pred = treelite.gtil.predict(tl_model, X, pred_margin=True)
expected_pred = xgb_model.predict(dtrain, output_margin=True).reshape(
- (1, X.shape[0], -1)
+ (X.shape[0], 1, -1)
)
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)

@@ -474,7 +474,7 @@ def test_xgb_multi_target_binary_classifier(
expected_pred = bst.predict(
xgb.DMatrix(X_pred), output_margin=pred_margin, validate_features=False
)
- expected_pred = np.transpose(expected_pred[:, :, np.newaxis], axes=(1, 0, 2))
+ expected_pred = expected_pred[:, :, np.newaxis]
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@@ -556,7 +556,7 @@ def test_xgb_multi_target_regressor(

out_pred = treelite.gtil.predict(tl_model, X_pred)
expected_pred = xgb_model.predict(xgb.DMatrix(X_pred), validate_features=False)
- expected_pred = np.transpose(expected_pred[:, :, np.newaxis], axes=(1, 0, 2))
+ expected_pred = expected_pred[:, :, np.newaxis]
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=3)


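For the multi-target tests the update is simpler still: XGBoost already returns predictions as (num_row, num_target), so the old transpose drops away and only a trailing class axis is added. A toy illustration:

```python
import numpy as np

xgb_pred = np.ones((8, 3))             # (num_row, num_target), as XGBoost returns
expected = xgb_pred[:, :, np.newaxis]  # (num_row, num_target, 1): new layout
# Old layout needed: np.transpose(xgb_pred[:, :, np.newaxis], axes=(1, 0, 2))
assert expected.shape == (8, 3, 1)
```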
2 changes: 1 addition & 1 deletion tests/serializer/compatibility_tester.py
@@ -37,7 +37,7 @@ def load(args):
with open(args.model_pickle_path, "rb") as f:
clf = pickle.load(f)
tl_model = treelite.Model.deserialize(args.checkpoint_path)
- expected_prob = clf.predict_proba(X).reshape((1, X.shape[0], -1))
+ expected_prob = clf.predict_proba(X).reshape((X.shape[0], 1, -1))
out_prob = treelite.gtil.predict(tl_model, X)
np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5)
print("Test passed!")
2 changes: 1 addition & 1 deletion tests/serializer/test_serializer.py
@@ -59,7 +59,7 @@ def test_serialize_as_buffer(clazz):
params["init"] = "zero"
clf = clazz(**params)
clf.fit(X, y)
- expected_prob = clf.predict_proba(X).reshape((1, X.shape[0], -1))
+ expected_prob = clf.predict_proba(X).reshape((X.shape[0], 1, -1))

# Prediction should be correct after a round-trip
tl_model = treelite.sklearn.import_model(clf)
