-
Notifications
You must be signed in to change notification settings - Fork 423
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* functional + class version * add req * add installation note * add doc * add test * Apply suggestions from code review Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
- Loading branch information
1 parent
39ca748
commit 8bbc750
Showing
12 changed files
with
408 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
pesq>=0.0.3 | ||
pystoi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import namedtuple | ||
from functools import partial | ||
|
||
import pytest | ||
import torch | ||
from pystoi import stoi as stoi_backend | ||
from torch import Tensor | ||
|
||
from tests.helpers import seed_all | ||
from tests.helpers.testers import MetricTester | ||
from torchmetrics.audio import STOI | ||
from torchmetrics.functional import stoi | ||
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 | ||
|
||
seed_all(42) | ||
|
||
Input = namedtuple("Input", ["preds", "target"]) | ||
|
||
inputs_8k = Input( | ||
preds=torch.rand(2, 3, 8000), | ||
target=torch.rand(2, 3, 8000), | ||
) | ||
inputs_16k = Input( | ||
preds=torch.rand(2, 3, 16000), | ||
target=torch.rand(2, 3, 16000), | ||
) | ||
|
||
|
||
def stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool): | ||
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time] | ||
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time] | ||
target = target.detach().cpu().numpy() | ||
preds = preds.detach().cpu().numpy() | ||
mss = [] | ||
for b in range(preds.shape[0]): | ||
pesq_val = stoi_backend(target[b, ...], preds[b, ...], fs, extended) | ||
mss.append(pesq_val) | ||
return torch.tensor(mss) | ||
|
||
|
||
def average_metric(preds, target, metric_func): | ||
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] | ||
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] | ||
return metric_func(preds, target).mean() | ||
|
||
|
||
stoi_original_batch_8k_ext = partial(stoi_original_batch, fs=8000, extended=True) | ||
stoi_original_batch_16k_ext = partial(stoi_original_batch, fs=16000, extended=True) | ||
stoi_original_batch_8k_noext = partial(stoi_original_batch, fs=8000, extended=False) | ||
stoi_original_batch_16k_noext = partial(stoi_original_batch, fs=16000, extended=False) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds, target, sk_metric, fs, extended", | ||
[ | ||
(inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_ext, 8000, True), | ||
(inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_ext, 16000, True), | ||
(inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_noext, 8000, False), | ||
(inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_noext, 16000, False), | ||
], | ||
) | ||
class TestSTOI(MetricTester): | ||
atol = 1e-2 | ||
|
||
@pytest.mark.parametrize("ddp", [True, False]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) | ||
def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step): | ||
self.run_class_metric_test( | ||
ddp, | ||
preds, | ||
target, | ||
STOI, | ||
sk_metric=partial(average_metric, metric_func=sk_metric), | ||
dist_sync_on_step=dist_sync_on_step, | ||
metric_args=dict(fs=fs, extended=extended), | ||
) | ||
|
||
def test_stoi_functional(self, preds, target, sk_metric, fs, extended): | ||
self.run_functional_metric_test( | ||
preds, | ||
target, | ||
stoi, | ||
sk_metric, | ||
metric_args=dict(fs=fs, extended=extended), | ||
) | ||
|
||
def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended): | ||
self.run_differentiability_test( | ||
preds=preds, | ||
target=target, | ||
metric_module=STOI, | ||
metric_functional=stoi, | ||
metric_args=dict(fs=fs, extended=extended), | ||
) | ||
|
||
@pytest.mark.skipif( | ||
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6" | ||
) | ||
def test_stoi_half_cpu(self, preds, target, sk_metric, fs, extended): | ||
pytest.xfail("STOI metric does not support cpu + half precision") | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") | ||
def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended): | ||
self.run_precision_test_gpu( | ||
preds=preds, | ||
target=target, | ||
metric_module=STOI, | ||
metric_functional=partial(stoi, fs=fs, extended=extended), | ||
metric_args=dict(fs=fs, extended=extended), | ||
) | ||
|
||
|
||
def test_error_on_different_shape(metric_class=STOI): | ||
metric = metric_class(16000) | ||
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): | ||
metric(torch.randn(100), torch.randn(50)) | ||
|
||
|
||
def test_on_real_audio(): | ||
import os | ||
|
||
from scipy.io import wavfile | ||
|
||
current_file_dir = os.path.dirname(__file__) | ||
|
||
rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav")) | ||
rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav")) | ||
assert torch.allclose( | ||
stoi(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(), | ||
torch.tensor(0.6739177), | ||
rtol=0.0001, | ||
atol=1e-4, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Optional | ||
|
||
from torch import Tensor, tensor | ||
|
||
from torchmetrics.functional.audio.stoi import stoi | ||
from torchmetrics.metric import Metric | ||
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE | ||
|
||
|
||
class STOI(Metric): | ||
r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. | ||
Note that input will be moved to `cpu` to perform the metric calculation. | ||
Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due | ||
to additive noise, single/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. | ||
The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good | ||
alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are | ||
interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms, | ||
on speech intelligibility. Description taken from [Cees Taal's website](http://www.ceestaal.nl/code/). | ||
.. note:: using this metrics requires you to have ``pystoi`` install. Either install as ``pip install | ||
torchmetrics[audio]`` or ``pip install pystoi`` | ||
Forward accepts | ||
- ``preds``: ``shape [...,time]`` | ||
- ``target``: ``shape [...,time]`` | ||
Args: | ||
fs: | ||
sampling frequency (Hz) | ||
extended: | ||
whether to use the extended STOI described in [4] | ||
compute_on_step: | ||
Forward only calls ``update()`` and returns None if this is set to False. default: True | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. | ||
process_group: | ||
Specify the process group on which synchronization is called. default: None (which selects the entire world) | ||
dist_sync_fn: | ||
Callback that performs the allgather operation on the metric state. When `None`, DDP | ||
will be used to perform the allgather. | ||
Returns: | ||
average STOI value | ||
Raises: | ||
ModuleNotFoundError: | ||
If ``pystoi`` package is not installed | ||
Example: | ||
>>> from torchmetrics.audio import STOI | ||
>>> import torch | ||
>>> g = torch.manual_seed(1) | ||
>>> preds = torch.randn(8000) | ||
>>> target = torch.randn(8000) | ||
>>> stoi = STOI(8000, False) | ||
>>> stoi(preds, target) | ||
tensor(-0.0100) | ||
References: | ||
[1] https://github.com/mpariente/pystoi | ||
[2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time Objective Intelligibility Measure for | ||
Time-Frequency Weighted Noisy Speech', ICASSP 2010, Texas, Dallas. | ||
[3] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for Intelligibility Prediction of | ||
Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011. | ||
[4] J. Jensen and C. H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated | ||
Noise Maskers', IEEE Transactions on Audio, Speech and Language Processing, 2016. | ||
""" | ||
sum_stoi: Tensor | ||
total: Tensor | ||
is_differentiable = False | ||
higher_is_better = True | ||
|
||
def __init__( | ||
self, | ||
fs: int, | ||
extended: bool = False, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, | ||
) -> None: | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn, | ||
) | ||
if not _PYSTOI_AVAILABLE: | ||
raise ModuleNotFoundError( | ||
"STOI metric requires that pystoi is installed." | ||
" Either install as `pip install torchmetrics[audio]` or `pip install pystoi`" | ||
) | ||
self.fs = fs | ||
self.extended = extended | ||
|
||
self.add_state("sum_stoi", default=tensor(0.0), dist_reduce_fx="sum") | ||
self.add_state("total", default=tensor(0), dist_reduce_fx="sum") | ||
|
||
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore | ||
"""Update state with predictions and targets. | ||
Args: | ||
preds: Predictions from model | ||
target: Ground truth values | ||
""" | ||
stoi_batch = stoi(preds, target, self.fs, self.extended, False).to(self.sum_stoi.device) | ||
|
||
self.sum_stoi += stoi_batch.sum() | ||
self.total += stoi_batch.numel() | ||
|
||
def compute(self) -> Tensor: | ||
"""Computes average STOI.""" | ||
return self.sum_stoi / self.total |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.