Skip to content

Commit

Permalink
Fix test warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer authored and ricardoV94 committed Nov 25, 2024
1 parent 7300a68 commit ae66e82
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/link/pytorch/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")


class TestOp(Op):
class BatchedTestOp(Op):
gufunc_signature = "(m,n),(n,p)->(m,p)"

def __init__(self, final_shape):
Expand All @@ -27,7 +27,7 @@ def perform(self, *_):
raise RuntimeError("In perform")


@basic.pytorch_funcify.register(TestOp)
@basic.pytorch_funcify.register(BatchedTestOp)
def evaluate_test_op(op, **_):
def func(a, b):
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
Expand All @@ -42,7 +42,7 @@ def test_blockwise_broadcast():

x = pt.tensor4("x", shape=(5, 1, 2, 3))
y = pt.tensor3("y", shape=(3, 3, 2))
op = TestOp((2, 2))
op = BatchedTestOp((2, 2))
z = Blockwise(op)(x, y)

f = pytensor.function([x, y], z, mode="PYTORCH")
Expand Down

0 comments on commit ae66e82

Please sign in to comment.