Skip to content

Commit da32765

Browse files
committed
Modify run_prediction so the user is able to "deselect" view to not include.
- parser_defaults.py: Allow empty or none for --cfg_ax, --cfg_sag, --cfg_cor. - run_prediction.py: If cgf_ax/sag/cor is none/empty, do not use that view - some formatting cleanup
1 parent de8416a commit da32765

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

FastSurferCNN/inference.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,7 @@ def eval(
382382

383383
# check if we need a special mapping (e.g. as for sagittal)
384384
if self.get_plane() == "sagittal":
385-
pred = map_prediction_sagittal2full(
386-
pred, num_classes=self.get_num_classes(), lut=self.lut
387-
)
385+
pred = map_prediction_sagittal2full(pred, num_classes=self.get_num_classes(), lut=self.lut)
388386

389387
# permute the prediction into the out slice order
390388
pred = pred.permute(*self.permute_order[plane]).to(

FastSurferCNN/run_prediction.py

+10-15
Original file line numberDiff line numberDiff line change
@@ -254,34 +254,29 @@ def __init__(
254254
self.labels = self.lut["ID"].values
255255
self.torch_labels = torch.from_numpy(self.lut["ID"].values)
256256
self.names = ["SubjectName", "Average", "Subcortical", "Cortical"]
257-
self.cfg_fin, cfg_cor, cfg_sag, cfg_ax = args2cfg(
258-
cfg_ax, cfg_cor, cfg_sag, batch_size=batch_size,
259-
)
257+
self.cfg_fin, cfg_cor, cfg_sag, cfg_ax = args2cfg(cfg_ax, cfg_cor, cfg_sag, batch_size=batch_size)
260258
# the order in this dictionary dictates the order in the view aggregation
261259
self.view_ops = {
262260
"coronal": {"cfg": cfg_cor, "ckpt": ckpt_cor},
263261
"sagittal": {"cfg": cfg_sag, "ckpt": ckpt_sag},
264262
"axial": {"cfg": cfg_ax, "ckpt": ckpt_ax},
265263
}
266-
self.num_classes = max(
267-
view["cfg"].MODEL.NUM_CLASSES for view in self.view_ops.values()
268-
)
264+
# self.num_classes = max(view["cfg"].MODEL.NUM_CLASSES for view in self.view_ops.values() if view["cfg"])
265+
# currently, num_classes must be 79 in all cases. This seems like it is a config option here, but in reality it
266+
# is not, so we hard-code it here. Only sagittal has < 79 classes, but num_classes is only used to set the
267+
# dimensions of the view aggregation tensor, which is after splitting the classes from sagittal to all.
268+
self.num_classes = 79
269269
self.models = {}
270270
for plane, view in self.view_ops.items():
271271
if all(view[key] is not None for key in ("cfg", "ckpt")):
272-
self.models[plane] = Inference(
273-
view["cfg"], ckpt=view["ckpt"], device=self.device, lut=self.lut,
274-
)
272+
self.models[plane] = Inference(view["cfg"], ckpt=view["ckpt"], device=self.device, lut=self.lut)
275273

276274
if vox_size == "min":
277275
self.vox_size = "min"
278276
elif 0.0 < float(vox_size) <= 1.0:
279277
self.vox_size = float(vox_size)
280278
else:
281-
raise ValueError(
282-
f"Invalid value for vox_size, must be between 0 and 1 or 'min', was "
283-
f"{vox_size}."
284-
)
279+
raise ValueError(f"Invalid value for vox_size, must be between 0 and 1 or 'min', was {vox_size}.")
285280
self.conform_to_1mm_threshold = conform_to_1mm_threshold
286281

287282
@property
@@ -571,15 +566,15 @@ def make_parser():
571566
parser,
572567
"checkpoint",
573568
files,
574-
CHECKPOINT_PATHS_FILE
569+
CHECKPOINT_PATHS_FILE,
575570
)
576571

577572
# 4. CFG-file with default options for network
578573
parser = parser_defaults.add_plane_flags(
579574
parser,
580575
"config",
581576
files,
582-
CHECKPOINT_PATHS_FILE
577+
CHECKPOINT_PATHS_FILE,
583578
)
584579

585580
# 5. technical parameters

FastSurferCNN/utils/parser_defaults.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ def add_plane_flags(
407407
if configtype not in PLANE_SHORT:
408408
raise ValueError("type must be either config or checkpoint.")
409409

410+
def cast_type(__value: str) -> Path | None:
411+
return Path(__value) if configtype == "checkpoint" and __value and __value.lower() != "none" else None
412+
410413
from FastSurferCNN.utils.checkpoint import load_checkpoint_config_defaults
411414
defaults = load_checkpoint_config_defaults(configtype, defaults_path)
412415

@@ -422,7 +425,7 @@ def add_plane_flags(
422425
plane_short = plane[: index + 2]
423426
parser.add_argument(
424427
f"--{PLANE_SHORT[configtype]}_{plane_short}",
425-
type=Path,
428+
type=cast_type,
426429
dest=f"{PLANE_SHORT[configtype]}_{plane_short}",
427430
help=PLANE_HELP[configtype].format(plane),
428431
default=path,

0 commit comments

Comments
 (0)