
Commit baf3e47

typing fixes
1 parent b14bb60 commit baf3e47

2 files changed: +90 -40 lines


viscy/data/hcs.py

+36 -27
@@ -25,7 +25,7 @@
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
 
-from viscy.data.typing import ChannelMap, Sample
+from viscy.data.typing import ChannelMap, HCSStackIndex, NormMeta, Sample
 
 
 def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]:
@@ -64,11 +64,11 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
     as is the case with ``train_patches_per_stack > 1``.
     :return Sample: Batch sample (dictionary of tensors)
     """
-    collated = {}
+    collated: Sample = {}
     for key in batch[0].keys():
         data = []
         for sample in batch:
-            if isinstance(sample[key], list):
+            if isinstance(sample[key], Sequence):
                 data.extend(sample[key])
             else:
                 data.append(sample[key])
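
Note: switching the check from `list` to `Sequence` lets the collate helper flatten tuples of tensors as well as lists, while a lone tensor still falls through to `append`. A minimal sketch of the check in isolation (illustrative only, not part of this commit; the sample dict below is a placeholder):

    from collections.abc import Sequence

    import torch

    # A sample value may be a single tensor or a sequence of tensors.
    sample = {"source": (torch.zeros(1, 5, 64, 64), torch.zeros(1, 5, 64, 64))}

    data = []
    if isinstance(sample["source"], Sequence):
        data.extend(sample["source"])  # lists and tuples are both flattened
    else:
        data.append(sample["source"])  # a single tensor is appended as-is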
@@ -84,7 +84,7 @@ class SlidingWindowDataset(Dataset):
     :param ChannelMap channels: source and target channel names,
         e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
     :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
-    :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform:
+    :param Callable[[dict[str, Tensor]], dict[str, Tensor]] | None transform:
         a callable that transforms data, defaults to None
     """
 
@@ -93,7 +93,7 @@ def __init__(
         positions: list[Position],
         channels: ChannelMap,
         z_window_size: int,
-        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None,
+        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
     ) -> None:
         super().__init__()
         self.positions = positions
@@ -116,18 +116,18 @@ def _get_windows(self) -> None:
         w = 0
         self.window_keys = []
         self.window_arrays = []
-        self.window_norm_meta = []
+        self.window_norm_meta: list[NormMeta | None] = []
         for fov in self.positions:
-            img_arr = fov["0"]
+            img_arr: ImageArray = fov["0"]
             ts = img_arr.frames
             zs = img_arr.slices - self.z_window_size + 1
             w += ts * zs
             self.window_keys.append(w)
             self.window_arrays.append(img_arr)
-            self.window_norm_meta.append(fov.zattrs.get("normalization", 0))
+            self.window_norm_meta.append(fov.zattrs.get("normalization", None))
         self._max_window = w
 
-    def _find_window(self, index: int) -> tuple[int, int]:
+    def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]:
         """Look up window given index."""
         window_idx = sorted(self.window_keys + [index + 1]).index(index + 1)
         w = self.window_keys[window_idx]
@@ -136,16 +136,16 @@ def _find_window(self, index: int) -> tuple[int, int]:
         return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta)
 
     def _read_img_window(
-        self, img: ImageArray, ch_idx: list[str], tz: int
-    ) -> tuple[tuple[Tensor], tuple[str, int, int]]:
+        self, img: ImageArray, ch_idx: list[int], tz: int
+    ) -> tuple[list[Tensor], HCSStackIndex]:
         """Read image window as tensor.
 
         :param ImageArray img: NGFF image array
-        :param list[int] channels: list of channel indices to read,
+        :param list[int] ch_idx: list of channel indices to read,
             output channel ordering will reflect the sequence
         :param int tz: window index within the FOV, counted Z-first
-        :return tuple[Tensor], tuple[str, int, int]:
-            tuple of (C=1, Z, Y, X) image tensors,
+        :return list[Tensor], HCSStackIndex:
+            list of (C=1, Z, Y, X) image tensors,
             tuple of image name, time index, and Z index
         """
         zs = img.shape[-3] - self.z_window_size + 1
@@ -162,8 +162,8 @@ def __len__(self) -> int:
         return self._max_window
 
     def _stack_channels(
-        self, sample_images: list[dict[str, Tensor]], key: str
-    ) -> Tensor:
+        self, sample_images: list[dict[str, Tensor]] | dict[str, Tensor], key: str
+    ) -> Tensor | list[Tensor]:
         """Stack single-channel images into a multi-channel tensor."""
         if not isinstance(sample_images, list):
             return torch.stack([sample_images[ch][0] for ch in self.channels[key]])
@@ -187,7 +187,8 @@ def __getitem__(self, index: int) -> Sample:
         # since adding a reference to a tensor does not copy
         # maybe write a weight map in preprocessing to use more information?
         sample_images["weight"] = sample_images[self.channels["target"][0]]
-        sample_images["norm_meta"] = norm_meta
+        if norm_meta is not None:
+            sample_images["norm_meta"] = norm_meta
         if self.transform:
             sample_images = self.transform(sample_images)
         # if isinstance(sample_images, list):
@@ -224,7 +225,7 @@ def __init__(
         positions: list[Position],
         channels: ChannelMap,
         z_window_size: int,
-        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None,
+        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
         ground_truth_masks: str = None,
     ) -> None:
         super().__init__(positions, channels, z_window_size, transform)
@@ -268,9 +269,9 @@ class HCSDataModule(LightningDataModule):
         defaults to "2.5D"
     :param tuple[int, int] yx_patch_size: patch size in (Y, X),
         defaults to (256, 256)
-    :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms
+    :param list[MapTransform] normalizations: MONAI dictionary transforms
         applied to selected channels, defaults to [] (no normalization)
-    :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms
+    :param list[MapTransform] augmentations: MONAI dictionary transforms
         applied to the training set, defaults to [] (no augmentation)
     :param bool caching: whether to decompress all the images and cache the result,
         will store in ``/tmp/$SLURM_JOB_ID/`` if available,
@@ -291,8 +292,8 @@ def __init__(
         num_workers: int = 8,
         architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
         yx_patch_size: tuple[int, int] = (256, 256),
-        normalizations: Optional[list[MapTransform]] = [],
-        augmentations: Optional[list[MapTransform]] = [],
+        normalizations: list[MapTransform] = [],
+        augmentations: list[MapTransform] = [],
         caching: bool = False,
         ground_truth_masks: Optional[Path] = None,
     ):
@@ -315,9 +316,20 @@ def __init__(
     @property
     def cache_path(self):
         return Path(
-            tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name
+            tempfile.gettempdir(),
+            os.getenv("SLURM_JOB_ID", "viscy_cache"),
+            self.data_path.name,
         )
 
+    def _data_log_path(self) -> Path:
+        log_dir = Path.cwd()
+        if self.trainer:
+            if self.trainer.logger:
+                if self.trainer.logger.log_dir:
+                    log_dir = Path(self.trainer.logger.log_dir)
+        log_dir.mkdir(parents=True, exist_ok=True)
+        return log_dir / "data.log"
+
     def prepare_data(self):
         if not self.caching:
             return
@@ -328,10 +340,7 @@ def prepare_data(self):
         console_handler = logging.StreamHandler()
         console_handler.setLevel(logging.INFO)
         logger.addHandler(console_handler)
-        os.makedirs(self.trainer.logger.log_dir, exist_ok=True)
-        file_handler = logging.FileHandler(
-            os.path.join(self.trainer.logger.log_dir, "data.log")
-        )
+        file_handler = logging.FileHandler(self._data_log_path())
         file_handler.setLevel(logging.DEBUG)
         logger.addHandler(file_handler)
         logger.info(f"Caching dataset at {self.cache_path}.")
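
With the default added to `os.getenv`, caching no longer requires a SLURM environment: the cache falls back to a `viscy_cache` directory under the system temp dir, and the new `_data_log_path` helper falls back to the current working directory when no trainer logger is attached. A small sketch of the same `os.getenv` fallback pattern (illustrative only; `dataset.zarr` is a placeholder name):

    import os
    import tempfile
    from pathlib import Path

    # Use the SLURM job ID for the cache directory when present,
    # otherwise a fixed "viscy_cache" directory under the temp dir.
    job_dir = os.getenv("SLURM_JOB_ID", "viscy_cache")
    cache_path = Path(tempfile.gettempdir(), job_dir, "dataset.zarr")
    print(cache_path)  # e.g. /tmp/viscy_cache/dataset.zarr outside SLURM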

viscy/data/typing.py

+54 -13
@@ -1,22 +1,63 @@
-from typing import Sequence, TypedDict, Union
+from __future__ import annotations
 
-from torch import Tensor
+from typing import TYPE_CHECKING, NamedTuple, Sequence, TypedDict, TypeVar
+
+if TYPE_CHECKING:
+    from torch import Tensor
+
+T = TypeVar("T")
+OneOrSeq = T | Sequence[T]
+
+
+class LevelNormStats(TypedDict):
+    mean: float
+    std: float
+    median: float
+    iqr: float
+
+
+class ChannelNormStats(TypedDict):
+    dataset_statistics: LevelNormStats
+    fov_statistics: LevelNormStats
+
+
+NormMeta = dict[str, ChannelNormStats]
+
+
+class HCSStackIndex(NamedTuple):
+    """HCS stack index."""
+
+    # name of the image array, e.g. "A/1/0/0"
+    image: str
+    time: int
+    z: int
 
 
 class Sample(TypedDict, total=False):
-    """Image sample type for mini-batches."""
+    """
+    Image sample type for mini-batches.
+    All fields are optional.
+    """
+
+    index: HCSStackIndex
+    # Image data
+    source: OneOrSeq[Tensor]
+    target: OneOrSeq[Tensor]
+    weight: OneOrSeq[Tensor]
+    # Instance segmentation masks
+    labels: OneOrSeq[Tensor]
+    # None: not available
+    norm_meta: NormMeta
+
+
+class _ChannelMap(TypedDict):
+    """Source channel names."""
 
-    # all optional
-    index: tuple[str, int, int]
-    source: Union[Tensor, Sequence[Tensor]]
-    target: Union[Tensor, Sequence[Tensor]]
-    labels: Union[Tensor, Sequence[Tensor]]
-    norm_meta: dict[str, dict]
+    source: OneOrSeq[str]
 
 
-class ChannelMap(TypedDict, total=False):
+class ChannelMap(_ChannelMap, total=False):
     """Source and target channel names."""
 
-    source: Union[str, Sequence[str]]
-    # optional
-    target: Union[str, Sequence[str]]
+    # TODO: use typing.NotRequired when upgrading to Python 3.11
+    target: OneOrSeq[str]
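
For reference, the new types compose like this (an illustrative sketch, not part of the commit; channel names and tensor shapes are placeholders borrowed from the docstring example in hcs.py):

    import torch

    from viscy.data.typing import ChannelMap, HCSStackIndex, Sample

    channels: ChannelMap = {"source": "Phase", "target": ["Nuclei", "Membrane"]}
    index = HCSStackIndex(image="A/1/0/0", time=0, z=5)
    sample: Sample = {
        "index": index,
        # (C, Z, Y, X) windows; target holds one tensor per target channel
        "source": torch.zeros(1, 5, 256, 256),
        "target": [torch.zeros(1, 5, 256, 256), torch.zeros(1, 5, 256, 256)],
    }

Since `Sample` is declared with `total=False`, every field, including `norm_meta`, may be omitted, which is what the `if norm_meta is not None` guard in hcs.py relies on.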
