fixing rng sync when using custom sampler and batch_sampler #696

Merged
src/accelerate/data_loader.py (25 changes: 15 additions & 10 deletions)
@@ -364,10 +364,11 @@ def __iter__(self):
 
     @property
     def total_batch_size(self):
+        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
         return (
-            self.batch_sampler.batch_size
-            if self.batch_sampler.split_batches
-            else (self.batch_sampler.batch_size * self.batch_sampler.num_processes)
+            batch_sampler.batch_size
+            if batch_sampler.split_batches
+            else (batch_sampler.batch_size * batch_sampler.num_processes)
         )
 
     @property
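The new `isinstance` check covers DataLoaders whose batch sampler was passed through the `sampler=` argument. The snippet below is not from the PR; it is a minimal sketch of that attribute layout: PyTorch keeps whatever you pass as `sampler=` on `dataloader.sampler` and builds its own plain `BatchSampler` wrapper as `dataloader.batch_sampler`, so the Accelerate-specific attributes (`split_batches`, `num_processes`) live on `self.sampler` when it is a `BatchSampler` (in Accelerate's case a `BatchSamplerShard`, which subclasses `BatchSampler`).

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(32))
custom = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)

# The batch sampler is handed over through `sampler=` (the pattern this PR targets).
loader = DataLoader(dataset, sampler=custom, batch_size=8)

print(loader.sampler is custom)                  # True  -> the object carrying the batching logic
print(type(loader.batch_sampler).__name__)       # BatchSampler -> PyTorch's own wrapper around it
print(isinstance(loader.sampler, BatchSampler))  # True  -> the check total_batch_size now performs
```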
@@ -639,14 +640,17 @@ def prepare_data_loader(
             )
         else:
             # New batch sampler for the current process.
-            if hasattr(dataloader.sampler, "generator"):
-                if dataloader.sampler.generator is None:
-                    dataloader.sampler.generator = torch.Generator()
-                generator = dataloader.sampler.generator
-                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
-            elif getattr(dataloader.batch_sampler, "generator", None) is not None:
-                generator = dataloader.batch_sampler.generator
+            sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+            if sampler_is_batch_sampler:
+                sampler = dataloader.sampler.sampler
+            else:
+                sampler = dataloader.batch_sampler.sampler
+            if hasattr(sampler, "generator"):
+                if sampler.generator is None:
+                    sampler.generator = torch.Generator()
+                generator = sampler.generator
+                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
 
+            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,
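With a batch sampler passed as `sampler=`, the index-level sampler that owns the `generator` sits at `dataloader.sampler.sampler`, which the previous code never reached; the rewritten block resolves the right sampler first and only then creates and seeds a `torch.Generator` for it. That generator is also handed to `DataLoaderShard` (see `generator=generator` in the next hunk), so Accelerate can keep its state in step across processes. The sketch below is illustration only, not Accelerate code: it shows why a sampler generator with identical state on every process produces the same shuffled batch order everywhere, which is what lets each process keep just its own shard of every batch.

```python
import torch
from torch.utils.data import BatchSampler, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(12))

def batches_for_seed(seed):
    generator = torch.Generator()
    generator.manual_seed(seed)  # same idea as the manual_seed call in the patch
    sampler = RandomSampler(dataset, generator=generator)
    return list(BatchSampler(sampler, batch_size=4, drop_last=False))

# Two "processes" whose sampler generators share the same state shuffle identically.
assert batches_for_seed(1234) == batches_for_seed(1234)
print(batches_for_seed(1234))
```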
@@ -692,6 +696,7 @@ def prepare_data_loader(
             new_dataset,
             device=device if put_on_device and state.distributed_type != DistributedType.TPU else None,
             sampler=new_batch_sampler,
+            batch_size=getattr(dataloader, "batch_size", _PYTORCH_DATALOADER_KWARGS["batch_size"]),
             rng_types=rng_types,
             generator=generator,
             **kwargs,
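Because the sharded batch sampler goes into `DataLoaderShard` through `sampler=`, PyTorch does not pick a batch size up from it, so the new `batch_size=` argument copies the original dataloader's value and otherwise falls back to the library's stored PyTorch default. The snippet below is a sketch, not Accelerate code, showing the gap the forwarding works around: a DataLoader built with `batch_sampler=` reports `batch_size=None`, while one built with `sampler=` silently keeps PyTorch's default of 1 unless the value is passed explicitly.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(32))
bs = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)

via_batch_sampler = DataLoader(dataset, batch_sampler=bs)
via_sampler = DataLoader(dataset, sampler=bs)  # no batch_size given

print(via_batch_sampler.batch_size)  # None -> batching fully delegated to the batch sampler
print(via_sampler.batch_size)        # 1    -> PyTorch default; hence the explicit batch_size= above
```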