Skip to content

Commit

Permalink
Improve transforms test codebase (pytorch#2620)
Browse files Browse the repository at this point in the history
* Improve transforms test codebase
- refactored compareTensorToPIL, _create_data, approxEqualTensorToPIL methods

* Fixed flake8
  • Loading branch information
vfdev-5 authored and bryant1410 committed Nov 22, 2020
1 parent 391b1d5 commit 511d0da
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 32 deletions.
28 changes: 28 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from torch._six import string_classes
from collections import OrderedDict

import numpy as np
from PIL import Image


@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
Expand Down Expand Up @@ -329,3 +332,28 @@ def freeze_rng_state():
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
torch.set_rng_state(rng_state)


class TransformsTester(unittest.TestCase):

def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img

def compareTensorToPIL(self, tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)

def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
err = getattr(torch, method)(tensor - pil_tensor).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
26 changes: 3 additions & 23 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,20 @@
import unittest
import random
import colorsys
import math

from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC

import numpy as np
from PIL.Image import NEAREST, BILINEAR, BICUBIC

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F

from common_utils import TransformsTester

class Tester(unittest.TestCase):

def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img

def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)

def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
mae = torch.abs(tensor - pil_tensor).mean().item()
self.assertTrue(
mae < tol,
msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
class Tester(TransformsTester):

def _test_vflip(self, device):
script_vflip = torch.jit.script(F_t.vflip)
Expand Down
12 changes: 3 additions & 9 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image

from PIL.Image import NEAREST, BILINEAR, BICUBIC

import numpy as np

import unittest

from common_utils import TransformsTester

class Tester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
return tensor, pil_img

def compareTensorToPIL(self, tensor, pil_image):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor))
class Tester(TransformsTester):

def _test_functional_geom_op(self, func, fn_kwargs):
if fn_kwargs is None:
Expand Down

0 comments on commit 511d0da

Please sign in to comment.