 
 from FastSurferCNN.segstats import PVStats
 
-from .common import ApproxAndLog, SubjectDefinition, Tolerances
+from .common import ApproxAndLog, SubjectDefinition, Tolerances, write_table_file
 
 logger = getLogger(__name__)
 
@@ -54,7 +54,7 @@ def measure_tolerances(stats_file: str) -> MeasureTolerances:
     MeasureTolerances
         The list of measures expected for this file.
     """
-    thresholds_file = Path(__file__).parent / f"data/thresholds/{stats_file}.yaml"
+    thresholds_file = Path(__file__).parent / f"data/{stats_file}.yaml"
     assert Path(thresholds_file).is_file(), f"The threshold file {thresholds_file} does not exist!"
     return MeasureTolerances(thresholds_file)
 
@@ -73,7 +73,7 @@ def stats_tolerances(stats_file: str) -> Tolerances:
     Tolerances
         Per-structure tolerances object
     """
-    thresholds_file = Path(__file__).parent / f"data/thresholds/{stats_file}.yaml"
+    thresholds_file = Path(__file__).parent / f"data/{stats_file}.yaml"
     assert Path(thresholds_file).is_file(), f"The threshold file {thresholds_file} does not exist!"
     return Tolerances(thresholds_file)
 
@@ -159,6 +159,7 @@ def test_measure_thresholds(
     ref_subject: SubjectDefinition,
     stats_file: str,
     measure_tolerances: MeasureTolerances,
+    pytestconfig: pytest.Config,
 ):
     """
     Test if the measure is within thresholds in stats_file.
@@ -173,6 +174,8 @@ def test_measure_thresholds(
         Name of the test directory.
     measure_tolerances : MeasureTolerances
         The object to provide the measure tolerances for stats_file.
+    pytestconfig : pytest.Config
+        The session's config object.
 
     Raises
     ------
@@ -198,6 +201,14 @@ def check_measure(measure: str) -> bool:
         actual: MeasureTuple = actual_annots[measure]
         return expected == ApproxAndLog(actual, rel=measure_tolerances.threshold(measure))
 
+    delta_dir: Path = pytestconfig.getoption("--collect_csv")
+    if delta_dir:
+        delta_dir.mkdir(parents=True, exist_ok=True)
+        values = [(m, expected_annots[m][2], actual_annots[m][2]) for m in expected_measures]
+        scores: dict[str, float] = {m: abs(a - b) for m, a, b in values}
+        scores.update({m + "_rel": abs(a - b) / max((abs(a), abs(b), 1e-8)) for m, a, b in values})
+        write_table_file(delta_dir / "stats-measure.csv", test_subject.name, stats_file, scores)
+
     failed_measures = (m for m in expected_measures if not check_measure(m))
     measures_outside_spec = [f"Measure {m}: {expected_annots[m][2]} <> {actual_annots[m][2]}" for m in failed_measures]
     assert [] == measures_outside_spec, f"Some Measures are outside of the threshold in {test_subject}: {stats_file}!"
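
Both tests read the output directory from pytestconfig.getoption("--collect_csv") and skip the CSV export when the value is falsy. The registration of that option is not part of this diff; the following is a minimal sketch of how it could be declared in the suite's conftest.py, assuming a Path-typed argument with a None default (pytest_addoption and parser.addoption are standard pytest hooks, the defaults shown here are assumptions):

# Hypothetical conftest.py sketch -- not part of this diff.
from pathlib import Path


def pytest_addoption(parser):
    """Register the --collect_csv option used by the stats tests."""
    parser.addoption(
        "--collect_csv",
        type=Path,      # the argument is converted to a pathlib.Path
        default=None,   # falsy default, so no delta CSV files are written
        help="Directory into which per-subject delta CSV files are collected.",
    )
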
@@ -241,6 +252,7 @@ def test_stats_table(
     ref_subject: SubjectDefinition,
     stats_file: str,
     stats_tolerances: Tolerances,
+    pytestconfig: pytest.Config,
 ):
     """
     Test if the tables are within the threshold.
@@ -255,6 +267,8 @@ def test_stats_table(
         Name of the test directory.
     stats_tolerances : Tolerances
        The object to provide the tolerances for stats_file.
+    pytestconfig : pytest.Config
+        The session's config object.
 
     Raises
     ------
@@ -264,9 +278,14 @@ def test_stats_table(
     _, _, expected_table = ref_subject.load_stats_file(stats_file)
     _, _, actual_table = test_subject.load_stats_file(stats_file)
     actual_segids = [stats["SegId"] for stats in actual_table]
+    ignored_columns = ["SegId", "StructName"]
 
     def filter_keys(stats: PVStats) -> dict[str, int | float]:
-        return {k: v for k, v in stats.items() if k not in ["SegId", "StructName"]}
+        return {k: v for k, v in stats.items() if k not in ignored_columns}
+
+    delta_dir: Path = pytestconfig.getoption("--collect_csv")
+    table_data = []
+    keys = []
 
     expected_different = []
     actual_different = []
@@ -275,10 +294,27 @@ def filter_keys(stats: PVStats) -> dict[str, int | float]:
         _expected = filter_keys(expected)
         actual = actual_table[actual_segids.index(expected_segid)]
         _actual = filter_keys(actual)
+        if delta_dir:
+            keys.append((list(_expected.keys()), list(_actual.keys())))
+            table_data.append((expected_segid, _expected, _actual))
         if not _expected == ApproxAndLog(_actual, abs=stats_tolerances.threshold(expected_segid)):
             expected_different.append(expected)
             actual_different.append(actual)
 
+    if delta_dir:
+        expected_keys, actual_keys = zip(*keys, strict=True)
+        assert expected_keys == actual_keys, "To compare tables, the keys of reference and test have to be identical!"
+
+        def relative(a, b, field) -> float:
+            return abs(a[field] - b[field]) / max((abs(a[field]), abs(b[field]), 1e-8))
+
+        delta_dir.mkdir(parents=True, exist_ok=True)
+
+        for field in actual_keys[0]:
+            scores: dict[str, float] = {f"{seg_id}:{field}": abs(a[field] - b[field]) for seg_id, a, b in table_data}
+            scores.update({f"{seg_id}:rel-{field}": relative(a, b, field) for seg_id, a, b in table_data})
+            write_table_file(delta_dir / "stats-table.csv", test_subject.name, stats_file, scores)
+
 
     assert expected_different == actual_different, f"The tables of some structures in {stats_file} are 'too' different!"
 
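
Both tests funnel their per-measure and per-structure deltas through write_table_file, which this diff imports from .common but does not show. Purely to illustrate the call sites above (target CSV path, subject name, stats file name, dict of score names to values), here is one hedged sketch of what such a helper could look like; the actual implementation in .common may differ:

# Hypothetical sketch only -- the real write_table_file lives in .common and is not shown in this diff.
import csv
from pathlib import Path


def write_table_file(csv_path: Path, subject: str, stats_file: str, scores: dict[str, float]) -> None:
    """Append one row per score to csv_path so results from many subjects accumulate in one CSV."""
    write_header = not csv_path.is_file()
    with open(csv_path, "a", newline="") as csv_file:
        writer = csv.writer(csv_file)
        if write_header:
            writer.writerow(["subject", "stats_file", "score", "value"])  # header only once
        for name, value in scores.items():
            writer.writerow([subject, stats_file, name, value])
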