@@ -25,7 +25,7 @@
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
 
-from viscy.data.typing import ChannelMap, Sample
+from viscy.data.typing import ChannelMap, HCSStackIndex, NormMeta, Sample
 
 
 def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]:
@@ -64,11 +64,11 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
     as is the case with ``train_patches_per_stack > 1``.
     :return Sample: Batch sample (dictionary of tensors)
     """
-    collated = {}
+    collated: Sample = {}
     for key in batch[0].keys():
         data = []
         for sample in batch:
-            if isinstance(sample[key], list):
+            if isinstance(sample[key], Sequence):
                 data.extend(sample[key])
             else:
                 data.append(sample[key])
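Not part of the diff: a minimal sketch of what this collation does when each sample carries multiple patches per stack (``train_patches_per_stack > 1``). List values are flattened before stacking; plain tensors are appended. Shapes are invented for illustration.

```python
import torch

# two samples, each holding two (C, Z, Y, X) patches from the same stack
batch = [
    {"source": [torch.zeros(1, 5, 64, 64), torch.zeros(1, 5, 64, 64)]},
    {"source": [torch.zeros(1, 5, 64, 64), torch.zeros(1, 5, 64, 64)]},
]

collated = {}
for key in batch[0].keys():
    data = []
    for sample in batch:
        if isinstance(sample[key], list):  # the diff widens this check to Sequence
            data.extend(sample[key])  # flatten the per-stack patches
        else:
            data.append(sample[key])
    collated[key] = torch.stack(data)

print(collated["source"].shape)  # torch.Size([4, 1, 5, 64, 64])
```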
@@ -84,7 +84,7 @@ class SlidingWindowDataset(Dataset):
     :param ChannelMap channels: source and target channel names,
         e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
     :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
-    :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform:
+    :param Callable[[dict[str, Tensor]], dict[str, Tensor]] | None transform:
        a callable that transforms data, defaults to None
     """
 
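For reference, the ``ChannelMap`` shape the docstring describes, written as a plain dict (channel names are hypothetical):

```python
channels = {
    "source": "Phase",                 # a single name or a list of names
    "target": ["Nuclei", "Membrane"],  # one tensor channel per name
}
# with the signature change above, omitting transform is now type-correct:
# SlidingWindowDataset(positions, channels, z_window_size=5)
```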
@@ -93,7 +93,7 @@ def __init__(
         positions: list[Position],
         channels: ChannelMap,
         z_window_size: int,
-        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None,
+        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
     ) -> None:
         super().__init__()
         self.positions = positions
@@ -116,18 +116,18 @@ def _get_windows(self) -> None:
         w = 0
         self.window_keys = []
         self.window_arrays = []
-        self.window_norm_meta = []
+        self.window_norm_meta: list[NormMeta | None] = []
         for fov in self.positions:
-            img_arr = fov["0"]
+            img_arr: ImageArray = fov["0"]
             ts = img_arr.frames
             zs = img_arr.slices - self.z_window_size + 1
             w += ts * zs
             self.window_keys.append(w)
             self.window_arrays.append(img_arr)
-            self.window_norm_meta.append(fov.zattrs.get("normalization", 0))
+            self.window_norm_meta.append(fov.zattrs.get("normalization", None))
         self._max_window = w
 
-    def _find_window(self, index: int) -> tuple[int, int]:
+    def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]:
         """Look up window given index."""
         window_idx = sorted(self.window_keys + [index + 1]).index(index + 1)
         w = self.window_keys[window_idx]
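The sorted-and-index trick in ``_find_window`` is equivalent to a left bisect into the cumulative window counts built by ``_get_windows``. A standalone sketch with made-up counts:

```python
import bisect

# cumulative (T * Z) window counts for three FOVs with 10, 5, and 8 windows
window_keys = [10, 15, 23]

def find_fov(index: int) -> tuple[int, int]:
    """Return (fov_idx, local_window_idx) for a global window index."""
    fov_idx = bisect.bisect_left(window_keys, index + 1)
    before = window_keys[fov_idx - 1] if fov_idx else 0
    return fov_idx, index - before

assert find_fov(0) == (0, 0)   # first window of FOV 0
assert find_fov(9) == (0, 9)   # last window of FOV 0
assert find_fov(10) == (1, 0)  # rolls over to FOV 1
assert find_fov(22) == (2, 7)  # last window overall
```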
@@ -136,16 +136,16 @@ def _find_window(self, index: int) -> tuple[int, int]:
         return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta)
 
     def _read_img_window(
-        self, img: ImageArray, ch_idx: list[str], tz: int
-    ) -> tuple[tuple[Tensor], tuple[str, int, int]]:
+        self, img: ImageArray, ch_idx: list[int], tz: int
+    ) -> tuple[list[Tensor], HCSStackIndex]:
         """Read image window as tensor.
 
         :param ImageArray img: NGFF image array
-        :param list[int] channels: list of channel indices to read,
+        :param list[int] ch_idx: list of channel indices to read,
             output channel ordering will reflect the sequence
         :param int tz: window index within the FOV, counted Z-first
-        :return tuple[Tensor], tuple[str, int, int]:
-            tuple of (C=1, Z, Y, X) image tensors,
+        :return list[Tensor], HCSStackIndex:
+            list of (C=1, Z, Y, X) image tensors,
             tuple of image name, time index, and Z index
         """
         zs = img.shape[-3] - self.z_window_size + 1
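"Counted Z-first" means consecutive ``tz`` indices sweep the Z starting positions before advancing the time index. The ``divmod`` decomposition below is an assumption, but it is consistent with ``zs = slices - z_window_size + 1``:

```python
frames, slices, z_window_size = 4, 10, 5
zs = slices - z_window_size + 1  # 6 valid Z starting positions per timepoint

for tz in (0, 5, 6, 23):
    t, z = divmod(tz, zs)  # Z-first: tz 0..5 -> t=0, tz 6..11 -> t=1, ...
    print(f"tz={tz:2d} -> t={t}, z_start={z}")
# tz= 0 -> t=0, z_start=0
# tz= 5 -> t=0, z_start=5
# tz= 6 -> t=1, z_start=0
# tz=23 -> t=3, z_start=5
```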
@@ -162,8 +162,8 @@ def __len__(self) -> int:
         return self._max_window
 
     def _stack_channels(
-        self, sample_images: list[dict[str, Tensor]], key: str
-    ) -> Tensor:
+        self, sample_images: list[dict[str, Tensor]] | dict[str, Tensor], key: str
+    ) -> Tensor | list[Tensor]:
         """Stack single-channel images into a multi-channel tensor."""
         if not isinstance(sample_images, list):
             return torch.stack([sample_images[ch][0] for ch in self.channels[key]])
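Why the return type widens to ``Tensor | list[Tensor]``: after a multi-sample crop transform, ``sample_images`` can be a list of dicts rather than one dict. A sketch of both paths (the list branch is assumed, as it is not shown in this hunk):

```python
import torch

channels = {"target": ["Nuclei", "Membrane"]}
# each channel holds a (C=1, Z, Y, X) tensor; [0] drops the singleton C dim
sample = {ch: torch.zeros(1, 5, 64, 64) for ch in ("Nuclei", "Membrane")}

def stack_channels(sample_images, key):
    if not isinstance(sample_images, list):
        return torch.stack([sample_images[ch][0] for ch in channels[key]])
    return [stack_channels(s, key) for s in sample_images]  # assumed branch

print(stack_channels(sample, "target").shape)           # torch.Size([2, 5, 64, 64])
print(len(stack_channels([sample, sample], "target")))  # 2
```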
@@ -187,7 +187,8 @@ def __getitem__(self, index: int) -> Sample:
         # since adding a reference to a tensor does not copy
         # maybe write a weight map in preprocessing to use more information?
         sample_images["weight"] = sample_images[self.channels["target"][0]]
-        sample_images["norm_meta"] = norm_meta
+        if norm_meta is not None:
+            sample_images["norm_meta"] = norm_meta
         if self.transform:
             sample_images = self.transform(sample_images)
         # if isinstance(sample_images, list):
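The new guard omits the key entirely instead of storing ``None``, so dictionary transforms and ``_collate_samples`` never encounter a null entry:

```python
sample = {"source": "...", "target": "..."}
norm_meta = None  # FOV with no precomputed normalization statistics

if norm_meta is not None:
    sample["norm_meta"] = norm_meta

assert "norm_meta" not in sample  # absent key, not a None value
```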
@@ -224,7 +225,7 @@ def __init__(
         positions: list[Position],
         channels: ChannelMap,
         z_window_size: int,
-        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None,
+        transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
         ground_truth_masks: str = None,
     ) -> None:
         super().__init__(positions, channels, z_window_size, transform)
@@ -268,9 +269,9 @@ class HCSDataModule(LightningDataModule):
         defaults to "2.5D"
     :param tuple[int, int] yx_patch_size: patch size in (Y, X),
         defaults to (256, 256)
-    :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms
+    :param list[MapTransform] normalizations: MONAI dictionary transforms
         applied to selected channels, defaults to [] (no normalization)
-    :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms
+    :param list[MapTransform] augmentations: MONAI dictionary transforms
         applied to the training set, defaults to [] (no augmentation)
     :param bool caching: whether to decompress all the images and cache the result,
         will store in ``/tmp/$SLURM_JOB_ID/`` if available,
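For illustration, the two parameters take MONAI dictionary transforms keyed by channel name (the keys here are hypothetical):

```python
from monai.transforms import NormalizeIntensityd, RandAffined

normalizations = [NormalizeIntensityd(keys=["Phase"])]
augmentations = [RandAffined(keys=["Phase", "Nuclei"], prob=0.5)]
```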
@@ -291,8 +292,8 @@ def __init__(
         num_workers: int = 8,
         architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
         yx_patch_size: tuple[int, int] = (256, 256),
-        normalizations: Optional[list[MapTransform]] = [],
-        augmentations: Optional[list[MapTransform]] = [],
+        normalizations: list[MapTransform] = [],
+        augmentations: list[MapTransform] = [],
         caching: bool = False,
         ground_truth_masks: Optional[Path] = None,
     ):
@@ -315,9 +316,20 @@ def __init__(
     @property
     def cache_path(self):
         return Path(
-            tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name
+            tempfile.gettempdir(),
+            os.getenv("SLURM_JOB_ID", "viscy_cache"),
+            self.data_path.name,
         )
 
+    def _data_log_path(self) -> Path:
+        log_dir = Path.cwd()
+        if self.trainer:
+            if self.trainer.logger:
+                if self.trainer.logger.log_dir:
+                    log_dir = Path(self.trainer.logger.log_dir)
+        log_dir.mkdir(parents=True, exist_ok=True)
+        return log_dir / "data.log"
+
     def prepare_data(self):
         if not self.caching:
             return
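Effect of the ``os.getenv`` fallback above: without a SLURM job, ``Path(tempfile.gettempdir(), None, ...)`` used to raise ``TypeError``; now the cache lands under a ``viscy_cache`` directory. A standalone sketch with a hypothetical store name:

```python
import os
import tempfile
from pathlib import Path

def cache_path(data_name: str) -> Path:
    return Path(
        tempfile.gettempdir(),
        os.getenv("SLURM_JOB_ID", "viscy_cache"),  # fallback off-cluster
        data_name,
    )

print(cache_path("plate.zarr"))
# /tmp/viscy_cache/plate.zarr locally; /tmp/<job id>/plate.zarr under SLURM
```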
@@ -328,10 +340,7 @@ def prepare_data(self):
         console_handler = logging.StreamHandler()
         console_handler.setLevel(logging.INFO)
         logger.addHandler(console_handler)
-        os.makedirs(self.trainer.logger.log_dir, exist_ok=True)
-        file_handler = logging.FileHandler(
-            os.path.join(self.trainer.logger.log_dir, "data.log")
-        )
+        file_handler = logging.FileHandler(self._data_log_path())
         file_handler.setLevel(logging.DEBUG)
         logger.addHandler(file_handler)
         logger.info(f"Caching dataset at {self.cache_path}.")
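A self-contained sketch of the two-handler setup ``prepare_data`` builds: console at INFO, file at DEBUG via the new ``_data_log_path`` helper (the logger name and directory are assumptions):

```python
import logging
from pathlib import Path

logger = logging.getLogger("viscy_data")  # hypothetical logger name
logger.setLevel(logging.DEBUG)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
logger.addHandler(console_handler)

log_dir = Path.cwd()  # stand-in for the trainer-aware _data_log_path()
file_handler = logging.FileHandler(log_dir / "data.log")
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)

logger.info("Caching dataset.")    # reaches console and file
logger.debug("Per-array detail.")  # file only
```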