diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 1ed7284ce305..c290c03d3de2 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -198,7 +198,7 @@ def _create_data(objective, n_samples=1_000, output='array', chunk_size=500, **k def _r2_score(dy_true, dy_pred): numerator = ((dy_true - dy_pred) ** 2).sum(axis=0, dtype=np.float64) - denominator = ((dy_true - dy_pred.mean(axis=0)) ** 2).sum(axis=0, dtype=np.float64) + denominator = ((dy_true - dy_true.mean(axis=0)) ** 2).sum(axis=0, dtype=np.float64) return (1 - numerator / denominator).compute()