Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename stoi functional #758

Merged
merged 4 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `MatthewsCorrcoef` -> `MatthewsCorrCoef`
* `PearsonCorrcoef` -> `PearsonCorrCoef`
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`
- Renamed audio STOI metric `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` ([#753](https://github.com/PyTorchLightning/metrics/pull/753))
- Renamed audio STOI metric: ([#753](https://github.com/PyTorchLightning/metrics/pull/753), [#758](https://github.com/PyTorchLightning/metrics/pull/758))
* `audio.STOI` -> `audio.ShortTermObjectiveIntelligibility`
* `functional.audio.stoi` -> `functional.audio.short_term_objective_intelligibility`
- Renamed audio PESQ metrics: ([#751](https://github.com/PyTorchLightning/metrics/pull/751))
* `functional.audio.pesq` -> `functional.audio.perceptual_evaluation_speech_quality`
* `audio.PESQ` -> `audio.PerceptualEvaluationSpeechQuality`
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ signal_noise_ratio [func]
:noindex:


stoi [func]
~~~~~~~~~~~
short_term_objective_intelligibility [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.audio.stoi.stoi
.. autofunction:: torchmetrics.functional.audio.stoi.short_term_objective_intelligibility
:noindex:


Expand Down
10 changes: 5 additions & 5 deletions tests/audio/test_stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.functional.audio.stoi import short_term_objective_intelligibility
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_stoi_functional(self, preds, target, sk_metric, fs, extended):
self.run_functional_metric_test(
preds,
target,
stoi,
short_term_objective_intelligibility,
sk_metric,
metric_args=dict(fs=fs, extended=extended),
)
Expand All @@ -102,7 +102,7 @@ def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended):
preds=preds,
target=target,
metric_module=ShortTermObjectiveIntelligibility,
metric_functional=stoi,
metric_functional=short_term_objective_intelligibility,
metric_args=dict(fs=fs, extended=extended),
)

Expand All @@ -118,7 +118,7 @@ def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended):
preds=preds,
target=target,
metric_module=ShortTermObjectiveIntelligibility,
metric_functional=partial(stoi, fs=fs, extended=extended),
metric_functional=partial(short_term_objective_intelligibility, fs=fs, extended=extended),
metric_args=dict(fs=fs, extended=extended),
)

Expand All @@ -139,7 +139,7 @@ def test_on_real_audio():
rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav"))
rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav"))
assert torch.allclose(
stoi(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(),
short_term_objective_intelligibility(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(),
torch.tensor(0.6739177),
rtol=0.0001,
atol=1e-4,
Expand Down
6 changes: 4 additions & 2 deletions torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from deprecate import deprecated, void
from torch import Tensor, tensor

from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.functional.audio.stoi import short_term_objective_intelligibility
from torchmetrics.metric import Metric
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE
Expand Down Expand Up @@ -125,7 +125,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model
target: Ground truth values
"""
stoi_batch = stoi(preds, target, self.fs, self.extended, False).to(self.sum_stoi.device)
stoi_batch = short_term_objective_intelligibility(preds, target, self.fs, self.extended, False).to(
self.sum_stoi.device
)

self.sum_stoi += stoi_batch.sum()
self.total += stoi_batch.numel()
Expand Down
29 changes: 26 additions & 3 deletions torchmetrics/functional/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import torch
from deprecate import deprecated, void

from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

Expand All @@ -22,10 +23,13 @@
stoi_backend = None
from torch import Tensor

from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.checks import _check_same_shape


def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_same_device: bool = False) -> Tensor:
def short_term_objective_intelligibility(
preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_same_device: bool = False
) -> Tensor:
r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to `cpu` to perform the metric calculation.

Expand Down Expand Up @@ -59,12 +63,12 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa
If ``pystoi`` package is not installed

Example:
>>> from torchmetrics.functional.audio.stoi import stoi
>>> from torchmetrics.functional.audio.stoi import short_term_objective_intelligibility
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi(preds, target, 8000).float()
>>> short_term_objective_intelligibility(preds, target, 8000).float()
tensor(-0.0100)

References:
Expand Down Expand Up @@ -103,3 +107,22 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa
stoi_val = stoi_val.to(preds.device)

return stoi_val


@deprecated(target=short_term_objective_intelligibility, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_same_device: bool = False) -> Tensor:
    r"""STOI (Short Term Objective Intelligibility)

    .. deprecated:: v0.7
        Use :func:`torchmetrics.functional.audio.short_term_objective_intelligibility`. Will be removed in v0.8.

    Deprecated alias kept for backward compatibility. The parameters mirror those of
    :func:`short_term_objective_intelligibility` one-to-one (same names, order and defaults), because the
    ``deprecated`` decorator forwards the call to that function; see it for the meaning of each argument.

    Example:
        >>> from torchmetrics.functional.audio.stoi import stoi
        >>> import torch
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> target = torch.randn(8000)
        >>> stoi(preds, target, 8000).float()
        tensor(-0.0100)
    """
    # ``void`` only marks the arguments as used; the actual computation is done by the decorator's target.
    return void(preds, target, fs, extended, keep_same_device)