From 530d147ffe1393a4bbcaa39353db3d4234e92b77 Mon Sep 17 00:00:00 2001
From: Brian McFee
Date: Sat, 16 Mar 2024 09:57:05 -0400
Subject: [PATCH] test_transcription modernized

---
 tests/test_transcription.py | 122 ++++++++++++++++--------
 1 file changed, 54 insertions(+), 68 deletions(-)

diff --git a/tests/test_transcription.py b/tests/test_transcription.py
index 89fcf0f8..4f245a12 100644
--- a/tests/test_transcription.py
+++ b/tests/test_transcription.py
@@ -4,8 +4,7 @@
 import numpy as np
 import glob
 import json
-from nose.tools import raises
-import warnings
+import pytest

 A_TOL = 1e-12

@@ -14,6 +13,28 @@
 EST_GLOB = 'data/transcription/est*.txt'
 SCORES_GLOB = 'data/transcription/output*.json'

+ref_files = sorted(glob.glob(REF_GLOB))
+est_files = sorted(glob.glob(EST_GLOB))
+sco_files = sorted(glob.glob(SCORES_GLOB))
+
+assert len(ref_files) == len(est_files) == len(sco_files) > 0
+
+file_sets = list(zip(ref_files, est_files, sco_files))
+
+
+@pytest.fixture
+def transcription_data(request):
+    ref_f, est_f, sco_f = request.param
+    with open(sco_f, "r") as f:
+        expected_scores = json.load(f)
+    # Load in reference transcription
+    ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_f)
+    # Load in estimated transcription
+    est_int, est_pitch = mir_eval.io.load_valued_intervals(est_f)
+
+    return ref_int, ref_pitch, est_int, est_pitch, expected_scores
+
+
 REF = np.array([
     [0.100, 0.300, 220.000],
     [0.300, 0.400, 246.942],
@@ -155,9 +176,6 @@
     assert np.allclose(scores_exp, scores_gen, atol=A_TOL)


-def __check_score(score, expected_score):
-    assert np.allclose(score, expected_score, atol=A_TOL)
-

 def test_onset_precision_recall_f1():

@@ -191,91 +209,57 @@
     assert np.allclose(scores_exp, scores_gen, atol=A_TOL)


-def test_regression():
-
-    # Regression tests
-    ref_files = sorted(glob.glob(REF_GLOB))
-    est_files = sorted(glob.glob(EST_GLOB))
-    sco_files = sorted(glob.glob(SCORES_GLOB))
+@pytest.mark.parametrize("transcription_data", file_sets, indirect=True)
+def test_regression(transcription_data):

-    for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files):
-        with open(sco_f, 'r') as f:
-            expected_scores = json.load(f)
-        # Load in reference transcription
-        ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_f)
-        # Load in estimated transcription
-        est_int, est_pitch = mir_eval.io.load_valued_intervals(est_f)
-        scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int,
-                                                 est_pitch)
-        for metric in scores:
-            # This is a simple hack to make nosetest's messages more useful
-            yield (__check_score, scores[metric], expected_scores[metric])
+    ref_int, ref_pitch, est_int, est_pitch, expected_scores = transcription_data
+    scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int,
+                                             est_pitch)
+    assert scores.keys() == expected_scores.keys()
+    for metric in scores:
+        assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL)


-def test_invalid_pitch():
-    ref_int, ref_pitch = np.array([[0, 1]]), np.array([-100])
-    est_int, est_pitch = np.array([[0, 1]]), np.array([100])
-
-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           ref_int, ref_pitch, est_int, est_pitch)
-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           est_int, est_pitch, ref_int, ref_pitch)
+@pytest.mark.xfail(raises=ValueError)
+@pytest.mark.parametrize('ref_pitch, est_pitch', [
+    (np.array([-100]), np.array([100])),
+    (np.array([100]), np.array([-100]))])
+def test_invalid_pitch(ref_pitch, est_pitch):
+    ref_int = np.array([[0, 1]])
+    mir_eval.transcription.validate(ref_int, ref_pitch, ref_int, est_pitch)


-def test_inconsistent_int_pitch():
-    ref_int, ref_pitch = np.array([[0, 1], [2, 3]]), np.array([100])
-    est_int, est_pitch = np.array([[0, 1]]), np.array([100])
+@pytest.mark.xfail(raises=ValueError)
+@pytest.mark.parametrize('ref_int, est_int', [
+    (np.array([[0, 1], [2, 3]]), np.array([[0, 1]])),
+    (np.array([[0, 1]]), np.array([[0, 1], [2, 3]]))])
+def test_inconsistent_int_pitch(ref_int, est_int):

-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           ref_int, ref_pitch, est_int, est_pitch)
-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           est_int, est_pitch, ref_int, ref_pitch)
+    ref_pitch = np.array([100])
+    mir_eval.transcription.validate(ref_int, ref_pitch, est_int, ref_pitch)


 def test_empty_ref():
-    warnings.resetwarnings()
-    warnings.simplefilter('always')
-    with warnings.catch_warnings(record=True) as out:
-
-        ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([])
-        est_int, est_pitch = np.array([[0, 1]]), np.array([100])
+    ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([])
+    est_int, est_pitch = np.array([[0, 1]]), np.array([100])

+    with pytest.warns(UserWarning, match='Reference notes are empty'):
         mir_eval.transcription.validate(ref_int, ref_pitch, est_int,
                                         est_pitch)

-    # Make sure that the warning triggered
-    assert len(out) > 0
-
-    # And that the category is correct
-    assert out[0].category is UserWarning
-
-    # And that it says the right thing (roughly)
-    assert 'empty' in str(out[0].message).lower()
-

 def test_empty_est():
-    warnings.resetwarnings()
-    warnings.simplefilter('always')
-    with warnings.catch_warnings(record=True) as out:
-
-        ref_int, ref_pitch = np.array([[0, 1]]), np.array([100])
-        est_int, est_pitch = np.empty(shape=(0, 2)), np.array([])
+    ref_int, ref_pitch = np.array([[0, 1]]), np.array([100])
+    est_int, est_pitch = np.empty(shape=(0, 2)), np.array([])

+    with pytest.warns(UserWarning, match='Estimated notes are empty'):
         mir_eval.transcription.validate(ref_int, ref_pitch, est_int,
                                         est_pitch)

-    # Make sure that the warning triggered
-    assert len(out) > 0
-
-    # And that the category is correct
-    assert out[0].category is UserWarning
-
-    # And that it says the right thing (roughly)
-    assert 'empty' in str(out[0].message).lower()
-

+@pytest.mark.filterwarnings("ignore:.*notes are empty")
 def test_precision_recall_f1_overlap_empty():

     ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([])
@@ -294,6 +278,7 @@
     assert (precision, recall, f1) == (0, 0, 0)


+@pytest.mark.filterwarnings("ignore:.*notes are empty")
 def test_onset_precision_recall_f1_empty():

     ref_int = np.empty(shape=(0, 2))
@@ -310,6 +295,7 @@
     assert (precision, recall, f1) == (0, 0, 0)


+@pytest.mark.filterwarnings("ignore:.*notes are empty")
 def test_offset_precision_recall_f1_empty():

     ref_int = np.empty(shape=(0, 2))