From 5b3141e51bef7f5519877c721bfd96d7ac23ee43 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Fri, 14 Jan 2022 21:02:01 +0100 Subject: [PATCH 1/4] Refactor stoi functional --- CHANGELOG.md | 4 +++- tests/audio/test_stoi.py | 10 ++++----- torchmetrics/audio/stoi.py | 6 ++++-- torchmetrics/functional/audio/stoi.py | 29 ++++++++++++++++++++++++--- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd23613d028..8382baafd47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `MatthewsCorrcoef` -> `MatthewsCorrCoef` * `PearsonCorrcoef` -> `PearsonCorrCoef` * `SpearmanCorrcoef` -> `SpearmanCorrCoef` -- Renamed audio STOI metric `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` ([#753](https://github.com/PyTorchLightning/metrics/pull/753)) +- Renamed audio STOI metric: ([#753](https://github.com/PyTorchLightning/metrics/pull/753)) + * `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` + * `functional.audio.stoi` to `functional.audio.short_term_objective_intelligibility` - Renamed audio PESQ metrics: ([#751](https://github.com/PyTorchLightning/metrics/pull/751)) * `functional.audio.pesq` -> `functional.audio.perceptual_evaluation_speech_quality` * `audio.PESQ` -> `audio.PerceptualEvaluationSpeechQuality` diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py index 5d437de48fa..3142d2c9d97 100644 --- a/tests/audio/test_stoi.py +++ b/tests/audio/test_stoi.py @@ -22,7 +22,7 @@ from tests.helpers import seed_all from tests.helpers.testers import MetricTester from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility -from torchmetrics.functional.audio.stoi import stoi +from torchmetrics.functional.audio.stoi import short_term_objective_intelligibility from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -92,7 +92,7 @@ def test_stoi_functional(self, preds, target, sk_metric, fs, extended): self.run_functional_metric_test( preds, target, - stoi, + short_term_objective_intelligibility, sk_metric, metric_args=dict(fs=fs, extended=extended), ) @@ -102,7 +102,7 @@ def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended): preds=preds, target=target, metric_module=ShortTermObjectiveIntelligibility, - metric_functional=stoi, + metric_functional=short_term_objective_intelligibility, metric_args=dict(fs=fs, extended=extended), ) @@ -118,7 +118,7 @@ def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended): preds=preds, target=target, metric_module=ShortTermObjectiveIntelligibility, - metric_functional=partial(stoi, fs=fs, extended=extended), + metric_functional=partial(short_term_objective_intelligibility, fs=fs, extended=extended), metric_args=dict(fs=fs, extended=extended), ) @@ -139,7 +139,7 @@ def test_on_real_audio(): rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav")) rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav")) assert torch.allclose( - stoi(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(), + short_term_objective_intelligibility(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(), torch.tensor(0.6739177), rtol=0.0001, atol=1e-4, diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index 5ecffbcfbfa..a7a74a14e63 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -16,7 +16,7 @@ from deprecate import deprecated, void from torch import Tensor, tensor -from torchmetrics.functional.audio.stoi import stoi +from torchmetrics.functional.audio.stoi import short_term_objective_intelligibility from torchmetrics.metric import Metric from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE @@ -125,7 +125,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model target: Ground truth values """ - stoi_batch = stoi(preds, target, self.fs, self.extended, False).to(self.sum_stoi.device) + stoi_batch = short_term_objective_intelligibility( + preds, target, self.fs, self.extended, False).to(self.sum_stoi.device + ) self.sum_stoi += stoi_batch.sum() self.total += stoi_batch.numel() diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py index 17b98056468..4abce8c25fe 100644 --- a/torchmetrics/functional/audio/stoi.py +++ b/torchmetrics/functional/audio/stoi.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np import torch +from deprecate import deprecated, void from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE @@ -22,10 +23,13 @@ stoi_backend = None from torch import Tensor +from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape -def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_same_device: bool = False) -> Tensor: +def short_term_objective_intelligibility( + preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_same_device: bool = False +) -> Tensor: r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. Note that input will be moved to `cpu` to perform the metric calculation. @@ -59,12 +63,12 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa If ``pystoi`` package is not installed Example: - >>> from torchmetrics.functional.audio.stoi import stoi + >>> from torchmetrics.functional.audio.stoi import short_term_objective_intelligibility >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) - >>> stoi(preds, target, 8000).float() + >>> short_term_objective_intelligibility(preds, target, 8000).float() tensor(-0.0100) References: @@ -103,3 +107,22 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa stoi_val = stoi_val.to(preds.device) return stoi_val + + +@deprecated(target=short_term_objective_intelligibility, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) +def stoi(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False) -> Tensor: + r"""STOI (Short Term Objective Intelligibility) + + .. deprecated:: v0.7 + Use :func:`torchmetrics.functional.audio.short_term_objective_intelligibility`. Will be removed in v0.8. + + Example: + >>> from torchmetrics.functional.audio.stoi import stoi + >>> import torch + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> stoi(preds, target, 8000).float() + tensor(-0.0100) + """ + return void(preds, target, fs, mode, keep_same_device) From cb501f15988c675fca2d0c0509576cc036761cf0 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Fri, 14 Jan 2022 21:03:16 +0100 Subject: [PATCH 2/4] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8382baafd47..220e7e56b3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `MatthewsCorrcoef` -> `MatthewsCorrCoef` * `PearsonCorrcoef` -> `PearsonCorrCoef` * `SpearmanCorrcoef` -> `SpearmanCorrCoef` -- Renamed audio STOI metric: ([#753](https://github.com/PyTorchLightning/metrics/pull/753)) +- Renamed audio STOI metric: ([#753](https://github.com/PyTorchLightning/metrics/pull/753), [#758](https://github.com/PyTorchLightning/metrics/pull/758)) * `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` * `functional.audio.stoi` to `functional.audio.short_term_objective_intelligibility` - Renamed audio PESQ metrics: ([#751](https://github.com/PyTorchLightning/metrics/pull/751)) From a7d3fa23fe78ce11737970ecf41deef43dc0f7e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Jan 2022 20:03:53 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/audio/stoi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index a7a74a14e63..a3ad567e256 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -125,8 +125,8 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model target: Ground truth values """ - stoi_batch = short_term_objective_intelligibility( - preds, target, self.fs, self.extended, False).to(self.sum_stoi.device + stoi_batch = short_term_objective_intelligibility(preds, target, self.fs, self.extended, False).to( + self.sum_stoi.device ) self.sum_stoi += stoi_batch.sum() From c4ef76a0835fada57a514beabbfb1fe809eac97e Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 14 Jan 2022 21:12:41 +0100 Subject: [PATCH 4/4] docs --- docs/source/references/functional.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 1ca4e5bcd80..4cd9880c3dc 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -52,10 +52,10 @@ signal_noise_ratio [func] :noindex: -stoi [func] -~~~~~~~~~~~ +short_term_objective_intelligibility [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.audio.stoi.stoi +.. autofunction:: torchmetrics.functional.audio.stoi.short_term_objective_intelligibility :noindex: