Dfp pm seg instance 3 (#563)
* ENH: Add tensormap for instance_3

* FIX: Put back Dice import

* COMP: Remove duplicate imports
daniellepace authored May 17, 2024
1 parent 3e8b662 commit ddcfb60
Showing 3 changed files with 59 additions and 29 deletions.
ml4h/metrics.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@
 from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy
 from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
 
-#from neurite.tf.losses import Dice
+from neurite.tf.losses import Dice
 
 STRING_METRICS = [
     'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae',
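
The restored import makes neurite's Dice available to ml4h.metrics again. As a hedged illustration only (the model below is a stand-in, not code from this repository, and it assumes neurite's Dice exposes a loss(y_true, y_pred) method as in its published examples), the class is typically wired into Keras like this:

    import tensorflow as tf
    from neurite.tf.losses import Dice

    # Illustrative only: a tiny 4-class segmentation head compiled with a
    # soft-Dice loss; the architecture is a placeholder, not ml4h's model.
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(4, 3, padding='same', activation='softmax', input_shape=(384, 384, 1)),
    ])
    model.compile(optimizer='adam', loss=Dice().loss)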
ml4h/recipes.py (60 changes: 37 additions & 23 deletions)
@@ -34,9 +34,7 @@
 from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
 from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription
 from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients, concordance_index_censored
-from ml4h.plots import plot_dice, plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
-from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice
-from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
+from ml4h.plots import plot_dice, plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
 from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations
 from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model
 from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival
@@ -140,7 +138,7 @@ def run(args):
 
     except Exception as e:
         logging.exception(e)
-
+
     if args.gcs_cloud_bucket is not None:
         save_to_google_cloud(args)
 
@@ -348,10 +346,14 @@ def option_picker(sample_id, data_descriptions):
     valid_ids = list(mrn_df[mrn_df.split == 'valid'].index)
     test_ids = list(mrn_df[mrn_df.split == 'test'].index)
 
-    train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg,
-                                                get_epoch=shuffle_get_epoch)
-    valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg,
-                                                get_epoch=shuffle_get_epoch)
+    train_dataset = SampleGetterIterableDataset(
+        sample_ids=list(train_ids), sample_getter=sg,
+        get_epoch=shuffle_get_epoch,
+    )
+    valid_dataset = SampleGetterIterableDataset(
+        sample_ids=list(valid_ids), sample_getter=sg,
+        get_epoch=shuffle_get_epoch,
+    )
 
     num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
     num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
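
The two expressions above split args.num_workers between the training and validation loaders in proportion to their step counts, and the trailing `or` clause guarantees each loader at least one worker whenever any workers were requested at all. A worked example with illustrative values (not taken from the repository):

    # 80 training steps, 20 validation steps, 5 workers in total.
    training_steps, validation_steps, num_workers = 80, 20, 5
    num_train_workers = int(training_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
    num_valid_workers = int(validation_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
    assert (num_train_workers, num_valid_workers) == (4, 1)
    # With num_workers = 0 both expressions evaluate to 0, so no worker processes are spawned.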
@@ -447,10 +449,14 @@ def option_picker(sample_id, data_descriptions):
     valid_ids = list(mrn_df[mrn_df.split == 'valid'].index)
     test_ids = list(mrn_df[mrn_df.split == 'test'].index)
 
-    train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg,
-                                                get_epoch=shuffle_get_epoch)
-    valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg,
-                                                get_epoch=shuffle_get_epoch)
+    train_dataset = SampleGetterIterableDataset(
+        sample_ids=list(train_ids), sample_getter=sg,
+        get_epoch=shuffle_get_epoch,
+    )
+    valid_dataset = SampleGetterIterableDataset(
+        sample_ids=list(valid_ids), sample_getter=sg,
+        get_epoch=shuffle_get_epoch,
+    )
 
     num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
     num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
@@ -483,14 +489,16 @@ def option_picker(sample_id, data_descriptions):
         output_data_descriptions=output_dds,  # what we want a model to predict from the input data
         option_picker=option_picker,
     )
-    test_dataset = SampleGetterIterableDataset(sample_ids=list(test_ids), sample_getter=test_sg,
-                                               get_epoch=shuffle_get_epoch)
+    test_dataset = SampleGetterIterableDataset(
+        sample_ids=list(test_ids), sample_getter=test_sg,
+        get_epoch=shuffle_get_epoch,
+    )
 
     generate_test = TensorMapDataLoader2(
         batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out,
         dataset=test_dataset,
         num_workers=num_train_workers,
-        )
+    )
 
     y_trues = defaultdict(list)
     y_preds = defaultdict(list)
@@ -518,8 +526,10 @@ def option_picker(sample_id, data_descriptions):
             plot_survival(y_preds[otm.name], y_trues[otm.name], f'{otm.name.upper()} Model:{args.id}', otm.days_window)
         elif otm.is_categorical():
             plot_roc(y_preds[otm.name], y_trues[otm.name], otm.channel_map, f'{otm.name} ROC')
-            plot_precision_recall_per_class(y_preds[otm.name], y_trues[otm.name], otm.channel_map,
-                                            f'{otm.name} Precision Recall')
+            plot_precision_recall_per_class(
+                y_preds[otm.name], y_trues[otm.name], otm.channel_map,
+                f'{otm.name} Precision Recall',
+            )
         elif otm.is_continuous():
             plot_scatter(y_preds[otm.name], y_trues[otm.name], f'{otm.name} Scatter')
@@ -561,7 +571,7 @@ def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000
                 space_dict[f'{otm.name}_event'].append(str(sick[0]))
                 space_dict[f'{otm.name}_follow_up'].append(str(follow_up[0]))
             for k in target:
-                if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous' ]:
+                if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous']:
                     space_dict[f'{k}'].append(target[k][b].numpy())
                 elif k in ['datetime']:
                     space_dict[f'{k}'].append(float_to_datetime(int(target[k][b].numpy())))
@@ -749,13 +759,17 @@ def infer_multimodal_multitask(args):
                 hd5_path = os.path.join(args.output_folder, args.id, 'inferred_hd5s', f'{sample_id}{TENSOR_EXT}')
                 os.makedirs(os.path.dirname(hd5_path), exist_ok=True)
                 with h5py.File(hd5_path, 'a') as hd5:
-                    hd5.create_dataset(f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]),
-                                       compression='gzip')
+                    hd5.create_dataset(
+                        f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]),
+                        compression='gzip',
+                    )
                     if otm.path_prefix == 'ukb_ecg_rest':
                         for lead in otm.channel_map:
-                            hd5.create_dataset(f'/ukb_ecg_rest/{lead}/instance_0',
-                                               data=otm.rescale(y[0, otm.channel_map[lead]]),
-                                               compression='gzip')
+                            hd5.create_dataset(
+                                f'/ukb_ecg_rest/{lead}/instance_0',
+                                data=otm.rescale(y[0, otm.channel_map[lead]]),
+                                compression='gzip',
+                            )
                 inference_writer.writerow(csv_row)
                 tensor_paths_inferred.add(tensor_paths[0])
                 stats['count'] += 1
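
For quick inspection, the datasets written in this block can be read back with h5py. A minimal hedged sketch (the file name and TensorMap name below are placeholders patterned on the f-strings in the diff):

    import h5py

    # Placeholders throughout: adjust to your run's output folder and TensorMap names.
    with h5py.File('sample_123.hd5', 'r') as hd5:
        truth = hd5['my_tensormap_truth'][:]  # ground truth, rescaled by otm.rescale
        if 'ukb_ecg_rest' in hd5:
            for lead in hd5['/ukb_ecg_rest']:  # one group per ECG lead
                prediction = hd5[f'/ukb_ecg_rest/{lead}/instance_0'][:]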
ml4h/tensormap/ukb/mri.py (26 changes: 21 additions & 5 deletions)
@@ -2748,13 +2748,21 @@ def _pad_crop_single_channel(tm, hd5, dependents={}, key_prefix=None):
         img,
     )
 
-def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
-    if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
-        key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
-    elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
-        key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
+def _pad_crop_single_channel_t1map_b2_instance(tm, hd5, instance):
+    if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' in hd5:
+        key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}'
+    elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' in hd5:
+        key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}'
     else:
         raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}')
+    return key_prefix
+
+def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
+    key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, 2)
     return _pad_crop_single_channel(tm, hd5, dependents, key_prefix)
+
+def _pad_crop_single_channel_t1map_b2_instance_3(tm, hd5, dependents={}):
+    key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, 3)
+    return _pad_crop_single_channel(tm, hd5, dependents, key_prefix)
 
 t1map_b2 = TensorMap(
@@ -2765,6 +2773,14 @@ def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
     tensor_from_file=_pad_crop_single_channel_t1map_b2,
 )
 
+t1map_b2_instance3 = TensorMap(
+    'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map',
+    shape=(384, 384, 1),
+    path_prefix='ukb_cardiac_mri',
+    normalization=Standardize(mean=455.81, std=609.50),
+    tensor_from_file=_pad_crop_single_channel_t1map_b2_instance_3,
+)
+
 t1map_pancreas = TensorMap(
     'shmolli_192i_pancreas_t1map',
     shape=(288, 384, 1),
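The new t1map_b2_instance3 TensorMap reuses the shared helper with instance=3. As a hedged usage sketch (the hd5 path is a placeholder; the call follows the tensor_from_file(tm, hd5) convention visible in the diff):

    import h5py
    from ml4h.tensormap.ukb.mri import t1map_b2, t1map_b2_instance3

    # Resolve the instance-2 and instance-3 T1 maps through the shared helper;
    # each call raises ValueError when neither shmolli series name is present.
    with h5py.File('sample.hd5', 'r') as hd5:  # placeholder path
        img_instance2 = t1map_b2.tensor_from_file(t1map_b2, hd5)
        img_instance3 = t1map_b2_instance3.tensor_from_file(t1map_b2_instance3, hd5)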
