From 821ed10f293e1b1d6437159f224293f997361606 Mon Sep 17 00:00:00 2001
From: Oke Aditya
Date: Sun, 12 Feb 2023 15:17:35 +0530
Subject: [PATCH 1/3] Fix dropout issue in swin transformers

---
 torchvision/models/swin_transformer.py       | 10 ++++++----
 torchvision/models/video/swin_transformer.py |  6 ++++--
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 0d3ab9ad32a..0fdf4a77d59 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -207,11 +208,11 @@ def shifted_window_attention(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +287,7 @@ def get_relative_position_bias(self) -> torch.Tensor:
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +307,7 @@ def forward(self, x: Tensor):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )


diff --git a/torchvision/models/video/swin_transformer.py b/torchvision/models/video/swin_transformer.py
index c6a1602d255..4db972274f5 100644
--- a/torchvision/models/video/swin_transformer.py
+++ b/torchvision/models/video/swin_transformer.py
@@ -124,6 +124,7 @@ def shifted_window_attention_3d(
     dropout: float = 0.0,
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
+    training: bool = False,
 ) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -194,11 +195,11 @@ def shifted_window_attention_3d(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(
@@ -310,6 +311,7 @@ def forward(self, x: Tensor) -> Tensor:
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )


From d07b25ebe2e7d81237d007b2e8821e7187e378d0 Mon Sep 17 00:00:00 2001
From: Oke Aditya
Date: Thu, 16 Feb 2023 16:39:16 +0530
Subject: [PATCH 2/3] Fix dropout issue in swin transformers add docstring

---
 torchvision/models/swin_transformer.py       | 1 +
 torchvision/models/video/swin_transformer.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 0fdf4a77d59..386f73a51fe 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -144,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, True): If you set the dropout parameters, set training=True while training.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
diff --git a/torchvision/models/video/swin_transformer.py b/torchvision/models/video/swin_transformer.py
index 4db972274f5..ad62807ccc3 100644
--- a/torchvision/models/video/swin_transformer.py
+++ b/torchvision/models/video/swin_transformer.py
@@ -124,7 +124,7 @@ def shifted_window_attention_3d(
     dropout: float = 0.0,
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
-    training: bool = False,
+    training: bool = True,
 ) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -141,6 +141,7 @@ def shifted_window_attention_3d(
         dropout (float): Dropout ratio of output. Default: 0.0.
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
+        training (bool, True): If you set the dropout parameters, set training=True while training.
     Returns:
         Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
     """

From b3054b4d1879858c7ab1225d253780bb911a6c95 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 16 Feb 2023 16:07:11 +0000
Subject: [PATCH 3/3] Doc fixes + fix one more call

---
 torchvision/models/swin_transformer.py       | 3 ++-
 torchvision/models/video/swin_transformer.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 386f73a51fe..249ca37b9d2 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -144,7 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
-        training (bool, True): If you set the dropout parameters, set training=True while training.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -394,6 +394,7 @@ def forward(self, x: Tensor):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )


diff --git a/torchvision/models/video/swin_transformer.py b/torchvision/models/video/swin_transformer.py
index ad62807ccc3..25cf3cf997e 100644
--- a/torchvision/models/video/swin_transformer.py
+++ b/torchvision/models/video/swin_transformer.py
@@ -141,7 +141,7 @@ def shifted_window_attention_3d(
         dropout (float): Dropout ratio of output. Default: 0.0.
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
-        training (bool, True): If you set the dropout parameters, set training=True while training.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
     """
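
Background on the fix: `torch.nn.functional.dropout` defaults to `training=True`, so a bare `F.dropout(x, p=dropout)` keeps zeroing activations even when the surrounding module is in eval mode; forwarding `training=self.training`, as these patches do, makes the call a no-op outside training. A minimal sketch of that difference, separate from the patched torchvision code:

    import torch
    import torch.nn.functional as F

    x = torch.ones(2, 4)

    # The functional form is stochastic regardless of any module's mode,
    # because its `training` argument defaults to True.
    y_unpatched = F.dropout(x, p=0.5)
    print(torch.equal(y_unpatched, x))  # False: elements are zeroed, survivors scaled by 1/(1-p)

    # Passing an explicit flag (the patches use training=self.training)
    # disables dropout outside training.
    y_patched = F.dropout(x, p=0.5, training=False)
    print(torch.equal(y_patched, x))  # True: dropout is a no-op in eval mode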