-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial release of code for running inference with DROID --------- Co-authored-by: Christopher Reeder <[email protected]>
- Loading branch information
1 parent
28a5eea
commit b0c05a0
Showing
8 changed files
with
835 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# DROID (Dimensional Reconstruction of Imaging Data) | ||
|
||
DROID is a 3-D convolutional neural network modeling approach for echocardiographic view | ||
classification and quantification of LA dimension, LV wall thickness, chamber diameter and | ||
ejection fraction. | ||
|
||
The DROID echo movie encoder is based on the | ||
[MoViNet-A2-Base](https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/3) | ||
video classification model. MoViNet was fine-tuned in a supervised fashion to produce two | ||
specialized encoders: | ||
- DROID-LA | ||
- input views: PLAX, A4C, A2C | ||
- output predictions: LA A/P | ||
- DROID-LV | ||
- input views: PLAX, A4C, A2C | ||
- output predictions: LVEF, LVEDD, LVESD, IVS, PWT | ||
|
||
Multi-instance attention heads were then trained to integrate up to 40 view encodings to predict | ||
a single measurement of each type per echo study. | ||
|
||
## Requirements | ||
In addition to the `ml4h` repository, DROID also requires `ml4ht_data_source` plus other dependencies. First, clone the | ||
ml4h repositories: | ||
```commandline | ||
git clone https://github.com/broadinstitute/ml4h.git | ||
git clone https://github.com/broadinstitute/ml4ht_data_source.git | ||
``` | ||
|
||
For convenience, we provide a docker image containing additional dependencies: | ||
```commandline | ||
docker run -it --gpus all --rm -v {PARENT_DIRECTORY_OF_REPOS} -v {OPTIONAL_DATA_DIRECTORY} \ | ||
us-central1-docker.pkg.dev/broad-ml4cvd/droid/droid:0.1 /bin/bash | ||
``` | ||
|
||
Within the docker container, install the `ml4ht_data_source` package: |  ||
```commandline | ||
pip install --user ml4ht_data_source | ||
``` | ||
|
||
## Usage | ||
### Preprocessing | ||
The following scripts are designed to handle echo movies that have been processed and stored in Lightning | ||
Memory-Mapped Database (lmdb) files. We create one lmdb per echo study in which the keys are the filenames of the dicoms and | ||
the values are echo movies that have been anonymized, cropped, and converted to avis. See `echo_to_lmdb.py` for an | ||
example. | ||
|
||
### Inference | ||
`echo_supervised_inference_recipe.py` can be used to obtain predictions from echo movies given either the DROID-LA or | ||
DROID-LV specialized encoders. | ||
|
||
An example of parameters to use when running this script are: | ||
```commandline | ||
python echo_supervised_inference_recipe.py \ | ||
--n_input_frames 16 \ | ||
--output_labels LA_A_P \ | ||
--selected_views A4C --selected_views A2C --selected_views PLAX \ | ||
--selected_doppler standard \ | ||
--selected_quality good \ | ||
--selected_canonical on_axis \ | ||
--split_idx 0 \ | ||
--n_splits 1 \ | ||
--skip_modulo 4 \ | ||
--wide_file {WIDE_FILE_PATH} \ | ||
--splits_file {SPLITS_JSON} \ | ||
--lmdb_folder {LMDB_DIRECTORY_PATH} \ | ||
--pretrained_ckpt_dir {SPECIALIZED_ENCODER_PATH} \ | ||
--movinet_ckpt_dir {MoViNet-A2-Base_PATH} \ | ||
--output_dir {WHERE_TO_STORE_PREDICTIONS} | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import os | ||
import io | ||
import av | ||
import itertools | ||
|
||
import lmdb | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from PIL import ImageFile | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
from ml4ht.data.data_description import DataDescription | ||
|
||
# Loading-option key under which the echo view name is passed between
# get_loading_options and get_raw_data.
VIEW_OPTION_KEY = 'view'

# DICOM header fields recorded for each echo movie.
metadata_elements = (
    'PhotometricInterpretation TransferSyntaxUID SamplesPerPixel '
    'BitsAllocated BitsStored HighBit PixelRepresentation '
    'PlanarConfiguration NumberOfFrames Rows Columns'
).split()
|
||
|
||
class LmdbEchoStudyVideoDataDescription(DataDescription):
    """Serves echo movie frames for a given study/view out of per-study LMDBs.

    Sample ids have the form ``<prefix>_<study>_<view>``; each study has its
    own ``<study>.lmdb`` directory whose keys are view names and whose values
    are encoded video bytes (avis — see the preprocessing notes in the repo
    README). A parquet log ``log_<study>.pq`` inside the lmdb folder records
    which views were successfully stored.
    """

    def __init__(
        self,
        local_lmdb_dir: str,
        name: str,
        transforms=None,
        nframes: int = None,
        skip_modulo: int = 1,
        start_beat=0,
    ):
        """Configure frame extraction.

        :param local_lmdb_dir: directory containing the per-study ``*.lmdb`` folders.
        :param name: identifier returned by the ``name`` property.
        :param transforms: optional list of callables ``(frame, loading_option) -> frame``
            applied to each decoded frame, in order.
        :param nframes: number of frames to keep per sample.
            NOTE(review): the default of None is not actually usable — the
            arithmetic below would raise TypeError; effectively required.
        :param skip_modulo: keep only every ``skip_modulo``-th decoded frame.
        :param start_beat: number of leading beats (groups of ``skip_modulo``
            frames) to skip before sampling starts.
        """
        self.local_lmdb_dir = local_lmdb_dir
        self._name = name
        # Total count of raw decoded frames to walk: enough to yield `nframes`
        # kept frames after skipping `start_beat` beats and downsampling by
        # `skip_modulo`. (Original code first assigned `self.nframes = nframes`
        # and immediately overwrote it — dead store removed.)
        self.nframes = (nframes + start_beat) * skip_modulo
        self.start_beat = start_beat
        # Per-frame transformations applied in get_raw_data.
        self.transforms = transforms or []
        self.skip_modulo = skip_modulo

    def get_loading_options(self, sample_id):
        """Return the loading options (the stored view) for ``sample_id``.

        :raises ValueError: if the requested view was not stored in the LMDB
            according to the study's parquet log.
        """
        _, study, view = sample_id.split('_')
        lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb")
        lmdb_log = pd.read_parquet(os.path.join(lmdb_folder, f'log_{study}.pq')).set_index('view')
        # Keep only views that were actually written to the LMDB.
        lmdb_log = lmdb_log[lmdb_log['stored']]

        if view not in lmdb_log.index:
            raise ValueError('View not saved in the LMDB')

        return [
            {VIEW_OPTION_KEY: view}
        ]

    def get_raw_data(self, sample_id, loading_option=None):
        """Decode, subsample, transform, and scale the movie for ``sample_id``.

        Returns a float32 array of frames scaled to [0, 1] (squeezed, so a
        single kept frame loses its leading axis).
        """
        try:
            # sample_id may arrive as bytes (e.g. out of a TF pipeline).
            sample_id = sample_id.decode('UTF-8')
        except (UnicodeDecodeError, AttributeError):
            pass  # already a str
        _, study, view = sample_id.split('_')

        lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb")

        env = lmdb.open(lmdb_folder, readonly=True, lock=False)
        try:
            nframes = self.nframes
            frames = []
            with env.begin(buffers=True) as txn:
                in_mem_bytes_io = io.BytesIO(txn.get(view.encode('utf-8')))
                video_container = av.open(in_mem_bytes_io, metadata_errors="ignore")
                try:
                    # cycle() lets clips shorter than nframes wrap around and
                    # repeat until enough raw frames have been seen.
                    video_frames = itertools.cycle(video_container.decode(video=0))
                    for i, frame in enumerate(video_frames):
                        if i == nframes:
                            break
                        # Skip the leading start_beat "beats" worth of frames.
                        if i < (self.start_beat * self.skip_modulo):
                            continue
                        # Downsample: keep only every skip_modulo-th frame.
                        if self.skip_modulo > 1 and (i % self.skip_modulo) != 0:
                            continue
                        frame = np.array(frame.to_image())
                        for transform in self.transforms:
                            frame = transform(frame, loading_option)
                        frames.append(frame)
                    del video_frames
                finally:
                    # Close the container even if decoding/transforms raise
                    # (original leaked it on error).
                    video_container.close()
        finally:
            env.close()
        return np.squeeze(np.array(frames, dtype='float32') / 255.)

    @property
    def name(self):
        """Name given at construction time."""
        return self._name
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Label-encoding tables for the categorical echo tasks.
# Simple tasks map class label -> integer index; the clinical-reading tasks
# map label -> {'index': ..., 'weight': ...}, where 'weight' is presumably a
# per-class loss weight (inverse-frequency style) — TODO confirm against the
# training code.
category_dictionaries = {
    # Echo view classification (15 standard views).
    'view': {
        'PLAX': 0,
        'Ascending_aorta': 1,
        'RV_inflow': 2,
        'RV_focused': 3,
        'Pulmonary_artery': 4,
        'PSAX_AV': 5,
        'PSAX_MV': 6,
        'PSAX_papillary': 7,
        'PSAX_apex': 8,
        'A4C': 9,
        'A5C': 10,
        'A3C': 11,
        'A2C': 12,
        'Suprasternal': 13,
        'Subcostal': 14
    },
    # Acquisition mode.
    'doppler': {
        'standard': 0,
        'doppler': 1,
        '3-D': 2
    },

    # Image-quality gate.
    'quality': {
        'good': 0,
        'unusable': 1,
    },
    # Whether the view is on the canonical imaging axis.
    'canonical': {
        'on_axis': 0,
        'off_axis': 1
    },
    'LV_EjectionFraction': {
        'N': {
            'index': 0,
            'weight': 0.259667,
        },
        'A': {
            'index': 1,
            'weight': 0.862008,
        },
        'I': {
            'index': 2,
            'weight': 0.916131,
        },
        'L': {
            'index': 3,
            'weight': 0.980843,
        },
        # NOTE(review): 'H' shares index 0 with 'N' — looks like a possible
        # typo (expected a distinct index, e.g. 4); confirm before relying on
        # this encoding.
        'H': {
            'index': 0,
            'weight': 0.981351,
        }
    },
    'LV_FunctionDescription': {
        '4.0': {
            'index': 0,
            'weight': 0.520803,
        },
        '2.0': {
            'index': 1,
            'weight': 0.662169,
        },
        '3.0': {
            'index': 2,
            'weight': 0.817028,
        }
    },
    'LV_CavitySize': {
        'N': {
            'index': 0,
            'weight': 0.209487,
        },
        'D': {
            'index': 1,
            'weight': 0.833406,
        },
        'S': {
            'index': 2,
            'weight': 0.957354,
        },
        'P': {
            'index': 3,
            'weight': 1.0
        }
    },
    'RV_SystolicFunction': {
        'N': {
            'index': 0,
            'weight': 0.19156206811684748,
        },
        'Y': {
            'index': 1,
            'weight': 2.5944871794871798,
        },
        'A': {
            'index': 2,
            'weight': 4.161422989923915,
        },
        'L': {
            'index': 3,
            'weight': 8.256629946960423
        }
    }
}
Oops, something went wrong.