Skip to content

Commit dd64b31

Browse files
committed
ddp caching fixes
1 parent 01c71cf commit dd64b31

File tree

1 file changed

+14

−16

lines changed

viscy/data/hcs.py

+14 −16
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,6 @@ def __getitem__(self, index: int) -> Sample:
199199
sample["target"] = self._stack_channels(sample_images, "target")
200200
return sample
201201

202-
def __del__(self):
203-
"""Close the Zarr store when the dataset instance gets GC'ed."""
204-
self.positions[0].zgroup.store.close()
205-
206202

207203
class MaskTestDataset(SlidingWindowDataset):
208204
"""Torch dataset where each element is a window of
@@ -310,7 +306,13 @@ def __init__(
310306
self.augmentations = augmentations
311307
self.caching = caching
312308
self.ground_truth_masks = ground_truth_masks
313-
self.tmp_zarr = None
309+
self.prepare_data_per_node = True
310+
311+
@property
312+
def cache_path(self):
313+
return Path(
314+
tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name
315+
)
314316

315317
def prepare_data(self):
316318
if not self.caching:
@@ -322,20 +324,14 @@ def prepare_data(self):
322324
console_handler = logging.StreamHandler()
323325
console_handler.setLevel(logging.INFO)
324326
logger.addHandler(console_handler)
325-
os.mkdir(self.trainer.logger.log_dir)
327+
os.makedirs(self.trainer.logger.log_dir, exist_ok=True)
326328
file_handler = logging.FileHandler(
327329
os.path.join(self.trainer.logger.log_dir, "data.log")
328330
)
329331
file_handler.setLevel(logging.DEBUG)
330332
logger.addHandler(file_handler)
331-
# cache in temporary directory
332-
self.tmp_zarr = os.path.join(
333-
tempfile.gettempdir(),
334-
os.getenv("SLURM_JOB_ID"),
335-
os.path.basename(self.data_path),
336-
)
337-
logger.info(f"Caching dataset at {self.tmp_zarr}.")
338-
tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr)
333+
logger.info(f"Caching dataset at {self.cache_path}.")
334+
tmp_store = zarr.NestedDirectoryStore(self.cache_path)
339335
with open_ome_zarr(self.data_path, mode="r") as lazy_plate:
340336
_, skipped, _ = zarr.copy(
341337
lazy_plate.zgroup,
@@ -373,7 +369,7 @@ def _setup_fit(self, dataset_settings: dict):
373369
val_transform = Compose(self.normalizations + fit_transform)
374370

375371
dataset_settings["channels"]["target"] = self.target_channel
376-
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
372+
data_path = self.cache_path if self.caching else self.data_path
377373
plate = open_ome_zarr(data_path, mode="r")
378374

379375
# disable metadata tracking in MONAI for performance
@@ -410,7 +406,7 @@ def _setup_test(self, dataset_settings: dict):
410406
logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
411407

412408
dataset_settings["channels"]["target"] = self.target_channel
413-
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
409+
data_path = self.cache_path if self.cache_path else self.data_path
414410
plate = open_ome_zarr(data_path, mode="r")
415411
if self.ground_truth_masks:
416412
self.test_dataset = MaskTestDataset(
@@ -476,7 +472,9 @@ def train_dataloader(self):
476472
num_workers=self.num_workers,
477473
shuffle=True,
478474
persistent_workers=bool(self.num_workers),
475+
prefetch_factor=4,
479476
collate_fn=_collate_samples,
477+
drop_last=True,
480478
)
481479

482480
def val_dataloader(self):

0 commit comments

Comments
 (0)