From 97d6bdce032ab27d552bf07b8cd90af395f79964 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Fri, 1 Nov 2024 22:10:48 -0700 Subject: [PATCH] Address pr comments --- pytensor/link/pytorch/dispatch/elemwise.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index eca9cab989..72f97af1fa 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -21,13 +21,11 @@ def elemwise_fn(*inputs): def elemwise_fn(*inputs): Elemwise._check_runtime_broadcast(node, inputs) - shaped_inputs = torch.broadcast_tensors(*inputs) + broadcast_inputs = torch.broadcast_tensors(*inputs) ufunc = base_fn - for _ in range(shaped_inputs[0].dim()): + for _ in range(broadcast_inputs[0].dim()): ufunc = torch.vmap(ufunc) - # @todo: This will fail for anything that calls - # `.item()` - return ufunc(*shaped_inputs) + return ufunc(*broadcast_inputs) return elemwise_fn