Add support for arbitrary image resolutions #24

Open. Wants to merge 7 commits into base: main. Changes from 4 commits.
4 changes: 3 additions & 1 deletion maxim/blocks/attentions.py
@@ -134,7 +134,9 @@ def apply(x, x_image):

# Get attention maps for num_channels
x2 = tf.nn.sigmoid(
Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(
image
)
)

# Get attended feature maps
7 changes: 3 additions & 4 deletions maxim/blocks/block_gating.py
@@ -6,7 +6,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages
from ..layers import SwapAxes, TFBlockImages, TFUnblockImages


def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
@@ -47,8 +47,7 @@ def apply(x):
K.int_shape(x)[3],
)
fh, fw = block_size
gh, gw = h // fh, w // fw
x = BlockImages()(x, patch_size=(fh, fw))
x, gh, gw = TFBlockImages()(x, patch_size=(fh, fw))
# MLP2: Local (block) mixing part, provides within-block communication.
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
y = layers.Dense(
@@ -65,7 +64,7 @@ def apply(x):
)(y)
y = layers.Dropout(dropout_rate)(y)
x = x + y
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
x = TFUnblockImages()(x, patch_size=(fh, fw), grid_size=(gh, gw))
return x

return apply
7 changes: 3 additions & 4 deletions maxim/blocks/grid_gating.py
@@ -6,7 +6,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages
from ..layers import SwapAxes, TFBlockImagesByGrid, TFUnblockImages
Owner

Why is there a separate layer for handling blocking by grids?

Author (@LeviVasconcelos, Dec 16, 2022)

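BlockByGrid can be implemented as follows (please see a more detailed explanation here):

def BlockByGrid(image, grid_size):
    # Pseudo-code: derive the block size from the grid size, then reuse block-based splitting.
    image_height, image_width = image.shape[1], image.shape[2]  # assumes (N, H, W, C)
    block_size = (image_height // grid_size[0], image_width // grid_size[1])
    return BlockImages()(image, patch_size=block_size)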

However, while implementing TFBlockImages I used tf.split, which expects the number of splits (num_or_size_splits) to be a Python integer known when the graph is built.

In cases where we only have the grid_size and the block_size has to be computed on the fly from the image dimensions (as here), the block size is a tensor, so we can't use tf.split. That's why I also wrote BlockByGrid.



def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
@@ -47,9 +47,8 @@ def apply(x):
K.int_shape(x)[3],
)
gh, gw = grid_size
fh, fw = h // gh, w // gw

x = BlockImages()(x, patch_size=(fh, fw))
x, ph, pw = TFBlockImagesByGrid()(x, grid_size=(gh, gw))
Comment on lines -52 to +51
Owner

How come these operations are the same?

Author (@LeviVasconcelos, Dec 16, 2022)

In the original implementation, the authors implement blocking by grid by computing the block size of a grid cell and then calling BlockImages (which splits images into patches of that block size).

In the paper, the authors explain the difference between "grid" and "block" as follows:
[Image: Screenshot from 2022-12-16 12-53-08]
Note that we can achieve the same result as the grid split by forwarding a block size of [3,2] instead. This is exactly what the authors do in the original code, as highlighted here.

The two are equivalent because TFBlockImagesByGrid performs the split from grid_size directly, instead of from the block size (called (fh, fw) in the code) as the authors did.

A more formal test is performed here.

Owner

Thanks for explaining.

> Note that we can achieve the same result as the grid split by forwarding a block size of [3,2] instead.

How is the block size of [3, 2] interpreted in that case?

Author (@LeviVasconcelos, Dec 17, 2022)

In the original code it is done as explained; note here:

gh, gw = grid_size
fh, fw = h // gh, w // gw
u = BlockImages()(u, patch_size=(fh, fw))

Note that this code is very similar to the pseudo-code written here. grid_size is passed as a parameter, but h and w have to be inferred from the image dimensions, which are None when the input shape is (None, None, 3). They therefore can't be used in the einops operations, and the way I found to overcome this was to rewrite the operations in tf.
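
To illustrate the problem described above: with a variable-size input such as (None, None, 3), the static shape holds None for height and width, so integer arithmetic like h // gh cannot run. A minimal pure-Python sketch, where the tuple static_shape stands in for what K.int_shape(x) would return and patch_size_from_grid is a hypothetical helper (neither is part of the PR):

```python
def patch_size_from_grid(static_shape, grid_size):
    """Compute (fh, fw) from a static (N, H, W, C) shape, as `h // gh, w // gw` does."""
    _, h, w, _ = static_shape
    gh, gw = grid_size
    return h // gh, w // gw


# Fixed-size input: plain integer arithmetic works.
print(patch_size_from_grid((None, 256, 256, 3), (16, 16)))

# Variable-size input: the spatial dimensions are None, so the division fails.
try:
    patch_size_from_grid((None, None, None, 3), (16, 16))
except TypeError as err:
    print(f"cannot compute patch size from a static shape: {err}")
```

This is why the rewritten layers read the dimensions with tf.shape, which yields tensors that are resolved when the model actually runs.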

We can use the block size [3,2] to compute the green part of the image (grid blocking with grid_size=[2,2]) this way:

In the example shown in the image, the image size is [6,4]. Thus, to split it with a grid_size of [2,2], we can do:

gh, gw = (2, 2)
h, w = (6,4) # image dimensions
fh, fw = h // gh, w // gw # Note that fh = 3, and fw = 2
block_image = BlockImages()(image_from_the_picture, patch_size=(fh, fw))  # patch_size=(3, 2)

The above code snippet implements the green part of the image, and is very similar to what we described first.

With TFBlockImagesByGrid(), we can simply do:

gh, gw = (2, 2)
block_image_using_tfblockByGrid, ph, pw = TFBlockImagesByGrid()(image_from_the_picture, grid_size=(gh, gw))

and block_image should be equivalent to block_image_using_tfblockByGrid, as asserted by this test.

I am not sure if this answers what you asked, though. Let me know.
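
The equivalence described above can be sketched without TensorFlow. The snippet below is a simplified model on a single-channel 6 x 4 image stored as nested lists; block_images and block_images_by_grid are hypothetical stand-ins for BlockImages and TFBlockImagesByGrid, not the PR's implementations:

```python
def block_images(image, patch_size):
    """Split an H x W image into flattened, non-overlapping (fh x fw) patches.

    Mirrors the einops pattern "(gh fh) (gw fw) -> (gh gw) (fh fw)" for one
    image with a single channel.
    """
    fh, fw = patch_size
    gh, gw = len(image) // fh, len(image[0]) // fw
    return [
        [image[gi * fh + fi][gj * fw + fj] for fi in range(fh) for fj in range(fw)]
        for gi in range(gh)
        for gj in range(gw)
    ]


def block_images_by_grid(image, grid_size):
    """Same split, but parameterised by the number of grid cells."""
    gh, gw = grid_size
    h, w = len(image), len(image[0])
    return block_images(image, (h // gh, w // gw))


image = [[row * 4 + col for col in range(4)] for row in range(6)]  # 6 x 4 image
by_patch = block_images(image, patch_size=(3, 2))
by_grid = block_images_by_grid(image, grid_size=(2, 2))
assert by_patch == by_grid  # 4 grid cells, each holding a 3 x 2 patch
```

Both calls produce the same layout; the only difference is which quantity the caller has to know up front.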

Owner

Thanks!

So TFBlockImagesByGrid() is more idiomatic in that sense: we want a grid size of (2, 2) in the output, so we pass that directly as an argument. Correct?

Author

Exactly!

# gMLP1: Global (grid) mixing part, provides global grid communication.
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
y = layers.Dense(
@@ -66,7 +65,7 @@ def apply(x):
)(y)
y = layers.Dropout(dropout_rate)(y)
x = x + y
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
x = TFUnblockImages()(x, grid_size=(gh, gw), patch_size=(ph, pw))
Owner

Same. You're changing the semantics of the code. Could you please elaborate on why?

Reading this change and the earlier x, ph, pw = TFBlockImagesByGrid()(x, grid_size=(gh, gw)), and comparing them with their previous versions, they don't read the same either.
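
One way to check that replacement layers keep the same semantics is a round trip: blocking followed by unblocking with the matching grid_size and patch_size should return the original image. A minimal pure-Python sketch under that assumption (block and unblock are hypothetical single-channel helpers, not the TF layers from this PR):

```python
def block(image, patch_size):
    """Split an H x W image into flattened (fh x fw) patches; also return the grid size."""
    fh, fw = patch_size
    gh, gw = len(image) // fh, len(image[0]) // fw
    patches = [
        [image[gi * fh + fi][gj * fw + fj] for fi in range(fh) for fj in range(fw)]
        for gi in range(gh)
        for gj in range(gw)
    ]
    return patches, (gh, gw)


def unblock(patches, grid_size, patch_size):
    """Inverse of block(): write every patch back at its grid position."""
    gh, gw = grid_size
    fh, fw = patch_size
    image = [[None] * (gw * fw) for _ in range(gh * fh)]
    for g, patch in enumerate(patches):
        gi, gj = divmod(g, gw)  # grid cell of this patch
        for p, value in enumerate(patch):
            fi, fj = divmod(p, fw)  # offset inside the patch
            image[gi * fh + fi][gj * fw + fj] = value
    return image


image = [[row * 4 + col for col in range(4)] for row in range(6)]
patches, grid_size = block(image, patch_size=(3, 2))
restored = unblock(patches, grid_size=grid_size, patch_size=(3, 2))
assert restored == image
```

The same kind of check against the einops-based layers, on real tensors, would be the stronger evidence for the PR itself.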

return x

return apply
32 changes: 19 additions & 13 deletions maxim/blocks/misc_gating.py
@@ -8,7 +8,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages
from ..layers import SwapAxes, TFBlockImages, TFBlockImagesByGrid, TFUnblockImages
from .block_gating import BlockGmlpLayer
from .grid_gating import GridGmlpLayer

@@ -117,23 +117,21 @@ def apply(x):

# Get grid MLP weights
gh, gw = grid_size
fh, fw = h // gh, w // gw
u = BlockImages()(u, patch_size=(fh, fw))
dim_u = K.int_shape(u)[-3]
u, phu, pwu = TFBlockImagesByGrid()(u, grid_size=(gh, gw))
dim_u = gh * gw
u = SwapAxes()(u, -1, -3)
u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
u = SwapAxes()(u, -1, -3)
u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))
u = TFUnblockImages()(u, grid_size=(gh, gw), patch_size=(phu, pwu))

# Get Block MLP weights
fh, fw = block_size
gh, gw = h // fh, w // fw
v = BlockImages()(v, patch_size=(fh, fw))
dim_v = K.int_shape(v)[-2]
v, gh, gw = TFBlockImages()(v, patch_size=(fh, fw))
dim_v = fh * fw
v = SwapAxes()(v, -1, -2)
v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
v = SwapAxes()(v, -1, -2)
v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))
v = TFUnblockImages()(v, patch_size=(fh, fw), grid_size=(gh, gw))

x = tf.concat([u, v], axis=-1)
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
@@ -178,7 +176,9 @@ def apply(x, y):

# Get gating weights from X
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(
x
)
x = tf.nn.gelu(x, approximate=True)
gx = GetSpatialGatingWeights(
features=num_channels,
@@ -191,7 +191,9 @@ def apply(x, y):

# Get gating weights from Y
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(
y
)
y = tf.nn.gelu(y, approximate=True)
gy = GetSpatialGatingWeights(
features=num_channels,
@@ -204,12 +206,16 @@ def apply(x, y):

# Apply cross gating: X = X * GY, Y = Y * GX
y = y * gx
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(
y
)
y = layers.Dropout(dropout_rate)(y)
y = y + shortcut_y

x = x * gy # gating x using y
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(
x
)
x = layers.Dropout(dropout_rate)(x)
x = x + y + shortcut_x # get all aggregated signals
return x, y
10 changes: 1 addition & 9 deletions maxim/blocks/others.py
@@ -38,17 +38,9 @@ def UpSampleRatio(
"""Upsample features given a ratio > 0."""

def apply(x):
n, h, w, c = (
K.int_shape(x)[0],
K.int_shape(x)[1],
K.int_shape(x)[2],
K.int_shape(x)[3],
)

# Following `jax.image.resize()`
x = Resizing(
height=int(h * ratio),
width=int(w * ratio),
ratio=1 / ratio,
method="bilinear",
antialias=True,
name=f"{name}_resizing_{K.get_uid('Resizing')}",
140 changes: 96 additions & 44 deletions maxim/layers.py
@@ -1,65 +1,114 @@
"""
Layers based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
and reworked to cope with variable image dimensions
"""

import einops
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from tensorflow.keras import backend as K
from tensorflow.keras import layers


@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
class TFBlockImages(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def call(self, x, patch_size):
def call(self, image, patch_size):
bs, h, w, num_channels = (
K.int_shape(x)[0],
K.int_shape(x)[1],
K.int_shape(x)[2],
K.int_shape(x)[3],
tf.shape(image)[0],
tf.shape(image)[1],
tf.shape(image)[2],
tf.shape(image)[3],
)
ph, pw = patch_size
gh = h // ph
gw = w // pw
pad = [[0, 0], [0, 0]]
patches = tf.space_to_batch_nd(image, [ph, pw], pad)
patches = tf.split(patches, ph * pw, axis=0)
patches = tf.stack(patches, 3) # (bs, h/p, h/p, p*p, 3)
patches_dim = tf.shape(patches)
patches = tf.reshape(
patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1]
)
patches = tf.reshape(
patches,
(patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels),
)
return [patches, gh, gw]

grid_height, grid_width = h // patch_size[0], w // patch_size[1]
def get_config(self):
return super().get_config()

x = einops.rearrange(
x,
"n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
gh=grid_height,
gw=grid_width,
fh=patch_size[0],
fw=patch_size[1],
)

return x
@tf.keras.utils.register_keras_serializable("maxim")
class TFBlockImagesByGrid(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def call(self, image, grid_size):
bs, h, w, num_channels = (
tf.shape(image)[0],
tf.shape(image)[1],
tf.shape(image)[2],
tf.shape(image)[3],
)
gh, gw = grid_size
ph = h // gh
pw = w // gw
pad = [[0, 0], [0, 0]]

def block_single_image(img):
pat = tf.expand_dims(img, 0) # batch = 1
pat = tf.space_to_batch_nd(pat, [ph, pw], pad) # p*p*bs, g, g, c
pat = tf.expand_dims(pat, 3) # pxpxbs, g, g, 1, c
pat = tf.transpose(pat, perm=[3, 1, 2, 0, 4]) # 1, g, g, pxp, c
pat = tf.reshape(pat, [gh, gw, ph * pw, num_channels])
return pat

patches = image
patches = tf.map_fn(fn=lambda x: block_single_image(x), elems=patches)
patches_dim = tf.shape(patches)
patches = tf.reshape(
patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1]
)
patches = tf.reshape(
patches,
(patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels),
)
return [patches, ph, pw]

def get_config(self):
config = super().get_config().copy()
return config
return super().get_config()


@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
class TFUnblockImages(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def call(self, x, grid_size, patch_size):
x = einops.rearrange(
x,
"n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
gh=grid_size[0],
gw=grid_size[1],
fh=patch_size[0],
fw=patch_size[1],
def call(self, x, patch_size, grid_size):
bs, grid_sqrt, patch_sqrt, num_channels = (
tf.shape(x)[0],
tf.shape(x)[1],
tf.shape(x)[2],
tf.shape(x)[3],
)
ph, pw = patch_size
gh, gw = grid_size

return x
pad = [[0, 0], [0, 0]]

y = tf.reshape(x, (bs, gh, gw, -1, num_channels)) # (bs, gh, gw, ph*pw, 3)
y = tf.expand_dims(y, 0)
y = tf.transpose(y, perm=[4, 1, 2, 3, 0, 5])
y = tf.reshape(y, [bs * ph * pw, gh, gw, num_channels])
y = tf.batch_to_space(y, [ph, pw], pad)

return y

def get_config(self):
config = super().get_config().copy()
return config
return super().get_config()


@tf.keras.utils.register_keras_serializable("maxim")
@@ -76,28 +125,31 @@ def get_config(self):


@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
Owner

What is the need to split this into Up and Down?

Author

I found it easier to read, but indeed it adds a chunk of code. Reformatting to use a single layer only.

Owner

If it's easier to read, I would consider adding an elaborate comment in the script so that readers are aware.

Author

I think it is better now (with a single resizing layer).
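
As a rough sketch of how a single ratio-driven layer can cover both directions: the target size is the floor of each spatial dimension divided by ratio, and upsampling passes the reciprocal, as UpSampleRatio does with ratio=1 / ratio. The helpers below are hypothetical and use plain Python numbers instead of tensors:

```python
def resized_hw(height, width, ratio):
    """Target (height, width) as floor(dim / ratio), mirroring `shape[1:3] // ratio`."""
    return int(height // ratio), int(width // ratio)


def upsampled_hw(height, width, ratio):
    """Upsampling by `ratio` reuses the same rule with the reciprocal ratio."""
    return resized_hw(height, width, 1 / ratio)


# ratio > 1 shrinks the image; passing its reciprocal enlarges it.
assert resized_hw(64, 48, 2) == (32, 24)
assert upsampled_hw(64, 48, 2) == (128, 96)
```

Ratios whose reciprocal is not exactly representable in floating point may round differently from the previous int(h * ratio), so the resulting sizes are worth checking for such values.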

def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
class Resizing(tf.keras.layers.Layer):
def __init__(self, ratio: float, method="bilinear", antialias=True, **kwargs):
super().__init__(**kwargs)
self.height = height
self.width = width
self.antialias = antialias
self.ratio = ratio
self.method = method
self.antialias = antialias

def call(self, x):
return tf.image.resize(
x,
size=(self.height, self.width),
antialias=self.antialias,
def call(self, img):
shape = tf.shape(img)

new_sh = tf.cast(shape[1:3], tf.float32) // self.ratio

x = tf.image.resize(
img,
size=tf.cast(new_sh, tf.int32),
method=self.method,
antialias=self.antialias,
)
return x

def get_config(self):
config = super().get_config().copy()
config.update(
{
"height": self.height,
"width": self.width,
"ratio": self.ratio,
"antialias": self.antialias,
"method": self.method,
}