From 43d5247c19839029f7aa0fd2e8c565ab1c44f582 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Mon, 12 Sep 2022 14:53:49 +0530
Subject: [PATCH 1/3] fixing rng sync when using custom sampler and batch_sampler

---
 src/accelerate/data_loader.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 84c2ebd67ad..582b4e6fee4 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -364,10 +364,14 @@ def __iter__(self):
 
     @property
     def total_batch_size(self):
+        if self.batch_sampler is None:
+            batch_sampler = self.sampler
+        else:
+            batch_sampler = self.batch_sampler
         return (
-            self.batch_sampler.batch_size
-            if self.batch_sampler.split_batches
-            else (self.batch_sampler.batch_size * self.batch_sampler.num_processes)
+            batch_sampler.batch_size
+            if batch_sampler.split_batches
+            else (batch_sampler.batch_size * batch_sampler.num_processes)
         )
 
     @property
@@ -639,14 +643,17 @@ def prepare_data_loader(
             )
         else:
             # New batch sampler for the current process.
-            if hasattr(dataloader.sampler, "generator"):
-                if dataloader.sampler.generator is None:
-                    dataloader.sampler.generator = torch.Generator()
-                generator = dataloader.sampler.generator
-                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
-            elif getattr(dataloader.batch_sampler, "generator", None) is not None:
-                generator = dataloader.batch_sampler.generator
             sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+            if sampler_is_batch_sampler:
+                sampler = dataloader.sampler.sampler
+            else:
+                sampler = dataloader.batch_sampler.sampler
+            if hasattr(sampler, "generator"):
+                if sampler.generator is None:
+                    sampler.generator = torch.Generator()
+                generator = sampler.generator
+                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,
@@ -692,6 +699,7 @@
             new_dataset,
             device=device if put_on_device and state.distributed_type != DistributedType.TPU else None,
             sampler=new_batch_sampler,
+            batch_size=None,
             rng_types=rng_types,
             generator=generator,
             **kwargs,

From c823458f806683b53573f1e3f74a21dc8cd604ca Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Mon, 12 Sep 2022 19:00:55 +0530
Subject: [PATCH 2/3] addressing comments

---
 src/accelerate/data_loader.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 582b4e6fee4..1eb8ff79df7 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -364,10 +364,7 @@ def __iter__(self):
     @property
     def total_batch_size(self):
-        if self.batch_sampler is None:
-            batch_sampler = self.sampler
-        else:
-            batch_sampler = self.batch_sampler
+        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
         return (
             batch_sampler.batch_size
             if batch_sampler.split_batches
             else (batch_sampler.batch_size * batch_sampler.num_processes)
@@ -699,7 +696,7 @@ def prepare_data_loader(
             new_dataset,
             device=device if put_on_device and state.distributed_type != DistributedType.TPU else None,
             sampler=new_batch_sampler,
-            batch_size=None,
+            batch_size=getattr(dataloader, "batch_size", _PYTORCH_DATALOADER_KWARGS["batch_size"]),
             rng_types=rng_types,
             generator=generator,
             **kwargs,

From 9708f00ad0de9237da332838834be57739aa1c5e Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 12 Sep 2022 19:28:22 +0530
Subject: [PATCH 3/3] :sparkles:
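
Note: the snippet below is a minimal usage sketch, not part of the patch series, illustrating the configuration these commits target: a DataLoader whose sampler argument is itself a BatchSampler, so the generator that has to stay synchronized across processes lives on the inner RandomSampler rather than on the DataLoader. The dataset, seed, and batch size are arbitrary placeholders.

# Minimal sketch (not from the patches) of the custom-sampler setup the fix covers.
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

from accelerate import Accelerator

dataset = TensorDataset(torch.randn(64, 4))  # placeholder data

# The generator that every process must draw from in the same order.
generator = torch.Generator().manual_seed(42)
inner_sampler = RandomSampler(dataset, generator=generator)
batch_sampler = BatchSampler(inner_sampler, batch_size=8, drop_last=False)

# Passing the BatchSampler through `sampler=` (with batch_size=None) is the
# configuration that previously bypassed Accelerate's generator seeding.
dataloader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)

accelerator = Accelerator()
dataloader = accelerator.prepare(dataloader)

for (batch,) in dataloader:
    pass  # each process now iterates shards drawn from the same shuffled order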