Dropout in SwinTransformer unable to set training=False #7103
Comments
Hi @nagadomi, I'm not sure I understand what the problem might be. I'll assume that you're worried about the fact that dropout will always be applied, even during eval: that won't be a problem as long as you call `model.eval()` before running inference.
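(For illustration, a minimal sketch of why `model.eval()` alone does not reach these calls: the functional `F.dropout` defaults to `training=True`, so it keeps dropping activations unless the flag is forwarded explicitly.)

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 8)

# The functional form defaults to training=True: elements are randomly zeroed
# and the survivors rescaled, regardless of any surrounding module's eval() state.
print(F.dropout(x, p=0.5))

# Only an explicit training=False turns it into a no-op.
print(F.dropout(x, p=0.5, training=False))  # identical to x
```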
The code below is the proper fix, in my opinion.

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 0d3ab9ad32..c0d107facb 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -126,6 +126,7 @@ def shifted_window_attention(
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None,
+ training: bool = True,
):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -207,11 +208,11 @@ def shifted_window_attention(
attn = attn.view(-1, num_heads, x.size(1), x.size(1))
attn = F.softmax(attn, dim=-1)
- attn = F.dropout(attn, p=attention_dropout)
+ attn = F.dropout(attn, p=attention_dropout, training=training)
x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
x = F.linear(x, proj_weight, proj_bias)
- x = F.dropout(x, p=dropout)
+ x = F.dropout(x, p=dropout, training=training)
# reverse windows
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -306,6 +307,7 @@ class ShiftedWindowAttention(nn.Module):
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
+ training=self.training,
)
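For illustration, a minimal sketch (a hypothetical standalone module, not torchvision code) of the pattern this diff follows: an nn.Module that wraps a functional op forwards its own `self.training` flag, so that `model.eval()` actually takes effect.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalDropoutBlock(nn.Module):
    """Hypothetical block showing the pattern used in the diff above."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Forward the module's own train/eval flag to the functional call;
        # without it, F.dropout stays active even after model.eval().
        return F.dropout(x, p=self.p, training=self.training)


block = FunctionalDropoutBlock().eval()
x = torch.ones(2, 4)
assert torch.equal(block(x), x)  # dropout is a no-op in eval mode
```

Passing `training=self.training` keeps the functional call in sync with the module's train/eval state, which is what `nn.Dropout` does internally.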
Thanks for correcting me @nagadomi, you're right! Would you like to submit a PR with your proposed fix? One minor positive note here is that all the Swin transformers were trained with the default dropout value of 0, so the released weights are not affected by this.
I am not familiar with the conventions of this codebase.
I found this issue when I tried to export my trained model to ONNX. |
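(For context, a rough sketch of the kind of export that surfaces the problem. The non-zero `dropout`/`attention_dropout` keyword arguments passed through the `swin_t` builder are an assumption for illustration; with the default 0.0 the bug is silent.)

```python
import torch
from torchvision.models import swin_t

# Assumption: non-zero dropout rates are passed through the builder so the
# affected F.dropout calls actually do something.
model = swin_t(weights=None, dropout=0.1, attention_dropout=0.1)
model.eval()  # does not reach the F.dropout calls inside shifted_window_attention

dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "swin_t.onnx")
# The original reporter noticed the issue during this kind of export: with the
# unfixed code, dropout is still traced as active despite eval().
```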
OK, let me check this. I will submit a PR for it.
Thanks @oke-aditya , looking forward to your PR! |
Extremely sorry for the delay. 2023 has started off really hard here, with a lot of unplanned events. So, coming back to this: I did validate the issue, and I also validated how our codebase in general uses dropout (`nn.Dropout`, which tracks the module's training mode, versus the functional `F.dropout`, which does not unless told to). Since I remember adding the code for SwinTransformer3d (unreleased before 0.15), the bug had propagated there too; I am fixing that as well 😄

Reference: https://stackoverflow.com/questions/53419474/pytorch-nn-dropout-vs-f-dropout
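A minimal sketch of the distinction the Stack Overflow reference is about: `nn.Dropout` reads the owning module's train/eval state automatically, while the functional form needs the flag forwarded explicitly (as shown earlier).

```python
import torch
import torch.nn as nn

x = torch.ones(1, 6)

drop = nn.Dropout(p=0.5)
drop.eval()

# nn.Dropout reads self.training, so eval() makes it a no-op;
# F.dropout would still drop activations unless training= is passed.
print(torch.equal(drop(x), x))  # True
```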
🐛 Describe the bug
vision/torchvision/models/swin_transformer.py, lines 209 to 214 in 93df9a5

The `training` argument is not used here, so `training=True` (by default) even when `model.training=False`.
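A minimal reproduction sketch of the reported behaviour (as in the export sketch above, the non-zero dropout rates passed to the builder are an assumption; with the default 0.0 the bug has no visible effect):

```python
import torch
from torchvision.models import swin_t

model = swin_t(weights=None, dropout=0.1, attention_dropout=0.1)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)

# With the unfixed shifted_window_attention this prints False: the two
# eval-mode forward passes differ because F.dropout still runs with its
# default training=True.
print(torch.allclose(out1, out2))
```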
Versions
Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.35
Python version: 3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Ti
Nvidia driver version: 525.60.13
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] torch==1.13.1
[pip3] torchaudio==0.13.1
[pip3] torchtext==0.14.1
[pip3] torchvision==0.14.1
[conda] Could not collect