diff --git a/ml4h/metrics.py b/ml4h/metrics.py index c5cda5aef..d4444f174 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -10,7 +10,7 @@ from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error -#from neurite.tf.losses import Dice +from neurite.tf.losses import Dice STRING_METRICS = [ 'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae', diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 13473e195..bd367e993 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -34,9 +34,7 @@ from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients, concordance_index_censored -from ml4h.plots import plot_dice, plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp -from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice -from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp +from ml4h.plots import plot_dice, plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival @@ -140,7 +138,7 @@ def run(args): except Exception as e: logging.exception(e) - + if args.gcs_cloud_bucket is not None: save_to_google_cloud(args) @@ -348,10 +346,14 @@ def option_picker(sample_id, data_descriptions): valid_ids = list(mrn_df[mrn_df.split == 'valid'].index) test_ids = list(mrn_df[mrn_df.split == 'test'].index) - train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg, - get_epoch=shuffle_get_epoch) - valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg, - get_epoch=shuffle_get_epoch) + train_dataset = SampleGetterIterableDataset( + sample_ids=list(train_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch, + ) + valid_dataset = SampleGetterIterableDataset( + sample_ids=list(valid_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch, + ) num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) @@ -447,10 +449,14 @@ def option_picker(sample_id, data_descriptions): valid_ids = list(mrn_df[mrn_df.split == 'valid'].index) test_ids = list(mrn_df[mrn_df.split == 'test'].index) - train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg, - get_epoch=shuffle_get_epoch) - valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg, - get_epoch=shuffle_get_epoch) + train_dataset = SampleGetterIterableDataset( + sample_ids=list(train_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch, + ) + valid_dataset = SampleGetterIterableDataset( + sample_ids=list(valid_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch, + ) num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) @@ -483,14 +489,16 @@ def option_picker(sample_id, data_descriptions): output_data_descriptions=output_dds, # what we want a model to predict from the input data option_picker=option_picker, ) - test_dataset = SampleGetterIterableDataset(sample_ids=list(test_ids), sample_getter=test_sg, - get_epoch=shuffle_get_epoch) + test_dataset = SampleGetterIterableDataset( + sample_ids=list(test_ids), sample_getter=test_sg, + get_epoch=shuffle_get_epoch, + ) generate_test = TensorMapDataLoader2( batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out, dataset=test_dataset, num_workers=num_train_workers, - ) + ) y_trues = defaultdict(list) y_preds = defaultdict(list) @@ -518,8 +526,10 @@ def option_picker(sample_id, data_descriptions): plot_survival(y_preds[otm.name], y_trues[otm.name], f'{otm.name.upper()} Model:{args.id}', otm.days_window) elif otm.is_categorical(): plot_roc(y_preds[otm.name], y_trues[otm.name], otm.channel_map, f'{otm.name} ROC') - plot_precision_recall_per_class(y_preds[otm.name], y_trues[otm.name], otm.channel_map, - f'{otm.name} Precision Recall') + plot_precision_recall_per_class( + y_preds[otm.name], y_trues[otm.name], otm.channel_map, + f'{otm.name} Precision Recall', + ) elif otm.is_continuous(): plot_scatter(y_preds[otm.name], y_trues[otm.name], f'{otm.name} Scatter') @@ -561,7 +571,7 @@ def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000 space_dict[f'{otm.name}_event'].append(str(sick[0])) space_dict[f'{otm.name}_follow_up'].append(str(follow_up[0])) for k in target: - if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous' ]: + if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous']: space_dict[f'{k}'].append(target[k][b].numpy()) elif k in ['datetime']: space_dict[f'{k}'].append(float_to_datetime(int(target[k][b].numpy()))) @@ -749,13 +759,17 @@ def infer_multimodal_multitask(args): hd5_path = os.path.join(args.output_folder, args.id, 'inferred_hd5s', f'{sample_id}{TENSOR_EXT}') os.makedirs(os.path.dirname(hd5_path), exist_ok=True) with h5py.File(hd5_path, 'a') as hd5: - hd5.create_dataset(f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]), - compression='gzip') + hd5.create_dataset( + f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]), + compression='gzip', + ) if otm.path_prefix == 'ukb_ecg_rest': for lead in otm.channel_map: - hd5.create_dataset(f'/ukb_ecg_rest/{lead}/instance_0', - data=otm.rescale(y[0, otm.channel_map[lead]]), - compression='gzip') + hd5.create_dataset( + f'/ukb_ecg_rest/{lead}/instance_0', + data=otm.rescale(y[0, otm.channel_map[lead]]), + compression='gzip', + ) inference_writer.writerow(csv_row) tensor_paths_inferred.add(tensor_paths[0]) stats['count'] += 1 diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 828e5f059..2b1ca2010 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2748,13 +2748,21 @@ def _pad_crop_single_channel(tm, hd5, dependents={}, key_prefix=None): img, ) -def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}): - if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: - key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' - elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: - key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' +def _pad_crop_single_channel_t1map_b2_instance(tm, hd5, instance): + if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' in hd5: + key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' + elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' in hd5: + key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' else: raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}') + return key_prefix + +def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}): + key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, 2) + return _pad_crop_single_channel(tm, hd5, dependents, key_prefix) + +def _pad_crop_single_channel_t1map_b2_instance_3(tm, hd5, dependents={}): + key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, 3) return _pad_crop_single_channel(tm, hd5, dependents, key_prefix) t1map_b2 = TensorMap( @@ -2765,6 +2773,14 @@ def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}): tensor_from_file=_pad_crop_single_channel_t1map_b2, ) +t1map_b2_instance3 = TensorMap( + 'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map', + shape=(384, 384, 1), + path_prefix='ukb_cardiac_mri', + normalization=Standardize(mean=455.81, std=609.50), + tensor_from_file=_pad_crop_single_channel_t1map_b2_instance_3, +) + t1map_pancreas = TensorMap( 'shmolli_192i_pancreas_t1map', shape=(288, 384, 1),