Skip to content

Commit

Permalink
Bug fix and improvement in WSI (Project-MONAI#4216)
Browse files Browse the repository at this point in the history
* Make all transforms optional

Signed-off-by: Behrooz <[email protected]>

* Update wsireader tests

Signed-off-by: Behrooz <[email protected]>

* Remove optional from PersistentDataset and its derivatives

Signed-off-by: Behrooz <[email protected]>

* Add unittests for cache without transform

Signed-off-by: Behrooz <[email protected]>

* Add default replace_rate

Signed-off-by: Behrooz <[email protected]>

* Add default value

Signed-off-by: Behrooz <[email protected]>

* Set default replace_rate to 0.1

Signed-off-by: Behrooz <[email protected]>

* Update metadata to include path

Signed-off-by: Behrooz <[email protected]>

* Adds SmartCachePatchWSIDataset

Signed-off-by: Behrooz <[email protected]>

* Add unittests for SmartCachePatchWSIDataset

Signed-off-by: Behrooz <[email protected]>

* Update references

Signed-off-by: Behrooz <[email protected]>

* Update docs

Signed-off-by: Behrooz <[email protected]>

* Remove smart cache

Signed-off-by: Behrooz <[email protected]>

* Remove unused imports

Signed-off-by: Behrooz <[email protected]>

* Add path metadata for OpenSlide

Signed-off-by: Behrooz <[email protected]>

* Update metadata to be unified across different backends

Signed-off-by: Behrooz <[email protected]>

* Update wsi metadata for multi wsi objects

Signed-off-by: Behrooz <[email protected]>

* Add unittests for wsi metadata

Signed-off-by: Behrooz <[email protected]>
  • Loading branch information
bhashemian authored and Can-Zhao committed May 10, 2022
1 parent fd8c1de commit 6a22ed6
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 78 deletions.
16 changes: 9 additions & 7 deletions monai/data/wsi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np

Expand All @@ -32,10 +32,12 @@ class PatchWSIDataset(Dataset):
size: the size of patch to be extracted from the whole slide image.
level: the level at which the patches to be extracted (default to 0).
transform: transforms to be executed on input data.
reader: the module to be used for loading whole slide imaging,
- if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
- if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
- if `reader` is an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
reader: the module to be used for loading whole slide imaging. If `reader` is
- a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
- an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class
Note:
Expand All @@ -45,14 +47,14 @@ class PatchWSIDataset(Dataset):
[
{"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
{"image": "path/to/image2.tiff", "location": [100, 700], "label": 1}
{"image": "path/to/image2.tiff", "location": [100, 700], "size": [20, 20], "level": 2, "label": 1}
]
"""

def __init__(
self,
data: List,
data: Sequence,
size: Optional[Union[int, Tuple[int, int]]] = None,
level: Optional[int] = None,
transform: Optional[Callable] = None,
Expand Down
118 changes: 54 additions & 64 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from abc import abstractmethod
from os.path import abspath
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -53,6 +54,7 @@ class BaseWSIReader(ImageReader):
"""

supported_suffixes: List[str] = []
backend = ""

def __init__(self, level: int, **kwargs):
super().__init__()
Expand All @@ -63,7 +65,7 @@ def __init__(self, level: int, **kwargs):
@abstractmethod
def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -83,6 +85,11 @@ def get_level_count(self, wsi) -> int:
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
Expand All @@ -102,20 +109,29 @@ def get_patch(
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
def get_metadata(
self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int
) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.
Args:
wsi: the whole slide image object, from which the patch is loaded
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
metadata: Dict = {
"backend": self.backend,
"original_channel_dim": 0,
"spatial_shape": np.asarray(patch.shape[1:]),
"wsi": {"path": self.get_file_path(wsi)},
"patch": {"location": location, "size": size, "level": level},
}
return metadata

def get_data(
self,
Expand Down Expand Up @@ -194,8 +210,26 @@ def get_data(
patch_list.append(patch)

# Set patch-related metadata
each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level)
metadata.update(each_meta)
each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)

if len(wsi) == 1:
metadata = each_meta
else:
if not metadata:
metadata = {
"backend": each_meta["backend"],
"original_channel_dim": each_meta["original_channel_dim"],
"spatial_shape": each_meta["spatial_shape"],
"wsi": [each_meta["wsi"]],
"patch": [each_meta["patch"]],
}
else:
if metadata["original_channel_dim"] != each_meta["original_channel_dim"]:
raise ValueError("original_channel_dim is not consistent across wsi objects.")
if any(metadata["spatial_shape"] != each_meta["spatial_shape"]):
raise ValueError("spatial_shape is not consistent across wsi objects.")
metadata["wsi"].append(each_meta["wsi"])
metadata["patch"].append(each_meta["patch"])

return _stack_images(patch_list, metadata), metadata

Expand Down Expand Up @@ -247,7 +281,7 @@ def get_level_count(self, wsi) -> int:

def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -256,19 +290,9 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
return self.reader.get_size(wsi, level)

def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.
Args:
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0
"""
return self.reader.get_metadata(patch=patch, size=size, location=location, level=level)
def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return self.reader.get_file_path(wsi)

def get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
Expand Down Expand Up @@ -317,6 +341,7 @@ class CuCIMWSIReader(BaseWSIReader):
"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "cucim"

def __init__(self, level: int = 0, **kwargs):
super().__init__(level, **kwargs)
Expand All @@ -335,7 +360,7 @@ def get_level_count(wsi) -> int:
@staticmethod
def get_size(wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -344,27 +369,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
"""
return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0])

def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.
Args:
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0
"""
metadata: Dict = {
"backend": "cucim",
"spatial_shape": np.asarray(patch.shape[1:]),
"original_channel_dim": 0,
"location": location,
"size": size,
"level": level,
}
return metadata
def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return str(abspath(wsi.path))

def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Expand Down Expand Up @@ -440,6 +447,7 @@ class OpenSlideWSIReader(BaseWSIReader):
"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "openslide"

def __init__(self, level: int = 0, **kwargs):
super().__init__(level, **kwargs)
Expand All @@ -458,7 +466,7 @@ def get_level_count(wsi) -> int:
@staticmethod
def get_size(wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -467,27 +475,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
"""
return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0])

