Skip to content

Commit

Permalink
Speed up median_filter2d (#304)
Browse files Browse the repository at this point in the history
* speed up median_filter2d

* split odd and even case
  • Loading branch information
WindQAQ authored and facaiy committed Jun 18, 2019
1 parent ebebe3c commit dbadc7d
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions tensorflow_addons/image/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,34 +174,44 @@ def median_filter2d(image,
if rank == 3:
image = tf.expand_dims(image, axis=0)

image_shape = tf.shape(image)
batch_size = image_shape[0]
height = image_shape[1]
width = image_shape[2]
channels = image_shape[3]

# Explicitly pad the image
image = _pad(
image, filter_shape, mode=padding, constant_values=constant_values)

floor = (filter_shape[0] * filter_shape[1] + 1) // 2
ceil = (filter_shape[0] * filter_shape[1]) // 2 + 1

def _median_filter2d_single_channel(x):
# Median-filter a single channel: add a trailing axis of size 1 to x,
# gather every filter_shape[0] x filter_shape[1] window as a flat patch,
# and return the median of each patch cast back to x's dtype.
# NOTE(review): x is presumably one [batch, height, width] slice of the
# channel-first transpose handed to tf.map_fn -- confirm against the caller.
x = tf.expand_dims(x, axis=-1)
# Stride 1 with "VALID" padding: the image is expected to have been padded
# explicitly beforehand, so no implicit padding is applied here.
patches = tf.image.extract_patches(
x,
sizes=[1, filter_shape[0], filter_shape[1], 1],
strides=[1, 1, 1, 1],
rates=[1, 1, 1, 1],
padding="VALID")

# Note that the returned median is cast back to the original type.
# Take [5, 6, 7, 8] for example: the median is (6 + 7) / 2 = 6.5,
# which becomes int(6.5) = 6 if the original type is int.
# Only the `ceil` largest values of each patch are needed to reach its
# middle; positions floor - 1 and ceil - 1 are the two middle elements
# (they coincide when the window area is odd).
top = tf.nn.top_k(patches, k=ceil).values
median = (top[:, :, :, floor - 1] + top[:, :, :, ceil - 1]) / 2
return tf.dtypes.cast(median, x.dtype)

output = tf.map_fn(
_median_filter2d_single_channel,
elems=tf.transpose(image, [3, 0, 1, 2]),
dtype=image.dtype)
output = tf.transpose(output, [1, 2, 3, 0])
area = filter_shape[0] * filter_shape[1]

floor = (area + 1) // 2
ceil = area // 2 + 1

patches = tf.image.extract_patches(
image,
sizes=[1, filter_shape[0], filter_shape[1], 1],
strides=[1, 1, 1, 1],
rates=[1, 1, 1, 1],
padding="VALID")

patches = tf.reshape(
patches, shape=[batch_size, height, width, area, channels])

patches = tf.transpose(patches, [0, 1, 2, 4, 3])

# Note that the returned median is cast back to the original type.
# Take [5, 6, 7, 8] for example: the median is (6 + 7) / 2 = 6.5,
# which becomes int(6.5) = 6 if the original type is int.
top = tf.nn.top_k(patches, k=ceil).values
if area % 2 == 1:
median = top[:, :, :, :, floor - 1]
else:
median = (
top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2

output = tf.cast(median, image.dtype)

# Squeeze out the first axis so that the
# output has the same rank as the input image.
Expand Down

0 comments on commit dbadc7d

Please sign in to comment.