diff --git a/fets/data/pytorch/gandlf_data.py b/fets/data/pytorch/gandlf_data.py index 99a6c08..e1ebb12 100644 --- a/fets/data/pytorch/gandlf_data.py +++ b/fets/data/pytorch/gandlf_data.py @@ -20,6 +20,15 @@ from fets.data.gandlf_utils import get_dataframe_and_headers from fets.data import get_appropriate_file_paths_from_subject_dir +## added for reproducibility +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + numpy.random.seed(worker_seed) + random.seed(worker_seed) + +g = torch.Generator() +g.manual_seed(0) +## added for reproducibility # adapted from https://codereview.stackexchange.com/questions/132914/crop-black-border-of-image-using-numpy/132933#132933 def crop_image_outside_zeros(array, psize): @@ -672,9 +681,9 @@ def get_loaders(self, data_frame, train, augmentations): preprocessing=self.preprocessing, in_memory=self.in_memory) if train: - loader = DataLoader(data, shuffle=True, batch_size=self.batch_size) + loader = DataLoader(data, shuffle=True, batch_size=self.batch_size, worker_init_fn=seed_worker) else: - loader = DataLoader(data, shuffle=False, batch_size=1) + loader = DataLoader(data, shuffle=False, batch_size=1, , worker_init_fn=seed_worker) companion_loader = None if train: