Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix drop block #2250

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 44 additions & 58 deletions keras_cv/layers/regularization/dropblock_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from keras_cv.backend import config

if config.keras_3():
base_layer = tf.keras.layers.Layer
else:
from tensorflow.keras.__internal__.layers import BaseRandomLayer

base_layer = BaseRandomLayer

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import random
from keras_cv.utils import conv_utils


@keras_cv_export("keras_cv.layers.DropBlock2D")
class DropBlock2D(base_layer):
class DropBlock2D(keras.layers.Layer):
"""Applies DropBlock regularization to input features.

DropBlock is a form of structured dropout, where units in a contiguous
Expand Down Expand Up @@ -153,93 +145,87 @@ def __init__(
seed=None,
**kwargs,
):
# To-do: remove this once the layer is ported to keras 3
# https://github.com/keras-team/keras-cv/issues/2136
if config.keras_3():
super().__init__(**kwargs)
if not 0.0 <= rate <= 1.0:
raise ValueError(
"This layer is not yet compatible with Keras 3."
"Please switch to Keras 2 to use this layer."
f"rate must be a number between 0 and 1. " f"Received: {rate}"
)
else:
super().__init__(seed=seed, **kwargs)
if not 0.0 <= rate <= 1.0:
raise ValueError(
f"rate must be a number between 0 and 1. "
f"Received: {rate}"
)

self._rate = rate
(
self._dropblock_height,
self._dropblock_width,
) = conv_utils.normalize_tuple(
value=block_size, n=2, name="block_size", allow_zero=False
)
self.seed = seed

self._rate = rate
(
self._dropblock_height,
self._dropblock_width,
) = conv_utils.normalize_tuple(
value=block_size, n=2, name="block_size", allow_zero=False
)
self.seed = seed
self._random_generator = random.SeedGenerator(self.seed)

def call(self, x, training=None):
if not training or self._rate == 0.0:
return x

_, height, width, _ = tf.split(tf.shape(x), 4)
_, height, width, _ = ops.split(ops.shape(x), 4)

# Unnest scalar values
height = tf.squeeze(height)
width = tf.squeeze(width)
height = ops.squeeze(height)
width = ops.squeeze(width)

dropblock_height = tf.math.minimum(self._dropblock_height, height)
dropblock_width = tf.math.minimum(self._dropblock_width, width)
dropblock_height = ops.minimum(self._dropblock_height, height)
dropblock_width = ops.minimum(self._dropblock_width, width)

gamma = (
self._rate
* tf.cast(width * height, dtype=tf.float32)
/ tf.cast(dropblock_height * dropblock_width, dtype=tf.float32)
/ tf.cast(
* ops.cast(width * height, dtype="float32")
/ ops.cast(dropblock_height * dropblock_width, dtype="float32")
/ ops.cast(
(width - self._dropblock_width + 1)
* (height - self._dropblock_height + 1),
tf.float32,
"float32",
)
)

# Forces the block to be inside the feature map.
w_i, h_i = tf.meshgrid(tf.range(width), tf.range(height))
valid_block = tf.logical_and(
tf.logical_and(
w_i, h_i = ops.meshgrid(ops.arange(width), ops.arange(height))
valid_block = ops.logical_and(
ops.logical_and(
w_i >= int(dropblock_width // 2),
w_i < width - (dropblock_width - 1) // 2,
),
tf.logical_and(
ops.logical_and(
h_i >= int(dropblock_height // 2),
h_i < width - (dropblock_height - 1) // 2,
),
)

valid_block = tf.reshape(valid_block, [1, height, width, 1])
valid_block = ops.reshape(valid_block, [1, height, width, 1])

random_noise = self._random_generator.random_uniform(
tf.shape(x), dtype=tf.float32
random_noise = random.uniform(
ops.shape(x), seed=self._random_generator, dtype="float32"
)
valid_block = tf.cast(valid_block, dtype=tf.float32)
seed_keep_rate = tf.cast(1 - gamma, dtype=tf.float32)
valid_block = ops.cast(valid_block, dtype="float32")
seed_keep_rate = ops.cast(1 - gamma, dtype="float32")
block_pattern = (1 - valid_block + seed_keep_rate + random_noise) >= 1
block_pattern = tf.cast(block_pattern, dtype=tf.float32)
block_pattern = ops.cast(block_pattern, dtype="float32")

window_size = [1, self._dropblock_height, self._dropblock_width, 1]

# Double negative and max_pool is essentially min_pooling
block_pattern = -tf.nn.max_pool(
block_pattern = -ops.max_pool(
-block_pattern,
ksize=window_size,
pool_size=window_size,
strides=[1, 1, 1, 1],
padding="SAME",
)

# Slightly scale the values, to account for magnitude change
percent_ones = tf.cast(
tf.reduce_sum(block_pattern), tf.float32
) / tf.cast(tf.size(block_pattern), tf.float32)
percent_ones = ops.cast(ops.sum(block_pattern), "float32") / ops.cast(
ops.size(block_pattern), "float32"
)
return (
x / tf.cast(percent_ones, x.dtype) * tf.cast(block_pattern, x.dtype)
x
/ ops.cast(percent_ones, x.dtype)
* ops.cast(block_pattern, x.dtype)
)

def get_config(self):
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
13 changes: 0 additions & 13 deletions keras_cv/layers/regularization/dropblock_2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import tensorflow as tf

from keras_cv.backend.config import keras_3
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.tests.test_case import TestCase


@pytest.mark.skipif(keras_3(), reason="not implemented in keras 3")
class DropBlock2DTest(TestCase):
FEATURE_SHAPE = (1, 14, 14, 256) # Shape of ResNet block group 3
rng = tf.random.Generator.from_non_deterministic_state()
Expand Down Expand Up @@ -87,13 +84,3 @@ def test_input_gets_partially_zeroed_out_with_non_square_block_size(self):
@staticmethod
def _count_zeros(tensor: tf.Tensor) -> tf.Tensor:
return tf.size(tensor) - tf.math.count_nonzero(tensor, dtype=tf.int32)

def test_works_with_xla(self):
dummy_inputs = self.rng.uniform(shape=self.FEATURE_SHAPE)
layer = DropBlock2D(rate=0.1, block_size=7)

@tf.function(jit_compile=True)
def apply(x):
return layer(x, training=True)

apply(dummy_inputs)
Loading