diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py
index 665442b55e..60a89a7840 100644
--- a/monai/networks/blocks/spatialattention.py
+++ b/monai/networks/blocks/spatialattention.py
@@ -17,9 +17,6 @@
 import torch.nn as nn
 
 from monai.networks.blocks import SABlock
-from monai.utils import optional_import
-
-Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
 
 class SpatialAttentionBlock(nn.Module):
@@ -74,24 +71,10 @@ def __init__(
 
     def forward(self, x: torch.Tensor):
         residual = x
-
-        if self.spatial_dims == 1:
-            h = x.shape[2]
-            rearrange_input = Rearrange("b c h -> b h c")
-            rearrange_output = Rearrange("b h c -> b c h", h=h)
-        if self.spatial_dims == 2:
-            h, w = x.shape[2], x.shape[3]
-            rearrange_input = Rearrange("b c h w -> b (h w) c")
-            rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
-        else:
-            h, w, d = x.shape[2], x.shape[3], x.shape[4]
-            rearrange_input = Rearrange("b c h w d -> b (h w d) c")
-            rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
-
+        shape = x.shape
         x = self.norm(x)
-        x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
-
+        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
         x = self.attn(x)
-        x = rearrange_output(x)  # B x x C x x_dim * y_dim * [z_dim]
+        x = x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"
         x = x + residual
         return x
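The patch replaces the einops Rearrange layers with a plain reshape/transpose round trip that works for 1D, 2D, and 3D inputs alike. Below is a minimal standalone sanity check (not part of this patch) showing that the new tensor ops match the removed einops pattern for a 3D input; it assumes einops is still installed in the test environment.

# Standalone equivalence sketch, not part of the patch; requires torch and einops.
import torch
from einops.layers.torch import Rearrange

x = torch.randn(2, 4, 3, 5, 6)  # b c h w d
shape = x.shape

# New path: flatten spatial dims, then move channels last.
flat = x.reshape(*shape[:2], -1).transpose(1, 2)  # b (h w d) c
# Removed einops path for comparison.
assert torch.equal(flat, Rearrange("b c h w d -> b (h w d) c")(x))

# Inverse: channels back to dim 1, then restore the spatial shape.
restored = flat.transpose(1, 2).reshape(shape)  # b c h w d
assert torch.equal(restored, x)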