diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 84c2ebd67ad..1eb8ff79df7 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -364,10 +364,11 @@ def __iter__(self):
 
     @property
     def total_batch_size(self):
+        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
         return (
-            self.batch_sampler.batch_size
-            if self.batch_sampler.split_batches
-            else (self.batch_sampler.batch_size * self.batch_sampler.num_processes)
+            batch_sampler.batch_size
+            if batch_sampler.split_batches
+            else (batch_sampler.batch_size * batch_sampler.num_processes)
         )
 
     @property
@@ -639,14 +640,17 @@ def prepare_data_loader(
             )
         else:
             # New batch sampler for the current process.
-            if hasattr(dataloader.sampler, "generator"):
-                if dataloader.sampler.generator is None:
-                    dataloader.sampler.generator = torch.Generator()
-                generator = dataloader.sampler.generator
-                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
-            elif getattr(dataloader.batch_sampler, "generator", None) is not None:
-                generator = dataloader.batch_sampler.generator
             sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+            if sampler_is_batch_sampler:
+                sampler = dataloader.sampler.sampler
+            else:
+                sampler = dataloader.batch_sampler.sampler
+            if hasattr(sampler, "generator"):
+                if sampler.generator is None:
+                    sampler.generator = torch.Generator()
+                generator = sampler.generator
+                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,
@@ -692,6 +696,7 @@ def prepare_data_loader(
             new_dataset,
             device=device if put_on_device and state.distributed_type != DistributedType.TPU else None,
             sampler=new_batch_sampler,
+            batch_size=getattr(dataloader, "batch_size", _PYTORCH_DATALOADER_KWARGS["batch_size"]),
             rng_types=rng_types,
             generator=generator,
             **kwargs,
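
For reference, a minimal sketch (not part of the patch, and assuming the usual `Accelerator.prepare` entry point) of the DataLoader layout this change targets: a `BatchSampler` passed as the `sampler` argument with `batch_size=None`, so the shuffling sampler and its generator live one level deeper than `dataloader.sampler`.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

from accelerate import Accelerator

dataset = TensorDataset(torch.randn(64, 3))

# The BatchSampler is handed to DataLoader as `sampler`, not `batch_sampler`,
# so dataloader.sampler is a BatchSampler and dataloader.batch_size is None.
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=False)
dataloader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)

accelerator = Accelerator()
# With this patch, preparation seeds the generator of the inner RandomSampler
# (dataloader.sampler.sampler) and passes dataloader.batch_size through when
# rebuilding the sharded dataloader.
prepared = accelerator.prepare(dataloader)

for batch in prepared:
    pass  # each process iterates over its shard of every batch
```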