Skip to content

Commit

Permalink
fix test setup
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Oct 7, 2022
1 parent ad4d424 commit d8945e6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
2 changes: 1 addition & 1 deletion test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def make_video_loader(
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

def fn(shape, dtype, device):
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-2], dtype=dtype, device=device)
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
return features.Video(video, color_space=color_space)

return VideoLoader(
Expand Down
7 changes: 2 additions & 5 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,8 @@ def test_mixup_cutmix(self, transform, input):
features.ColorSpace.RGB,
],
dtypes=[torch.uint8],
**(
dict(num_frames=[1, "random"], extra_dims=[()])
if fn is make_videos
else dict(extra_dims=[(4,)])
),
extra_dims=[(), (4,)],
**(dict(num_frames=["random"]) if fn is make_videos else dict()),
)
for fn in [
make_images,
Expand Down

0 comments on commit d8945e6

Please sign in to comment.