Skip to content

Commit

Permalink
Group numba benchmark tests in same class
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 3, 2025
1 parent 2f2d0d3 commit 884dee9
Showing 1 changed file with 68 additions and 70 deletions.
138 changes: 68 additions & 70 deletions tests/link/numba/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,25 +130,6 @@ def test_elemwise_runtime_broadcast():
check_elemwise_runtime_broadcast(get_mode("NUMBA"))


def test_elemwise_speed(benchmark):
    """Benchmark the jitted function of a broadcasted Elemwise graph.

    Builds ``exp(2 * x * y + y)`` for a matrix ``x`` and a vector ``y``,
    compiles it with ``mode="NUMBA"``, checks the result against the NumPy
    evaluation of the same expression, and then benchmarks the jitted
    function.
    """
    # Symbolic names match the Python variables so graph debug output
    # refers to the right input.
    x = pt.dmatrix("x")
    y = pt.dvector("y")
    out = np.exp(2 * x * y + y)

    rng = np.random.default_rng(42)
    x_val = rng.normal(size=(200, 500))
    y_val = rng.normal(size=500)

    compiled = function([x, y], out, mode="NUMBA")
    # Benchmark the jitted function held by the VM rather than the
    # compiled function wrapper.
    jit_fn = compiled.vm.jit_fn

    # The numeric result gets its own name so the symbolic ``out`` is not
    # shadowed.
    (res,) = jit_fn(x_val, y_val)
    np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), res)

    benchmark(jit_fn, x_val, y_val)


@pytest.mark.parametrize(
"v, new_order",
[
Expand Down Expand Up @@ -631,41 +612,6 @@ def test_Argmax(x, axes, exc):
)


@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
    """Benchmark a max-shifted log-sum-exp graph compiled with mode="NUMBA".

    The compiled function is checked against ``scipy.special.logsumexp``
    before it is handed to ``benchmark``.
    """
    data = np.random.default_rng(23920).normal(size=size)

    X = pt.matrix("X")
    # Shift by the maximum along ``axis``; infinite maxima are replaced by 0.
    shift = pt.max(X, axis=axis, keepdims=True)
    shift = pt.switch(pt.isinf(shift), 0, shift)
    shifted_sum = pt.sum(pt.exp(X - shift), axis=axis, keepdims=True)
    lse = pt.log(shifted_sum) + shift

    lse_fn = pytensor.function([X], lse, mode="NUMBA")

    # JIT compile first
    actual = lse_fn(data)
    expected = scipy.special.logsumexp(data, axis=axis, keepdims=True)
    np.testing.assert_array_almost_equal(actual, expected)

    benchmark(lse_fn, data)


def test_fused_elemwise_benchmark(benchmark):
    """Benchmark a function returning ``-((x - mu) ** 2) / 2`` and its gradient.

    Both ``x`` and ``mu`` are shared variables, so the compiled function
    takes no inputs.
    """
    n = 100_000
    rng = np.random.default_rng(123)
    # Draw order is kept (x first, then mu) so the sampled values are unchanged.
    x = pytensor.shared(rng.normal(size=n), name="x")
    mu = pytensor.shared(rng.normal(size=n), name="mu")

    residual = x - mu
    logp = -(residual**2) / 2
    outputs = [logp, grad(logp.sum(), x)]

    fn = pytensor.function([], outputs, mode="NUMBA")
    fn()  # JIT compile first
    benchmark(fn)


def test_elemwise_out_type():
# Create a graph with an elemwise
# Ravel fails if the elemwise output type is reported incorrectly
Expand All @@ -681,22 +627,6 @@ def test_elemwise_out_type():
assert func(x_val).shape == (18,)


@pytest.mark.parametrize(
    "axis",
    [0, 1, 2, (0, 1), (0, 2), (1, 2), None],
    ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
    "c_contiguous",
    [True, False],
    ids=lambda x: f"c_contiguous={x}",
)
def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
    """Delegate to ``careduce_benchmark_tester`` with mode="NUMBA".

    The helper's return value is passed through unchanged.
    """
    result = careduce_benchmark_tester(
        axis, c_contiguous, mode="NUMBA", benchmark=benchmark
    )
    return result


def test_scalar_loop():
a = float64("a")
scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
Expand All @@ -709,3 +639,71 @@ def test_scalar_loop():
([x], [elemwise_loop]),
(np.array([1, 2, 3], dtype="float64"),),
)


class TestsBenchmark:
    """Benchmark tests for graphs compiled with ``mode="NUMBA"``.

    Every test takes the ``benchmark`` fixture and passes it (directly or
    through a helper) the callable to be timed.
    """

    def test_elemwise_speed(self, benchmark):
        """Benchmark the jitted function of a broadcasted Elemwise graph.

        Builds ``exp(2 * x * y + y)`` for a matrix ``x`` and a vector ``y``,
        checks the result against the NumPy evaluation of the same
        expression, and then benchmarks the jitted function.
        """
        # Symbolic names match the Python variables so graph debug output
        # refers to the right input.
        x = pt.dmatrix("x")
        y = pt.dvector("y")
        out = np.exp(2 * x * y + y)

        rng = np.random.default_rng(42)
        x_val = rng.normal(size=(200, 500))
        y_val = rng.normal(size=500)

        compiled = function([x, y], out, mode="NUMBA")
        # Benchmark the jitted function held by the VM rather than the
        # compiled function wrapper.
        jit_fn = compiled.vm.jit_fn

        # The numeric result gets its own name so the symbolic ``out`` is
        # not shadowed.
        (res,) = jit_fn(x_val, y_val)
        np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), res)

        benchmark(jit_fn, x_val, y_val)

    def test_fused_elemwise_benchmark(self, benchmark):
        """Benchmark a function returning ``-((x - mu) ** 2) / 2`` and its gradient.

        Both ``x`` and ``mu`` are shared variables, so the compiled function
        takes no inputs.
        """
        rng = np.random.default_rng(123)
        size = 100_000
        x = pytensor.shared(rng.normal(size=size), name="x")
        mu = pytensor.shared(rng.normal(size=size), name="mu")

        logp = -((x - mu) ** 2) / 2
        grad_logp = grad(logp.sum(), x)

        func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
        # JIT compile first
        func()
        benchmark(func)

    @pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
    @pytest.mark.parametrize("axis", [0, 1])
    def test_logsumexp_benchmark(self, size, axis, benchmark):
        """Benchmark a max-shifted log-sum-exp graph.

        The compiled function is checked against ``scipy.special.logsumexp``
        before it is handed to ``benchmark``.
        """
        X = pt.matrix("X")
        # Shift by the maximum along ``axis``; infinite maxima are replaced
        # by 0.
        X_max = pt.max(X, axis=axis, keepdims=True)
        X_max = pt.switch(pt.isinf(X_max), 0, X_max)
        X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max

        rng = np.random.default_rng(23920)
        X_val = rng.normal(size=size)

        X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")

        # JIT compile first
        res = X_lse_fn(X_val)
        exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
        np.testing.assert_array_almost_equal(res, exp_res)
        benchmark(X_lse_fn, X_val)

    @pytest.mark.parametrize(
        "axis",
        (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
        ids=lambda x: f"axis={x}",
    )
    @pytest.mark.parametrize(
        "c_contiguous",
        (True, False),
        ids=lambda x: f"c_contiguous={x}",
    )
    def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark):
        """Delegate to ``careduce_benchmark_tester`` with mode="NUMBA"."""
        return careduce_benchmark_tester(
            axis, c_contiguous, mode="NUMBA", benchmark=benchmark
        )

0 comments on commit 884dee9

Please sign in to comment.