From daffe602f2d4570fc4d073a17617eae62969ca92 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Tue, 21 Jan 2025 22:27:55 +0900 Subject: [PATCH 1/5] Add initial random_erasing --- keras/api/_tf_keras/keras/layers/__init__.py | 3 + keras/api/layers/__init__.py | 3 + .../image_preprocessing/random_erasing.py | 119 ++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 keras/src/layers/preprocessing/image_preprocessing/random_erasing.py diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 77e16bf97f0d..f4e7c56b82e5 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -172,6 +172,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_crop import ( RandomCrop, ) +from keras.src.layers.preprocessing.image_preprocessing.random_erasing import ( + RandomErasing, +) from keras.src.layers.preprocessing.image_preprocessing.random_flip import ( RandomFlip, ) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index d925484d03bf..f74cc7dac6f9 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -172,6 +172,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_crop import ( RandomCrop, ) +from keras.src.layers.preprocessing.image_preprocessing.random_erasing import ( + RandomErasing, +) from keras.src.layers.preprocessing.image_preprocessing.random_flip import ( RandomFlip, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py new file mode 100644 index 000000000000..042d2e4039d9 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py @@ -0,0 +1,119 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator +import tensorflow as tf + + +@keras_export("keras.layers.RandomErasing") +class RandomErasing(BaseImagePreprocessingLayer): + """CutMix data augmentation technique. + + CutMix is a data augmentation method where patches are cut and pasted + between two images in the dataset, while the labels are also mixed + proportionally to the area of the patches. + + Args: + factor: A single float or a tuple of two floats between 0 and 1. + If a tuple of numbers is passed, a `factor` is sampled + between the two values. + If a single float is passed, a value between 0 and the passed + float is sampled. These values define the range from which the + mixing weight is sampled. A higher factor increases the variability + in patch sizes, leading to more diverse and larger mixed patches. + Defaults to 1. + seed: Integer. Used to create a random seed. + + References: + - [CutMix paper]( https://arxiv.org/abs/1905.04899). + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + def __init__(self, factor=0.3, seed=None, data_format=None, **kwargs): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.seed = seed + self.generator = SeedGenerator(seed) + + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + self.channel_axis = -3 + else: + self.height_axis = -3 + self.width_axis = -2 + self.channel_axis = -1 + + def transform_images(self, images, transformation=None, training=True): + + batch_size, image_height, image_width, image_channel = ( + images.shape[0], + images.shape[self.height_axis], + images.shape[self.width_axis], + images.shape[self.channel_axis]) + + area = image_height * image_width + + scale = (0., 1.) + ratio = (0., 1.) + + min_area = area * scale[0] + max_area = area * scale[1] + min_aspect_ratio = ratio[0] + max_aspect_ratio = ratio[1] + + target_area = self.backend.random.uniform((), + min_area, + max_area, + dtype=self.compute_dtype) + + aspect_ratio = self.backend.random.uniform((), + min_aspect_ratio, + max_aspect_ratio, + dtype=self.compute_dtype) + + h = self.backend.cast(self.backend.numpy.sqrt(target_area * aspect_ratio), dtype='int32') + w = self.backend.cast(self.backend.numpy.sqrt(target_area / aspect_ratio), dtype='int32') + + x = self.backend.random.randint((), 0, image_height - h) + y = self.backend.random.randint((), 0, image_width - w) + + v = self.backend.random.normal(shape=[h, w, image_channel]) + + images = self.backend.convert_to_numpy(images) + images[..., x:x + h, y:y + w, :] = v + + tf.print(x, y, w, h) + images = self.backend.convert_to_tensor(images) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} From ccac985c47a5e672b66daf66c1178ee597dfc14f Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Wed, 22 Jan 2025 23:27:59 +0900 Subject: [PATCH 2/5] Update random_erasing logic --- .../image_preprocessing/random_erasing.py | 263 +++++++++++++++--- 1 file changed, 226 insertions(+), 37 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py index 042d2e4039d9..188011c42803 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py @@ -3,7 +3,6 @@ BaseImagePreprocessingLayer, ) from keras.src.random import SeedGenerator -import tensorflow as tf @keras_export("keras.layers.RandomErasing") @@ -32,9 +31,21 @@ class RandomErasing(BaseImagePreprocessingLayer): _USE_BASE_FACTOR = False _FACTOR_BOUNDS = (0, 1) - def __init__(self, factor=0.3, seed=None, data_format=None, **kwargs): + def __init__( + self, + factor=1.0, + ratio=(0.02, 0.33), + fill_value=None, + value_range=(0, 255), + seed=None, + data_format=None, + **kwargs, + ): super().__init__(data_format=data_format, **kwargs) self._set_factor(factor) + self.ratio = self._set_factor_by_name(ratio, "ratio") + self.fill_value = fill_value + self.value_range = value_range self.seed = seed self.generator = SeedGenerator(seed) @@ -47,47 +58,222 @@ def __init__(self, factor=0.3, seed=None, data_format=None, **kwargs): self.width_axis = -2 self.channel_axis = -1 - def transform_images(self, images, transformation=None, training=True): + def _set_factor_by_name(self, factor, name): + error_msg = ( + f"The `{name}` argument should be a number " + "(or a list of two numbers) " + "in the range " + f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. " + f"Received: factor={factor}" + ) + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError(error_msg) + if ( + factor[0] > self._FACTOR_BOUNDS[1] + or factor[1] < self._FACTOR_BOUNDS[0] + ): + raise ValueError(error_msg) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + if ( + factor < self._FACTOR_BOUNDS[0] + or factor > self._FACTOR_BOUNDS[1] + ): + raise ValueError(error_msg) + factor = abs(factor) + lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor] + else: + raise ValueError(error_msg) + return lower, upper + + def _compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed): + crop_length = self.backend.cast( + crop_ratio * image_length, dtype=self.compute_dtype + ) + + start_pos = self.backend.random.uniform( + shape=[batch_size], + minval=0, + maxval=1, + dtype=self.compute_dtype, + seed=seed, + ) * (image_length - crop_length) + + end_pos = start_pos + crop_length + + return start_pos, end_pos - batch_size, image_height, image_width, image_channel = ( - images.shape[0], - images.shape[self.height_axis], - images.shape[self.width_axis], - images.shape[self.channel_axis]) + def _generate_batch_mask(self, images_shape, box_corners): + def _generate_grid_xy(image_height, image_width): + grid_y, grid_x = self.backend.numpy.meshgrid( + self.backend.numpy.arange( + image_height, dtype=self.compute_dtype + ), + self.backend.numpy.arange( + image_width, dtype=self.compute_dtype + ), + indexing="ij", + ) + if self.data_format == "channels_last": + grid_y = self.backend.cast( + grid_y[None, :, :, None], dtype=self.compute_dtype + ) + grid_x = self.backend.cast( + grid_x[None, :, :, None], dtype=self.compute_dtype + ) + else: + grid_y = self.backend.cast( + grid_y[None, None, :, :], dtype=self.compute_dtype + ) + grid_x = self.backend.cast( + grid_x[None, None, :, :], dtype=self.compute_dtype + ) + return grid_x, grid_y - area = image_height * image_width + image_height, image_width = ( + images_shape[self.height_axis], + images_shape[self.width_axis], + ) + grid_x, grid_y = _generate_grid_xy(image_height, image_width) - scale = (0., 1.) - ratio = (0., 1.) + x0, x1, y0, y1 = box_corners - min_area = area * scale[0] - max_area = area * scale[1] - min_aspect_ratio = ratio[0] - max_aspect_ratio = ratio[1] + x0 = x0[:, None, None, None] + y0 = y0[:, None, None, None] + x1 = x1[:, None, None, None] + y1 = y1[:, None, None, None] - target_area = self.backend.random.uniform((), - min_area, - max_area, - dtype=self.compute_dtype) + batch_masks = ( + (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1) + ) + batch_masks = self.backend.numpy.repeat( + batch_masks, images_shape[self.channel_axis], axis=self.channel_axis + ) - aspect_ratio = self.backend.random.uniform((), - min_aspect_ratio, - max_aspect_ratio, - dtype=self.compute_dtype) + return batch_masks - h = self.backend.cast(self.backend.numpy.sqrt(target_area * aspect_ratio), dtype='int32') - w = self.backend.cast(self.backend.numpy.sqrt(target_area / aspect_ratio), dtype='int32') + def _get_fill_value(self, images, images_shape): + fill_value = self.fill_value + if fill_value is None: + fill_value = self.backend.random.normal( + images_shape, dtype=self.compute_dtype + ) + else: + error_msg = ( + "The `fill_value` argument should be a number " + "(or a list of three numbers) " + ) + if isinstance(fill_value, (tuple, list)): + if len(fill_value) != 3: + raise ValueError(error_msg) + fill_value = self.backend.numpy.full_like( + images, fill_value, dtype=self.compute_dtype + ) + elif isinstance(fill_value, (int, float)): + fill_value = ( + self.backend.numpy.ones( + images_shape, dtype=self.compute_dtype + ) + * fill_value + ) + else: + raise ValueError(error_msg) + fill_value = self.backend.numpy.clip( + fill_value, self.value_range[0], self.value_range[1] + ) + return fill_value + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) - x = self.backend.random.randint((), 0, image_height - h) - y = self.backend.random.randint((), 0, image_width - w) + image_height = images_shape[self.height_axis] + image_width = images_shape[self.width_axis] + + seed = seed or self._get_seed_generator(self.backend._backend) + + mix_weight = self.backend.random.uniform( + shape=(batch_size, 2), + minval=self.ratio[0], + maxval=self.ratio[1], + dtype=self.compute_dtype, + seed=seed, + ) + + mix_weight = self.backend.numpy.sqrt(mix_weight) + + x0, x1 = self._compute_crop_bounds( + batch_size, image_width, mix_weight[:, 0], seed + ) + y0, y1 = self._compute_crop_bounds( + batch_size, image_height, mix_weight[:, 1], seed + ) + + batch_masks = self._generate_batch_mask( + images_shape, + (x0, x1, y0, y1), + ) + + erase_probability = self.backend.random.uniform( + shape=(batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + + random_threshold = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + apply_erasing = random_threshold < erase_probability + + fill_value = self._get_fill_value(images, images_shape) + + return { + "apply_erasing": apply_erasing, + "batch_masks": batch_masks, + "fill_value": fill_value, + } + + def transform_images(self, images, transformation=None, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + batch_masks = transformation["batch_masks"] + apply_erasing = transformation["apply_erasing"] + fill_value = transformation["fill_value"] - v = self.backend.random.normal(shape=[h, w, image_channel]) + erased_images = self.backend.numpy.where( + batch_masks, + fill_value, + images, + ) - images = self.backend.convert_to_numpy(images) - images[..., x:x + h, y:y + w, :] = v + images = self.backend.numpy.where( + apply_erasing[:, None, None, None], + erased_images, + images, + ) - tf.print(x, y, w, h) - images = self.backend.convert_to_tensor(images) images = self.backend.cast(images, self.compute_dtype) return images @@ -95,15 +281,15 @@ def transform_labels(self, labels, transformation, training=True): return labels def transform_bounding_boxes( - self, - bounding_boxes, - transformation, - training=True, + self, + bounding_boxes, + transformation, + training=True, ): return bounding_boxes def transform_segmentation_masks( - self, segmentation_masks, transformation, training=True + self, segmentation_masks, transformation, training=True ): return segmentation_masks @@ -113,6 +299,9 @@ def compute_output_shape(self, input_shape): def get_config(self): config = { "factor": self.factor, + "ratio": self.ratio, + "fill_value": self.fill_value, + "value_range": self.value_range, "seed": self.seed, } base_config = super().get_config() From c96728717ff6216951f29ad444c1fbb16690f8e2 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Thu, 23 Jan 2025 19:11:46 +0900 Subject: [PATCH 3/5] Update description and add test case --- keras/src/layers/__init__.py | 3 + .../image_preprocessing/random_erasing.py | 47 ++++++---- .../random_erasing_test.py | 91 +++++++++++++++++++ 3 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 7071495a2f44..2628addd4a08 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -116,6 +116,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_crop import ( RandomCrop, ) +from keras.src.layers.preprocessing.image_preprocessing.random_erasing import ( + RandomErasing, +) from keras.src.layers.preprocessing.image_preprocessing.random_flip import ( RandomFlip, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py index 188011c42803..af7a2794d508 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py @@ -7,25 +7,36 @@ @keras_export("keras.layers.RandomErasing") class RandomErasing(BaseImagePreprocessingLayer): - """CutMix data augmentation technique. + """Random Erasing data augmentation technique. - CutMix is a data augmentation method where patches are cut and pasted - between two images in the dataset, while the labels are also mixed - proportionally to the area of the patches. + Random Erasing is a data augmentation method where random patches of + an image are erased (replaced by a constant value or noise) + during training to improve generalization. Args: - factor: A single float or a tuple of two floats between 0 and 1. - If a tuple of numbers is passed, a `factor` is sampled - between the two values. - If a single float is passed, a value between 0 and the passed - float is sampled. These values define the range from which the - mixing weight is sampled. A higher factor increases the variability - in patch sizes, leading to more diverse and larger mixed patches. - Defaults to 1. + factor: A single float or a tuple of two floats. + `factor` controls the extent to which the image hue is impacted. + `factor=0.0` makes this layer perform a no-op operation, + while a value of `1.0` performs the most aggressive + erasing available. If a tuple is used, a `factor` is + sampled between the two values for every image augmented. If a + single float is used, a value between `0.0` and the passed float is + sampled. Default is 1.0. + scale: A tuple of two floats representing the aspect ratio range of + the erased patch. This defines the width-to-height ratio of + the patch to be erased. It can help control the rw shape of + the erased region. Default is (0.02, 0.33). + fill_value: A value to fill the erased region with. This can be set to + a constant value or `None` to sample a random value + from a normal distribution. Default is `None`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. seed: Integer. Used to create a random seed. References: - - [CutMix paper]( https://arxiv.org/abs/1905.04899). + - [Random Erasing paper](https://arxiv.org/abs/1708.04896). """ _USE_BASE_FACTOR = False @@ -34,7 +45,7 @@ class RandomErasing(BaseImagePreprocessingLayer): def __init__( self, factor=1.0, - ratio=(0.02, 0.33), + scale=(0.02, 0.33), fill_value=None, value_range=(0, 255), seed=None, @@ -43,7 +54,7 @@ def __init__( ): super().__init__(data_format=data_format, **kwargs) self._set_factor(factor) - self.ratio = self._set_factor_by_name(ratio, "ratio") + self.scale = self._set_factor_by_name(scale, "scale") self.fill_value = fill_value self.value_range = value_range self.seed = seed @@ -212,8 +223,8 @@ def get_random_transformation(self, data, training=True, seed=None): mix_weight = self.backend.random.uniform( shape=(batch_size, 2), - minval=self.ratio[0], - maxval=self.ratio[1], + minval=self.scale[0], + maxval=self.scale[1], dtype=self.compute_dtype, seed=seed, ) @@ -299,7 +310,7 @@ def compute_output_shape(self, input_shape): def get_config(self): config = { "factor": self.factor, - "ratio": self.ratio, + "scale": self.scale, "fill_value": self.fill_value, "value_range": self.value_range, "seed": self.seed, diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py new file mode 100644 index 000000000000..1db6ae654eaa --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py @@ -0,0 +1,91 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomErasingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomErasing, + init_kwargs={ + "factor": 1.0, + "scale": 0.5, + "fill_value": 0, + "value_range": (0, 255), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_erasing_inference(self): + seed = 3481 + layer = layers.RandomErasing() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_erasing_no_op(self): + seed = 3481 + layer = layers.RandomErasing(factor=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + layer = layers.RandomErasing(scale=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + def test_random_erasing_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.ones((2, 2, 1)) + expected_output = np.array([[[[0.0], [1.0]], [[1.0], [1.0]]]]) + + else: + inputs = np.ones((1, 2, 2)) + + expected_output = np.array( + [[[[0.0, 0.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]] + ) + + layer = layers.RandomErasing(data_format=data_format) + + transformation = { + "apply_erasing": np.asarray([True]), + "batch_masks": np.asarray( + [[[[True], [False]], [[False], [False]]]] + ), + "fill_value": 0, + } + + output = layer.transform_images(inputs, transformation) + + print(output) + + self.assertAllClose(expected_output, output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomErasing(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() From 6c8bb85bc848700c105de58a2c65469e84eed381 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Thu, 23 Jan 2025 19:26:38 +0900 Subject: [PATCH 4/5] fix value range bug --- .../preprocessing/image_preprocessing/random_erasing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py index af7a2794d508..da874279d68a 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py @@ -167,8 +167,11 @@ def _generate_grid_xy(image_height, image_width): def _get_fill_value(self, images, images_shape): fill_value = self.fill_value if fill_value is None: - fill_value = self.backend.random.normal( - images_shape, dtype=self.compute_dtype + fill_value = ( + self.backend.random.normal( + images_shape, dtype=self.compute_dtype + ) + * self.value_range[1] ) else: error_msg = ( From fce9c0c345fc6f6c51e37668b0ed82954fea6305 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Fri, 24 Jan 2025 20:18:23 +0900 Subject: [PATCH 5/5] add seed for random fill_value --- .../preprocessing/image_preprocessing/random_erasing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py index da874279d68a..b49d3dee93e1 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py @@ -164,12 +164,14 @@ def _generate_grid_xy(image_height, image_width): return batch_masks - def _get_fill_value(self, images, images_shape): + def _get_fill_value(self, images, images_shape, seed): fill_value = self.fill_value if fill_value is None: fill_value = ( self.backend.random.normal( - images_shape, dtype=self.compute_dtype + images_shape, + dtype=self.compute_dtype, + seed=seed, ) * self.value_range[1] ) @@ -261,7 +263,7 @@ def get_random_transformation(self, data, training=True, seed=None): ) apply_erasing = random_threshold < erase_probability - fill_value = self._get_fill_value(images, images_shape) + fill_value = self._get_fill_value(images, images_shape, seed) return { "apply_erasing": apply_erasing,