Bug description
This is partly a feature request, partly a question, and partly a report of unwanted behavior. I would like to find a valid way to run DDP training and validation with large iterable datasets when each GPU process sees a different amount of data. When the Lightning Trainer is used as is, training hangs.
For training I have a workaround: a dataloader that loops over the data indefinitely on each GPU process, combined with max_steps instead of max_epochs. For evaluation/validation, however, I am unsure how to use torchmetrics to produce valid metrics without duplicating data.
Please see the script below.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
import os
import shutil

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification.auroc import AUROC


def main(args):
    print(args)
    env = TorchElasticEnvironment()

    if env.local_rank() == 0:
        path = "example-dataset"
        if os.path.isdir(path):
            shutil.rmtree(path)
        if not os.path.exists(path):
            os.mkdir(path)

        partition_sizes = [10, 20]
        total = 0
        for i, size in enumerate(partition_sizes):
            data = pd.DataFrame(
                {
                    "id": list(range(total, total + size)),
                    "inputs": [np.random.rand(5).tolist() for _ in range(size)],
                    "labels": np.random.randint(0, 2, size).tolist(),
                }
            )
            data.to_parquet(os.path.join(path, f"data{i}.parquet"))
            total += size

    class Model(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(5, 2)
            self.auroc = AUROC(task="binary")

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            print(f"{self.trainer.global_rank}: {batch['id'].cpu().numpy().tolist()} ")
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            return loss

        def validation_step(self, batch, batch_idx):
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            self.auroc(torch.softmax(y_hat, -1)[:, 1], batch["labels"])
            self.log(
                "loss", loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
            )
            self.log(
                "auroc",
                self.auroc,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

    # Load dataset
    dataset = load_dataset(
        "parquet",
        data_files=[
            "example-dataset/data0.parquet",
            "example-dataset/data1.parquet",
        ],
        split="train",
        streaming=True,
    )
    dataset = split_dataset_by_node(
        dataset, rank=env.global_rank(), world_size=env.world_size()
    )

    model = Model()

    if args.normal_dataloader:
        # Train model
        train_dl = DataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)

        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_epochs=1,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for training
    ## 1. Infinite dataloader - keep cycling over data and use max_steps instead
    class InfiniteDataLoader(DataLoader):
        """Dataloader that continually cycles over the dataset"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Initialize an iterator over the dataset.
            self.dataset_iterator = super().__iter__()
            self.epoch = 0
            self.iters = 0

        def __iter__(self):
            return self

        def __next__(self):
            try:
                batch = next(self.dataset_iterator)
            except StopIteration:
                # Dataset exhausted, use a new fresh iterator.
                self.increment_epoch()
                self.dataset_iterator = super().__iter__()
                batch = next(self.dataset_iterator)
            self.iters += 1
            return batch

        def set_epoch(self, epoch: int):
            "Set iteration for the dataset generator seed for shuffling"
            # We support if a custom `Dataset` implementation has `set_epoch`
            # or in general HF datasets `Datasets`
            if hasattr(self.dataset, "set_epoch"):
                self.dataset.set_epoch(epoch)

        def increment_epoch(self):
            self.epoch += 1
            self.iters = 0
            self.set_epoch(self.epoch)

    if args.infinite_dataloader:
        train_dl = InfiniteDataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)

        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_steps=4,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for eval
    ## 1. Load into memory and reshard --- would like to avoid this


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--normal-dataloader",
        action="store_true",
        help="Run normal dataloader with a sharded dataset",
    )
    parser.add_argument(
        "--infinite-dataloader",
        action="store_true",
        help="Run with an infinite dataloader with a sharded dataset using max_steps",
    )
    main(parser.parse_args())
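On the open evaluation question ("Solutions for eval" above), one possible direction is a hand-rolled validation pass: TorchMetrics accumulates state locally on every update() call and only synchronizes across ranks when compute() is called, so each rank can exhaust its own (uneven) shard and then call compute() exactly once. The validate_sharded helper below is a hypothetical, untested sketch; it assumes the same sharded streaming dataset as above and takes the bare nn.Linear(5, 2) (e.g. model.model) as net.

import torch
from torch.utils.data import DataLoader
from torchmetrics.classification.auroc import AUROC


def validate_sharded(net, dataset, device):
    """Hypothetical manual validation loop for uneven per-rank shards.

    Each rank iterates only its own shard (possibly a different number of
    batches); update() only touches local state, and the single collective
    call happens inside compute(), which every rank reaches exactly once.
    """
    auroc = AUROC(task="binary").to(device)
    net.eval()
    val_dl = DataLoader(dataset, batch_size=5)
    with torch.no_grad():
        for batch in val_dl:
            # Same input handling as training_step/validation_step above.
            inputs = torch.vstack(batch["inputs"]).float().to(device)
            labels = batch["labels"].to(device)
            probs = torch.softmax(net(inputs), -1)[:, 1]
            auroc.update(probs, labels)  # local accumulation, no collective
    # Metric states are synced across ranks here, once per rank.
    return auroc.compute()

Whether this can be wired cleanly into the Trainer's validation loop is exactly the open question; as a standalone pass that every rank runs once, the only collective is the state synchronization inside compute().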
Error messages and logs
Running torchrun --nproc-per-node 2 example_ddp.py --normal-dataloader results in the processes hanging because the data is unevenly sharded across ranks.
Hi @ssharpe42, this is not a bug but rather intended behavior. In DDP, processes run independently under the assumption that they all perform exactly the same instructions (on different data). Any divergence will lead to a hang as soon as the code reaches a collective call.
Of course you can hack your system so that you accumulate gradients over different numbers of batches, but there might be dragons related to distributed sampling, metrics, unforeseen collective calls, aggregation of the loss, scaling of the gradients, etc.
There's some discussion on the topic here and here.
We are not planning to work on this right now, but I'd be interested in following the work if you decide to go for it.
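For reference, one way the kind of hack mentioned above can be expressed in plain PyTorch is DistributedDataParallel's join() context manager, which lets ranks that run out of batches early keep shadowing the collective calls (gradient all-reduces) of the ranks still training. The snippet below is a minimal sketch of that mechanism under a torchrun launch; it is not something the Lightning Trainer does for you, and the model and per-rank batch counts are invented for illustration.

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def run():
    # Assumes a torchrun launch, which sets RANK / WORLD_SIZE / LOCAL_RANK.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)

    model = DDP(nn.Linear(5, 2).to(device), device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Deliberately uneven work: rank 0 runs 2 batches, rank 1 runs 4, ...
    num_batches = 2 * (dist.get_rank() + 1)

    # join() keeps early-finishing ranks participating in the collective
    # calls issued by ranks that are still looping, so this does not hang.
    with model.join():
        for _ in range(num_batches):
            x = torch.randn(5, 5, device=device)
            y = torch.randint(0, 2, (5,), device=device)
            loss = nn.functional.cross_entropy(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    run()

Even with join(), the caveats above about metrics, loss aggregation, and gradient scaling still apply: the ranks that finish early effectively contribute zero gradients to the shadowed all-reduces.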
Environment
Current environment
- GPU:
- NVIDIA A10G
- NVIDIA A10G
- NVIDIA A10G
- NVIDIA A10G
- available: True
- version: 12.4
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- pytorch-lightning: 2.4.0
- torch: 2.5.1
- torchmetrics: 1.5.1
- absl-py: 2.1.0
- accelerate: 0.34.2
- aiohappyeyeballs: 2.4.3
- aiohttp: 3.10.10
- aiosignal: 1.3.1
- antlr4-python3-runtime: 4.9.3
- astroid: 3.3.5
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 24.2.0
- autocommand: 2.2.2
- autoflake: 2.3.1
- autopep8: 2.3.1
- backports.tarfile: 1.2.0
- black: 24.10.0
- boto3: 1.35.54
- botocore: 1.35.54
- c1-cube-versioning: 0.2.8
- c1-fm-model: 0.2.0
- certifi: 2024.8.30
- cfgv: 3.4.0
- charset-normalizer: 3.4.0
- click: 8.1.7
- comm: 0.2.2
- contourpy: 1.3.0
- coverage: 7.6.4
- cramjam: 2.9.0
- cycler: 0.12.1
- datasets: 2.18.0
- debugpy: 1.8.7
- decorator: 5.1.1
- dill: 0.3.8
- distlib: 0.3.9
- evaluate: 0.4.3
- exceptiongroup: 1.2.2
- executing: 2.1.0
- fastparquet: 2024.5.0
- filelock: 3.16.1
- flake8: 7.1.1
- fonttools: 4.54.1
- frozenlist: 1.5.0
- fsspec: 2024.2.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- grpcio: 1.67.1
- huggingface-hub: 0.26.2
- hydra-callbacks: 0.6.1
- hydra-core: 1.3.2
- identify: 2.6.1
- idna: 3.10
- importlib-metadata: 8.5.0
- importlib-resources: 6.4.5
- inflect: 7.3.1
- iniconfig: 2.0.0
- intake: 2.0.7
- ipykernel: 6.29.5
- ipython: 8.18.1
- isort: 5.13.2
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- jsonpath-ng: 1.6.1
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- kiwisolver: 1.4.7
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- markdown: 3.7
- markdown-it-py: 3.0.0
- markupsafe: 3.0.2
- matplotlib: 3.9.2
- matplotlib-inline: 0.1.7
- mccabe: 0.7.0
- mdurl: 0.1.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- mypy: 1.13.0
- mypy-extensions: 1.0.0
- nbqa: 1.9.0
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- nodeenv: 1.9.1
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- omegaconf: 2.3.0
- packaging: 24.1
- pandas: 1.5.3
- parso: 0.8.4
- pathspec: 0.12.1
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 11.0.0
- pip: 24.3.1
- platformdirs: 4.3.6
- pluggy: 1.5.0
- ply: 3.11
- pre-commit: 4.0.1
- pre-commit-hooks: 5.0.0
- prompt-toolkit: 3.0.48
- propcache: 0.2.0
- protobuf: 5.28.3
- psutil: 6.1.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 14.0.1
- pyarrow-hotfix: 0.6
- pycodestyle: 2.12.1
- pydantic: 1.10.18
- pyflakes: 3.2.0
- pygments: 2.18.0
- pylint: 3.3.1
- pyparsing: 3.2.0
- pyrootutils: 1.0.4
- pytest: 8.3.3
- pytest-cov: 6.0.0
- pytest-mock: 3.14.0
- python-dateutil: 2.9.0
- python-dotenv: 1.0.1
- pytorch-lightning: 2.4.0
- pytz: 2024.2
- pyyaml: 6.0.1
- pyzmq: 26.2.0
- regex: 2024.9.11
- requests: 2.32.3
- rich: 13.9.4
- ruamel.yaml: 0.18.6
- ruamel.yaml.clib: 0.2.12
- rubicon-ml: 0.10.3
- s3fs: 0.4.2
- s3transfer: 0.10.3
- safetensors: 0.4.5
- scikit-learn: 1.5.2
- scipy: 1.13.1
- seaborn: 0.13.2
- setuptools: 75.3.0
- six: 1.16.0
- smmap: 5.0.1
- stack-data: 0.6.2
- sympy: 1.13.1
- tensorboard: 2.18.0
- tensorboard-data-server: 0.7.2
- threadpoolctl: 3.5.0
- tokenize-rt: 6.1.0
- tokenizers: 0.20.3
- tomli: 2.0.2
- tomlkit: 0.13.2
- torch: 2.5.1
- torchmetrics: 1.5.1
- tornado: 6.4.1
- tqdm: 4.66.6
- traitlets: 5.14.3
- transformers: 4.45.2
- triton: 3.1.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- urllib3: 1.26.20
- virtualenv: 20.27.1
- wcwidth: 0.2.13
- werkzeug: 3.1.2
- wheel: 0.44.0
- xxhash: 3.5.0
- yarl: 1.17.1
- zipp: 3.20.2
- OS: Linux
- architecture:
- 64bit
-
- processor: x86_64
- python: 3.9.20
- release: 5.10.226-214.880.amzn2.x86_64
- version: #1 SMP Tue Oct 8 16:18:15 UTC 2024
More info
No response
cc @Borda