From a2941cc568d4ccec23fe7bc0c00f455eae6dfe14 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 6 Jan 2022 12:40:16 +0100 Subject: [PATCH] sisnr --- CHANGELOG.md | 5 ++--- README.md | 2 +- docs/source/references/functional.rst | 6 ++--- docs/source/references/modules.rst | 6 ++--- tests/audio/test_si_snr.py | 18 +++++++-------- tests/audio/test_snr.py | 8 +++---- torchmetrics/__init__.py | 4 ++-- torchmetrics/audio/__init__.py | 2 +- torchmetrics/audio/si_snr.py | 8 +++---- torchmetrics/audio/snr.py | 10 ++++----- torchmetrics/functional/__init__.py | 5 ++--- torchmetrics/functional/audio/__init__.py | 2 +- torchmetrics/functional/audio/si_snr.py | 8 +++---- torchmetrics/functional/audio/snr.py | 27 +++++------------------ 14 files changed, 47 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cfcd0b1f16a..ca4e0536f97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,9 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed audio SNR metrics: ([#712](https://github.com/PyTorchLightning/metrics/pull/712)) - * `functional.snr` -> `functional.signal_distortion_ratio` - * `functional.si_snr` -> `functional.scale_invariant_signal_noise_ratio` - * `SI_SNR` -> `ScaleInvariantSNR` + * `functional.si_snr` -> `functional.sisnr` + * `SI_SNR` -> `SISNR` ### Removed diff --git a/README.md b/README.md index 1019c7d1864..536eff29b36 100644 --- a/README.md +++ b/README.md @@ -267,7 +267,7 @@ We currently have implemented metrics within the following domains: - Audio ( [SI_SDR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-sdr), - [ScaleInvariantSNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSNR), + [SISNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#SISNR), [SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr) and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics) ) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index b79f85d7d83..edac0afb059 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -38,10 +38,10 @@ si_sdr [func] :noindex: -scale_invariant_signal_noise_ratio [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +sisnr [func] +~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.si_snr +.. autofunction:: torchmetrics.functional.sisnr :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 6abd716e3e6..cc0a5a81808 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -97,10 +97,10 @@ SI_SDR .. autoclass:: torchmetrics.SI_SDR :noindex: -ScaleInvariantSNR -~~~~~~~~~~~~~~~~~ +SISNR +~~~~~ -.. autoclass:: torchmetrics.ScaleInvariantSNR +.. autoclass:: torchmetrics.SISNR :noindex: SNR diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 5156dc16920..35c07bd533b 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -21,8 +21,8 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.audio import ScaleInvariantSNR -from torchmetrics.functional import scale_invariant_signal_noise_ratio +from torchmetrics.audio import SISNR +from torchmetrics.functional import sisnr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -79,7 +79,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): ddp, preds, target, - ScaleInvariantSNR, + SISNR, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, ) @@ -88,7 +88,7 @@ def test_si_snr_functional(self, preds, target, sk_metric): self.run_functional_metric_test( preds, target, - scale_invariant_signal_noise_ratio, + sisnr, sk_metric, ) @@ -96,8 +96,8 @@ def test_si_snr_differentiability(self, preds, target, sk_metric): self.run_differentiability_test( preds=preds, target=target, - metric_module=ScaleInvariantSNR, - metric_functional=scale_invariant_signal_noise_ratio, + metric_module=SISNR, + metric_functional=sisnr, ) @pytest.mark.skipif( @@ -111,12 +111,12 @@ def test_si_snr_half_gpu(self, preds, target, sk_metric): self.run_precision_test_gpu( preds=preds, target=target, - metric_module=ScaleInvariantSNR, - metric_functional=scale_invariant_signal_noise_ratio, + metric_module=SISNR, + metric_functional=sisnr, ) -def test_error_on_different_shape(metric_class=ScaleInvariantSNR): +def test_error_on_different_shape(metric_class=SISNR): metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index ed6ca585928..e5ff2d5e7d7 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -23,7 +23,7 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.audio import SNR -from torchmetrics.functional import signal_noise_ratio +from torchmetrics.functional import snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -96,7 +96,7 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): self.run_functional_metric_test( preds, target, - signal_noise_ratio, + snr, sk_metric, metric_args=dict(zero_mean=zero_mean), ) @@ -106,7 +106,7 @@ def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): preds=preds, target=target, metric_module=SNR, - metric_functional=signal_noise_ratio, + metric_functional=snr, metric_args={"zero_mean": zero_mean}, ) @@ -122,7 +122,7 @@ def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): preds=preds, target=target, metric_module=SNR, - metric_functional=signal_noise_ratio, + metric_functional=snr, metric_args={"zero_mean": zero_mean}, ) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index e35816c1682..d3a852bca01 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -13,7 +13,7 @@ from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 -from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR, ScaleInvariantSNR # noqa: E402 +from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR, SISNR # noqa: E402 from torchmetrics.classification import ( # noqa: E402, F401 AUC, AUROC, @@ -144,7 +144,7 @@ "SDR", "SI_SDR", "SI_SNR", - "ScaleInvariantSNR", + "SISNR", "SNR", "SpearmanCorrcoef", "SpearmanCorrCoef", diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index a45102f2801..0b043bd466b 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -15,4 +15,4 @@ from torchmetrics.audio.sdr import SDR # noqa: F401 from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 -from torchmetrics.audio.snr import SNR, ScaleInvariantSNR # noqa: F401 +from torchmetrics.audio.snr import SNR, SISNR # noqa: F401 diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index 5dbf98ae71d..eabce5d6b3f 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -16,14 +16,14 @@ from torch import Tensor -from torchmetrics.audio.snr import ScaleInvariantSNR +from torchmetrics.audio.snr import SISNR -class SI_SNR(ScaleInvariantSNR): +class SI_SNR(SISNR): """Scale-invariant signal-to-noise ratio (SI-SNR). .. deprecated:: v0.7 - Use :class:`torchmetrics.ScaleInvariantSNR`. Will be removed in v0.8. + Use :class:`torchmetrics.SISNR`. Will be removed in v0.8. Example: >>> import torch @@ -39,5 +39,5 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, ) -> None: - warn("`SI_SDR` was renamed to `ScaleInvariantSNR` in v0.7 and it will be removed in v0.8", DeprecationWarning) + warn("`SI_SDR` was renamed to `SISNR` in v0.7 and it will be removed in v0.8", DeprecationWarning) super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index 8a0adb4958c..a02461061b8 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -15,7 +15,7 @@ from torch import Tensor, tensor -from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, snr +from torchmetrics.functional.audio.snr import sisnr, snr from torchmetrics.metric import Metric @@ -111,7 +111,7 @@ def compute(self) -> Tensor: return self.sum_snr / self.total -class ScaleInvariantSNR(Metric): +class SISNR(Metric): """Scale-invariant signal-to-noise ratio (SI-SNR). Forward accepts @@ -140,10 +140,10 @@ class ScaleInvariantSNR(Metric): Example: >>> import torch - >>> from torchmetrics import ScaleInvariantSNR + >>> from torchmetrics import SISNR >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_snr = ScaleInvariantSNR() + >>> si_snr = SISNR() >>> si_snr(preds, target) tensor(15.0918) @@ -182,7 +182,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model target: Ground truth values """ - si_snr_batch = scale_invariant_signal_noise_ratio(preds=preds, target=target) + si_snr_batch = sisnr(preds=preds, target=target) self.sum_si_snr += si_snr_batch.sum() self.total += si_snr_batch.numel() diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index a3e92094051..a455930f72b 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -15,7 +15,7 @@ from torchmetrics.functional.audio.sdr import sdr from torchmetrics.functional.audio.si_sdr import si_sdr from torchmetrics.functional.audio.si_snr import si_snr -from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr +from torchmetrics.functional.audio.snr import sisnr, snr, snr from torchmetrics.functional.classification.accuracy import accuracy from torchmetrics.functional.classification.auc import auc from torchmetrics.functional.classification.auroc import auroc @@ -131,9 +131,8 @@ "sdr", "si_sdr", "si_snr", - "scale_invariant_signal_noise_ratio", + "sisnr", "snr", - "signal_noise_ratio", "spearman_corrcoef", "specificity", "squad", diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index cf0f11f8c04..7a120d19529 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -15,4 +15,4 @@ from torchmetrics.functional.audio.sdr import sdr # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 -from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr # noqa: F401 +from torchmetrics.functional.audio.snr import sisnr, snr, snr # noqa: F401 diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index 65acdd8f59c..108ace3e948 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -15,14 +15,14 @@ from torch import Tensor -from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio +from torchmetrics.functional.audio.snr import sisnr def si_snr(preds: Tensor, target: Tensor) -> Tensor: """Scale-invariant signal-to-noise ratio (SI-SNR). .. deprecated:: v0.7 - Use :func:`torchmetrics.functional.scale_invariant_signal_noise_ratio`. Will be removed in v0.8. + Use :func:`torchmetrics.functional.sisnr`. Will be removed in v0.8. Example: >>> import torch @@ -30,7 +30,7 @@ def si_snr(preds: Tensor, target: Tensor) -> Tensor: tensor(15.0918) """ warn( - "`si_snr` was renamed to `scale_invariant_signal_noise_ratio` in v0.7 and it will be removed in v0.8", + "`si_snr` was renamed to `sisnr` in v0.7 and it will be removed in v0.8", DeprecationWarning, ) - return scale_invariant_signal_noise_ratio(preds, target) + return sisnr(preds, target) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index 6ed6735d94b..dd4d98740c6 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -20,7 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: +def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: r"""Signal-to-noise ratio (SNR_): .. math:: @@ -42,10 +42,10 @@ def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) - snr value of shape [...] Example: - >>> from torchmetrics.functional.audio import signal_noise_ratio + >>> from torchmetrics.functional.audio import snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> signal_noise_ratio(preds, target) + >>> snr(preds, target) tensor(16.1805) References: @@ -68,22 +68,7 @@ def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) - return snr_value -def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: - r"""Signal-to-noise ratio (SNR_) - - .. deprecated:: v0.7 - Use :func:`torchmetrics.functional.signal_noise_ratio`. Will be removed in v0.8. - - Example: - >>> snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) - tensor(16.1805) - - """ - warn("`snr` was renamed to `signal_noise_ratio` in v0.7 and it will be removed in v0.8", DeprecationWarning) - return signal_noise_ratio(preds, target, zero_mean) - - -def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor: +def sisnr(preds: Tensor, target: Tensor) -> Tensor: """Scale-invariant signal-to-noise ratio (SI-SNR). Args: @@ -97,10 +82,10 @@ def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor: Example: >>> import torch - >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio + >>> from torchmetrics.functional.audio import sisnr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> scale_invariant_signal_noise_ratio(preds, target) + >>> sisnr(preds, target) tensor(15.0918) References: