Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dropout issue in swin transformers #7224

Merged
merged 5 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions torchvision/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def shifted_window_attention(
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None,
):
training: bool = True,
) -> Tensor:
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Expand Down Expand Up @@ -207,11 +208,11 @@ def shifted_window_attention(
attn = attn.view(-1, num_heads, x.size(1), x.size(1))

attn = F.softmax(attn, dim=-1)
attn = F.dropout(attn, p=attention_dropout)
attn = F.dropout(attn, p=attention_dropout, training=training)

x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
x = F.linear(x, proj_weight, proj_bias)
x = F.dropout(x, p=dropout)
x = F.dropout(x, p=dropout, training=training)

# reverse windows
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
Expand Down Expand Up @@ -286,7 +287,7 @@ def get_relative_position_bias(self) -> torch.Tensor:
self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type]
)

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Expand All @@ -306,6 +307,7 @@ def forward(self, x: Tensor):
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
training=self.training,
)


Expand Down
6 changes: 4 additions & 2 deletions torchvision/models/video/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def shifted_window_attention_3d(
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
training: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set the default to True since it's the default of torch.nn.functional.dropout, and also the default of shifted_window_attention ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well caught this was mistake it should have been True

) -> Tensor:
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
Expand Down Expand Up @@ -194,11 +195,11 @@ def shifted_window_attention_3d(
attn = attn.view(-1, num_heads, x.size(1), x.size(1))

attn = F.softmax(attn, dim=-1)
attn = F.dropout(attn, p=attention_dropout)
attn = F.dropout(attn, p=attention_dropout, training=training)

x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
x = F.linear(x, proj_weight, proj_bias)
x = F.dropout(x, p=dropout)
x = F.dropout(x, p=dropout, training=training)

# reverse windows
x = x.view(
Expand Down Expand Up @@ -310,6 +311,7 @@ def forward(self, x: Tensor) -> Tensor:
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
training=self.training,
)


Expand Down