Skip to content

Commit

Permalink
ENH: Combine 2 paps after postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
daniellepace committed Jul 29, 2024
1 parent b1f84f0 commit 66b06eb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
2 changes: 1 addition & 1 deletion ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def parse_args():

# Arguments for explorations/infer_stats_from_segmented_regions
parser.add_argument('--analyze_ground_truth', default=False, action='store_true', help='Whether or not to filter by images with ground truth segmentations, for comparison')
parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots. Must be in the same order as the output channel map.')
parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots. Must be in the same order as the output channel map. Use + to merge structures before postprocessing, and ++ to merge structures after postprocessing.')
parser.add_argument('--erosion_radius', nargs='*', default=[], type=int, help='Radius of the unit disk structuring element for erosion preprocessing, optionally as a list per structure to analyze')
parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the images has intensity above the threshold')
Expand Down
52 changes: 38 additions & 14 deletions ml4h/explorations.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,13 +790,12 @@ def _intensity_thresh_auto(
return (bins[np.where(pred == 1)[0][-1]][0] + bins[np.where(pred == 0)[0][0]][0]) / 2

def _scatter_plots_from_segmented_region_stats(
inference_tsv_true, inference_tsv_pred, structures_to_analyze,
output_folder, id, input_name, output_name,
inference_tsv_true, inference_tsv_pred, output_folder, id, input_name, output_name,
):
df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)

results_to_plot = [f'{s}_median' for s in structures_to_analyze]
results_to_plot = [c for c in df_pred.columns if 'median' in c]
for col in results_to_plot:
for i in ['all', 'filter_outliers']: # Two types of plots
plot_data = pd.concat(
Expand Down Expand Up @@ -864,28 +863,54 @@ def infer_stats_from_segmented_regions(args):
_, _, generate_test = test_train_valid_tensor_generators(**args.__dict__)
model, _, _, _ = make_multimodal_multitask_model(**args.__dict__)

# The user can use '+' in a structure name to create a new channel that merges
# other channels before postprocessing, and '++' to create one that merges
# channels after postprocessing. Every channel referenced by a merge must also
# appear on its own in structures_to_analyze, with the plain structures listed
# first, then the '+' merges, then the '++' merges.
structures = args.structures_to_analyze
merged_locs = [
    k for k, s in enumerate(structures)
    if '+' in s and '++' not in s
]
merged_after_locs = [k for k, s in enumerate(structures) if '++' in s]
uni_locs = [k for k, s in enumerate(structures) if '+' not in s]
# Plain structures must precede all merged entries; guard uni_locs so an empty
# list raises a clean AssertionError rather than an IndexError.
assert (not merged_locs) or (uni_locs and merged_locs[0] > uni_locs[-1])
assert (not merged_after_locs) or (uni_locs and merged_after_locs[0] > uni_locs[-1])
merged_structures = [structures[k] for k in merged_locs]
merged_after_structures = [structures[k] for k in merged_after_locs]
uni_structures = [structures[k] for k in uni_locs]
# Map each merged structure name to the channel indices it combines.
merged_channels = [
    [tm_out.channel_map[name] for name in s.split('+')] for s in merged_structures
]
merged_after_channels = [
    [tm_out.channel_map[name] for name in s.split('++')] for s in merged_after_structures
]

# Plain structures have to be in the same order as the channel map.
good_channels = [tm_out.channel_map[k] for k in uni_structures]
assert good_channels == sorted(good_channels)
# Column titles for the stats TSVs: plain structure names (recovered from the
# channel map) followed by the '+' merges, then the '++' merges.
title_structures = [
    [k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0]
    for v in good_channels
] + merged_structures + merged_after_structures
nb_orig_channels = len(tm_out.channel_map)
# NOTE(review): '++' merges are applied after postprocessing, so they appear
# not to count toward the postprocessed channel total — confirm intended.
nb_out_channels = len(good_channels) + len(merged_structures)
bad_channels = [k for k in range(nb_orig_channels) if k not in good_channels]
# Every channel referenced by a merge must itself be analyzed individually.
for merge in merged_channels + merged_after_channels:
    for channel in merge:
        assert channel in good_channels

# Structuring element used for the erosion
if len(args.erosion_radius) > 0:
Expand Down Expand Up @@ -933,11 +958,11 @@ def infer_stats_from_segmented_regions(args):
inference_writer_pred = csv.writer(inference_file_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)

header = ['sample_id']
header += [f'{k}_mean' for k in good_structures]
header += [f'{k}_median' for k in good_structures]
header += [f'{k}_std' for k in good_structures]
header += [f'{k}_iqr' for k in good_structures]
header += [f'{k}_count' for k in good_structures]
header += [f'{k}_mean' for k in title_structures]
header += [f'{k}_median' for k in title_structures]
header += [f'{k}_std' for k in title_structures]
header += [f'{k}_iqr' for k in title_structures]
header += [f'{k}_count' for k in title_structures]
header += ['mri_date']
inference_writer_true.writerow(header)
inference_writer_pred.writerow(header)
Expand Down Expand Up @@ -1025,8 +1050,7 @@ def postprocess_seg_and_write_stats(y, inference_writer):
# Scatter plots
if args.analyze_ground_truth:
_scatter_plots_from_segmented_region_stats(
inference_tsv_true, inference_tsv_pred, args.structures_to_analyze,
args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(),
inference_tsv_true, inference_tsv_pred, args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(),
)

# pngs
Expand Down

0 comments on commit 66b06eb

Please sign in to comment.