Skip to content

Commit

Permalink
sisnr
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jan 6, 2022
1 parent a8ec4d7 commit a2941cc
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 64 deletions.
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0


- Renamed audio SNR metrics: ([#712](https://github.com/PyTorchLightning/metrics/pull/712))
* `functional.snr` -> `functional.signal_distortion_ratio`
* `functional.si_snr` -> `functional.scale_invariant_signal_noise_ratio`
* `SI_SNR` -> `ScaleInvariantSNR`
* `functional.si_snr` -> `functional.sisnr`
* `SI_SNR` -> `SISNR`


### Removed
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ We currently have implemented metrics within the following domains:

- Audio (
[SI_SDR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-sdr),
[ScaleInvariantSNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSNR),
[SISNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#SISNR),
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
)
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ si_sdr [func]
:noindex:


scale_invariant_signal_noise_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
sisnr [func]
~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_snr
.. autofunction:: torchmetrics.functional.sisnr
:noindex:


Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ SI_SDR
.. autoclass:: torchmetrics.SI_SDR
:noindex:

ScaleInvariantSNR
~~~~~~~~~~~~~~~~~
SISNR
~~~~~

.. autoclass:: torchmetrics.ScaleInvariantSNR
.. autoclass:: torchmetrics.SISNR
:noindex:

SNR
Expand Down
18 changes: 9 additions & 9 deletions tests/audio/test_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import ScaleInvariantSNR
from torchmetrics.functional import scale_invariant_signal_noise_ratio
from torchmetrics.audio import SISNR
from torchmetrics.functional import sisnr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
ddp,
preds,
target,
ScaleInvariantSNR,
SISNR,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
)
Expand All @@ -88,16 +88,16 @@ def test_si_snr_functional(self, preds, target, sk_metric):
self.run_functional_metric_test(
preds,
target,
scale_invariant_signal_noise_ratio,
sisnr,
sk_metric,
)

def test_si_snr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=ScaleInvariantSNR,
metric_functional=scale_invariant_signal_noise_ratio,
metric_module=SISNR,
metric_functional=sisnr,
)

@pytest.mark.skipif(
Expand All @@ -111,12 +111,12 @@ def test_si_snr_half_gpu(self, preds, target, sk_metric):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=ScaleInvariantSNR,
metric_functional=scale_invariant_signal_noise_ratio,
metric_module=SISNR,
metric_functional=sisnr,
)


def test_error_on_different_shape(metric_class=ScaleInvariantSNR):
def test_error_on_different_shape(metric_class=SISNR):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
8 changes: 4 additions & 4 deletions tests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SNR
from torchmetrics.functional import signal_noise_ratio
from torchmetrics.functional import snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean):
self.run_functional_metric_test(
preds,
target,
signal_noise_ratio,
snr,
sk_metric,
metric_args=dict(zero_mean=zero_mean),
)
Expand All @@ -106,7 +106,7 @@ def test_snr_differentiability(self, preds, target, sk_metric, zero_mean):
preds=preds,
target=target,
metric_module=SNR,
metric_functional=signal_noise_ratio,
metric_functional=snr,
metric_args={"zero_mean": zero_mean},
)

Expand All @@ -122,7 +122,7 @@ def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean):
preds=preds,
target=target,
metric_module=SNR,
metric_functional=signal_noise_ratio,
metric_functional=snr,
metric_args={"zero_mean": zero_mean},
)

Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR, ScaleInvariantSNR # noqa: E402
from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR, SISNR # noqa: E402
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
AUROC,
Expand Down Expand Up @@ -144,7 +144,7 @@
"SDR",
"SI_SDR",
"SI_SNR",
"ScaleInvariantSNR",
"SISNR",
"SNR",
"SpearmanCorrcoef",
"SpearmanCorrCoef",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from torchmetrics.audio.sdr import SDR # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR, ScaleInvariantSNR # noqa: F401
from torchmetrics.audio.snr import SNR, SISNR # noqa: F401
8 changes: 4 additions & 4 deletions torchmetrics/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from torch import Tensor

from torchmetrics.audio.snr import ScaleInvariantSNR
from torchmetrics.audio.snr import SISNR


class SI_SNR(ScaleInvariantSNR):
class SI_SNR(SISNR):
"""Scale-invariant signal-to-noise ratio (SI-SNR).
.. deprecated:: v0.7
Use :class:`torchmetrics.ScaleInvariantSNR`. Will be removed in v0.8.
Use :class:`torchmetrics.SISNR`. Will be removed in v0.8.
Example:
>>> import torch
Expand All @@ -39,5 +39,5 @@ def __init__(
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
warn("`SI_SDR` was renamed to `ScaleInvariantSNR` in v0.7 and it will be removed in v0.8", DeprecationWarning)
warn("`SI_SDR` was renamed to `SISNR` in v0.7 and it will be removed in v0.8", DeprecationWarning)
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
10 changes: 5 additions & 5 deletions torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from torch import Tensor, tensor

from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, snr
from torchmetrics.functional.audio.snr import sisnr, snr
from torchmetrics.metric import Metric


Expand Down Expand Up @@ -111,7 +111,7 @@ def compute(self) -> Tensor:
return self.sum_snr / self.total


class ScaleInvariantSNR(Metric):
class SISNR(Metric):
"""Scale-invariant signal-to-noise ratio (SI-SNR).
Forward accepts
Expand Down Expand Up @@ -140,10 +140,10 @@ class ScaleInvariantSNR(Metric):
Example:
>>> import torch
>>> from torchmetrics import ScaleInvariantSNR
>>> from torchmetrics import SISNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = ScaleInvariantSNR()
>>> si_snr = SISNR()
>>> si_snr(preds, target)
tensor(15.0918)
Expand Down Expand Up @@ -182,7 +182,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model
target: Ground truth values
"""
si_snr_batch = scale_invariant_signal_noise_ratio(preds=preds, target=target)
si_snr_batch = sisnr(preds=preds, target=target)

self.sum_si_snr += si_snr_batch.sum()
self.total += si_snr_batch.numel()
Expand Down
5 changes: 2 additions & 3 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchmetrics.functional.audio.sdr import sdr
from torchmetrics.functional.audio.si_sdr import si_sdr
from torchmetrics.functional.audio.si_snr import si_snr
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr
from torchmetrics.functional.audio.snr import sisnr, snr, snr
from torchmetrics.functional.classification.accuracy import accuracy
from torchmetrics.functional.classification.auc import auc
from torchmetrics.functional.classification.auroc import auroc
Expand Down Expand Up @@ -131,9 +131,8 @@
"sdr",
"si_sdr",
"si_snr",
"scale_invariant_signal_noise_ratio",
"sisnr",
"snr",
"signal_noise_ratio",
"spearman_corrcoef",
"specificity",
"squad",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from torchmetrics.functional.audio.sdr import sdr # noqa: F401
from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401
from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr # noqa: F401
from torchmetrics.functional.audio.snr import sisnr, snr, snr # noqa: F401
8 changes: 4 additions & 4 deletions torchmetrics/functional/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@

from torch import Tensor

from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio
from torchmetrics.functional.audio.snr import sisnr


def si_snr(preds: Tensor, target: Tensor) -> Tensor:
"""Scale-invariant signal-to-noise ratio (SI-SNR).
.. deprecated:: v0.7
Use :func:`torchmetrics.functional.scale_invariant_signal_noise_ratio`. Will be removed in v0.8.
Use :func:`torchmetrics.functional.sisnr`. Will be removed in v0.8.
Example:
>>> import torch
>>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0]))
tensor(15.0918)
"""
warn(
"`si_snr` was renamed to `scale_invariant_signal_noise_ratio` in v0.7 and it will be removed in v0.8",
"`si_snr` was renamed to `sisnr` in v0.7 and it will be removed in v0.8",
DeprecationWarning,
)
return scale_invariant_signal_noise_ratio(preds, target)
return sisnr(preds, target)
27 changes: 6 additions & 21 deletions torchmetrics/functional/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torchmetrics.utilities.checks import _check_same_shape


def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
r"""Signal-to-noise ratio (SNR_):
.. math::
Expand All @@ -42,10 +42,10 @@ def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -
snr value of shape [...]
Example:
>>> from torchmetrics.functional.audio import signal_noise_ratio
>>> from torchmetrics.functional.audio import snr
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> signal_noise_ratio(preds, target)
>>> snr(preds, target)
tensor(16.1805)
References:
Expand All @@ -68,22 +68,7 @@ def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -
return snr_value


def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
r"""Signal-to-noise ratio (SNR_)
.. deprecated:: v0.7
Use :func:`torchmetrics.functional.signal_noise_ratio`. Will be removed in v0.8.
Example:
>>> snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0]))
tensor(16.1805)
"""
warn("`snr` was renamed to `signal_noise_ratio` in v0.7 and it will be removed in v0.8", DeprecationWarning)
return signal_noise_ratio(preds, target, zero_mean)


def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
def sisnr(preds: Tensor, target: Tensor) -> Tensor:
"""Scale-invariant signal-to-noise ratio (SI-SNR).
Args:
Expand All @@ -97,10 +82,10 @@ def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
Example:
>>> import torch
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> from torchmetrics.functional.audio import sisnr
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> scale_invariant_signal_noise_ratio(preds, target)
>>> sisnr(preds, target)
tensor(15.0918)
References:
Expand Down

0 comments on commit a2941cc

Please sign in to comment.