diff --git a/docs/source/internal.mdx b/docs/source/internal.mdx
index 8da23bee816..4ff1d6ff920 100644
--- a/docs/source/internal.mdx
+++ b/docs/source/internal.mdx
@@ -26,7 +26,7 @@ The main work on your PyTorch `DataLoader` is done by the following function:
 
 [[autodoc]] data_loader.prepare_data_loader
 
-### BatchSamplerShard
+### DataLoaderShard
 
 [[autodoc]] data_loader.DataLoaderShard
 
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 524034eb190..f58752e5392 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -38,6 +38,22 @@
 if is_tpu_available(check_device=False):
     import torch_xla.distributed.parallel_loader as xpl
 
+    class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
+        """
+        Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.
+
+        **Available attributes:**
+
+        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
+            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
+            number of processes
+        """
+
+        @property
+        def total_batch_size(self):
+            return self._loader.total_batch_size
+
+
 logger = get_logger(__name__)
 
 # kwargs of the DataLoader in min version 1.4.0.
@@ -289,6 +305,12 @@ class DataLoaderShard(DataLoader):
             A random number generator to keep synchronized across processes.
         kwargs:
             All other keyword arguments to pass to the regular `DataLoader` initialization.
+
+    **Available attributes:**
+
+        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
+            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
+            number of processes
     """
 
     def __init__(self, dataset, device=None, rng_types=None, generator=None, **kwargs):
@@ -321,6 +343,14 @@ def __iter__(self):
                 yield current_batch
                 break
 
+    @property
+    def total_batch_size(self):
+        return (
+            self.batch_sampler.batch_size
+            if self.batch_sampler.split_batches
+            else (self.batch_sampler.batch_size * self.batch_sampler.num_processes)
+        )
+
 
 class DataLoaderDispatcher(DataLoader):
     """
@@ -334,6 +364,12 @@ class DataLoaderDispatcher(DataLoader):
            the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
            `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
            size of the `dataloader` is a round multiple of `batch_size`.
+
+    **Available attributes:**
+
+        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
+            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
+            number of processes
     """
 
     def __init__(self, dataset, split_batches: bool = False, **kwargs):
@@ -432,6 +468,12 @@ def __len__(self):
         else:
             return math.ceil(whole_length / self.state.num_processes)
 
+    @property
+    def total_batch_size(self):
+        return (
+            self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
+        )
+
 
 def prepare_data_loader(
     dataloader: DataLoader,
@@ -577,7 +619,10 @@ def prepare_data_loader(
 
     if dispatch_batches:
         dataloader = DataLoaderDispatcher(
-            new_dataset, split_batches=split_batches, batch_sampler=new_batch_sampler, **kwargs
+            new_dataset,
+            split_batches=split_batches,
+            batch_sampler=new_batch_sampler,
+            **kwargs,
         )
     else:
         dataloader = DataLoaderShard(
@@ -590,5 +635,5 @@ def prepare_data_loader(
     )
 
     if state.distributed_type == DistributedType.TPU:
-        return xpl.MpDeviceLoader(dataloader, device)
+        return MpDeviceLoaderWrapper(dataloader, device)
     return dataloader
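For context (not part of the diff): a minimal usage sketch of the new `total_batch_size` property on a prepared dataloader. The `Accelerator` setup is standard; the toy dataset and per-process batch size of 8 below are illustrative assumptions only.

```python
# Illustrative sketch: after `accelerator.prepare`, the returned object is a
# DataLoaderShard, DataLoaderDispatcher, or MpDeviceLoaderWrapper, all of which
# expose `total_batch_size` after this change.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()

# Hypothetical toy dataset with a per-process batch size of 8.
dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=8)

dataloader = accelerator.prepare(dataloader)

# With the default split_batches=False this equals 8 * accelerator.num_processes;
# with split_batches=True it stays equal to the original batch size.
print(dataloader.total_batch_size)
```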