Skip to content

Commit

Permalink
Revert to initial stoi indexing (#12)
Browse files Browse the repository at this point in the history
* Revert to initial stoi indexing

* Update to Python 3.8 in CI
  • Loading branch information
mpariente authored Dec 30, 2023
1 parent b9746a8 commit 1acea38
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ jobs:
timeout-minutes: 10
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.6
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: 3.8

- name: Install python dependencies
run: |
Expand Down
9 changes: 5 additions & 4 deletions torch_stoi/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def forward(self, est_targets: torch.Tensor,
targets, est_targets, self.dyn_range, self.win, self.win_len,
self.win_len//2
)
# Remove the last mask frame to replicate pystoi behavior
mask, _ = mask.sort(-1, descending=True)
mask = mask[..., 1:]

# Here comes the real computation, take STFT
x_spec = self.stft(targets, self.win, self.nfft, overlap=2)
Expand Down Expand Up @@ -256,9 +259,7 @@ def remove_silent_frames(x, y, dyn_range, window, framelen, hop):
x_frames = x_frames.permute(0, 2, 1)
y_frames = y_frames.permute(0, 2, 1)

mask, _ = mask.long().sort(-1, descending=True)

return x_sil, y_sil, mask
return x_sil, y_sil, mask.long()

@staticmethod
def stft(x, win, fft_size, overlap=4):
Expand All @@ -272,7 +273,7 @@ def stft(x, win, fft_size, overlap=4):
win_len = win.shape[0]
hop = int(win_len / overlap)
frames = unfold(x[:, None, None, :], kernel_size=(1, win_len),
stride=(1, hop))
stride=(1, hop))[..., :-1]
return torch.fft.rfft(frames*win[:, None], n=fft_size, dim=1)

@staticmethod
Expand Down

0 comments on commit 1acea38

Please sign in to comment.