Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JanekEbb committed Dec 31, 2023
1 parent 92681c8 commit 64ed67a
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,45 +160,6 @@ def test_accumulated_statistics(t_step, num_jobs):
), (stats['cts']['b'], expected_cross_triggers)


def test_no_ground_truth_events():
detection_scores = np.concatenate((np.arange(5), np.arange(4)[::-1]))
timestamps = np.arange(10)/10
scores = pd.DataFrame(
np.array((
timestamps[:-1], timestamps[1:],
detection_scores, np.zeros_like(detection_scores)
)).T,
columns=['onset', 'offset', 'a', 'b'],
)
change_point_scores, stats = accumulated_intermediate_statistics(
scores={'1': scores},
ground_truth={'1': []},
dtc_threshold=.5, gtc_threshold=.5, cttc_threshold=.5,
)[0]['a']
expected_change_point_scores = [4, np.inf] # highest score where a false positive occurs when threshold falls below it
expected_true_positives = [0, 0]
expected_false_positives = [1, 0]
expected_cross_triggers = [0, 0]

assert stats['n_ref'] == 0, stats['n_ref']
assert np.abs(stats['t_ref']) < 1e-12, stats['t_ref']
assert len(change_point_scores) == len(expected_change_point_scores), (
change_point_scores, expected_change_point_scores)
assert (change_point_scores == expected_change_point_scores).all(), (
change_point_scores, expected_change_point_scores)
assert (
(stats['tps'] == expected_true_positives).all()
), (stats['tps'], expected_true_positives)
assert (stats['fps'] == expected_false_positives).all(), (
stats['fps'], expected_false_positives)
assert stats['cts'].keys() == {'b'}, stats['cts']
assert stats['cts']['b'].shape == (len(change_point_scores),), (
stats['cts']['b'], expected_cross_triggers)
assert (
(stats['cts']['b'] == expected_cross_triggers).all()
), (stats['cts']['b'], expected_cross_triggers)


@pytest.mark.parametrize("dtc_threshold", [.5, .6])
def test_event_offset_beyond_file_offset(dtc_threshold):
detection_scores = np.concatenate((np.arange(5), 1+np.arange(3)[::-1]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ def test_segment_based_area_under_roc_vs_sklearn():
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
timestamps = np.arange(len(y_scores)+1) * segment_length
gt = {'1': [(timestamps[2], timestamps[4], 'a')]}
auroc_sklearn = roc_auc_score(y_true, y_scores)
gt = {
'1': [(timestamps[idx], timestamps[idx+1], 'a') for idx, t in enumerate(y_true) if t]
}
scores = {
'1': create_score_dataframe(
y_scores[..., None], timestamps=timestamps, event_classes=['a']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ def test_segment_based_average_precision_vs_sklearn():
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
timestamps = np.arange(len(y_scores)+1) * segment_length
gt = {'1': [(timestamps[2], timestamps[4], 'a')]}
ap_sklearn = average_precision_score(y_true, y_scores)
gt = {
'1': [(timestamps[idx], timestamps[idx+1], 'a') for idx, t in enumerate(y_true) if t]
}
scores = {
'1': create_score_dataframe(
y_scores[..., None], timestamps=timestamps, event_classes=['a']
Expand Down

0 comments on commit 64ed67a

Please sign in to comment.