Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: SNR & SI_SNR #712

Merged
merged 18 commits into from
Jan 8, 2022
Merged
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))


- Added `ignore_index` to to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))
- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))


- Added support for multi references in `ROUGEScore` ([#680](https://github.com/PyTorchLightning/metrics/pull/680))
Expand Down Expand Up @@ -71,6 +71,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `SI_SDR` -> `ScaleInvariantSignalDistortionRatio`


- Renamed audio SNR metrics: ([#712](https://github.com/PyTorchLightning/metrics/pull/712))
* `functional.snr` -> `functional.signal_distortion_ratio`
* `functional.si_snr` -> `functional.scale_invariant_signal_noise_ratio`
* `SNR` -> `SignalNoiseRatio`
* `SI_SNR` -> `ScaleInvariantSignalNoiseRatio`


### Removed

- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ We currently have implemented metrics within the following domains:

- Audio (
[ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio),
[SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr),
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
[ScaleInvariantSignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalNoiseRatio),
[SignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#SignalNoiseRatio)
and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
)
- Classification (
Expand Down
10 changes: 5 additions & 5 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ scale_invariant_signal_distortion_ratio [func]
:noindex:


si_snr [func]
~~~~~~~~~~~~~
scale_invariant_signal_noise_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_snr
:noindex:


snr [func]
~~~~~~~~~~
signal_noise_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~
Borda marked this conversation as resolved.
Show resolved Hide resolved

.. autofunction:: torchmetrics.functional.snr
.. autofunction:: torchmetrics.functional.signal_noise_ratio
Borda marked this conversation as resolved.
Show resolved Hide resolved
:noindex:


Expand Down
23 changes: 14 additions & 9 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ the metric will be computed over the ``time`` dimension.
.. doctest::

>>> import torch
>>> from torchmetrics import SignalNoiseRatio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)
>>> from torchmetrics import SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SNR()
>>> snr_val = snr(preds, target)
>>> snr_val
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)

PESQ
Expand All @@ -97,16 +102,16 @@ ScaleInvariantSignalDistortionRatio
.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio
:noindex:

SI_SNR
~~~~~~
ScaleInvariantSignalNoiseRatio
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SI_SNR
.. autoclass:: torchmetrics.ScaleInvariantSignalNoiseRatio
:noindex:

SNR
~~~
SignalNoiseRatio
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SNR
.. autoclass:: torchmetrics.SignalNoiseRatio
:noindex:

STOI
Expand Down
24 changes: 17 additions & 7 deletions tests/audio/test_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SI_SNR
from torchmetrics.functional import si_snr
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
from torchmetrics.functional import scale_invariant_signal_noise_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
ddp,
preds,
target,
SI_SNR,
ScaleInvariantSignalNoiseRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
)
Expand All @@ -88,12 +88,17 @@ def test_si_snr_functional(self, preds, target, sk_metric):
self.run_functional_metric_test(
preds,
target,
si_snr,
scale_invariant_signal_noise_ratio,
sk_metric,
)

def test_si_snr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr)
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=ScaleInvariantSignalNoiseRatio,
metric_functional=scale_invariant_signal_noise_ratio,
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
Expand All @@ -103,10 +108,15 @@ def test_si_snr_half_cpu(self, preds, target, sk_metric):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_si_snr_half_gpu(self, preds, target, sk_metric):
self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr)
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=ScaleInvariantSignalNoiseRatio,
metric_functional=scale_invariant_signal_noise_ratio,
)


def test_error_on_different_shape(metric_class=SI_SNR):
def test_error_on_different_shape(metric_class=ScaleInvariantSignalNoiseRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
28 changes: 18 additions & 10 deletions tests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SNR
from torchmetrics.functional import snr
from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.functional import signal_noise_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Time = 100
TIME = 100

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
)


Expand Down Expand Up @@ -86,7 +86,7 @@ def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
ddp,
preds,
target,
SNR,
SignalNoiseRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(zero_mean=zero_mean),
Expand All @@ -96,14 +96,18 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean):
self.run_functional_metric_test(
preds,
target,
snr,
signal_noise_ratio,
sk_metric,
metric_args=dict(zero_mean=zero_mean),
)

def test_snr_differentiability(self, preds, target, sk_metric, zero_mean):
self.run_differentiability_test(
preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean}
preds=preds,
target=target,
metric_module=SignalNoiseRatio,
metric_functional=signal_noise_ratio,
metric_args={"zero_mean": zero_mean},
)

@pytest.mark.skipif(
Expand All @@ -115,11 +119,15 @@ def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean):
self.run_precision_test_gpu(
preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean}
preds=preds,
target=target,
metric_module=SignalNoiseRatio,
metric_functional=signal_noise_ratio,
metric_args={"zero_mean": zero_mean},
)


def test_error_on_different_shape(metric_class=SNR):
def test_error_on_different_shape(metric_class=SignalNoiseRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
4 changes: 4 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
SI_SNR,
SNR,
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
SignalDistortionRatio,
SignalNoiseRatio,
)
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
Expand Down Expand Up @@ -154,6 +156,8 @@
"ScaleInvariantSignalDistortionRatio",
"SI_SDR",
"SI_SNR",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"SNR",
"SpearmanCorrcoef",
"SpearmanCorrCoef",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401
77 changes: 11 additions & 66 deletions torchmetrics/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,90 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from warnings import warn

from torch import Tensor, tensor
from torch import Tensor

from torchmetrics.functional.audio.si_snr import si_snr
from torchmetrics.metric import Metric
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio


class SI_SNR(Metric):
class SI_SNR(ScaleInvariantSignalNoiseRatio):
"""Scale-invariant signal-to-noise ratio (SI-SNR).

Forward accepts

- ``preds``: ``shape [...,time]``
- ``target``: ``shape [...,time]``

Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.

Raises:
TypeError:
if target and preds have a different shape

Returns:
average si-snr value
.. deprecated:: v0.7
Use :class:`torchmetrics.ScaleInvariantSignalNoiseRatio`. Will be removed in v0.8.

Example:
>>> import torch
>>> from torchmetrics import SI_SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = SI_SNR()
>>> si_snr_val = si_snr(preds, target)
>>> si_snr_val
>>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0]))
tensor(15.0918)

References:
[1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech
Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp.
696-700, doi: 10.1109/ICASSP.2018.8462116.
"""

is_differentiable = True
sum_si_snr: Tensor
total: Tensor
higher_is_better = True

def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
warn(
"`SI_SNR` was renamed to `ScaleInvariantSignalNoiseRatio` in v0.7 and it will be removed in v0.8",
DeprecationWarning,
)

self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
si_snr_batch = si_snr(preds=preds, target=target)

self.sum_si_snr += si_snr_batch.sum()
self.total += si_snr_batch.numel()

def compute(self) -> Tensor:
"""Computes average SI-SNR."""
return self.sum_si_snr / self.total
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
Loading