Skip to content

Commit

Permalink
4462 dataset summary for metatensor (#4728)
Browse files Browse the repository at this point in the history
* dataset summary for metatensor

Signed-off-by: Wenqi Li <[email protected]>

* dataset summary support metatensor

Signed-off-by: Wenqi Li <[email protected]>

* fixes tests

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Jul 20, 2022
1 parent 422cc6d commit 73cd27e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
33 changes: 24 additions & 9 deletions monai/data/dataset_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from itertools import chain
from typing import List, Optional

Expand All @@ -18,9 +19,10 @@
from monai.config import KeysCollection
from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import affine_to_spacing
from monai.transforms import concatenate
from monai.utils import PostFix, convert_data_type
from monai.utils import PostFix, convert_data_type, convert_to_tensor

DEFAULT_POST_FIX = PostFix.meta()

Expand All @@ -30,9 +32,9 @@ class DatasetSummary:
This class provides a way to calculate a reasonable output voxel spacing according to
the input dataset. The achieved values can used to resample the input in 3d segmentation tasks
(like using as the `pixdim` parameter in `monai.transforms.Spacingd`).
In addition, it also supports to count the mean, std, min and max intensities of the input,
In addition, it also supports to compute the mean, std, min and max intensities of the input,
and these statistics are helpful for image normalization
(like using in `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).
(as parameters of `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).
The algorithm for calculation refers to:
`Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
Expand All @@ -58,6 +60,7 @@ def __init__(
for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
the metadata is a dictionary object which contains: filename, affine, original_shape, etc.
if None, will try to construct meta_keys by `{image_key}_{meta_key_postfix}`.
This is not required if `data[image_key]` is a MetaTensor.
meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the metadata from dict,
the metadata is a dictionary object (default: ``meta_dict``).
num_workers: how many subprocesses to use for data loading.
Expand All @@ -80,17 +83,21 @@ def collect_meta_data(self):
"""

for data in self.data_loader:
if self.meta_key not in data:
raise ValueError(f"To collect metadata for the dataset, key `{self.meta_key}` must exist in `data`.")
self.all_meta_data.append(data[self.meta_key])
if isinstance(data[self.image_key], MetaTensor):
meta_dict = data[self.image_key].meta
elif self.meta_key in data:
meta_dict = data[self.meta_key]
else:
warnings.warn(f"To collect metadata for the dataset, `{self.meta_key}` or `data.meta` must exist.")
self.all_meta_data.append(meta_dict)

def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0):
"""
Calculate the target spacing according to all spacings.
If the target spacing is very anisotropic,
decrease the spacing value of the maximum axis according to percentile.
So far, this function only supports NIFTI images which store spacings in headers with key "pixdim".
After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.
The spacing is computed from `affine_to_spacing(data[spacing_key][0], 3)` if `data[spacing_key]` is a matrix,
otherwise, the `data[spacing_key]` must be a vector of pixdim values.
Args:
spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``).
Expand All @@ -103,7 +110,15 @@ def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold:
self.collect_meta_data()
if spacing_key not in self.all_meta_data[0]:
raise ValueError("The provided spacing_key is not in self.all_meta_data.")
spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data]
spacings = []
for data in self.all_meta_data:
spacing_vals = convert_to_tensor(data[spacing_key][0], track_meta=False, wrap_sequence=True)
if spacing_vals.ndim == 1: # vector
spacings.append(spacing_vals[:3][None])
elif spacing_vals.ndim == 2: # matrix
spacings.append(affine_to_spacing(spacing_vals, 3)[None])
else:
raise ValueError("data[spacing_key] must be a vector or a matrix.")
all_spacings = concatenate(to_cat=spacings, axis=0)
all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)

Expand Down
9 changes: 3 additions & 6 deletions tests/test_dataset_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
from monai.data import Dataset, DatasetSummary, create_test_image_3d
from monai.transforms import LoadImaged
from monai.transforms.compose import Compose
from monai.transforms.meta_utility.dictionary import FromMetaTensord
from monai.transforms.utility.dictionary import ToNumpyd
from monai.utils import set_determinism
from monai.utils.enums import PostFix


def test_collate(batch):
Expand Down Expand Up @@ -56,7 +54,6 @@ def test_spacing_intensity(self):
t = Compose(
[
LoadImaged(keys=["image", "label"]),
FromMetaTensord(keys=["image", "label"]),
ToNumpyd(keys=["image", "label", "image_meta_dict", "label_meta_dict"]),
]
)
Expand All @@ -65,7 +62,7 @@ def test_spacing_intensity(self):
# test **kwargs of `DatasetSummary` for `DataLoader`
calculator = DatasetSummary(dataset, num_workers=4, meta_key="image_meta_dict", collate_fn=test_collate)

target_spacing = calculator.get_target_spacing()
target_spacing = calculator.get_target_spacing(spacing_key="pixdim")
self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
calculator.calculate_statistics()
np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5)
Expand Down Expand Up @@ -93,10 +90,10 @@ def test_anisotropic_spacing(self):
{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
]

t = Compose([LoadImaged(keys=["image", "label"]), FromMetaTensord(keys=["image", "label"])])
t = Compose([LoadImaged(keys=["image", "label"])])
dataset = Dataset(data=data_dicts, transform=t)

calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix=PostFix.meta())
calculator = DatasetSummary(dataset, num_workers=4)

target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))
Expand Down

0 comments on commit 73cd27e

Please sign in to comment.