diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 1da34ff392..b2ccc1ef1e 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -130,25 +130,6 @@ def test_elemwise_runtime_broadcast():
     check_elemwise_runtime_broadcast(get_mode("NUMBA"))
 
 
-def test_elemwise_speed(benchmark):
-    x = pt.dmatrix("y")
-    y = pt.dvector("z")
-
-    out = np.exp(2 * x * y + y)
-
-    rng = np.random.default_rng(42)
-
-    x_val = rng.normal(size=(200, 500))
-    y_val = rng.normal(size=500)
-
-    func = function([x, y], out, mode="NUMBA")
-    func = func.vm.jit_fn
-    (out,) = func(x_val, y_val)
-    np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)
-
-    benchmark(func, x_val, y_val)
-
-
 @pytest.mark.parametrize(
     "v, new_order",
     [
@@ -631,41 +612,6 @@ def test_Argmax(x, axes, exc):
     )
 
 
-@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
-@pytest.mark.parametrize("axis", [0, 1])
-def test_logsumexp_benchmark(size, axis, benchmark):
-    X = pt.matrix("X")
-    X_max = pt.max(X, axis=axis, keepdims=True)
-    X_max = pt.switch(pt.isinf(X_max), 0, X_max)
-    X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max
-
-    rng = np.random.default_rng(23920)
-    X_val = rng.normal(size=size)
-
-    X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")
-
-    # JIT compile first
-    res = X_lse_fn(X_val)
-    exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
-    np.testing.assert_array_almost_equal(res, exp_res)
-    benchmark(X_lse_fn, X_val)
-
-
-def test_fused_elemwise_benchmark(benchmark):
-    rng = np.random.default_rng(123)
-    size = 100_000
-    x = pytensor.shared(rng.normal(size=size), name="x")
-    mu = pytensor.shared(rng.normal(size=size), name="mu")
-
-    logp = -((x - mu) ** 2) / 2
-    grad_logp = grad(logp.sum(), x)
-
-    func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
-    # JIT compile first
-    func()
-    benchmark(func)
-
-
 def test_elemwise_out_type():
     # Create a graph with an elemwise
     # Ravel failes if the elemwise output type is reported incorrectly
@@ -681,22 +627,6 @@ def test_elemwise_out_type():
     assert func(x_val).shape == (18,)
 
 
-@pytest.mark.parametrize(
-    "axis",
-    (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
-    ids=lambda x: f"axis={x}",
-)
-@pytest.mark.parametrize(
-    "c_contiguous",
-    (True, False),
-    ids=lambda x: f"c_contiguous={x}",
-)
-def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
-    return careduce_benchmark_tester(
-        axis, c_contiguous, mode="NUMBA", benchmark=benchmark
-    )
-
-
 def test_scalar_loop():
     a = float64("a")
     scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
@@ -709,3 +639,71 @@ def test_scalar_loop():
             ([x], [elemwise_loop]),
             (np.array([1, 2, 3], dtype="float64"),),
         )
+
+
+class TestsBenchmark:
+    def test_elemwise_speed(self, benchmark):
+        x = pt.dmatrix("y")
+        y = pt.dvector("z")
+
+        out = np.exp(2 * x * y + y)
+
+        rng = np.random.default_rng(42)
+
+        x_val = rng.normal(size=(200, 500))
+        y_val = rng.normal(size=500)
+
+        func = function([x, y], out, mode="NUMBA")
+        func = func.vm.jit_fn
+        (out,) = func(x_val, y_val)
+        np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)
+
+        benchmark(func, x_val, y_val)
+
+    def test_fused_elemwise_benchmark(self, benchmark):
+        rng = np.random.default_rng(123)
+        size = 100_000
+        x = pytensor.shared(rng.normal(size=size), name="x")
+        mu = pytensor.shared(rng.normal(size=size), name="mu")
+
+        logp = -((x - mu) ** 2) / 2
+        grad_logp = grad(logp.sum(), x)
+
+        func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
+        # JIT compile first
+        func()
+        benchmark(func)
+
+    @pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
+    @pytest.mark.parametrize("axis", [0, 1])
+    def test_logsumexp_benchmark(self, size, axis, benchmark):
+        X = pt.matrix("X")
+        X_max = pt.max(X, axis=axis, keepdims=True)
+        X_max = pt.switch(pt.isinf(X_max), 0, X_max)
+        X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max
+
+        rng = np.random.default_rng(23920)
+        X_val = rng.normal(size=size)
+
+        X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")
+
+        # JIT compile first
+        res = X_lse_fn(X_val)
+        exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
+        np.testing.assert_array_almost_equal(res, exp_res)
+        benchmark(X_lse_fn, X_val)
+
+    @pytest.mark.parametrize(
+        "axis",
+        (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
+        ids=lambda x: f"axis={x}",
+    )
+    @pytest.mark.parametrize(
+        "c_contiguous",
+        (True, False),
+        ids=lambda x: f"c_contiguous={x}",
+    )
+    def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark):
+        return careduce_benchmark_tester(
+            axis, c_contiguous, mode="NUMBA", benchmark=benchmark
+        )