def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.
Args:
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0
"""
metadata: Dict = {
"backend": "openslide",
"spatial_shape": np.asarray(patch.shape[1:]),
"original_channel_dim": 0,
"location": location,
"size": size,
"level": level,
}
return metadata
def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return str(abspath(wsi._filename))

def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Expand Down
29 changes: 22 additions & 7 deletions tests/test_wsireader_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,13 @@ class Tests(unittest.TestCase):
def test_read_whole_image(self, file_path, level, expected_shape):
reader = WSIReader(self.backend, level=level)
with reader.read(file_path) as img_obj:
img = reader.get_data(img_obj)[0]
img, meta = reader.get_data(img_obj)
self.assertTupleEqual(img.shape, expected_shape)
self.assertEqual(meta["backend"], self.backend)
self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path)))
self.assertEqual(meta["patch"]["level"], level)
self.assertTupleEqual(meta["patch"]["size"], expected_shape[1:])
self.assertTupleEqual(meta["patch"]["location"], (0, 0))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_read_region(self, file_path, patch_info, expected_img):
Expand All @@ -138,29 +143,39 @@ def test_read_region(self, file_path, patch_info, expected_img):
reader.get_data(img_obj, **patch_info)[0]
else:
# Read twice to check multiple calls
img = reader.get_data(img_obj, **patch_info)[0]
img, meta = reader.get_data(img_obj, **patch_info)
img2 = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, img2.shape)
self.assertIsNone(assert_array_equal(img, img2))
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
self.assertEqual(meta["backend"], self.backend)
self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path)))
self.assertEqual(meta["patch"]["level"], patch_info["level"])
self.assertTupleEqual(meta["patch"]["size"], expected_img.shape[1:])
self.assertTupleEqual(meta["patch"]["location"], patch_info["location"])

@parameterized.expand([TEST_CASE_3])
def test_read_region_multi_wsi(self, file_path, patch_info, expected_img):
def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img):
kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {}
reader = WSIReader(self.backend, **kwargs)
img_obj = reader.read(file_path, **kwargs)
img_obj_list = reader.read(file_path_list, **kwargs)
if self.backend == "tifffile":
with self.assertRaises(ValueError):
reader.get_data(img_obj, **patch_info)[0]
reader.get_data(img_obj_list, **patch_info)[0]
else:
# Read twice to check multiple calls
img = reader.get_data(img_obj, **patch_info)[0]
img2 = reader.get_data(img_obj, **patch_info)[0]
img, meta = reader.get_data(img_obj_list, **patch_info)
img2 = reader.get_data(img_obj_list, **patch_info)[0]
self.assertTupleEqual(img.shape, img2.shape)
self.assertIsNone(assert_array_equal(img, img2))
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
self.assertEqual(meta["backend"], self.backend)
self.assertEqual(meta["wsi"][0]["path"], str(os.path.abspath(file_path_list[0])))
self.assertEqual(meta["patch"][0]["level"], patch_info["level"])
self.assertTupleEqual(meta["patch"][0]["size"], expected_img.shape[1:])
self.assertTupleEqual(meta["patch"][0]["location"], patch_info["location"])

@parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
@skipUnless(has_tiff, "Requires tifffile.")
Expand Down

0 comments on commit 6a22ed6

Please sign in to comment.