From 3eb83939b31f734a3c85f2deb0ca2d38152196f9 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 14 Jan 2022 11:02:47 +0100 Subject: [PATCH 1/3] ShortTermObjectiveIntelligibility --- docs/source/references/modules.rst | 4 +-- tests/audio/test_stoi.py | 10 ++++---- torchmetrics/audio/__init__.py | 4 +++ torchmetrics/audio/stoi.py | 37 ++++++++++++++++++++++++--- torchmetrics/functional/audio/stoi.py | 2 +- 5 files changed, 46 insertions(+), 11 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index a1c95469f8b..e1cecb26ca8 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -111,10 +111,10 @@ SignalNoiseRatio .. autoclass:: torchmetrics.SignalNoiseRatio :noindex: -STOI +ShortTermObjectiveIntelligibility ~~~~ -.. autoclass:: torchmetrics.audio.stoi.STOI +.. autoclass:: torchmetrics.audio.stoi.ShortTermObjectiveIntelligibility :noindex: diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py index cd4192e83d7..5d437de48fa 100644 --- a/tests/audio/test_stoi.py +++ b/tests/audio/test_stoi.py @@ -21,7 +21,7 @@ from tests.helpers import seed_all from tests.helpers.testers import MetricTester -from torchmetrics.audio.stoi import STOI +from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility from torchmetrics.functional.audio.stoi import stoi from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -82,7 +82,7 @@ def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_st ddp, preds, target, - STOI, + ShortTermObjectiveIntelligibility, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(fs=fs, extended=extended), @@ -101,7 +101,7 @@ def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended): self.run_differentiability_test( preds=preds, target=target, - metric_module=STOI, + metric_module=ShortTermObjectiveIntelligibility, metric_functional=stoi, metric_args=dict(fs=fs, extended=extended), ) @@ -117,13 +117,13 @@ def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended): self.run_precision_test_gpu( preds=preds, target=target, - metric_module=STOI, + metric_module=ShortTermObjectiveIntelligibility, metric_functional=partial(stoi, fs=fs, extended=extended), metric_args=dict(fs=fs, extended=extended), ) -def test_error_on_different_shape(metric_class=STOI): +def test_error_on_different_shape(metric_class=ShortTermObjectiveIntelligibility): metric = metric_class(16000) with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 6f1bd8492d0..1ac8a6c0eff 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -16,3 +16,7 @@ from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401 +from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE + +if _PYSTOI_AVAILABLE: + from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility # noqa: F401 diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index 67eb3d28e83..5ecffbcfbfa 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -13,14 +13,16 @@ # limitations under the License. from typing import Any, Callable, Optional +from deprecate import deprecated, void from torch import Tensor, tensor from torchmetrics.functional.audio.stoi import stoi from torchmetrics.metric import Metric +from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE -class STOI(Metric): +class ShortTermObjectiveIntelligibility(Metric): r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. Note that input will be moved to `cpu` to perform the metric calculation. @@ -63,12 +65,12 @@ class STOI(Metric): If ``pystoi`` package is not installed Example: - >>> from torchmetrics.audio.stoi import STOI + >>> from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) - >>> stoi = STOI(8000, False) + >>> stoi = ShortTermObjectiveIntelligibility(8000, False) >>> stoi(preds, target) tensor(-0.0100) @@ -131,3 +133,32 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore def compute(self) -> Tensor: """Computes average STOI.""" return self.sum_stoi / self.total + + +class STOI(ShortTermObjectiveIntelligibility): + r"""STOI (Short Term Objective Intelligibility), a wrapper for the pystoi package. + + .. deprecated:: v0.7 + Use :class:`torchmetrics.audio.ShortTermObjectiveIntelligibility`. Will be removed in v0.8. + + Example: + >>> import torch + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> stoi = STOI(8000, False) + >>> stoi(preds, target) + tensor(-0.0100) + """ + + @deprecated(target=ShortTermObjectiveIntelligibility, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) + def __init__( + self, + fs: int, + extended: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + ) -> None: + void(fs, extended, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py index ff8acb12c96..17b98056468 100644 --- a/torchmetrics/functional/audio/stoi.py +++ b/torchmetrics/functional/audio/stoi.py @@ -82,7 +82,7 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa """ if not _PYSTOI_AVAILABLE: raise ModuleNotFoundError( - "STOI metric requires that `pystoi` is installed." + "ShortTermObjectiveIntelligibility metric requires that `pystoi` is installed." " Either install as `pip install torchmetrics[audio]` or `pip install pystoi`." ) _check_same_shape(preds, target) From db2f3b9bc5e847b7cea0f88b8c98e53f814dd93e Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 14 Jan 2022 11:08:39 +0100 Subject: [PATCH 2/3] chlog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 780798b458a..1fd7b8e5ea5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `MatthewsCorrcoef` -> `MatthewsCorrCoef` * `PearsonCorrcoef` -> `PearsonCorrCoef` * `SpearmanCorrcoef` -> `SpearmanCorrCoef` +- Renamed audio STOI metric `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` ([#753](https://github.com/PyTorchLightning/metrics/pull/753)) - Renamed audio SDR metrics: ([#711](https://github.com/PyTorchLightning/metrics/pull/711)) * `functional.sdr` -> `functional.signal_distortion_ratio` * `functional.si_sdr` -> `functional.scale_invariant_signal_distortion_ratio` From 49f1d421957ca0a95ba9354c2187e866f2fa895c Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 14 Jan 2022 12:05:44 +0100 Subject: [PATCH 3/3] docs --- .github/workflows/docs-check.yml | 3 ++- docs/source/references/modules.rst | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index ab81cbac4b1..66abccbd121 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -37,9 +37,10 @@ jobs: working-directory: ./docs run: | # First run the same pipeline as Read-The-Docs - apt-get update && sudo apt-get install -y cmake + sudo apt-get update && sudo apt-get install -y cmake make doctest make coverage + shell: bash make-docs: runs-on: ubuntu-20.04 diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index e1cecb26ca8..79537ae6e17 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -112,7 +112,7 @@ SignalNoiseRatio :noindex: ShortTermObjectiveIntelligibility -~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: torchmetrics.audio.stoi.ShortTermObjectiveIntelligibility :noindex: