audio metrics: SNR, SI_SDR, SI_SNR (#292)

* add snr, si_sdr, si_snr
* format
* add noqa: F401 to __init__.py
* remove types in doc, change estimate to preds, remove EPS
* update functional.rst
* update CHANGELOG.md
* switch preds and target
* switch preds and target in Example
* add SNR, SI_SNR, SI_SDR module implementation
* add test
* add module doc
* use _check_same_shape
* to alphabetical order
* update test
* move Base to the top of Audio
* add soundfile
* gcc
* fix mocking
* image
* doctest
* mypy
* fix requirements
* fix dtype
* something
* update
* adjust
* Apply suggestions from code review
* update test_snr
* update test_si_snr
* new snr: use torch.finfo(preds.dtype).eps
* update test_snr.py
* new si_sdr imp
* update test_si_sdr
* update test_si_snr
* remove pb_bss_eval
* add museval
* update test files
* remove museval
* add funcs update return None annotation
* add 'Setup ffmpeg'
* update "Setup ffmpeg"
* use setup-conda@v1
* multi-OS
* update atol to 1e-5
* Apply suggestions from code review
* change atol to 1e-2
* update
* fix 'Setup Linux' not activated
* add sudo
* reduce Time to 100 to reduce the test time
* increase timeoutInMinutes to 40
* install ffmpeg
* timeout-minutes to 55
* +git
* show-error-codes
* .detach().cpu().numpy() first
* add numpy
* numpy
* ignore_errors torchmetrics.audio.*
* solve mypy no-redef error
* remove --quiet
* pypesq
* apt
* add # type: ignore
* try without test_si_snr & test_si_sdr
* test_import_speechmetrics
* test_speechmetrics_si_sdr
* test_si_sdr_functional
* test audio only
* install libsndfile1
* add sisnr sisdr test
* test all & add quiet & remove test_speechmetrics
* remove sudo & install libsndfile1
* add test
* update
* fix tests
* typing
* fix typing
* fix bus error
* SRMRpy
* pesq
* gcc
* comment -u root cuda 10.2 whoami
* env

Co-authored-by: quancs <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Parent: a75445b
Commit: fe03f3a
Showing 22 changed files with 986 additions and 9 deletions.
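
For orientation, here is a minimal usage sketch of the new metric as exercised by the test file below (torchmetrics.functional.si_sdr and torchmetrics.audio.SI_SDR). The SNR and SI_SNR additions from this commit presumably follow the same call pattern, but that is an assumption; only SI_SDR appears in this file.

    # Usage sketch only: assumes the API exercised by the test file shown below.
    import torch
    from torchmetrics.audio import SI_SDR
    from torchmetrics.functional import si_sdr

    preds = torch.randn(8, 1, 100)   # [batch, channel, time], e.g. an enhanced signal
    target = torch.randn(8, 1, 100)  # clean reference with the same shape

    # Functional form: one SI-SDR value per sample/channel,
    # which is what the test compares against speechmetrics per item.
    per_sample = si_sdr(preds, target, zero_mean=True)

    # Module form: accumulates state across batches, then reduces on compute().
    metric = SI_SDR(zero_mean=True)
    metric.update(preds, target)
    print(metric.compute())

Passing preds and target with mismatched shapes raises a RuntimeError, as test_error_on_different_shape at the end of the file checks.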
@@ -0,0 +1,131 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import speechmetrics
import torch
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SI_SDR
from torchmetrics.functional import si_sdr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Time = 100

Input = namedtuple('Input', ["preds", "target"])

inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
)

speechmetrics_sisdr = speechmetrics.load('sisdr')


def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool):
    # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
    # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
    if zero_mean:
        preds = preds - preds.mean(dim=2, keepdim=True)
        target = target - target.mean(dim=2, keepdim=True)
    target = target.detach().cpu().numpy()
    preds = preds.detach().cpu().numpy()
    mss = []
    for i in range(preds.shape[0]):
        ms = []
        for j in range(preds.shape[1]):
            metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000)
            ms.append(metric['sisdr'][0])
        mss.append(ms)
    return torch.tensor(mss)


def average_metric(preds, target, metric_func):
    # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
    # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
    return metric_func(preds, target).mean()


speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True)
speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False)


@pytest.mark.parametrize(
    "preds, target, sk_metric, zero_mean",
    [
        (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True),
        (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False),
    ],
)
class TestSISDR(MetricTester):
    atol = 1e-2

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            SI_SDR,
            sk_metric=partial(average_metric, metric_func=sk_metric),
            dist_sync_on_step=dist_sync_on_step,
            metric_args=dict(zero_mean=zero_mean),
        )

    def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean):
        self.run_functional_metric_test(
            preds,
            target,
            si_sdr,
            sk_metric,
            metric_args=dict(zero_mean=zero_mean),
        )

    def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=SI_SDR,
            metric_functional=si_sdr,
            metric_args={'zero_mean': zero_mean}
        )

    @pytest.mark.skipif(
        not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations not available before pytorch v1.6'
    )
    def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean):
        pytest.xfail("SI-SDR metric does not support cpu + half precision")

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
    def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean):
        self.run_precision_test_gpu(
            preds=preds,
            target=target,
            metric_module=SI_SDR,
            metric_functional=si_sdr,
            metric_args={'zero_mean': zero_mean}
        )


def test_error_on_different_shape(metric_class=SI_SDR):
    metric = metric_class()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100, ), torch.randn(50, ))
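
To make explicit what the speechmetrics comparison above is checking, the sketch below computes SI-SDR from its commonly cited definition (Le Roux et al., 2019): rescale the target by alpha = <preds, target> / ||target||^2 and report the ratio of scaled-target energy to residual energy in dB. This is an illustration under those assumptions, not the committed torchmetrics implementation; the eps term and the helper name si_sdr_reference are choices made here for the example.

    # Illustrative SI-SDR from the standard definition; not the library's exact code.
    import torch


    def si_sdr_reference(preds: torch.Tensor, target: torch.Tensor, zero_mean: bool = False) -> torch.Tensor:
        """Scale-invariant SDR over the last (time) dimension, one value per leading index."""
        if zero_mean:
            preds = preds - preds.mean(dim=-1, keepdim=True)
            target = target - target.mean(dim=-1, keepdim=True)
        eps = torch.finfo(preds.dtype).eps  # numerical safety only
        # Optimal scaling of the target onto the prediction.
        alpha = (preds * target).sum(dim=-1, keepdim=True) / (target.pow(2).sum(dim=-1, keepdim=True) + eps)
        scaled_target = alpha * target
        noise = preds - scaled_target
        return 10 * torch.log10(scaled_target.pow(2).sum(dim=-1) / (noise.pow(2).sum(dim=-1) + eps))


    # e.g. si_sdr_reference(torch.randn(4, 1, 100), torch.randn(4, 1, 100), zero_mean=True) -> shape [4, 1]

The test's atol of 1e-2 reflects that the speechmetrics reference runs in numpy at a fixed rate of 16000 Hz, so results agree only approximately with the torch implementation.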