@@ -254,34 +254,29 @@ def __init__(
254
254
self .labels = self .lut ["ID" ].values
255
255
self .torch_labels = torch .from_numpy (self .lut ["ID" ].values )
256
256
self .names = ["SubjectName" , "Average" , "Subcortical" , "Cortical" ]
257
- self .cfg_fin , cfg_cor , cfg_sag , cfg_ax = args2cfg (
258
- cfg_ax , cfg_cor , cfg_sag , batch_size = batch_size ,
259
- )
257
+ self .cfg_fin , cfg_cor , cfg_sag , cfg_ax = args2cfg (cfg_ax , cfg_cor , cfg_sag , batch_size = batch_size )
260
258
# the order in this dictionary dictates the order in the view aggregation
261
259
self .view_ops = {
262
260
"coronal" : {"cfg" : cfg_cor , "ckpt" : ckpt_cor },
263
261
"sagittal" : {"cfg" : cfg_sag , "ckpt" : ckpt_sag },
264
262
"axial" : {"cfg" : cfg_ax , "ckpt" : ckpt_ax },
265
263
}
266
- self .num_classes = max (
267
- view ["cfg" ].MODEL .NUM_CLASSES for view in self .view_ops .values ()
268
- )
264
+ # self.num_classes = max(view["cfg"].MODEL.NUM_CLASSES for view in self.view_ops.values() if view["cfg"])
265
+ # currently, num_classes must be 79 in all cases. This seems like it is a config option here, but in reality it
266
+ # is not, so we hard-code it here. Only sagittal has < 79 classes, but num_classes is only used to set the
267
+ # dimensions of the view aggregation tensor, which is after splitting the classes from sagittal to all.
268
+ self .num_classes = 79
269
269
self .models = {}
270
270
for plane , view in self .view_ops .items ():
271
271
if all (view [key ] is not None for key in ("cfg" , "ckpt" )):
272
- self .models [plane ] = Inference (
273
- view ["cfg" ], ckpt = view ["ckpt" ], device = self .device , lut = self .lut ,
274
- )
272
+ self .models [plane ] = Inference (view ["cfg" ], ckpt = view ["ckpt" ], device = self .device , lut = self .lut )
275
273
276
274
if vox_size == "min" :
277
275
self .vox_size = "min"
278
276
elif 0.0 < float (vox_size ) <= 1.0 :
279
277
self .vox_size = float (vox_size )
280
278
else :
281
- raise ValueError (
282
- f"Invalid value for vox_size, must be between 0 and 1 or 'min', was "
283
- f"{ vox_size } ."
284
- )
279
+ raise ValueError (f"Invalid value for vox_size, must be between 0 and 1 or 'min', was { vox_size } ." )
285
280
self .conform_to_1mm_threshold = conform_to_1mm_threshold
286
281
287
282
@property
@@ -571,15 +566,15 @@ def make_parser():
571
566
parser ,
572
567
"checkpoint" ,
573
568
files ,
574
- CHECKPOINT_PATHS_FILE
569
+ CHECKPOINT_PATHS_FILE ,
575
570
)
576
571
577
572
# 4. CFG-file with default options for network
578
573
parser = parser_defaults .add_plane_flags (
579
574
parser ,
580
575
"config" ,
581
576
files ,
582
- CHECKPOINT_PATHS_FILE
577
+ CHECKPOINT_PATHS_FILE ,
583
578
)
584
579
585
580
# 5. technical parameters
0 commit comments