test_transcription modernized
bmcfee committed Mar 16, 2024
1 parent 0aff744 commit 530d147
Showing 1 changed file with 54 additions and 68 deletions.
122 changes: 54 additions & 68 deletions tests/test_transcription.py
@@ -4,8 +4,7 @@
import numpy as np
import glob
import json
-from nose.tools import raises
-import warnings
+import pytest

A_TOL = 1e-12

@@ -14,6 +13,28 @@
EST_GLOB = 'data/transcription/est*.txt'
SCORES_GLOB = 'data/transcription/output*.json'

+ref_files = sorted(glob.glob(REF_GLOB))
+est_files = sorted(glob.glob(EST_GLOB))
+sco_files = sorted(glob.glob(SCORES_GLOB))
+
+assert len(ref_files) == len(est_files) == len(sco_files) > 0
+
+file_sets = list(zip(ref_files, est_files, sco_files))
+
+
+@pytest.fixture
+def transcription_data(request):
+    ref_f, est_f, sco_f = request.param
+    with open(sco_f, "r") as f:
+        expected_scores = json.load(f)
+    # Load in reference transcription
+    ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_f)
+    # Load in estimated transcription
+    est_int, est_pitch = mir_eval.io.load_valued_intervals(est_f)
+
+    return ref_int, ref_pitch, est_int, est_pitch, expected_scores
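The fixture above relies on pytest's indirect parametrization: each (ref, est, score) triple in file_sets is handed to the fixture as request.param before the test body runs, so the loading logic lives in one place and every file triple becomes its own test case. A minimal, self-contained sketch of the pattern follows; it is illustrative only (not part of this commit), and the case names and file names are placeholders:

import pytest

cases = [
    ("ref0.txt", "est0.txt", "out0.json"),
    ("ref1.txt", "est1.txt", "out1.json"),
]


@pytest.fixture
def loaded(request):
    # request.param is one tuple from `cases`; a real fixture would load the files here
    ref_f, est_f, sco_f = request.param
    return {"ref": ref_f, "est": est_f, "scores": sco_f}


@pytest.mark.parametrize("loaded", cases, indirect=True)
def test_case(loaded):
    # the test receives the fixture's return value, once per entry in `cases`
    assert set(loaded) == {"ref", "est", "scores"}

Compared with the old nose yield-style loop, each triple shows up as a separate test with its own id in the report.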


REF = np.array([
    [0.100, 0.300, 220.000],
    [0.300, 0.400, 246.942],
@@ -155,9 +176,6 @@ def test_precision_recall_f1_overlap():
    assert np.allclose(scores_exp, scores_gen, atol=A_TOL)


-def __check_score(score, expected_score):
-    assert np.allclose(score, expected_score, atol=A_TOL)


def test_onset_precision_recall_f1():

@@ -191,91 +209,57 @@ def test_offset_precision_recall_f1():
    assert np.allclose(scores_exp, scores_gen, atol=A_TOL)


-def test_regression():
-
-    # Regression tests
-    ref_files = sorted(glob.glob(REF_GLOB))
-    est_files = sorted(glob.glob(EST_GLOB))
-    sco_files = sorted(glob.glob(SCORES_GLOB))
+@pytest.mark.parametrize("transcription_data", file_sets, indirect=True)
+def test_regression(transcription_data):

-    for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files):
-        with open(sco_f, 'r') as f:
-            expected_scores = json.load(f)
-        # Load in reference transcription
-        ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_f)
-        # Load in estimated transcription
-        est_int, est_pitch = mir_eval.io.load_valued_intervals(est_f)
-        scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int,
-                                                 est_pitch)
-        for metric in scores:
-            # This is a simple hack to make nosetest's messages more useful
-            yield (__check_score, scores[metric], expected_scores[metric])
+    ref_int, ref_pitch, est_int, est_pitch, expected_scores = transcription_data
+
+    scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int,
+                                             est_pitch)
+    assert scores.keys() == expected_scores.keys()
+    for metric in scores:
+        assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL)

-def test_invalid_pitch():
-
-    ref_int, ref_pitch = np.array([[0, 1]]), np.array([-100])
-    est_int, est_pitch = np.array([[0, 1]]), np.array([100])
-
-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           ref_int, ref_pitch, est_int, est_pitch)
-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           est_int, est_pitch, ref_int, ref_pitch)
+@pytest.mark.xfail(raises=ValueError)
+@pytest.mark.parametrize('ref_pitch, est_pitch', [
+    (np.array([-100]), np.array([100])),
+    (np.array([100]), np.array([-100]))])
+def test_invalid_pitch(ref_pitch, est_pitch):
+
+    ref_int = np.array([[0, 1]])
+    mir_eval.transcription.validate(ref_int, ref_pitch, ref_int, est_pitch)

-def test_inconsistent_int_pitch():
-
-    ref_int, ref_pitch = np.array([[0, 1], [2, 3]]), np.array([100])
-    est_int, est_pitch = np.array([[0, 1]]), np.array([100])
+@pytest.mark.xfail(raises=ValueError)
+@pytest.mark.parametrize('ref_int, est_int', [
+    (np.array([[0, 1], [2, 3]]), np.array([[0, 1]])),
+    (np.array([[0, 1]]), np.array([[0, 1], [2, 3]]))])
+def test_inconsistent_int_pitch(ref_int, est_int):

-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           ref_int, ref_pitch, est_int, est_pitch)
-    yield (raises(ValueError)(mir_eval.transcription.validate),
-           est_int, est_pitch, ref_int, ref_pitch)
+    ref_pitch = np.array([100])
+    mir_eval.transcription.validate(ref_int, ref_pitch, est_int, ref_pitch)
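For the two tests above, the xfail(raises=ValueError) marker passes exactly when the test body raises that exception; the same behaviour can also be written with an explicit pytest.raises context manager. The sketch below is illustrative only and uses a stand-in validate function, not mir_eval's:

import pytest


def validate(x):
    # stand-in for a validator that rejects negative values
    if x < 0:
        raise ValueError("negative value")


@pytest.mark.xfail(raises=ValueError)
@pytest.mark.parametrize("x", [-1, -2])
def test_marker_style(x):
    validate(x)


@pytest.mark.parametrize("x", [-1, -2])
def test_context_manager_style(x):
    with pytest.raises(ValueError):
        validate(x)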


def test_empty_ref():

-    warnings.resetwarnings()
-    warnings.simplefilter('always')
-    with warnings.catch_warnings(record=True) as out:
-
-        ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([])
-        est_int, est_pitch = np.array([[0, 1]]), np.array([100])
+    ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([])
+    est_int, est_pitch = np.array([[0, 1]]), np.array([100])

+    with pytest.warns(UserWarning, match='Reference notes are empty'):
        mir_eval.transcription.validate(ref_int, ref_pitch, est_int, est_pitch)

-        # Make sure that the warning triggered
-        assert len(out) > 0
-
-        # And that the category is correct
-        assert out[0].category is UserWarning
-
-        # And that it says the right thing (roughly)
-        assert 'empty' in str(out[0].message).lower()


def test_empty_est():

-    warnings.resetwarnings()
-    warnings.simplefilter('always')
-    with warnings.catch_warnings(record=True) as out:
-
-        ref_int, ref_pitch = np.array([[0, 1]]), np.array([100])
-        est_int, est_pitch = np.empty(shape=(0, 2)), np.array([])
+    ref_int, ref_pitch = np.array([[0, 1]]), np.array([100])
+    est_int, est_pitch = np.empty(shape=(0, 2)), np.array([])

+    with pytest.warns(UserWarning, match='Estimated notes are empty'):
        mir_eval.transcription.validate(ref_int, ref_pitch, est_int, est_pitch)

-        # Make sure that the warning triggered
-        assert len(out) > 0
-
-        # And that the category is correct
-        assert out[0].category is UserWarning
-
-        # And that it says the right thing (roughly)
-        assert 'empty' in str(out[0].message).lower()
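pytest.warns(UserWarning, match=...) replaces the manual warnings.catch_warnings(record=True) bookkeeping from the nose-era versions of these two tests: the block fails unless a warning of the given category with a matching message is emitted. A small illustrative sketch (validate_empty is a placeholder, not mir_eval's API):

import warnings

import pytest


def validate_empty():
    # placeholder that emits the kind of warning the real validator produces
    warnings.warn("Reference notes are empty", UserWarning)


def test_warns_and_matches():
    # fails unless a UserWarning whose message matches the regex is raised
    with pytest.warns(UserWarning, match="notes are empty"):
        validate_empty()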


+@pytest.mark.filterwarnings("ignore:.*notes are empty")
def test_precision_recall_f1_overlap_empty():

    ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([])
@@ -294,6 +278,7 @@ def test_precision_recall_f1_overlap_empty():
    assert (precision, recall, f1) == (0, 0, 0)


+@pytest.mark.filterwarnings("ignore:.*notes are empty")
def test_onset_precision_recall_f1_empty():

    ref_int = np.empty(shape=(0, 2))
@@ -310,6 +295,7 @@ def test_onset_precision_recall_f1_empty():
    assert (precision, recall, f1) == (0, 0, 0)


+@pytest.mark.filterwarnings("ignore:.*notes are empty")
def test_offset_precision_recall_f1_empty():

    ref_int = np.empty(shape=(0, 2))
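The filterwarnings mark added to the three *_empty tests installs an ignore filter for the duration of each test, so the expected "notes are empty" warnings do not clutter the report. An illustrative sketch, assuming a placeholder noisy() function in place of mir_eval's validator:

import warnings

import pytest


def noisy():
    # placeholder that emits the warning we want to silence
    warnings.warn("Estimated notes are empty", UserWarning)


@pytest.mark.filterwarnings("ignore:.*notes are empty")
def test_runs_quietly():
    # the mark scopes an "ignore" filter to this test only
    noisy()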
