From b6dd1631d01ad37581f27e39ebbe8866ac1bef73 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 18 Jun 2019 01:52:39 +0800 Subject: [PATCH 1/2] speed up median_filter2d --- tensorflow_addons/image/filters.py | 54 +++++++++++++++++------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index 503d4657c9..ae5d9e4c9f 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -174,34 +174,40 @@ def median_filter2d(image, if rank == 3: image = tf.expand_dims(image, axis=0) + image_shape = tf.shape(image) + batch_size = image_shape[0] + height = image_shape[1] + width = image_shape[2] + channels = image_shape[3] + # Explicitly pad the image image = _pad( image, filter_shape, mode=padding, constant_values=constant_values) - floor = (filter_shape[0] * filter_shape[1] + 1) // 2 - ceil = (filter_shape[0] * filter_shape[1]) // 2 + 1 - - def _median_filter2d_single_channel(x): - x = tf.expand_dims(x, axis=-1) - patches = tf.image.extract_patches( - x, - sizes=[1, filter_shape[0], filter_shape[1], 1], - strides=[1, 1, 1, 1], - rates=[1, 1, 1, 1], - padding="VALID") - - # Note the returned median is casted back to the original type - # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 3.5 - # It turns out to be int(6.5) = 6 if the original type is int - top = tf.nn.top_k(patches, k=ceil).values - median = (top[:, :, :, floor - 1] + top[:, :, :, ceil - 1]) / 2 - return tf.dtypes.cast(median, x.dtype) - - output = tf.map_fn( - _median_filter2d_single_channel, - elems=tf.transpose(image, [3, 0, 1, 2]), - dtype=image.dtype) - output = tf.transpose(output, [1, 2, 3, 0]) + area = filter_shape[0] * filter_shape[1] + + floor = (area + 1) // 2 + ceil = area // 2 + 1 + + patches = tf.image.extract_patches( + image, + sizes=[1, filter_shape[0], filter_shape[1], 1], + strides=[1, 1, 1, 1], + rates=[1, 1, 1, 1], + padding="VALID") + + patches = tf.reshape( + patches, shape=[batch_size, height, width, area, channels]) + + patches = tf.transpose(patches, [0, 1, 2, 4, 3]) + + # Note the returned median is casted back to the original type + # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 3.5 + # It turns out to be int(6.5) = 6 if the original type is int + top = tf.nn.top_k(patches, k=ceil).values + output = tf.cast( + (top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2, + image.dtype) # Squeeze out the first axis to make sure # output has the same dimension with image. From a5ad3d796266c4ba27856b112b78f2452b23402c Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 18 Jun 2019 16:42:01 +0800 Subject: [PATCH 2/2] split odd and even case --- tensorflow_addons/image/filters.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index ae5d9e4c9f..861badd256 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -205,9 +205,13 @@ def median_filter2d(image, # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 3.5 # It turns out to be int(6.5) = 6 if the original type is int top = tf.nn.top_k(patches, k=ceil).values - output = tf.cast( - (top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2, - image.dtype) + if area % 2 == 1: + median = top[:, :, :, :, floor - 1] + else: + median = ( + top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2 + + output = tf.cast(median, image.dtype) # Squeeze out the first axis to make sure # output has the same dimension with image.