Skip to content

Commit

Permalink
changing for loop to torch.arange
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Mar 26, 2024
1 parent 85d7e66 commit 8c37797
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def slice_scatter_decomposition(
for i, src_each_dim in enumerate(list(src_dim)):
if i != dim:
index_tensor_shape.append(src_each_dim)
for index in range(start, end, step):
cat_tensors.append(index * torch.ones(index_tensor_shape))
indices = torch.arange(start, end, step)
cat_tensors = tuple((indices.unsqueeze(1) * torch.ones(index_tensor_shape)))
index_tensor = torch.stack(cat_tensors, dim)
index_tensor = index_tensor.to(torch.int64).cuda()
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src)
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
return output_tensor


Expand Down

0 comments on commit 8c37797

Please sign in to comment.