
Commit

fixing rng sync when using custom sampler and batch_sampler (#696)
* fixing rng sync when using custom sampler and batch_sampler

* addressing comments

* ✨

Co-authored-by: Sylvain Gugger <[email protected]>
pacman100 and sgugger authored Sep 12, 2022
1 parent 8444465 commit 8d27597
Showing 1 changed file with 15 additions and 10 deletions.
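
The change targets DataLoaders whose `sampler` argument is itself a `torch.utils.data.BatchSampler` (automatic batching disabled), so the object that actually owns the shuffling generator is `dataloader.sampler.sampler` rather than `dataloader.batch_sampler`. A minimal sketch of that setup, using plain PyTorch only; the dataset and sizes are made up for illustration:

    import torch
    from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

    dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
    generator = torch.Generator()
    inner_sampler = RandomSampler(dataset, generator=generator)
    batch_sampler = BatchSampler(inner_sampler, batch_size=8, drop_last=False)

    # The BatchSampler is passed as `sampler` with batch_size=None, so
    # dataloader.sampler is a BatchSampler, dataloader.batch_sampler is None,
    # and the shuffle generator lives on dataloader.sampler.sampler.
    dataloader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)

Before this commit, prepare_data_loader only looked for a generator on `dataloader.sampler` and `dataloader.batch_sampler`, so in this layout the inner generator was never found or re-seeded.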
src/accelerate/data_loader.py: 25 changes (15 additions, 10 deletions)
@@ -364,10 +364,11 @@ def __iter__(self):
 
     @property
     def total_batch_size(self):
+        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
         return (
-            self.batch_sampler.batch_size
-            if self.batch_sampler.split_batches
-            else (self.batch_sampler.batch_size * self.batch_sampler.num_processes)
+            batch_sampler.batch_size
+            if batch_sampler.split_batches
+            else (batch_sampler.batch_size * batch_sampler.num_processes)
         )
 
     @property
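
When the sharded batch sampler ends up in `self.sampler` (the `sampler_is_batch_sampler` case prepared further down in this diff), `self.batch_sampler` is no longer the object that carries `split_batches` and `num_processes`, so the property now resolves whichever attribute is the `BatchSampler` before doing the arithmetic. A rough sketch of that arithmetic with made-up numbers:

    # Illustrative only: the effective batch size reported by total_batch_size.
    batch_size, num_processes = 8, 4
    split_batches = False
    total = batch_size if split_batches else batch_size * num_processes   # -> 32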
@@ -639,14 +640,17 @@ def prepare_data_loader(
             )
         else:
             # New batch sampler for the current process.
-            if hasattr(dataloader.sampler, "generator"):
-                if dataloader.sampler.generator is None:
-                    dataloader.sampler.generator = torch.Generator()
-                generator = dataloader.sampler.generator
-                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
-            elif getattr(dataloader.batch_sampler, "generator", None) is not None:
-                generator = dataloader.batch_sampler.generator
+            sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+            if sampler_is_batch_sampler:
+                sampler = dataloader.sampler.sampler
+            else:
+                sampler = dataloader.batch_sampler.sampler
+            if hasattr(sampler, "generator"):
+                if sampler.generator is None:
+                    sampler.generator = torch.Generator()
+                generator = sampler.generator
+                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
 
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,
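
The old branch checked `dataloader.sampler` directly, which, with the setup sketched above, is a plain `BatchSampler` without a `generator`, so the underlying `RandomSampler`'s generator was never seeded and each process could draw a different shuffle order. The new code unwraps the inner sampler first, seeds its generator, and hands that generator to the prepared loader (see `generator=generator` in the final hunk below), so accelerate's usual rng synchronization (the `rng_types` argument) can cover it. A toy check, not library code, showing why seeding the generator the sampler actually uses keeps the order identical everywhere:

    # Toy check (not library code): identical seeds on the generator that the
    # RandomSampler actually uses yield identical shuffle orders.
    import torch
    from torch.utils.data import RandomSampler, TensorDataset

    small = TensorDataset(torch.arange(10, dtype=torch.float32))

    def order_for(seed):
        g = torch.Generator()
        g.manual_seed(seed)
        return list(RandomSampler(small, generator=g))

    assert order_for(1234) == order_for(1234)   # same seed -> same order on every process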
@@ -692,6 +696,7 @@ def prepare_data_loader(
             new_dataset,
             device=device if put_on_device and state.distributed_type != DistributedType.TPU else None,
             sampler=new_batch_sampler,
+            batch_size=getattr(dataloader, "batch_size", _PYTORCH_DATALOADER_KWARGS["batch_size"]),
             rng_types=rng_types,
             generator=generator,
             **kwargs,
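
In this branch the sharded batch sampler is passed as `sampler`, so the wrapper now also forwards the original loader's `batch_size`, falling back to the PyTorch default recorded in `_PYTORCH_DATALOADER_KWARGS` when the attribute is missing. End-to-end, the usage this commit fixes looks roughly like the following (continuing the earlier sketch; run under a multi-process accelerate launch):

    # Illustrative usage (continuing the earlier sketch); the printed value
    # depends on the number of processes in the launch.
    from accelerate import Accelerator

    accelerator = Accelerator()
    prepared = accelerator.prepare(dataloader)
    print(prepared.total_batch_size)   # batch_size (8 above) x accelerator.num_processes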
