Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SaveImage to use MetaTensor #4370

Merged
merged 11 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ jobs:

cron-tutorial-notebooks:
if: github.repository == 'Project-MONAI/MONAI'
needs: cron-gpu # so that monai itself is verified first
# needs: cron-gpu # so that monai itself is verified first
container:
image: nvcr.io/nvidia/pytorch:22.04-py3 # testing with the latest pytorch base image
options: "--gpus all --ipc=host"
Expand Down
5 changes: 3 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import warnings
from pathlib import Path
from pydoc import locate
from typing import Dict, List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -386,12 +386,13 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ
if write_kwargs is not None:
self.write_kwargs.update(write_kwargs)

def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
def __call__(self, img: Union[torch.Tensor, np.ndarray]):
"""
Args:
img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`.
meta_data: key-value pairs of metadata corresponding to the data.
"""
meta_data = img.meta if isinstance(img, MetaTensor) else None
subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None
filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index)
Expand Down
19 changes: 5 additions & 14 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.data.image_reader import ImageReader
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.transform import MapTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, deprecated_arg, ensure_tuple_rep
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, deprecated_arg
from monai.utils.enums import PostFix

__all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"]
Expand Down Expand Up @@ -129,12 +129,6 @@ class SaveImaged(MapTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
meta_keys: explicitly indicate the key of the corresponding metadata dictionary.
For example, for data with key `image`, the metadata by default is in `image_meta_dict`.
The metadata is a dictionary contains values such as filename, original_shape.
This argument can be a sequence of string, map to the `keys`.
If `None`, will try to construct meta_keys by `key_{meta_key_postfix}`.
meta_key_postfix: if `meta_keys` is `None`, use `key_{meta_key_postfix}` to retrieve the metadict.
output_dir: output image directory.
output_postfix: a string appended to all output file names, default to `trans`.
output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
Expand Down Expand Up @@ -189,6 +183,8 @@ class SaveImaged(MapTransform):

"""

@deprecated_arg(name="meta_keys", since="0.8", msg_suffix="Use MetaTensor input")
@deprecated_arg(name="meta_key_postfix", since="0.8", msg_suffix="Use MetaTensor input")
def __init__(
self,
keys: KeysCollection,
Expand All @@ -212,8 +208,6 @@ def __init__(
writer: Union[image_writer.ImageWriter, str, None] = None,
) -> None:
super().__init__(keys, allow_missing_keys)
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.saver = SaveImage(
output_dir=output_dir,
output_postfix=output_postfix,
Expand All @@ -237,11 +231,8 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ

def __call__(self, data):
d = dict(data)
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
if meta_key is None and meta_key_postfix is not None:
meta_key = f"{key}_{meta_key_postfix}"
meta_data = d[meta_key] if meta_key is not None else None
self.saver(img=d[key], meta_data=meta_data)
for key in self.key_iterator(d):
self.saver(img=d[key])
return d


Expand Down
22 changes: 17 additions & 5 deletions monai/transforms/meta_utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
"""

from copy import deepcopy
from typing import Dict, Hashable, Mapping
from typing import Dict, Hashable, Mapping, Optional

from monai.config.type_definitions import NdarrayOrTensor
from monai.config.type_definitions import KeysCollection, NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform
from monai.utils.enums import PostFix, TransformBackends
from monai.utils.misc import ensure_tuple_rep

__all__ = [
"FromMetaTensord",
Expand Down Expand Up @@ -78,13 +79,24 @@ class ToMetaTensord(MapTransform, InvertibleTransform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, keys: KeysCollection, meta_keys: Optional[KeysCollection] = None) -> None:
wyli marked this conversation as resolved.
Show resolved Hide resolved
"""
Args:
keys: keys to be converted to MetaTensor.
meta_keys: keys to fetch `PostFix.meta` and `PostFix.transforms` from.
"""
super().__init__(keys=keys)
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
for key, meta_key in zip(self.keys, self.meta_keys):
self.push_transform(d, key)
im = d[key]
meta = d.pop(PostFix.meta(key), None)
transforms = d.pop(PostFix.transforms(key), None)
meta_dict_key = PostFix.meta(meta_key if meta_key is not None else key) # type: ignore
meta = d.pop(meta_dict_key, None)
meta_transforms_key = PostFix.transforms(meta_key if meta_key is not None else key) # type: ignore
transforms = d.pop(meta_transforms_key, None)
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore
d[key] = im
return d
Expand Down
25 changes: 14 additions & 11 deletions tests/test_image_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from parameterized import parameterized

from monai.data.image_reader import ITKReader, NibabelReader, NrrdReader, PILReader
from monai.data.image_reader import ITKReader, MetaTensor, NibabelReader, NrrdReader, PILReader
from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer
from monai.transforms import LoadImage, SaveImage, moveaxis
from monai.utils import OptionalImportError
Expand All @@ -42,14 +42,13 @@ def nifti_rw(self, test_data, reader, writer, dtype, resample=True):
saver = SaveImage(
output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
)
saver(
p(test_data),
{
"filename_or_obj": f"{filepath}.png",
"affine": np.eye(4),
"original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
},
)
meta_dict = {
"filename_or_obj": f"{filepath}.png",
"affine": np.eye(4),
"original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
}
test_data = MetaTensor(p(test_data), meta=meta_dict)
saver(test_data)
saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)
self.assertTrue(os.path.exists(saved_path))
loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True)
Expand Down Expand Up @@ -97,7 +96,8 @@ def png_rw(self, test_data, reader, writer, dtype, resample=True):
saver = SaveImage(
output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
)
saver(p(test_data), {"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)})
test_data = MetaTensor(p(test_data), meta={"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)})
saver(test_data)
saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)
self.assertTrue(os.path.exists(saved_path))
loader = LoadImage(reader=reader)
Expand Down Expand Up @@ -151,7 +151,10 @@ def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
saver = SaveImage(
output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
)
saver(p(test_data), {"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape})
test_data = MetaTensor(
p(test_data), meta={"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape}
)
saver(test_data)
saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)
loader = LoadImage(reader=reader)
data = loader(saved_path)
Expand Down
9 changes: 4 additions & 5 deletions tests/test_integration_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,25 @@ def test_shape(self, config_file, expected_shape):
# test override with the whole overriding file
json.dump("Dataset", f)

saver = LoadImage(image_only=True)

if sys.platform == "win32":
override = "--network $@network_def.to(@device) --dataset#_target_ Dataset"
else:
override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}"
# test with `monai.bundle` as CLI entry directly
cmd = f"-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg {override}"
cmd = f"-m monai.bundle run evaluating --postprocessing#transforms#3#output_postfix seg {override}"
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
test_env = os.environ.copy()
print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
subprocess.check_call(la + ["--args_file", def_args_file], env=test_env)
self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)
loader = LoadImage()
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)

# here test the script with `google fire` tool as CLI
cmd = "-m fire monai.bundle.scripts run --runner_id evaluating"
cmd += f" --evaluator#amp False {override}"
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
subprocess.check_call(la, env=test_env)
self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ignite.engine import Engine, Events
from torch.utils.data import DataLoader

from monai.data import ImageDataset, create_test_image_3d, decollate_batch
from monai.data import ImageDataset, MetaTensor, create_test_image_3d, decollate_batch
from monai.inferers import sliding_window_inference
from monai.networks import eval_mode, predict_segmentation
from monai.networks.nets import UNet
Expand Down Expand Up @@ -49,7 +49,7 @@ def _sliding_window_processor(_engine, batch):
def save_func(engine):
meta_data = decollate_batch(engine.state.batch[2])
for m, o in zip(meta_data, engine.state.output):
saver(o, m)
saver(MetaTensor(o, meta=m))

infer_engine = Engine(_sliding_window_processor)
infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_3d, decollate_batch
from monai.data import MetaTensor, create_test_image_3d, decollate_batch
from monai.engines import IterationEvents, SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
CheckpointLoader,
Expand All @@ -51,6 +51,7 @@
SaveImage,
SaveImaged,
ScaleIntensityd,
ToMetaTensord,
ToTensord,
)
from monai.utils import set_determinism
Expand Down Expand Up @@ -253,9 +254,8 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor
AsDiscreted(keys="pred", threshold=0.5),
KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
# test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
SaveImaged(
keys="pred", meta_keys=PostFix.meta("image"), output_dir=root_dir, output_postfix="seg_transform"
),
ToMetaTensord(keys="pred", meta_keys="image"),
SaveImaged(keys="pred", output_dir=root_dir, output_postfix="seg_transform"),
]
)
val_handlers = [
Expand All @@ -270,7 +270,7 @@ def save_func(engine):
if isinstance(meta_data, dict):
meta_data = decollate_batch(meta_data)
for m, o in zip(meta_data, from_engine("pred")(engine.state.output)):
saver(o, m)
saver(MetaTensor(o, meta=m))

evaluator = SupervisedEvaluator(
device=device,
Expand Down
10 changes: 1 addition & 9 deletions tests/test_load_imaged.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Compose, EnsureChannelFirstD, FromMetaTensord, LoadImaged, SaveImageD
from monai.transforms.meta_utility.dictionary import ToMetaTensord
from monai.utils.enums import PostFix
from tests.utils import assert_allclose

KEYS = ["image", "label", "extra"]
Expand Down Expand Up @@ -97,14 +96,7 @@ def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext):
self.assertTupleEqual(img_dict["img"].shape, ch_shape)

with tempfile.TemporaryDirectory() as tempdir:
save_xform = Compose(
[
FromMetaTensord(keys),
SaveImageD(
keys, meta_keys=PostFix.meta("img"), output_dir=tempdir, squeeze_end_dims=False, output_ext=ext
),
]
)
save_xform = SaveImageD(keys, output_dir=tempdir, squeeze_end_dims=False, output_ext=ext)
save_xform(img_dict) # save to nifti

new_xforms = Compose(
Expand Down
7 changes: 3 additions & 4 deletions tests/test_resample_to_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,11 @@ def test_correct(self, reader, writer):
with self.assertRaises(ValueError):
ResampleToMatch(mode=None)(img=data["im2"], img_dst=data["im1"])
im_mod = ResampleToMatch()(data["im2"], data["im1"])
im_mod, meta = im_mod.as_tensor(), im_mod.meta
saver = SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, writer=writer)
meta["filename_or_obj"] = get_rand_fname()
saver({"im3": im_mod, "im3_meta_dict": meta})
im_mod.meta["filename_or_obj"] = get_rand_fname()
saver({"im3": im_mod})

saved = nib.load(os.path.join(self.tmpdir, meta["filename_or_obj"]))
saved = nib.load(os.path.join(self.tmpdir, im_mod.meta["filename_or_obj"]))
assert_allclose(data["im1"].shape[1:], saved.shape)
assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19]))

Expand Down
9 changes: 6 additions & 3 deletions tests/test_save_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import tempfile
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.transforms import SaveImage

TEST_CASE_1 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nii.gz"}, ".nii.gz", False]
Expand All @@ -26,7 +26,7 @@
TEST_CASE_3 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nrrd"}, ".nrrd", False]

TEST_CASE_4 = [
np.random.randint(0, 255, (3, 2, 4, 5), dtype=np.uint8),
torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8),
{"filename_or_obj": "testfile0.dcm"},
".dcm",
False,
Expand All @@ -36,14 +36,17 @@
class TestSaveImage(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_saved_content(self, test_data, meta_data, output_ext, resample):
if meta_data is not None:
test_data = MetaTensor(test_data, meta=meta_data)

with tempfile.TemporaryDirectory() as tempdir:
trans = SaveImage(
output_dir=tempdir,
output_ext=output_ext,
resample=resample,
separate_folder=False, # test saving into the same folder
)
trans(test_data, meta_data)
trans(test_data)

filepath = "testfile0" if meta_data is not None else "0"
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + "_trans" + output_ext)))
Expand Down
Loading