Skip to content

Commit

Permalink
RandomGeoSampler: improve performance (#1968)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Apr 16, 2024
1 parent ef0530b commit 925b93f
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,16 @@ def get_random_bounding_box(
"""
t_size = _to_tuple(size)

width = (bounds.maxx - bounds.minx - t_size[1]) // res
height = (bounds.maxy - bounds.miny - t_size[0]) // res
# May be negative if bounding box is smaller than patch size
width = (bounds.maxx - bounds.minx - t_size[1]) / res
height = (bounds.maxy - bounds.miny - t_size[0]) / res

minx = bounds.minx
miny = bounds.miny

# random.randrange crashes for inputs <= 0
if width > 0:
minx += torch.rand(1).item() * width * res
if height > 0:
miny += torch.rand(1).item() * height * res
# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down

0 comments on commit 925b93f

Please sign in to comment.