Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
midsterx committed Nov 15, 2021
1 parent 5f9015c commit 0a475ec
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 37 deletions.
59 changes: 24 additions & 35 deletions tensorflow_addons/layers/max_unpooling_2d_v2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,40 +24,29 @@

def _max_unpooling_2d_v2(updates, mask, output_size):
"""Unpool the outputs of a maximum pooling operation."""
output_size_attr = " ".join(["i: %d" % v for v in output_size])
experimental_implements = [
'name: "addons:MaxUnpooling2DV2"',
'attr { key: "output_size" value { list {%s} } }' % output_size_attr,
]
experimental_implements = " ".join(experimental_implements)

@tf.function(experimental_implements=experimental_implements)
def func(updates, mask):
mask = tf.cast(mask, "int32")
input_shape = tf.shape(updates, out_type="int32")
input_shape = [updates.shape[i] or input_shape[i] for i in range(4)]
output_shape = output_size

# Calculates indices for batch, height, width and feature maps.
one_like_mask = tf.ones_like(mask, dtype="int32")
batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0)
batch_range = tf.reshape(
tf.range(output_shape[0], dtype="int32"), shape=batch_shape
)
b = one_like_mask * batch_range
y = mask // (output_shape[2] * output_shape[3])
x = (mask // output_shape[3]) % output_shape[2]
feature_range = tf.range(output_shape[3], dtype="int32")
f = one_like_mask * feature_range

# Transposes indices & reshape update values to one dimension.
updates_size = tf.size(updates)
indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
values = tf.reshape(updates, [updates_size])
ret = tf.scatter_nd(indices, values, output_shape)
return ret

return func(updates, mask)
mask = tf.cast(mask, "int32")
input_shape = tf.shape(updates, out_type="int32")
input_shape = [updates.shape[i] or input_shape[i] for i in range(4)]
output_shape = output_size

# Calculates indices for batch, height, width and feature maps.
one_like_mask = tf.ones_like(mask, dtype="int32")
batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0)
batch_range = tf.reshape(
tf.range(output_shape[0], dtype="int32"), shape=batch_shape
)
b = one_like_mask * batch_range
y = mask // (output_shape[2] * output_shape[3])
x = (mask // output_shape[3]) % output_shape[2]
feature_range = tf.range(output_shape[3], dtype="int32")
f = one_like_mask * feature_range

# Transposes indices & reshape update values to one dimension.
updates_size = tf.size(updates)
indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
values = tf.reshape(updates, [updates_size])
ret = tf.scatter_nd(indices, values, output_shape)
return ret


@tf.keras.utils.register_keras_serializable(package="Addons")
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 0 additions & 1 deletion tools/testing/source_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def test_no_experimental_api():
allowlist = [
"tensorflow_addons/optimizers/weight_decay_optimizers.py",
"tensorflow_addons/layers/max_unpooling_2d.py",
"tensorflow_addons/layers/max_unpooling_2d_v2.py",
"tensorflow_addons/image/dense_image_warp.py",
]
for file_path, line_idx, line in get_lines_of_source_code(allowlist):
Expand Down

0 comments on commit 0a475ec

Please sign in to comment.