From 29ff22b8ff6935d6559b1c14b8dbcecf13442bf3 Mon Sep 17 00:00:00 2001 From: charSLee013 Date: Fri, 28 Jun 2024 01:38:32 +0800 Subject: [PATCH 1/2] refactor: remove unnecessary transpositions --- ChatTTS/model/dvae.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 491477e82..50b01a734 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -154,9 +154,8 @@ def __init__( ) self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) - def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: - # B, T, C - x = input.transpose_(1, 2) + def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: + # B, C, T y = self.conv_in(x) del x for f in self.decoder_block: @@ -214,7 +213,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: dec_out = self.out_conv( self.decoder( - input=vq_feats.transpose_(1, 2), + x=vq_feats, ).transpose_(1, 2), ) From 627d7d6a99ffe1b2e85473c48db39775f2052626 Mon Sep 17 00:00:00 2001 From: charSLee013 Date: Fri, 28 Jun 2024 15:37:55 +0800 Subject: [PATCH 2/2] refactor: remove the redundant output transpose --- ChatTTS/model/dvae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 50b01a734..c3cc91ee2 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -163,7 +163,7 @@ def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: x = self.conv_out(y) del y - return x.transpose_(1, 2) + return x class DVAE(nn.Module): @@ -214,7 +214,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: dec_out = self.out_conv( self.decoder( x=vq_feats, - ).transpose_(1, 2), + ), ) return torch.mul(dec_out, self.coef, out=dec_out)