Skip to content

Commit

Permalink
Revert to initial stoi indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente committed Dec 29, 2023
1 parent b9746a8 commit cac3ead
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torch_stoi/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def forward(self, est_targets: torch.Tensor,
targets, est_targets, self.dyn_range, self.win, self.win_len,
self.win_len//2
)
# Remove the last mask frame to replicate pystoi behavior
mask, _ = mask.sort(-1, descending=True)
mask = mask[..., 1:]

# Here comes the real computation, take STFT
x_spec = self.stft(targets, self.win, self.nfft, overlap=2)
Expand Down Expand Up @@ -256,9 +259,7 @@ def remove_silent_frames(x, y, dyn_range, window, framelen, hop):
x_frames = x_frames.permute(0, 2, 1)
y_frames = y_frames.permute(0, 2, 1)

mask, _ = mask.long().sort(-1, descending=True)

return x_sil, y_sil, mask
return x_sil, y_sil, mask.long()

@staticmethod
def stft(x, win, fft_size, overlap=4):
Expand All @@ -272,7 +273,7 @@ def stft(x, win, fft_size, overlap=4):
win_len = win.shape[0]
hop = int(win_len / overlap)
frames = unfold(x[:, None, None, :], kernel_size=(1, win_len),
stride=(1, hop))
stride=(1, hop))[..., :-1]
return torch.fft.rfft(frames*win[:, None], n=fft_size, dim=1)

@staticmethod
Expand Down

0 comments on commit cac3ead

Please sign in to comment.