
Commit

Codacy/pylint fixes
georgeyiasemis committed Jan 7, 2025
1 parent eb7817e commit e93f3df
Showing 6 changed files with 62 additions and 46 deletions.
2 changes: 1 addition & 1 deletion direct/nn/transformers/transformers.py
@@ -7,7 +7,7 @@
 from typing import Any, Callable, Optional
 
 import torch
-import torch.nn as nn
+from torch import nn
 
 from direct.constants import COMPLEX_SIZE
 from direct.data.transforms import reduce_operator
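Note on the import change (the same fix recurs in the files below): pylint's consider-using-from-import check (R0402) flags `import torch.nn as nn` and prefers the `from` form. A minimal sketch showing the two spellings are equivalent:

```python
# Both spellings bind the same module object; pylint's
# consider-using-from-import (R0402) prefers the second.
import torch.nn as nn_alias  # form removed by this commit
from torch import nn         # form adopted by this commit

assert nn_alias is nn
```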
94 changes: 55 additions & 39 deletions direct/nn/transformers/uformer.py
@@ -18,9 +18,9 @@
 from typing import List, Optional
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
+from torch import nn
 from torch.nn.init import trunc_normal_
 
 from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_square, unnorm, unpad_to_original
@@ -263,20 +263,20 @@ def forward(
         v : torch.Tensor
             Value tensor.
         """
-        b, n, c, h = *x.shape, self.heads
-        l = int(math.sqrt(n))
+        _, n, _, h = *x.shape, self.heads
+        f = int(math.sqrt(n))
         w = int(math.sqrt(n))
 
         attn_kv = x if attn_kv is None else attn_kv
-        x = rearrange(x, "b (l w) c -> b c l w", l=l, w=w)
-        attn_kv = rearrange(attn_kv, "b (l w) c -> b c l w", l=l, w=w)
+        x = rearrange(x, "b (f w) c -> b c f w", f=f, w=w)
+        attn_kv = rearrange(attn_kv, "b (f w) c -> b c f w", f=f, w=w)
         q = self.to_q(x)
-        q = rearrange(q, "b (h d) l w -> b h (l w) d", h=h)
+        q = rearrange(q, "b (h d) f w -> b h (f w) d", h=h)
 
         k = self.to_k(attn_kv)
         v = self.to_v(attn_kv)
-        k = rearrange(k, "b (h d) l w -> b h (l w) d", h=h)
-        v = rearrange(v, "b (h d) l w -> b h (l w) d", h=h)
+        k = rearrange(k, "b (h d) f w -> b h (f w) d", h=h)
+        v = rearrange(v, "b (h d) f w -> b h (f w) d", h=h)
         return q, k, v
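Note: the rename from `l` to `f` resolves the ambiguous single-character name that pylint flags (`l` is easily misread as `1` or `I`); the einops pattern is otherwise unchanged. A minimal sketch, with made-up sizes, of what `"b (h d) f w -> b h (f w) d"` does:

```python
# Split channels into (heads, head_dim) and flatten the spatial grid
# into a token axis, giving the usual multi-head attention layout.
import torch
from einops import rearrange

b, h, d, f, w = 2, 4, 8, 16, 16     # batch, heads, head_dim, height, width
q = torch.randn(b, h * d, f, w)     # conv-projected feature map
q = rearrange(q, "b (h d) f w -> b h (f w) d", h=h)
print(q.shape)                      # torch.Size([2, 4, 256, 8])
```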


@@ -369,7 +369,8 @@ class WindowAttentionModule(nn.Module):
     num_heads : int
         Number of heads for multi-head self-attention.
     token_projection : AttentionTokenProjectionType
-        Type of projection for token-level queries, keys, and values. Either "conv" or "linear".
+        Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+        or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
     qkv_bias : bool
         Whether to use bias in the linear projection layer for queries, keys, and values.
     qk_scale : float
@@ -402,7 +403,8 @@ def __init__(
         num_heads : int
             Number of heads for multi-head self-attention.
         token_projection : AttentionTokenProjectionType
-            Type of projection for token-level queries, keys, and values. Either "conv" or "linear".
+            Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+            or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
         qkv_bias : bool
             Whether to use bias in the linear projection layer for queries, keys, and values.
         qk_scale : float
@@ -438,12 +440,10 @@ def __init__(
         self.register_buffer("relative_position_index", relative_position_index)
         trunc_normal_(self.relative_position_bias_table, std=0.02)
 
-        if token_projection == "conv":
+        if token_projection == AttentionTokenProjectionType.CONV:
             self.qkv = ConvProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias)
-        elif token_projection == "linear":
-            self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias)
         else:
-            raise Exception("Projection error!")
+            self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias)
 
         self.token_projection = token_projection
         self.attn_drop = nn.Dropout(attn_drop)
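Note: comparing against the enum instead of raw strings also lets the `else` branch absorb the old `raise Exception("Projection error!")`, which pylint flags as broad-exception-raised; invalid values now fail when the enum value is constructed. A hedged sketch of the pattern (the string values are assumptions, not taken from the repository):

```python
# Sketch only: a str-valued enum validates input once, so the dispatch
# above can use a plain if/else without a defensive raise.
from enum import Enum

class AttentionTokenProjectionType(str, Enum):
    LINEAR = "linear"
    CONV = "conv"

proj = AttentionTokenProjectionType("conv")   # validates the raw string
if proj == AttentionTokenProjectionType.CONV:
    qkv = "ConvProjectionModule"              # stand-in for the real module
else:
    qkv = "LinearProjectionModule"
# AttentionTokenProjectionType("typo")        # would raise ValueError
```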
@@ -773,7 +773,6 @@ def window_partition(x: torch.Tensor, win_size: int, dilation_rate: int = 1) ->
     B, H, W, C = x.shape
     if dilation_rate != 1:
         x = x.permute(0, 3, 1, 2)  # B, C, H, W
-        assert type(dilation_rate) is int, "dilation_rate should be a int"
         x = F.unfold(
             x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size
         )  # B, C*Wh*Ww, H/Wh*W/Ww
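Note: the deleted assert used `type(dilation_rate) is int`, the pattern behind pylint's unidiomatic-typecheck warning; if a runtime guard were still wanted, `isinstance` would be the idiomatic spelling:

```python
# Idiomatic alternative to the removed assert (a sketch, not repo code).
dilation_rate = 2
if not isinstance(dilation_rate, int):
    raise TypeError("dilation_rate should be an int")
```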
@@ -1042,7 +1041,7 @@ def __init__(
             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2),
         )
         if act_layer is not None:
-            self.proj.add_module(act_layer(inplace=True))
+            self.proj.add_module("activation", act_layer(inplace=True))
         if norm_layer is not None:
             self.norm = norm_layer(out_channels)
         else:
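Note: this hunk is a real bug fix rather than style: `nn.Module.add_module(name, module)` takes the child's name as its first argument, so the old one-argument call would raise a TypeError whenever `act_layer` was set. A minimal sketch:

```python
# add_module registers a child under an explicit string key.
import torch
from torch import nn

proj = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1))
proj.add_module("activation", nn.ReLU(inplace=True))
print([name for name, _ in proj.named_children()])  # ['0', 'activation']
```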
@@ -1110,9 +1109,11 @@ class LeWinTransformerBlock(nn.Module):
     norm_layer : nn.Module
         The normalization layer to use. Default: nn.LayerNorm.
     token_projection : AttentionTokenProjectionType
-        Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+        Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+        or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
     token_mlp : LeWinTransformerMLPTokenType
-        Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+        Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+        LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
     modulator : bool
         Whether to use a modulator in the attention mechanism. Default: False.
     cross_modulator : bool
@@ -1171,9 +1172,11 @@ def __init__(
         norm_layer : nn.Module
             The normalization layer to use. Default: nn.LayerNorm.
         token_projection : AttentionTokenProjectionType
-            Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+            Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+            or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
         token_mlp : LeWinTransformerMLPTokenType
-            Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+            Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+            LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
         modulator : bool
             Whether to use a modulator in the attention mechanism. Default: False.
         cross_modulator : bool
@@ -1226,12 +1229,10 @@ def __init__(
         self.drop_path = DropoutPath(drop_path) if drop_path > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        if token_mlp in ["ffn", "mlp"]:
+        if token_mlp == LeWinTransformerMLPTokenType.MLP:
             self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-        elif token_mlp == "leff":
-            self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer)
         else:
-            raise Exception("FFN error!")
+            self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer)
 
     def with_pos_embed(self, tensor: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Add positional embeddings to the input tensor.
@@ -1274,7 +1275,7 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch
         W = int(math.sqrt(L))
 
         ## input mask
-        if mask != None:
+        if mask is not None:
             input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1)
             input_mask_windows = window_partition(input_mask, self.win_size)  # nW, win_size, win_size, 1
             attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size)  # nW, win_size*win_size
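Note: besides satisfying pylint's singleton-comparison rule (compare to None with `is`/`is not`), identity is the safe null-check for optional tensors, since `!=` on a `torch.Tensor` is overloaded for elementwise comparison:

```python
# Identity vs. elementwise comparison on an optional tensor argument.
import torch

mask = torch.ones(1, 1, 4, 4)
print(mask is not None)     # True: a plain bool, no tensor semantics involved
print((mask != 0).shape)    # torch.Size([1, 1, 4, 4]): `!=` is elementwise
```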
@@ -1382,9 +1383,11 @@ class BasicUFormerLayer(nn.Module):
     norm_layer : nn.Module
         The normalization layer to use. Default: nn.LayerNorm.
     token_projection : AttentionTokenProjectionType
-        Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+        Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+        or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
     token_mlp : LeWinTransformerMLPTokenType
-        Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+        Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+        LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
     shift_flag : bool
         Whether to use shift in the attention sliding windows or not. Default: True.
     modulator : bool
@@ -1408,7 +1411,7 @@ def __init__(
         drop_path: List[float] | float = 0.0,
         norm_layer: nn.Module = nn.LayerNorm,
         token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR,
-        token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.FFN,
+        token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.MLP,
         shift_flag: bool = True,
         modulator: bool = False,
         cross_modulator: bool = False,
@@ -1441,9 +1444,11 @@ def __init__(
         norm_layer : nn.Module
             The normalization layer to use. Default: nn.LayerNorm.
         token_projection : AttentionTokenProjectionType
-            Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+            Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+            or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
         token_mlp : LeWinTransformerMLPTokenType
-            Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+            Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+            LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
         shift_flag : bool
             Whether to use shift in the attention sliding windows or not. Default: True.
         modulator : bool
@@ -1541,9 +1546,11 @@ class UFormer(nn.Module):
     patch_norm : bool
         Whether to use normalization for the patch embeddings. Default: True.
     token_projection : AttentionTokenProjectionType
-        Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+        Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+        or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
     token_mlp : LeWinTransformerMLPTokenType
-        Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+        Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+        LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
     shift_flag : bool
         Whether to use shift operation in the local attention mechanism. Default: True.
     modulator : bool
@@ -1622,9 +1629,11 @@ def __init__(
         patch_norm : bool
             Whether to use normalization for the patch embeddings. Default: True.
         token_projection : AttentionTokenProjectionType
-            Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+            Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+            or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
         token_mlp : LeWinTransformerMLPTokenType
-            Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+            Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+            LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
         shift_flag : bool
             Whether to use shift operation in the local attention mechanism. Default: True.
         modulator : bool
@@ -1783,7 +1792,10 @@ def no_weight_decay_keywords(self):
         return {"relative_position_bias_table"}
 
     def extra_repr(self) -> str:
-        return f"embedding_dim={self.embedding_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}"
+        return (
+            f"embedding_dim={self.embedding_dim}, token_projection={self.token_projection}, "
+            + f"token_mlp={self.mlp},win_size={self.win_size}"
+        )
 
     def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Performs forward pass of :class:`UFormer`.
@@ -1865,9 +1877,11 @@ class UFormerModel(nn.Module):
     patch_norm : bool
         Whether to use normalization for the patch embeddings. Default: True.
     token_projection : AttentionTokenProjectionType
-        Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+        Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+        or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
     token_mlp : LeWinTransformerMLPTokenType
-        Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+        Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+        LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
     shift_flag : bool
         Whether to use shift operation in the local attention mechanism. Default: True.
     modulator : bool
@@ -1943,9 +1957,11 @@ def __init__(
         patch_norm : bool
             Whether to use normalization for the patch embeddings. Default: True.
         token_projection : AttentionTokenProjectionType
-            Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.
+            Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR
+            or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
         token_mlp : LeWinTransformerMLPTokenType
-            Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.
+            Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or
+            LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
         shift_flag : bool
             Whether to use shift operation in the local attention mechanism. Default: True.
         modulator : bool
5 changes: 3 additions & 2 deletions direct/nn/transformers/utils.py
@@ -7,9 +7,9 @@
 from math import ceil, floor
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from torch import nn
 
 __all__ = ["init_weights", "norm", "pad_to_divisible", "pad_to_square", "unnorm", "unpad_to_original", "DropoutPath"]

@@ -34,7 +34,8 @@ def pad_to_divisible(x: torch.Tensor, pad_size: tuple[int, ...]) -> tuple[torch.
         pad_before = (p_dim - dim % p_dim) % p_dim / 2
         pads.append((floor(pad_before), ceil(pad_before)))
 
-    # Reverse and flatten pads to match torch's expected (pad_n_before, pad_n_after, ..., pad_1_before, pad_1_after) format
+    # Reverse and flatten pads to match torch's expected
+    # (pad_n_before, pad_n_after, ..., pad_1_before, pad_1_after) format
     flat_pads = tuple(val for sublist in pads[::-1] for val in sublist)
     x = F.pad(x, flat_pads)
3 changes: 1 addition & 2 deletions direct/nn/transformers/vit.py
@@ -24,9 +24,9 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from torch import nn
 
 from direct.constants import COMPLEX_SIZE
 from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_divisible, unnorm, unpad_to_original
@@ -978,7 +978,6 @@ def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor:
         torch.Tensor
             The image tensor.
         """
-        pass
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Performs forward pass of :class:`VisionTransformer`.
2 changes: 1 addition & 1 deletion tests/tests_nn/test_transformers.py
@@ -64,7 +64,7 @@ def create_input(shape):
 )
 @pytest.mark.parametrize(
     "token_mlp",
-    [LeWinTransformerMLPTokenType.FFN, LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
+    [LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
 )
 def test_uformer(
     shape,
2 changes: 1 addition & 1 deletion tests/tests_nn/test_transformers_engine.py
@@ -92,7 +92,7 @@ def create_sample(shape, **kwargs):
 )
 @pytest.mark.parametrize(
     "token_mlp",
-    [LeWinTransformerMLPTokenType.FFN],
+    [LeWinTransformerMLPTokenType.MLP],
 )
 def test_image_uformer_engine(
     shape,
