Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
faizan-m committed Nov 18, 2021
1 parent 376e22e commit 8e68a1d
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions tensorflow_addons/image/dense_image_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,35 +57,41 @@ def _interpolate_bilinear_with_checks(
) -> tf.Tensor:
"""Perform checks on inputs without tf.function decorator to avoid flakiness."""
if indexing != "ij" and indexing != "xy":
raise ValueError("Indexing mode must be 'ij' or 'xy'")
raise ValueError("Indexing mode must be 'ij' or 'xy'")

grid = tf.convert_to_tensor(grid)
query_points = tf.convert_to_tensor(query_points)
grid_shape = tf.shape(grid)
query_shape = tf.shape(query_points)

with tf.control_dependencies([
tf.Assert(tf.equal(tf.rank(grid), 4), ["Grid must be 4D Tensor"]),
tf.Assert(
tf.greater_equal(grid_shape[1], 2),
["Grid height must be at least 2."]),
tf.Assert(
tf.greater_equal(grid_shape[2], 2),
["Grid width must be at least 2."]),
tf.Assert(
tf.equal(tf.rank(query_points), 3),
["Query points must be 3 dimensional."]),
tf.Assert(
tf.equal(query_shape[2], 2),
["Query points last dimension must be 2."])
]):
with tf.control_dependencies(
[
tf.Assert(tf.equal(tf.rank(grid), 4), ["Grid must be 4D Tensor"]),
tf.Assert(
tf.greater_equal(grid_shape[1], 2), ["Grid height must be at least 2."]
),
tf.Assert(
tf.greater_equal(grid_shape[2], 2), ["Grid width must be at least 2."]
),
tf.Assert(
tf.equal(tf.rank(query_points), 3),
["Query points must be 3 dimensional."],
),
tf.Assert(
tf.equal(query_shape[2], 2), ["Query points last dimension must be 2."]
),
]
):
return _interpolate_bilinear_impl(grid, query_points, indexing, name)


@tf.function
def _interpolate_bilinear_impl(grid: types.TensorLike,
query_points: types.TensorLike, indexing: str,
name: Optional[str]) -> tf.Tensor:
def _interpolate_bilinear_impl(
grid: types.TensorLike,
query_points: types.TensorLike,
indexing: str,
name: Optional[str],
) -> tf.Tensor:
"""tf.function implementation of interpolate_bilinear."""
with tf.name_scope(name or "interpolate_bilinear"):
grid_shape = tf.shape(grid)
Expand Down

0 comments on commit 8e68a1d

Please sign in to comment.