@@ -199,10 +199,6 @@ def __getitem__(self, index: int) -> Sample:
         sample["target"] = self._stack_channels(sample_images, "target")
         return sample
 
-    def __del__(self):
-        """Close the Zarr store when the dataset instance gets GC'ed."""
-        self.positions[0].zgroup.store.close()
-
 
 class MaskTestDataset(SlidingWindowDataset):
     """Torch dataset where each element is a window of
@@ -310,7 +306,13 @@ def __init__(
         self.augmentations = augmentations
         self.caching = caching
         self.ground_truth_masks = ground_truth_masks
-        self.tmp_zarr = None
+        self.prepare_data_per_node = True
+
+    @property
+    def cache_path(self):
+        return Path(
+            tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name
+        )
 
     def prepare_data(self):
         if not self.caching:
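Two details in this hunk work together: `prepare_data_per_node = True` tells Lightning to run `prepare_data` on the local rank 0 of every node, which node-local caching requires, and the `cache_path` property recomputes the location on demand instead of carrying mutable `tmp_zarr` state, so ranks that never ran `prepare_data` still resolve the same path. Note that off-cluster `os.getenv("SLURM_JOB_ID")` returns `None`, which `Path()` rejects with a `TypeError`; a minimal standalone sketch of the same idea with a fallback (the fallback segment is an assumption, not part of the patch):

```python
import os
import tempfile
from pathlib import Path


def node_local_cache(data_path: Path) -> Path:
    # Scope the cache to the SLURM job so concurrent jobs on the same
    # node do not collide; fall back to the PID off-cluster (assumption).
    job = os.getenv("SLURM_JOB_ID") or f"pid-{os.getpid()}"
    return Path(tempfile.gettempdir(), job, data_path.name)


print(node_local_cache(Path("/hpc/data/plate.zarr")))
# e.g. /tmp/123456/plate.zarr inside a SLURM job
```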
@@ -322,20 +324,14 @@ def prepare_data(self):
         console_handler = logging.StreamHandler()
         console_handler.setLevel(logging.INFO)
         logger.addHandler(console_handler)
-        os.mkdir(self.trainer.logger.log_dir)
+        os.makedirs(self.trainer.logger.log_dir, exist_ok=True)
         file_handler = logging.FileHandler(
             os.path.join(self.trainer.logger.log_dir, "data.log")
         )
         file_handler.setLevel(logging.DEBUG)
         logger.addHandler(file_handler)
-        # cache in temporary directory
-        self.tmp_zarr = os.path.join(
-            tempfile.gettempdir(),
-            os.getenv("SLURM_JOB_ID"),
-            os.path.basename(self.data_path),
-        )
-        logger.info(f"Caching dataset at {self.tmp_zarr}.")
-        tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr)
+        logger.info(f"Caching dataset at {self.cache_path}.")
+        tmp_store = zarr.NestedDirectoryStore(self.cache_path)
         with open_ome_zarr(self.data_path, mode="r") as lazy_plate:
             _, skipped, _ = zarr.copy(
                 lazy_plate.zgroup,
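For context, `zarr.copy` returns a `(copied, skipped, copied_bytes)` tuple, which is why `skipped` is unpacked above: it lets `prepare_data` report how much of the cache already existed. A minimal sketch of the same caching pattern against a plain Zarr group (the source path and the `if_exists` policy are assumptions, not taken from this patch):

```python
import tempfile
from pathlib import Path

import zarr

src = zarr.open_group("plate.zarr", mode="r")  # assumed source location
cache = Path(tempfile.gettempdir(), "plate.zarr")
tmp_store = zarr.NestedDirectoryStore(str(cache))

# Copy the whole hierarchy into a node-local store; "skip" lets a
# restarted job reuse nodes that were already cached (an assumption).
copied, skipped, copied_bytes = zarr.copy(
    src, zarr.open_group(tmp_store, mode="a"), name="/", if_exists="skip"
)
print(f"copied {copied} nodes ({copied_bytes} bytes), reused {skipped}")
```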
@@ -373,7 +369,7 @@ def _setup_fit(self, dataset_settings: dict):
         val_transform = Compose(self.normalizations + fit_transform)
 
         dataset_settings["channels"]["target"] = self.target_channel
-        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
+        data_path = self.cache_path if self.caching else self.data_path
         plate = open_ome_zarr(data_path, mode="r")
 
         # disable metadata tracking in MONAI for performance
@@ -410,7 +406,7 @@ def _setup_test(self, dataset_settings: dict):
         logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
 
         dataset_settings["channels"]["target"] = self.target_channel
-        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
+        data_path = self.cache_path if self.caching else self.data_path
         plate = open_ome_zarr(data_path, mode="r")
         if self.ground_truth_masks:
             self.test_dataset = MaskTestDataset(
@@ -476,7 +472,9 @@ def train_dataloader(self):
             num_workers=self.num_workers,
             shuffle=True,
             persistent_workers=bool(self.num_workers),
+            prefetch_factor=4,
             collate_fn=_collate_samples,
+            drop_last=True,
         )
 
     def val_dataloader(self):
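Here `prefetch_factor=4` lets each worker keep four batches in flight to hide Zarr read and decompression latency, and `drop_last=True` discards the final partial batch so batch-size-sensitive layers never see a ragged (potentially size-1) batch. One caveat: PyTorch only accepts a non-default `prefetch_factor` when `num_workers > 0`. A minimal sketch outside Lightning (the dataset and batch size are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1000, 1, 32, 32))  # placeholder data
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,             # prefetch_factor requires workers > 0
    persistent_workers=True,   # keep workers (and their file handles) alive
    prefetch_factor=4,         # each worker preloads 4 batches ahead
    drop_last=True,            # avoid a small ragged final batch
)
for (batch,) in loader:
    pass  # training step would go here
```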