
Commit d9bbea2

Add functionality to export computed values into a CSV file
Move threshold files into data/ (instead of data/thresholds/) and remove the thresholds folder
1 parent a7b6143 commit d9bbea2

16 files changed: 121 additions, 10 deletions
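The CSV export added here is driven by the new --collect_csv pytest option (registered in conftest.py below), and the number of reference subjects can be capped via the new MAX_SUBJECTS environment variable. A hedged invocation sketch; the exact environment handling (REF_DIR) and test path are inferred from this diff rather than documented here:

    # assumed invocation: run the quicktest suite against reference data, cap the
    # number of subjects at 2, and collect per-comparison deltas into CSV files
    REF_DIR=/path/to/reference_subjects MAX_SUBJECTS=2 \
        python -m pytest test/quicktest --collect_csv=/path/to/csv_out

When --collect_csv is not given (default None), no CSV files are written and the tests run as before.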

test/quicktest/common.py (+35)
@@ -133,3 +133,38 @@ def __eq__(self, other):
         except BaseException as e:
             logger.exception(e)
         return within_bounds
+
+
+def write_table_file(
+    table_file: Path | None,
+    subject_id: str,
+    file: str,
+    scores: dict[int, int | float] | dict[str, int | float],
+) -> None:
+    """
+    Logs the calculated statistics (difference between test and reference) to a table file.
+
+    Parameters
+    ----------
+    table_file : Path, None
+        The file to write to, skip if None.
+    subject_id : str
+        The subject id.
+    file : str
+        The file associated with the comparison.
+    scores : dict[int | str, int | float]
+        The pairs of data associated with the comparison, e.g. index and value.
+    """
+    if not bool(table_file):
+        # no valid file passed, skip
+        return
+
+    for id, score in scores.items():
+        fmt = f'''"{{subject_id}}","{{file}}",{"{id:d}" if isinstance(id, int) else f'"{id}"'},'''
+        fmt += f"{{score:{'.6f' if isinstance(score, float) else 'd'}}}\n"
+        data = {"subject_id": subject_id, "file": file, "id": id, "score": score}
+        if not table_file.is_file():
+            with open(table_file, "w") as f:
+                f.write(",".join(data.keys()) + "\n")
+        with open(table_file, "a") as f:
+            f.write(fmt.format(**data))
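For orientation, write_table_file appends one row per score to the target CSV, creating a header from the data dict keys on first use; integer ids are written bare and float scores with six decimals. A hedged example of the resulting dice.csv (subject, file, and values are invented):

    subject_id,file,id,score
    "sub-01","aseg.mgz",17,0.998321
    "sub-01","aseg.mgz",53,0.997045

String ids, as used by the stats tests below (e.g. a measure name), are written quoted instead.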

test/quicktest/conftest.py (+13, -1)
@@ -32,10 +32,12 @@
 Folder with reference data (defined in environment variable).
 """
 reference_dir: Path = env["REF_DIR"]
+__subjects = (p for p in reference_dir.iterdir() if p.is_dir() and p.name not in ("logs", "slurm"))
+__max_subjects = int(os.environ.get("MAX_SUBJECTS", -1))
 """
 Load the test subjects from the reference path (one subject per folder).
 """
-ref_subjects: list[Path] = [p for p in reference_dir.iterdir() if p.is_dir()]
+ref_subjects: list[Path] = [p for i, p in enumerate(__subjects) if i < __max_subjects or __max_subjects < 0]

 assert len(ref_subjects) > 0, "No test subjects found!"

@@ -62,3 +64,13 @@ def ref_subject(request: pytest.FixtureRequest) -> SubjectDefinition:
 @pytest.fixture(scope="session")
 def test_subject(ref_subject: SubjectDefinition, subjects_dir: Path) -> SubjectDefinition:
     return ref_subject.with_subjects_dir(subjects_dir)
+
+def pytest_addoption(parser):
+    # the following option is for test_images and test_stats only
+    parser.addoption(
+        "--collect_csv",
+        action="store",
+        default=None,
+        type=Path,
+        help="Directory to store csv files that will collect all differences between reference and test.",
+    )
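The MAX_SUBJECTS cap above keeps at most that many reference subjects and treats the default of -1 as "no limit". A minimal self-contained sketch of that selection logic (plain strings stand in for the reference_dir paths):

    # stand-ins for reference_dir.iterdir() and int(os.environ.get("MAX_SUBJECTS", -1))
    subjects = ["subjA", "subjB", "subjC", "subjD"]
    max_subjects = 2
    picked = [s for i, s in enumerate(subjects) if i < max_subjects or max_subjects < 0]
    assert picked == ["subjA", "subjB"]  # with max_subjects = -1 all subjects would be kept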
File renamed without changes.
File renamed without changes.

test/quicktest/test_images.py (+33, -5)
@@ -10,7 +10,7 @@

 from FastSurferCNN.utils.metrics import dice_score

-from .common import SubjectDefinition, Tolerances
+from .common import SubjectDefinition, Tolerances, write_table_file

 logger = getLogger(__name__)

@@ -19,7 +19,7 @@
 @pytest.fixture(scope='module')
 def segmentation_tolerances(segmentation_image: str) -> Tolerances:

-    thresholds_file = Path(__file__).parent / f"data/thresholds/{segmentation_image}.yaml"
+    thresholds_file = Path(__file__).parent / f"data/{segmentation_image}.yaml"
     assert thresholds_file.exists(), f"The thresholds file {thresholds_file} does not exist!"
     return Tolerances(thresholds_file)

@@ -51,10 +51,11 @@ def compute_dice_score(test_data, reference_data, labels: dict[int, str]) -> tup
         Dice scores for each class.
     """
     dice_scores = {}
-    logger.debug("Dice scores:")
+    logger.debug("Dice scores (reporting non-zero):")
     for _, (label, lname) in enumerate(labels.items()):
         dice_scores[label] = dice_score(np.equal(reference_data, label), np.equal(test_data, label), validate=False)
-        logger.debug(f"Label {lname}: {dice_scores[label]:.4f}")
+        if dice_scores[label] > 0:
+            logger.debug(f"Label {lname}: {dice_scores[label]:.4f}")
     mean_dice_score = np.asarray(list(dice_scores.values())).mean()
     return mean_dice_score, dice_scores

@@ -96,6 +97,7 @@ def test_segmentation_image(
     ref_subject: SubjectDefinition,
     segmentation_image: str,
     segmentation_tolerances: Tolerances,
+    pytestconfig: pytest.Config,
 ):
     """
     Test the segmentation data by calculating and comparing dice scores.
@@ -110,6 +112,8 @@ def test_segmentation_image(
         Name of the segmentation image file.
     segmentation_tolerances: Tolerances
         Object to provide the relevant tolerances for the respective segmentation_image.
+    pytestconfig : pytest.Config
+        The session's config object.

     Raises
     ------
@@ -129,13 +133,23 @@
     # Compute the dice score
     mean_dice, dice_scores = compute_dice_score(test_data, reference_data, labels_lnames)

+    delta_dir: Path = pytestconfig.getoption("--collect_csv")
+    if delta_dir:
+        delta_dir.mkdir(parents=True, exist_ok=True)
+        write_table_file(delta_dir / "dice.csv", test_subject.name, segmentation_image, dice_scores)
+
     failed_labels = (i for i, dice in dice_scores.items() if not np.isclose(dice, 0, atol=labels_lnames_tols[i][1]))
     dice_exceeding_threshold = [f"Label {labels_lnames[lbl]}: {1-dice_scores[lbl]}" for lbl in failed_labels]
     assert [] == dice_exceeding_threshold, f"Dice scores in {segmentation_image} are not within range!"
     logger.debug("Dice scores are within range for all classes")


-def test_intensity_image(test_subject: SubjectDefinition, ref_subject: SubjectDefinition, intensity_image: str):
+def test_intensity_image(
+    test_subject: SubjectDefinition,
+    ref_subject: SubjectDefinition,
+    intensity_image: str,
+    pytestconfig: pytest.Config,
+):
     """
     Test the intensity data by calculating and comparing the mean square error.

@@ -147,6 +161,8 @@ def test_intensity_image(test_subject: SubjectDefinition, ref_subject: SubjectDe
         Definition of the reference subject.
     intensity_image : str
         Name of the image file.
+    pytestconfig : pytest.Config
+        The session's config object.

     Raises
     ------
@@ -158,6 +174,18 @@ def test_intensity_image(test_subject: SubjectDefinition, ref_subject: SubjectDe
     test_data = test_img.get_fdata()
     reference_file, reference_img = ref_subject.load_image(intensity_image)
     reference_data = reference_img.get_fdata()
+
+    delta_dir = pytestconfig.getoption("--collect_csv")
+    if delta_dir:
+        delta_dir.mkdir(parents=True, exist_ok=True)
+        # this analysis writes not only the max (tested below), but also the mean and some percentiles
+        absdelta = np.abs(test_data - reference_data)
+        scores = {p: v for p, v in zip(("median", "95th", "99th"), np.percentile(absdelta, (50, 95, 99)), strict=True)}
+        scores.update(mean=absdelta.mean(), max=np.max(absdelta))
+        reldelta = absdelta / np.maximum(np.maximum(abs(reference_data), abs(test_data)), 1e-8)
+        scores.update(rel=reldelta.mean(), relmax=np.max(reldelta))
+        write_table_file(delta_dir / "intensity.csv", test_subject.name, intensity_image, scores)
+
     # Check the image data
     np.testing.assert_allclose(test_data, reference_data, rtol=1e-4, err_msg="Image intensity data do not match!")

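The relative scores (rel, relmax) written to intensity.csv use a symmetric denominator, the larger magnitude of the two voxel values, clipped at 1e-8 to avoid division by zero. A hedged, standalone numpy sketch of that computation on toy data:

    import numpy as np

    test_data = np.array([1.0, 0.0, 2.0])
    reference_data = np.array([1.1, 0.0, 2.0])
    absdelta = np.abs(test_data - reference_data)
    # divide by the larger magnitude of reference and test value, but never by less than 1e-8
    reldelta = absdelta / np.maximum(np.maximum(np.abs(reference_data), np.abs(test_data)), 1e-8)
    print(reldelta.mean(), reldelta.max())  # these feed the "rel" and "relmax" columns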

test/quicktest/test_stats.py (+40, -4)
@@ -7,7 +7,7 @@

 from FastSurferCNN.segstats import PVStats

-from .common import ApproxAndLog, SubjectDefinition, Tolerances
+from .common import ApproxAndLog, SubjectDefinition, Tolerances, write_table_file

 logger = getLogger(__name__)

@@ -54,7 +54,7 @@ def measure_tolerances(stats_file: str) -> MeasureTolerances:
     MeasureTolerances
         The list of measures expected for this file.
     """
-    thresholds_file = Path(__file__).parent / f"data/thresholds/{stats_file}.yaml"
+    thresholds_file = Path(__file__).parent / f"data/{stats_file}.yaml"
     assert Path(thresholds_file).is_file(), f"The threshold file {thresholds_file} does not exist!"
     return MeasureTolerances(thresholds_file)

@@ -73,7 +73,7 @@ def stats_tolerances(stats_file: str) -> Tolerances:
     Tolerances
         Per-structure tolerances object
     """
-    thresholds_file = Path(__file__).parent / f"data/thresholds/{stats_file}.yaml"
+    thresholds_file = Path(__file__).parent / f"data/{stats_file}.yaml"
     assert Path(thresholds_file).is_file(), f"The threshold file {thresholds_file} does not exist!"
     return Tolerances(thresholds_file)

@@ -159,6 +159,7 @@ def test_measure_thresholds(
     ref_subject: SubjectDefinition,
     stats_file: str,
     measure_tolerances: MeasureTolerances,
+    pytestconfig: pytest.Config,
 ):
     """
     Test if the measure is within thresholds in stats_file.
@@ -173,6 +174,8 @@
         Name of the test directory.
     measure_tolerances : MeasureTolerances
         The object to provide the measure tolerances for stats_file.
+    pytestconfig : pytest.Config
+        The session's config object.

     Raises
     ------
@@ -198,6 +201,14 @@ def check_measure(measure: str) -> bool:
         actual: MeasureTuple = actual_annots[measure]
         return expected == ApproxAndLog(actual, rel=measure_tolerances.threshold(measure))

+    delta_dir: Path = pytestconfig.getoption("--collect_csv")
+    if delta_dir:
+        delta_dir.mkdir(parents=True, exist_ok=True)
+        values = [(m, expected_annots[m][2], actual_annots[m][2]) for m in expected_measures]
+        scores: dict[str, float] = {m: abs(a - b) for m, a, b in values}
+        scores.update({m + "_rel": abs(a - b) / max((abs(a), abs(b))) for m, a, b in values})
+        write_table_file(delta_dir / "stats-measure.csv", test_subject.name, stats_file, scores)
+
     failed_measures = (m for m in expected_measures if not check_measure(m))
     measures_outside_spec = [f"Measure {m}: {expected_annots[m][2]} <> {actual_annots[m][2]}" for m in failed_measures]
     assert [] == measures_outside_spec, f"Some Measures are outside of the threshold in {test_subject}: {stats_file}!"
@@ -241,6 +252,7 @@ def test_stats_table(
     ref_subject: SubjectDefinition,
     stats_file: str,
     stats_tolerances: Tolerances,
+    pytestconfig: pytest.Config,
 ):
     """
     Test if the tables are within the threshold.
@@ -255,6 +267,8 @@
         Name of the test directory.
     stats_tolerances : Tolerances
         The object to provide the tolerances for stats_file.
+    pytestconfig : pytest.Config
+        The session's config object.

     Raises
     ------
@@ -264,9 +278,14 @@
     _, _, expected_table = ref_subject.load_stats_file(stats_file)
     _, _, actual_table = test_subject.load_stats_file(stats_file)
     actual_segids = [stats["SegId"] for stats in actual_table]
+    ignored_columns = ["SegId", "StructName"]

     def filter_keys(stats: PVStats) -> dict[str, int | float]:
-        return {k: v for k, v in stats.items() if k not in ["SegId", "StructName"]}
+        return {k: v for k, v in stats.items() if k not in ignored_columns}
+
+    delta_dir: Path = pytestconfig.getoption("--collect_csv")
+    table_data = []
+    keys = []

     expected_different = []
     actual_different = []
@@ -275,10 +294,27 @@ def filter_keys(stats: PVStats) -> dict[str, int | float]:
         _expected = filter_keys(expected)
         actual = actual_table[actual_segids.index(expected_segid)]
         _actual = filter_keys(actual)
+        if delta_dir:
+            keys.append((list(_expected.keys()), list(_actual.keys())))
+            table_data.append((expected_segid, _expected, _actual))
         if not _expected == ApproxAndLog(_actual, abs=stats_tolerances.threshold(expected_segid)):
             expected_different.append(expected)
             actual_different.append(actual)

+    if delta_dir:
+        expected_keys, actual_keys = zip(*keys, strict=True)
+        assert expected_keys == actual_keys, "To compare tables, the keys of reference and test have to be identical!"
+
+        def relative(a, b, field) -> float:
+            return abs(a[field] - b[field]) / max((abs(a[field]), abs(b[field]), 1e-8))
+
+        delta_dir.mkdir(parents=True, exist_ok=True)
+
+        for field in actual_keys[0]:
+            scores: dict[str, float] = {f"{seg_id}:{field}": abs(a[field] - b[field]) for seg_id, a, b in table_data}
+            scores.update({f"{seg_id}:rel-{field}": relative(a, b, field) for seg_id, a, b in table_data})
+            write_table_file(delta_dir / "stats-table.csv", test_subject.name, stats_file, scores)
+
     assert expected_different == actual_different, f"The tables of some structures in {stats_file} are 'too' different!"

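The per-structure deltas from test_stats_table land in stats-table.csv with composite ids of the form "SegId:field" and "SegId:rel-field". Hedged example rows (structure id, stats file name, and values are illustrative only):

    subject_id,file,id,score
    "sub-01","aseg+DKT.stats","17:Volume_mm3",1.250000
    "sub-01","aseg+DKT.stats","17:rel-Volume_mm3",0.000312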
