diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index dca9b1f7282a..ec32adf90000 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -173,3 +173,4 @@ List of Contributors * [Jesse Brizzi](https://github.com/jessebrizzi) * [Hang Zhang](http://hangzh.com) * [Kou Ding](https://github.com/chinakook) +* [Istvan Fehervari](https://github.com/ifeherva) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index b4b9cc2f1c08..5af2b9556651 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -839,8 +839,8 @@ class ColorNormalizeAug(Augmenter): """ def __init__(self, mean, std): super(ColorNormalizeAug, self).__init__(mean=mean, std=std) - self.mean = nd.array(mean) if mean is not None else None - self.std = nd.array(std) if std is not None else None + self.mean = mean if mean is None or isinstance(mean, nd.NDArray) else nd.array(mean) + self.std = std if std is None or isinstance(std, nd.NDArray) else nd.array(std) def __call__(self, src): """Augmenter body""" @@ -999,14 +999,14 @@ def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, ra auglist.append(RandomGrayAug(rand_gray)) if mean is True: - mean = np.array([123.68, 116.28, 103.53]) + mean = nd.array([123.68, 116.28, 103.53]) elif mean is not None: - assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3] + assert isinstance(mean, (np.ndarray, nd.NDArray)) and mean.shape[0] in [1, 3] if std is True: - std = np.array([58.395, 57.12, 57.375]) + std = nd.array([58.395, 57.12, 57.375]) elif std is not None: - assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3] + assert isinstance(std, (np.ndarray, nd.NDArray)) and std.shape[0] in [1, 3] if mean is not None or std is not None: auglist.append(ColorNormalizeAug(mean, std)) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 78c3ce14eb43..636c5e2be67c 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -18,7 +18,7 @@ import mxnet as mx import numpy as np from mxnet.test_utils import * -from common import assertRaises +from common import assertRaises, with_seed import shutil import tempfile import unittest @@ -153,8 +153,19 @@ def test_imageiter(self): for batch in test_iter: pass - + @with_seed() def test_augmenters(self): + # ColorNormalizeAug + mean = np.random.rand(3) * 255 + std = np.random.rand(3) + 1 + width = np.random.randint(100, 500) + height = np.random.randint(100, 500) + src = np.random.rand(height, width, 3) * 255. + # We test numpy and mxnet NDArray inputs + color_norm_aug = mx.image.ColorNormalizeAug(mean=mx.nd.array(mean), std=std) + out_image = color_norm_aug(mx.nd.array(src)) + assert_almost_equal(out_image.asnumpy(), (src - mean) / std, atol=1e-3) + # only test if all augmenters will work # TODO(Joshua Zhang): verify the augmenter outputs im_list = [[0, x] for x in TestImage.IMAGES]