diff --git a/model_zoo/DROID/README.md b/model_zoo/DROID/README.md
new file mode 100644
index 000000000..3129ba731
--- /dev/null
+++ b/model_zoo/DROID/README.md
@@ -0,0 +1,69 @@
+# DROID (Dimensional Reconstruction of Imaging Data)
+
+DROID is a 3-D convolutional neural network approach for echocardiographic view classification and for
+quantification of left atrial (LA) dimension, LV wall thickness, chamber diameter, and ejection fraction.
+
+The DROID echo movie encoder is based on the
+[MoViNet-A2-Base](https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/3)
+video classification model. MoViNet was fine-tuned in a supervised fashion to produce two
+specialized encoders:
+- DROID-LA
+  - input views: PLAX, A4C, A2C
+  - output predictions: LA A/P
+- DROID-LV
+  - input views: PLAX, A4C, A2C
+  - output predictions: LVEF, LVEDD, LVESD, IVS, PWT
+
+Multi-instance attention heads were then trained to integrate up to 40 view encodings to predict
+a single measurement of each type per echo study.
+
+## Requirements
+In addition to the `ml4h` repository, DROID requires `ml4ht_data_source` and a few other dependencies. First, clone
+both repositories:
+```commandline
+git clone https://github.com/broadinstitute/ml4h.git
+git clone https://github.com/broadinstitute/ml4ht_data_source.git
+```
+
+For convenience, we provide a Docker image containing the additional dependencies:
+```commandline
+docker run -it --gpus all --rm -v {PARENT_DIRECTORY_OF_REPOS} -v {OPTIONAL_DATA_DIRECTORY} \
+us-central1-docker.pkg.dev/broad-ml4cvd/droid/droid:0.1 /bin/bash
+```
+
+Within the Docker container, install `ml4ht_data_source`:
+```commandline
+pip install --user ml4ht_data_source
+```
+
+## Usage
+### Preprocessing
+The scripts below expect echo movies that have been preprocessed and stored in Lightning Memory-Mapped Database
+(LMDB) files. We create one LMDB per echo study: the keys are the DICOM filenames and the values are the
+corresponding echo movies after anonymization, cropping, and conversion to AVI. See `echo_to_lmdb.py` for an
+example.
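+
+For example, assuming each study is stored as a `.tar.gz` archive, a conversion run might look like this (the flags
+mirror `parse_args` in `echo_to_lmdb.py`; the paths are placeholders):
+```commandline
+python echo_to_lmdb.py \
+    --df_path {STUDY_CSV_PATH} \
+    --study_folder {STUDY_DIRECTORY_PATH} \
+    --tar
+```
+The optional `--start` and `--end` arguments restrict processing to a range of rows in the csv.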
+
+### Inference
+`echo_supervised_inference_recipe.py` can be used to obtain predictions from echo movies given either the DROID-LA or
+DROID-LV specialized encoder. Predictions (or, when `--extract_embeddings` is set, encoder embeddings) are written as
+parquet files, one per data split, in a subdirectory of `--output_dir`.
+
+An example invocation is:
+```commandline
+python echo_supervised_inference_recipe.py \
+    --n_input_frames 16 \
+    --output_labels LA_A_P \
+    --selected_views A4C --selected_views A2C --selected_views PLAX \
+    --selected_doppler standard \
+    --selected_quality good \
+    --selected_canonical on_axis \
+    --split_idx 0 \
+    --n_splits 1 \
+    --skip_modulo 4 \
+    --wide_file {WIDE_FILE_PATH} \
+    --splits_file {SPLITS_JSON} \
+    --lmdb_folder {LMDB_DIRECTORY_PATH} \
+    --pretrained_ckpt_dir {SPECIALIZED_ENCODER_PATH} \
+    --movinet_ckpt_dir {MoViNet-A2-Base_PATH} \
+    --output_dir {WHERE_TO_STORE_PREDICTIONS}
+```
\ No newline at end of file
diff --git a/model_zoo/DROID/data_descriptions/__init__.py b/model_zoo/DROID/data_descriptions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/model_zoo/DROID/data_descriptions/echo.py b/model_zoo/DROID/data_descriptions/echo.py
new file mode 100644
index 000000000..29f8a4d29
--- /dev/null
+++ b/model_zoo/DROID/data_descriptions/echo.py
@@ -0,0 +1,103 @@
+import os
+import io
+import av
+import itertools
+
+import lmdb
+
+import numpy as np
+import pandas as pd
+
+from PIL import ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+from ml4ht.data.data_description import DataDescription
+
+VIEW_OPTION_KEY = 'view'
+
+# DICOM metadata elements describing how the pixel data is stored
+metadata_elements = ['PhotometricInterpretation',
+                     'TransferSyntaxUID',
+                     'SamplesPerPixel',
+                     'BitsAllocated',
+                     'BitsStored',
+                     'HighBit',
+                     'PixelRepresentation',
+                     'PlanarConfiguration',
+                     'NumberOfFrames',
+                     'Rows',
+                     'Columns',
+                     ]
+
+
+class LmdbEchoStudyVideoDataDescription(DataDescription):
+    """Loads echo movie frames for a given study and view from per-study LMDB files."""
+
+    def __init__(
+        self,
+        local_lmdb_dir: str,
+        name: str,
+        transforms=None,
+        nframes: int = None,
+        skip_modulo: int = 1,
+        start_beat=0,
+    ):
+        self.local_lmdb_dir = local_lmdb_dir
+        self._name = name
+        # Total number of raw frames to decode so that, after skipping the first
+        # `start_beat * skip_modulo` frames and keeping every `skip_modulo`-th frame,
+        # `nframes` frames remain.
+        self.nframes = (nframes + start_beat) * skip_modulo
+        self.start_beat = start_beat
+        # optional per-frame transformations
+        self.transforms = transforms or []
+        self.skip_modulo = skip_modulo
+
+    def get_loading_options(self, sample_id):
+        # sample IDs have the form '<patient>_<study>_<view>'
+        _, study, view = sample_id.split('_')
+        lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb")
+        lmdb_log = pd.read_parquet(os.path.join(lmdb_folder, f'log_{study}.pq')).set_index('view')
+        lmdb_log = lmdb_log[lmdb_log['stored']]
+
+        if view not in lmdb_log.index:
+            raise ValueError('View not saved in the LMDB')
+
+        return [
+            {VIEW_OPTION_KEY: view}
+        ]
+
+    def get_raw_data(self, sample_id, loading_option=None):
+        try:
+            sample_id = sample_id.decode('UTF-8')
+        except (UnicodeDecodeError, AttributeError):
+            pass
+        _, study, view = sample_id.split('_')
+
+        lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb")
+
+        env = lmdb.open(lmdb_folder, readonly=True, lock=False)
+        nframes = self.nframes
+
+        frames = []
+        with env.begin(buffers=True) as txn:
+            in_mem_bytes_io = io.BytesIO(txn.get(view.encode('utf-8')))
+            video_container = av.open(in_mem_bytes_io, metadata_errors="ignore")
+            # Cycle through the decoded frames so short movies wrap around until enough frames are collected
+            video_frames = itertools.cycle(video_container.decode(video=0))
+            for i, frame in enumerate(video_frames):
+                if i == nframes:
+                    break
+                if i < (self.start_beat * self.skip_modulo):
+                    continue
+                if self.skip_modulo > 1:
+                    if (i % self.skip_modulo) != 0:
+                        continue
+                frame = np.array(frame.to_image())
+                for transform in self.transforms:
+                    frame = transform(frame, loading_option)
+                frames.append(frame)
+            del video_frames
+            video_container.close()
+        env.close()
+        # Scale pixel values to [0, 1]
+        return np.squeeze(np.array(frames, dtype='float32') / 255.)
+ + @property + def name(self): + return self._name diff --git a/model_zoo/DROID/echo_defines.py b/model_zoo/DROID/echo_defines.py new file mode 100644 index 000000000..2427b4eaa --- /dev/null +++ b/model_zoo/DROID/echo_defines.py @@ -0,0 +1,105 @@ +category_dictionaries = { + 'view': { + 'PLAX': 0, + 'Ascending_aorta': 1, + 'RV_inflow': 2, + 'RV_focused': 3, + 'Pulmonary_artery': 4, + 'PSAX_AV': 5, + 'PSAX_MV': 6, + 'PSAX_papillary': 7, + 'PSAX_apex': 8, + 'A4C': 9, + 'A5C': 10, + 'A3C': 11, + 'A2C': 12, + 'Suprasternal': 13, + 'Subcostal': 14 + }, + 'doppler': { + 'standard': 0, + 'doppler': 1, + '3-D': 2 + }, + + 'quality': { + 'good': 0, + 'unusable': 1, + }, + 'canonical': { + 'on_axis': 0, + 'off_axis': 1 + }, + 'LV_EjectionFraction': { + 'N': { + 'index': 0, + 'weight': 0.259667, + }, + 'A': { + 'index': 1, + 'weight': 0.862008, + }, + 'I': { + 'index': 2, + 'weight': 0.916131, + }, + 'L': { + 'index': 3, + 'weight': 0.980843, + }, + 'H': { + 'index': 0, + 'weight': 0.981351, + } + }, + 'LV_FunctionDescription': { + '4.0': { + 'index': 0, + 'weight': 0.520803, + }, + '2.0': { + 'index': 1, + 'weight': 0.662169, + }, + '3.0': { + 'index': 2, + 'weight': 0.817028, + } + }, + 'LV_CavitySize': { + 'N': { + 'index': 0, + 'weight': 0.209487, + }, + 'D': { + 'index': 1, + 'weight': 0.833406, + }, + 'S': { + 'index': 2, + 'weight': 0.957354, + }, + 'P': { + 'index': 3, + 'weight': 1.0 + } + }, + 'RV_SystolicFunction': { + 'N': { + 'index': 0, + 'weight': 0.19156206811684748, + }, + 'Y': { + 'index': 1, + 'weight': 2.5944871794871798, + }, + 'A': { + 'index': 2, + 'weight': 4.161422989923915, + }, + 'L': { + 'index': 3, + 'weight': 8.256629946960423 + } + } +} diff --git a/model_zoo/DROID/echo_supervised_inference_recipe.py b/model_zoo/DROID/echo_supervised_inference_recipe.py new file mode 100644 index 000000000..3d26eab7e --- /dev/null +++ b/model_zoo/DROID/echo_supervised_inference_recipe.py @@ -0,0 +1,207 @@ +import argparse +import json +import logging +import os + +import numpy as np +import pandas as pd +import tensorflow as tf + +from data_descriptions.echo import LmdbEchoStudyVideoDataDescription +from echo_defines import category_dictionaries +from model_descriptions.echo import DDGenerator, create_movinet_classifier, create_regressor + +logging.basicConfig(level=logging.INFO) +tf.get_logger().setLevel(logging.ERROR) + + +def main( + n_input_frames, + output_labels, + wide_file, + splits_file, + selected_views, + selected_doppler, + selected_quality, + selected_canonical, + n_train_patients, + split_idx, + n_splits, + batch_size, + skip_modulo, + lmdb_folder, + pretrained_ckpt_dir, + movinet_ckpt_dir, + output_dir, + extract_embeddings, + start_beat +): + # Hide devices based on split + physical_devices = tf.config.list_physical_devices('GPU') + tf.config.set_visible_devices([physical_devices[split_idx % 4]], 'GPU') + + wide_df = pd.read_parquet(wide_file) + + # Select only view(s) of interest + selected_views_idx = [category_dictionaries['view'][v] for v in selected_views] + selected_doppler_idx = [category_dictionaries['doppler'][v] for v in selected_doppler] + selected_quality_idx = [category_dictionaries['quality'][v] for v in selected_quality] + selected_canonical_idx = [category_dictionaries['canonical'][v] for v in selected_canonical] + wide_df_selected = wide_df[ + (wide_df['view_prediction'].isin(selected_views_idx)) & + (wide_df['doppler_prediction'].isin(selected_doppler_idx)) & + (wide_df['quality_prediction'].isin(selected_quality_idx)) & + 
(wide_df['canonical_prediction'].isin(selected_canonical_idx)) + ] + + # Fill entries without measurements and get all sample_ids + for olabel in output_labels: + wide_df_selected.loc[wide_df_selected[olabel].isna(), olabel] = -1 + working_ids = wide_df_selected['sample_id'].values.tolist() + + # Read splits and partition dataset + with open(splits_file, 'r') as json_file: + splits = json.load(json_file) + + patient_train = splits['patient_train'] + patient_valid = splits['patient_valid'] + + if n_train_patients != 'all': + patient_train = patient_train[:int(int(n_train_patients) * 0.9)] + patient_valid = patient_valid[:int(int(n_train_patients) * 0.1)] + + if 'trainvalid' in lmdb_folder: + patient_inference = patient_train + patient_valid + if 'patient_internal_test' in splits: + patient_inference = patient_inference + splits['patient_internal_test'] + else: + patient_inference = splits['patient_test'] + + inference_ids = sorted([t for t in working_ids if int(t.split('_')[0]) in patient_inference]) + + INPUT_DD = LmdbEchoStudyVideoDataDescription( + lmdb_folder, + 'image', + [], + n_input_frames, + skip_modulo, + start_beat=start_beat + ) + + inference_ids_split = np.array_split(inference_ids, n_splits)[split_idx] + body_inference_ids = tf.data.Dataset.from_tensor_slices(inference_ids_split).batch(batch_size, drop_remainder=False) + n_inference_steps = len(inference_ids_split) // batch_size + int((len(inference_ids_split) % batch_size) > 0.5) + + io_inference_ds = body_inference_ids.interleave( + lambda sample_ids: tf.data.Dataset.from_generator( + DDGenerator( + INPUT_DD, + None + ), + output_signature=( + tf.TensorSpec(shape=(None, n_input_frames, 224, 224, 3), dtype=tf.float32), + ), + args=(sample_ids,) + ) + ) + + model, backbone = create_movinet_classifier( + n_input_frames, + batch_size, + num_classes=600, + checkpoint_dir=movinet_ckpt_dir, + ) + + backbone_output = backbone.layers[-1].output[0] + flatten = tf.keras.layers.Flatten()(backbone_output) + encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) + model_plus_head = create_regressor( + encoder, + input_shape=(n_input_frames, 224, 224, 3), + n_output_features=len(output_labels) + ) + model_plus_head.load_weights(pretrained_ckpt_dir) + + vois = '_'.join(selected_views) + ufm = 'conv7' + if extract_embeddings: + output_folder = os.path.join(output_dir, + f'inference_embeddings_{vois}_{ufm}_{lmdb_folder.split("/")[-1]}_{splits_file}_{start_beat}') + else: + output_folder = os.path.join(output_dir, + f'inference_{vois}_{ufm}_{lmdb_folder.split("/")[-1]}_{splits_file}_{start_beat}') + os.makedirs(output_folder, exist_ok=True) + + if extract_embeddings: + embeddings = encoder.predict(io_inference_ds, steps=n_inference_steps, verbose=1) + df = pd.DataFrame() + df['sample_id'] = inference_ids_split + for j, _ in enumerate(range(embeddings.shape[1])): + df[f'embedding_{j}'] = embeddings[:, j] + + df.to_parquet(os.path.join(output_folder, f'prediction_{split_idx}.pq')) + else: + predictions = model_plus_head.predict(io_inference_ds, steps=n_inference_steps, verbose=1) + df = pd.DataFrame() + df['sample_id'] = inference_ids_split + for j, _ in enumerate(range(predictions.shape[1])): + df[f'prediction_{j}'] = predictions[:, j] + + df.to_parquet(os.path.join(output_folder, f'prediction_{split_idx}.pq')) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--n_input_frames', type=int, default=50) + parser.add_argument('-o', '--output_labels', action='append') + 
parser.add_argument('--wide_file', type=str) + parser.add_argument('--splits_file') + parser.add_argument('-v', '--selected_views', action='append', choices=category_dictionaries['view'].keys(), + required=True) + parser.add_argument('-d', '--selected_doppler', action='append', choices=category_dictionaries['doppler'].keys(), + required=True) + parser.add_argument('-q', '--selected_quality', action='append', choices=category_dictionaries['quality'].keys(), + required=True) + parser.add_argument('-c', '--selected_canonical', action='append', + choices=category_dictionaries['canonical'].keys(), required=True) + parser.add_argument('-n', '--n_train_patients', default='all') + parser.add_argument('--split_idx', type=int, choices=range(4)) + parser.add_argument('--n_splits', type=int, default=4) + parser.add_argument('--batch_size', default=16, type=int) + parser.add_argument('--skip_modulo', type=int, default=1) + parser.add_argument('--lmdb_folder') + parser.add_argument('--pretrained_ckpt_dir', type=str) + parser.add_argument('--movinet_ckpt_dir', type=str) + parser.add_argument('--output_dir', type=str) + parser.add_argument('--extract_embeddings', action='store_true') + parser.add_argument('--start_beat', type=int, default=0) + + args = parser.parse_args() + root = logging.getLogger() + root.setLevel(logging.INFO) + + for arg, value in sorted(vars(args).items()): + logging.info(f"Argument {arg}: {value}") + + main( + n_input_frames=args.n_input_frames, + output_labels=args.output_labels, + wide_file=args.wide_file, + splits_file=args.splits_file, + selected_views=args.selected_views, + selected_doppler=args.selected_doppler, + selected_quality=args.selected_quality, + selected_canonical=args.selected_canonical, + n_train_patients=args.n_train_patients, + split_idx=args.split_idx, + n_splits=args.n_splits, + batch_size=args.batch_size, + skip_modulo=args.skip_modulo, + lmdb_folder=args.lmdb_folder, + pretrained_ckpt_dir=args.pretrained_ckpt_dir, + movinet_ckpt_dir=args.movinet_ckpt_dir, + output_dir=args.output_dir, + extract_embeddings=args.extract_embeddings, + start_beat=args.start_beat + ) diff --git a/model_zoo/DROID/echo_to_lmdb.py b/model_zoo/DROID/echo_to_lmdb.py new file mode 100644 index 000000000..dea8f6e39 --- /dev/null +++ b/model_zoo/DROID/echo_to_lmdb.py @@ -0,0 +1,269 @@ +import argparse +import glob +import io +import logging +import os +import tarfile + +import av +import cv2 +import lmdb +import numpy as np +import pandas as pd +import pydicom +import skimage.measure +import skimage.morphology + + +def get_largest_connected_area(segmentation): + labels = skimage.measure.label(segmentation) + assert (labels.max() != 0) # assume at least 1 CC + largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 + return largestCC + + +def remove_small_structures(mask): + # mask = skimage.morphology.closing(mask, np.ones((20, 20), np.float32)) + mask = skimage.morphology.opening(mask, np.ones((20, 20), np.float32)) + + return mask + + +def get_most_central_connected_area(segmentation): + height, width = segmentation.shape + segmentation[height // 2:height // 2 + 200, width // 2 - 100:width // 2 + 100] = 1 + labels = skimage.measure.label(segmentation) + + assert (labels.max() != 0) # assume at least 1 CC + + central_region = labels == np.argmax( + np.bincount(labels[height // 2 - 100:height // 2 + 100, width // 2 - 100:width // 2 + 100].flat)) + return central_region + + +def anonymize_echo_cone(x): + if len(x.shape) != 4: + raise ValueError('Only views with several frames can 
be anonymized') + + frames, height, width, channels = x.shape + x_r = x[..., 0] + temp = np.sum(x_r, 0).clip(0, 255).astype(np.uint8) + # cone = skimage.morphology.closing(temp>100.0, np.ones((50, 50))) + cone = get_most_central_connected_area(temp > 100.0) + cone = remove_small_structures(cone) + + max_rows = cone.max(axis=1) + max_cols = cone.max(axis=0) + tallest_pixel = np.argwhere(max_rows).min() + shortest_pixel = np.argwhere(max_rows).max() + leftmost_pixel = np.argwhere(max_cols).min() + rightmost_pixel = np.argwhere(max_cols).max() + shortest_rightmost_pixel = np.argwhere(cone[:, rightmost_pixel]).max() + tallest_pixel = min(60, tallest_pixel) + + cone[tallest_pixel:shortest_rightmost_pixel, width // 2:] = 1 + output = np.zeros_like(x_r) + for i in range(x_r.shape[0]): + output[i, :, :] = x_r[i, ...] * cone + return output, leftmost_pixel + + +def lmdb_to_gif(lmdb_folder, view, output_path=None): + env = lmdb.open(lmdb_folder, readonly=True, create=False) + with env.begin(buffers=True) as txn: + in_mem_bytes_io = io.BytesIO(txn.get(view.encode('utf-8'))) + video_container = av.open(in_mem_bytes_io, metadata_errors="ignore") + layers = [] + for frame in video_container.decode(video=0): + layers.append(frame.to_image()) + layers[0].save(f'{lmdb_folder}/{view}.gif', + save_all=True, append_images=layers[1:], loop=0) + + +def anonymize_echo(x): + if (len(x.shape) != 4) or (x.shape[0] < 2): + raise ValueError('Only views with several frames can be anonymized') + + # Some hardcoded hyperparams here, we might want to set as arguments + blur_size = 60 + unblur_size = 40 + eps = 10 + frames, height, width, channels = x.shape + temp = np.where(x < 5, 0, x) + temp = np.sum(temp, 0).clip(0, 255).astype(np.uint8) + + gray = cv2.cvtColor(temp, cv2.COLOR_BGR2GRAY) + kernel = np.ones((blur_size, blur_size), np.float32) / (blur_size ** 2) + filtered_gray = cv2.filter2D(gray, -1, kernel) + ret, thresh = cv2.threshold(filtered_gray, 250, 255, cv2.THRESH_BINARY_INV) + + mask = 1 - thresh.clip(0, 1) + mask[0:height // 10, :] = 0 + kernel = np.ones((unblur_size, unblur_size), np.float32) + filtered_mask = cv2.filter2D(mask, -1, kernel).clip(0, 1) + filtered_mask = np.where(filtered_mask == 0, 0, 1) + inside_mask = np.where(filtered_mask == 1) + + left_bottom_x = min(inside_mask[1]) + right_bottom_x = max(inside_mask[1]) + top_y = min(inside_mask[0]) + left_top_x = min(inside_mask[1][inside_mask[0] == top_y]) + right_top_x = max(inside_mask[1][inside_mask[0] == top_y]) + delta = blur_size + left_bottom_x += delta + left_top_x -= delta + right_top_x += delta + right_bottom_x -= delta + left_bottom_y = min(inside_mask[0][inside_mask[1] == left_bottom_x]) + left_top_y = min(inside_mask[0][inside_mask[1] == left_top_x]) + right_bottom_y = min(inside_mask[0][inside_mask[1] == right_bottom_x]) + right_top_y = min(inside_mask[0][inside_mask[1] == right_top_x]) + + left_slope = (left_top_y - left_bottom_y) / (left_top_x - left_bottom_x) + left_x_intercept = -left_bottom_y / left_slope + left_bottom_x + leftmost = [left_slope, left_x_intercept] + right_slope = (right_top_y - right_bottom_y) / (right_top_x - right_bottom_x) + right_x_intercept = -right_bottom_y / right_slope + right_bottom_x + rightmost = [right_slope, right_x_intercept] + + m1, m2 = np.meshgrid(np.arange(width), np.arange(height)) + # use epsilon to avoid masking part of the echo + mask = leftmost[0] * (m1 - leftmost[1]) - eps < m2 + mask *= rightmost[0] * (m1 - rightmost[1]) - eps < m2 + mask = np.reshape(mask, (height, width)).astype(np.int8) 
+ mask[top_y + delta:] = 0 + filtered_mask += mask + filtered_mask = filtered_mask.clip(0, 1) + + max_rows = filtered_mask.max(axis=1) + max_cols = filtered_mask.max(axis=0) + tallest_pixel = np.argwhere(max_rows).min() + leftmost_pixel = np.argwhere(max_cols).min() + tallest_pixel = max(80, tallest_pixel) + filtered_mask[tallest_pixel:, width // 2:] = 1 + + output = np.zeros_like(x) + for i in range(frames): + for c in range(channels): + output[i, :, :, c] = x[i, :, :, c] * filtered_mask + return output, leftmost_pixel + + +def array_to_cropped_avi(array, output_path, fps, target_size, leftmost_pixel=0): + frames, height, width = array.shape + + if frames < 2: + raise ValueError('You cannot save a video with no frames') + + fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') + out = cv2.VideoWriter(output_path, fourcc, fps, target_size) + + for i in range(frames): + outputA = array[i, :, :] + max_width = max(width, leftmost_pixel + height) + smallOutput = outputA[:, leftmost_pixel:max_width] + + # Resize image + output = cv2.resize(smallOutput, target_size, interpolation=cv2.INTER_CUBIC) + + finaloutput = cv2.merge([output, output, output]) + out.write(finaloutput) + out.release() + + +def targz_to_avis(study_id, study_folder, tar): + if os.path.isfile(os.path.join(study_folder, f'{study_id}.lmdb/data.mdb')): + logging.warning(f'Skipping {study_id} as it already exists') + return + + log_dic = {'study': [], 'view': [], 'log': [], 'stored': []} + target_size = (224, 224) + + if tar: + try: + with tarfile.open(os.path.join(study_folder, f'{study_id}.tar.gz'), 'r:gz') as targz: + targz.extractall(study_folder) + except: + logging.warning(f'Extraction failed for {study_folder}{study_id}.tar.gz') + return + + dicom_paths = glob.glob(os.path.join(study_folder, str(study_id), '*')) + env = lmdb.open(os.path.join(study_folder, f'{study_id}.lmdb'), map_size=2 ** 32 - 1) + with env.begin(write=True) as txn: + for dicom_path in dicom_paths: + dicom_filename = os.path.basename(dicom_path) + log_dic['view'].append(dicom_filename) + log_dic['study'].append(study_id) + logging.info(f'Reading {dicom_path}') + try: + dcm = pydicom.dcmread(dicom_path, force=True) + testarray = dcm.pixel_array + testarray_anon, leftmost_pixel = anonymize_echo_cone(testarray) + except Exception as e: + error_msg = f'{dicom_path}: {e}' + log_dic['log'].append(error_msg) + log_dic['stored'].append(False) + logging.warning(error_msg) + os.remove(dicom_path) + continue + + fps = 30 + try: + fps = dcm['CineRate'].value + except: + logging.info("Could not find frame rate, default to 30") + + video_path = os.path.join(study_folder, f'{study_id}.lmdb', f'{dicom_filename}.avi') + array_to_cropped_avi( + testarray_anon, + video_path, + fps, + target_size, + leftmost_pixel + ) + + # Save avi into lmdb + with open(video_path, 'rb') as avi: + logging.info(f'Adding {dicom_filename} to the transaction') + txn.put(key=dicom_filename.encode('utf-8'), value=avi.read()) + log_dic['log'].append('') + log_dic['stored'].append(True) + logging.info(f"Successfully stored {dicom_filename} into LMDB") + + log_df = pd.DataFrame(log_dic) + log_df.to_parquet(os.path.join(study_folder, f'{study_id}.lmdb', f'log_{study_id}.pq')) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--df_path', help='path to a csv that must contain a column named study with unique identifiers' + 'that correspond to directory names in the study_folder') + parser.add_argument('--study_folder', help='path to the directory containing one 
subdirectory per echo study')
+    parser.add_argument('--start', default=-1, help='optional, first row in the csv to process')
+    parser.add_argument('--end', default=-1, help='optional, last row in the csv to process (inclusive)')
+    parser.add_argument('--tar', action='store_true', help='indicates that study folders are stored as .tar.gz files')
+    return parser.parse_args()
+
+
+def main(**kwargs):
+    df_path = kwargs['df_path']
+    study_folder = kwargs['study_folder']
+    start = int(kwargs.get('start'))
+    end = int(kwargs.get('end'))
+    tar = kwargs.get('tar')
+
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+    df = pd.read_csv(df_path)
+    for i, row in df.iterrows():
+        # Process only rows in the optional, inclusive [start, end] range
+        if i < start:
+            continue
+        if -1 < end < i:
+            break
+        study = int(row['study'])
+        targz_to_avis(study, study_folder, tar)
+
+
+if __name__ == '__main__':
+    ARGS = parse_args()
+    main(**vars(ARGS))
diff --git a/model_zoo/DROID/model_descriptions/__init__.py b/model_zoo/DROID/model_descriptions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/model_zoo/DROID/model_descriptions/echo.py b/model_zoo/DROID/model_descriptions/echo.py
new file mode 100644
index 000000000..fe15aa497
--- /dev/null
+++ b/model_zoo/DROID/model_descriptions/echo.py
@@ -0,0 +1,82 @@
+import numpy as np
+import tensorflow as tf
+
+# from official.common import flags as tfm_flags
+from official.vision.beta.projects.movinet.modeling import movinet, movinet_model
+
+learning_rate = 0.0001
+hidden_units = 256
+dropout_rate = 0.5
+temperature = 0.05
+
+
+class DDGenerator:
+    """Yields batches of raw data from input/output DataDescriptions for tf.data generators."""
+
+    def __init__(self, input_dd, output_dd, fill_empty=False):
+        self.input_dd = input_dd
+        self.output_dd = output_dd
+        self.fill_empty = fill_empty
+
+    def __call__(self, sample_ids):
+        ret_input = []
+        ret_output = []
+        for sample_id in sample_ids:
+            ret_input.append(
+                self.input_dd.get_raw_data(sample_id)
+            )
+            if self.output_dd is not None:
+                ret_output.append(
+                    self.output_dd.get_raw_data(sample_id)
+                )
+            elif self.fill_empty:
+                # No output DataDescription: use NaN placeholders as targets
+                ret_output.append(np.nan)
+
+        if self.output_dd is None and not self.fill_empty:
+            yielded = (ret_input,)
+        else:
+            yielded = (ret_input, ret_output)
+        yield yielded
+
+
+def create_movinet_classifier(
+    n_input_frames,
+    batch_size,
+    checkpoint_dir,
+    num_classes,
+    freeze_backbone=False
+):
+    # Build a Kinetics-600 classifier first so the pretrained checkpoint can be restored into the backbone
+    backbone = movinet.Movinet(model_id='a2')
+    model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)
+    model.build([1, 1, 1, 1, 3])
+    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
+    checkpoint = tf.train.Checkpoint(model=model)
+    status = checkpoint.restore(checkpoint_path)
+    status.assert_existing_objects_matched()
+
+    # Rebuild the classifier with the requested number of classes on top of the restored backbone
+    model = movinet_model.MovinetClassifier(
+        backbone=backbone,
+        num_classes=num_classes
+    )
+    model.build([batch_size, n_input_frames, 224, 224, 3])
+
+    if freeze_backbone:
+        for layer in model.layers[:-1]:
+            layer.trainable = False
+        model.layers[-1].trainable = True
+
+    return model, backbone
+
+
+def create_regressor(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=1):
+    for layer in encoder.layers:
+        layer.trainable = trainable
+
+    # Regression head: dropout -> dense(relu) -> dropout -> linear outputs
+    inputs = tf.keras.Input(shape=input_shape, name='image')
+    features = encoder(inputs)
+    features = tf.keras.layers.Dropout(dropout_rate)(features)
+    features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
+    features = tf.keras.layers.Dropout(dropout_rate)(features)
+    outputs = tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features)
+
+    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor")
+
+    return model