diff --git a/tests/link/pytorch/test_blockwise.py b/tests/link/pytorch/test_blockwise.py index 762f9b985e..d0678fd2c4 100644 --- a/tests/link/pytorch/test_blockwise.py +++ b/tests/link/pytorch/test_blockwise.py @@ -12,7 +12,7 @@ basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") -class TestOp(Op): +class BatchedTestOp(Op): gufunc_signature = "(m,n),(n,p)->(m,p)" def __init__(self, final_shape): @@ -27,7 +27,7 @@ def perform(self, *_): raise RuntimeError("In perform") -@basic.pytorch_funcify.register(TestOp) +@basic.pytorch_funcify.register(BatchedTestOp) def evaluate_test_op(op, **_): def func(a, b): op.call_shapes.extend(map(torch.Tensor.size, [a, b])) @@ -42,7 +42,7 @@ def test_blockwise_broadcast(): x = pt.tensor4("x", shape=(5, 1, 2, 3)) y = pt.tensor3("y", shape=(3, 3, 2)) - op = TestOp((2, 2)) + op = BatchedTestOp((2, 2)) z = Blockwise(op)(x, y) f = pytensor.function([x, y], z, mode="PYTORCH")