Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename STOI #753

Merged
merged 4 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `MatthewsCorrcoef` -> `MatthewsCorrCoef`
* `PearsonCorrcoef` -> `PearsonCorrCoef`
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`
- Renamed audio STOI metric `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` ([#753](https://github.com/PyTorchLightning/metrics/pull/753))
- Renamed audio SDR metrics: ([#711](https://github.com/PyTorchLightning/metrics/pull/711))
* `functional.sdr` -> `functional.signal_distortion_ratio`
* `functional.si_sdr` -> `functional.scale_invariant_signal_distortion_ratio`
Expand Down
4 changes: 2 additions & 2 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ SignalNoiseRatio
.. autoclass:: torchmetrics.SignalNoiseRatio
:noindex:

STOI
ShortTermObjectiveIntelligibility
~~~~

.. autoclass:: torchmetrics.audio.stoi.STOI
.. autoclass:: torchmetrics.audio.stoi.ShortTermObjectiveIntelligibility
:noindex:


Expand Down
10 changes: 5 additions & 5 deletions tests/audio/test_stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio.stoi import STOI
from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

Expand Down Expand Up @@ -82,7 +82,7 @@ def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_st
ddp,
preds,
target,
STOI,
ShortTermObjectiveIntelligibility,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(fs=fs, extended=extended),
Expand All @@ -101,7 +101,7 @@ def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=STOI,
metric_module=ShortTermObjectiveIntelligibility,
metric_functional=stoi,
metric_args=dict(fs=fs, extended=extended),
)
Expand All @@ -117,13 +117,13 @@ def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=STOI,
metric_module=ShortTermObjectiveIntelligibility,
metric_functional=partial(stoi, fs=fs, extended=extended),
metric_args=dict(fs=fs, extended=extended),
)


def test_error_on_different_shape(metric_class=STOI):
def test_error_on_different_shape(metric_class=ShortTermObjectiveIntelligibility):
metric = metric_class(16000)
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

if _PYSTOI_AVAILABLE:
from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility # noqa: F401
37 changes: 34 additions & 3 deletions torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# limitations under the License.
from typing import Any, Callable, Optional

from deprecate import deprecated, void
from torch import Tensor, tensor

from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.metric import Metric
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE


class STOI(Metric):
class ShortTermObjectiveIntelligibility(Metric):
r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to `cpu` to perform the metric calculation.

Expand Down Expand Up @@ -63,12 +65,12 @@ class STOI(Metric):
If ``pystoi`` package is not installed

Example:
>>> from torchmetrics.audio.stoi import STOI
>>> from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = STOI(8000, False)
>>> stoi = ShortTermObjectiveIntelligibility(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)

Expand Down Expand Up @@ -131,3 +133,32 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
def compute(self) -> Tensor:
"""Computes average STOI."""
return self.sum_stoi / self.total


class STOI(ShortTermObjectiveIntelligibility):
Borda marked this conversation as resolved.
Show resolved Hide resolved
r"""STOI (Short Term Objective Intelligibility), a wrapper for the pystoi package.

.. deprecated:: v0.7
Use :class:`torchmetrics.audio.ShortTermObjectiveIntelligibility`. Will be removed in v0.8.

Example:
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = STOI(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
"""

@deprecated(target=ShortTermObjectiveIntelligibility, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
void(fs, extended, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
2 changes: 1 addition & 1 deletion torchmetrics/functional/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa
"""
if not _PYSTOI_AVAILABLE:
raise ModuleNotFoundError(
"STOI metric requires that `pystoi` is installed."
"ShortTermObjectiveIntelligibility metric requires that `pystoi` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install pystoi`."
)
_check_same_shape(preds, target)
Expand Down