Skip to content

Commit

Permalink
Fixed torch randint incoherent sampling (compatible to random.randint) (
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Jul 7, 2020
1 parent 86b6c3e commit b572d5e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
if w == tw and h == th:
return 0, 0, h, w

i = torch.randint(0, h - th, size=(1, )).item()
j = torch.randint(0, w - tw, size=(1, )).item()
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
return i, j, th, tw

def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
Expand Down Expand Up @@ -1433,8 +1433,8 @@ def get_params(
else:
v = torch.tensor(value)[:, None, None]

i = torch.randint(0, img_h - h, size=(1, )).item()
j = torch.randint(0, img_w - w, size=(1, )).item()
i = torch.randint(0, img_h - h + 1, size=(1, )).item()
j = torch.randint(0, img_w - w + 1, size=(1, )).item()
return i, j, h, w, v

# Return original image
Expand Down

0 comments on commit b572d5e

Please sign in to comment.