Skip to content

Commit

Permalink
Address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato committed Nov 6, 2024
1 parent a095076 commit 97d6bdc
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions pytensor/link/pytorch/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ def elemwise_fn(*inputs):

def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
shaped_inputs = torch.broadcast_tensors(*inputs)
broadcast_inputs = torch.broadcast_tensors(*inputs)
ufunc = base_fn
for _ in range(shaped_inputs[0].dim()):
for _ in range(broadcast_inputs[0].dim()):
ufunc = torch.vmap(ufunc)
# @todo: This will fail for anything that calls
# `.item()`
return ufunc(*shaped_inputs)
return ufunc(*broadcast_inputs)

return elemwise_fn

Expand Down

0 comments on commit 97d6bdc

Please sign in to comment.