How to deal with uneven inputs in DDP with sharded data without hanging #20404

ssharpe42 opened this issue Nov 7, 2024 · 1 comment
discussion In a discussion stage


ssharpe42 commented Nov 7, 2024

Bug description

This may be part feature request, part question, and part report of unwanted behavior. I would like to find a valid way to train and validate with DDP on large iterable datasets when each GPU process receives a different amount of data. With the Lightning Trainer as is, training hangs.

For training, my workaround is a dataloader that loops over the data indefinitely on each GPU process, combined with max_steps instead of max_epochs. For evaluation/validation, however, I am unsure how to produce valid metrics with torchmetrics without duplicating data.

Please see the script below.

What version are you seeing the problem on?


How to reproduce the bug

import os
import shutil

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification.auroc import AUROC

def main(args):
    env = TorchElasticEnvironment()

    if env.local_rank() == 0:
        path = "example-dataset"

        if os.path.isdir(path):
            shutil.rmtree(path)

        if not os.path.exists(path):
            os.mkdir(path)

        partition_sizes = [10, 20]
        total = 0
        for i, size in enumerate(partition_sizes):
            data = pd.DataFrame(
                {
                    "id": list(range(total, total + size)),
                    "inputs": [np.random.rand(5).tolist() for _ in range(size)],
                    "labels": np.random.randint(0, 2, size).tolist(),
                }
            )
            data.to_parquet(os.path.join(path, f"data{i}.parquet"))
            total += size

    class Model(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(5, 2)
            self.auroc = AUROC(task="binary")

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            print(f"{self.trainer.global_rank}: {batch['id'].cpu().numpy().tolist()} ")
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            return loss

        def validation_step(self, batch, batch_idx):
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            self.auroc(torch.softmax(y_hat, -1)[:, 1], batch["labels"])
            self.log(
                "loss", loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
            )

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

    # Load dataset
    dataset = load_dataset(...)  # arguments are elided in the post
    dataset = split_dataset_by_node(
        dataset, rank=env.global_rank(), world_size=env.world_size()
    )
    model = Model()

    if args.normal_dataloader:

        # Train model
        train_dl = DataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(...)  # Trainer arguments are elided in the post
        trainer.fit(model, train_dl, val_dl)

    # Solutions for training #

    # 1. Infinite dataloader - keep cycling over data and use max_steps instead
    class InfiniteDataLoader(DataLoader):
        """Dataloader that continually cycles over the dataset."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Initialize an iterator over the dataset.
            self.dataset_iterator = super().__iter__()
            self.epoch = 0
            self.iters = 0

        def __iter__(self):
            return self

        def __next__(self):
            try:
                batch = next(self.dataset_iterator)
            except StopIteration:
                # Dataset exhausted, use a new fresh iterator.
                self.dataset_iterator = super().__iter__()
                batch = next(self.dataset_iterator)
            self.iters += 1
            return batch

        def set_epoch(self, epoch: int):
            "Set iteration for the dataset generator seed for shuffling"

            # We support if a custom `Dataset` implementation has `set_epoch`
            # or in general HF datasets `Datasets`
            if hasattr(self.dataset, "set_epoch"):
                self.dataset.set_epoch(epoch)

        def increment_epoch(self):
            self.epoch += 1
            self.iters = 0

    if args.infinite_dataloader:
        train_dl = InfiniteDataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(...)  # Trainer arguments (including max_steps) are elided in the post
        trainer.fit(model, train_dl, val_dl)

    # Solutions for eval #
    # 1. Load into memory and reshard --- would like to avoid this

if __name__ == "__main__":
    import argparse

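    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--normal-dataloader",
        action="store_true",
        help="Run normal dataloader with a sharded dataset",
    )
    parser.add_argument(
        "--infinite-dataloader",
        action="store_true",
        help="Run with an infinite dataloader with a sharded dataset using max_steps",
    )
    args = parser.parse_args()
    main(args)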

Error messages and logs

Running torchrun --nproc-per-node 2 --normal-dataloader hangs because the two processes receive different amounts of data.

Epoch 0: | | 0/? [00:00<?, ?it/s]
1: [10, 11, 12, 13, 14]
0: [0, 1, 2, 3, 4]
Epoch 0: | | 1/? [00:00<00:00, 20.45it/s, v_num=57]
0: [5, 6, 7, 8, 9]
1: [15, 16, 17, 18, 19]
Epoch 0: | | 2/? [00:00<00:00, 38.07it/s, v_num=57]
1: [20, 21, 22, 23, 24]
Validation DataLoader 0: |
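
The log is consistent with a simple count. Assuming the 10-row file is read by rank 0 and the 20-row file by rank 1 (which the printed ids suggest), a batch size of 5 gives rank 0 two batches and rank 1 four, so rank 0 moves on to validation while rank 1 is still in its third training step. A minimal sketch of that arithmetic follows; the helper name `batches_per_rank` is hypothetical and not part of Lightning.

```python
import math


def batches_per_rank(partition_sizes, batch_size):
    """Number of batches each rank yields, assuming rank i reads partition i."""
    return [math.ceil(size / batch_size) for size in partition_sizes]


counts = batches_per_rank([10, 20], batch_size=5)
print(counts)  # [2, 4]

# Steps that every rank can complete together before the shortest rank runs dry.
common_steps = min(counts)
print(common_steps)  # 2
```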


Current environment
  • CUDA:
    - GPU:
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - available: True
    - version: 12.4
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.8
    - pytorch-lightning: 2.4.0
    - torch: 2.5.1
    - torchmetrics: 1.5.1
  • Packages:
    - absl-py: 2.1.0
    - accelerate: 0.34.2
    - aiohappyeyeballs: 2.4.3
    - aiohttp: 3.10.10
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - astroid: 3.3.5
    - asttokens: 2.4.1
    - async-timeout: 4.0.3
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - autoflake: 2.3.1
    - autopep8: 2.3.1
    - backports.tarfile: 1.2.0
    - black: 24.10.0
    - boto3: 1.35.54
    - botocore: 1.35.54
    - c1-cube-versioning: 0.2.8
    - c1-fm-model: 0.2.0
    - certifi: 2024.8.30
    - cfgv: 3.4.0
    - charset-normalizer: 3.4.0
    - click: 8.1.7
    - comm: 0.2.2
    - contourpy: 1.3.0
    - coverage: 7.6.4
    - cramjam: 2.9.0
    - cycler: 0.12.1
    - datasets: 2.18.0
    - debugpy: 1.8.7
    - decorator: 5.1.1
    - dill: 0.3.8
    - distlib: 0.3.9
    - evaluate: 0.4.3
    - exceptiongroup: 1.2.2
    - executing: 2.1.0
    - fastparquet: 2024.5.0
    - filelock: 3.16.1
    - flake8: 7.1.1
    - fonttools: 4.54.1
    - frozenlist: 1.5.0
    - fsspec: 2024.2.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - grpcio: 1.67.1
    - huggingface-hub: 0.26.2
    - hydra-callbacks: 0.6.1
    - hydra-core: 1.3.2
    - identify: 2.6.1
    - idna: 3.10
    - importlib-metadata: 8.5.0
    - importlib-resources: 6.4.5
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - intake: 2.0.7
    - ipykernel: 6.29.5
    - ipython: 8.18.1
    - isort: 5.13.2
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonpath-ng: 1.6.1
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.7
    - lightning: 2.4.0
    - lightning-utilities: 0.11.8
    - markdown: 3.7
    - markdown-it-py: 3.0.0
    - markupsafe: 3.0.2
    - matplotlib: 3.9.2
    - matplotlib-inline: 0.1.7
    - mccabe: 0.7.0
    - mdurl: 0.1.2
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.16
    - mypy: 1.13.0
    - mypy-extensions: 1.0.0
    - nbqa: 1.9.0
    - nest-asyncio: 1.6.0
    - networkx: 3.2.1
    - nodeenv: 1.9.1
    - numpy: 1.26.4
    - nvidia-cublas-cu12:
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12:
    - nvidia-cufft-cu12:
    - nvidia-curand-cu12:
    - nvidia-cusolver-cu12:
    - nvidia-cusparse-cu12:
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - omegaconf: 2.3.0
    - packaging: 24.1
    - pandas: 1.5.3
    - parso: 0.8.4
    - pathspec: 0.12.1
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 11.0.0
    - pip: 24.3.1
    - platformdirs: 4.3.6
    - pluggy: 1.5.0
    - ply: 3.11
    - pre-commit: 4.0.1
    - pre-commit-hooks: 5.0.0
    - prompt-toolkit: 3.0.48
    - propcache: 0.2.0
    - protobuf: 5.28.3
    - psutil: 6.1.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - pyarrow: 14.0.1
    - pyarrow-hotfix: 0.6
    - pycodestyle: 2.12.1
    - pydantic: 1.10.18
    - pyflakes: 3.2.0
    - pygments: 2.18.0
    - pylint: 3.3.1
    - pyparsing: 3.2.0
    - pyrootutils: 1.0.4
    - pytest: 8.3.3
    - pytest-cov: 6.0.0
    - pytest-mock: 3.14.0
    - python-dateutil: 2.9.0
    - python-dotenv: 1.0.1
    - pytorch-lightning: 2.4.0
    - pytz: 2024.2
    - pyyaml: 6.0.1
    - pyzmq: 26.2.0
    - regex: 2024.9.11
    - requests: 2.32.3
    - rich: 13.9.4
    - ruamel.yaml: 0.18.6
    - ruamel.yaml.clib: 0.2.12
    - rubicon-ml: 0.10.3
    - s3fs: 0.4.2
    - s3transfer: 0.10.3
    - safetensors: 0.4.5
    - scikit-learn: 1.5.2
    - scipy: 1.13.1
    - seaborn: 0.13.2
    - setuptools: 75.3.0
    - six: 1.16.0
    - smmap: 5.0.1
    - stack-data: 0.6.2
    - sympy: 1.13.1
    - tensorboard: 2.18.0
    - tensorboard-data-server: 0.7.2
    - threadpoolctl: 3.5.0
    - tokenize-rt: 6.1.0
    - tokenizers: 0.20.3
    - tomli: 2.0.2
    - tomlkit: 0.13.2
    - torch: 2.5.1
    - torchmetrics: 1.5.1
    - tornado: 6.4.1
    - tqdm: 4.66.6
    - traitlets: 5.14.3
    - transformers: 4.45.2
    - triton: 3.1.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - urllib3: 1.26.20
    - virtualenv: 20.27.1
    - wcwidth: 0.2.13
    - werkzeug: 3.1.2
    - wheel: 0.44.0
    - xxhash: 3.5.0
    - yarl: 1.17.1
    - zipp: 3.20.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - processor: x86_64
    - python: 3.9.20
    - release: 5.10.226-214.880.amzn2.x86_64
    - version: #1 SMP Tue Oct 8 16:18:15 UTC 2024

More info

No response

cc @Borda

ssharpe42 added the "bug" and "needs triage" labels Nov 7, 2024
lantiga added the "discussion" label and removed the "bug", "needs triage", and "ver: 2.4.x" labels Nov 18, 2024

lantiga commented Nov 18, 2024

Hi @ssharpe42, this is not a bug but intended behavior. In DDP, processes run independently under the assumption that they will execute exactly the same instructions (on different data). Any divergence leads to a hang as soon as the code reaches a collective call.
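
To illustrate the hang without GPUs, here is a toy simulation, not DDP itself: each thread stands in for a rank and a `threading.Barrier` stands in for the collective call at the end of every step. The function names and the 0.5 s timeout are assumptions made for this sketch; real collectives typically block far longer than that before any timeout applies.

```python
import threading


def run_rank(rank, num_batches, barrier, results):
    """Simulate one process: every step ends in a collective (a barrier here)."""
    for step in range(num_batches):
        try:
            # Stand-in for the gradient all-reduce that all ranks must join.
            barrier.wait(timeout=0.5)
        except threading.BrokenBarrierError:
            # The other rank never arrived, so this rank is left waiting.
            results[rank] = f"stuck at step {step}"
            return
    results[rank] = "finished"


def simulate(batches_per_rank):
    """Run one thread per rank and report how each one ended."""
    barrier = threading.Barrier(len(batches_per_rank))
    results = {}
    threads = [
        threading.Thread(target=run_rank, args=(rank, n, barrier, results))
        for rank, n in enumerate(batches_per_rank)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results


print(simulate([2, 2]))  # both ranks finish
print(simulate([2, 4]))  # rank 1 is left waiting at step 2
```

With even batch counts both simulated ranks finish; with 2 versus 4 batches, rank 1 waits at its third step for a partner that has already left the loop.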

Of course, you can hack your system so that you accumulate gradients over different numbers of batches, but there may be dragons related to distributed sampling, metrics, unforeseen collective calls, aggregation of the loss, scaling of the gradients, etc.
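
One of these pitfalls can be shown with plain arithmetic. DDP averages gradients across ranks with equal weight, and the same holds for any value reduced as a mean across ranks, so when ranks contribute different numbers of samples the result is not the average over all samples. A small sketch with made-up per-sample losses (the numbers are illustrative and not taken from the script above):

```python
# Hypothetical per-sample losses on two ranks with uneven data.
rank_losses = {
    0: [1.0] * 10,  # 10 samples, each with loss 1.0
    1: [4.0] * 20,  # 20 samples, each with loss 4.0
}

# An equal-weight average across ranks is the mean of the per-rank means.
per_rank_means = [sum(v) / len(v) for v in rank_losses.values()]
mean_of_means = sum(per_rank_means) / len(per_rank_means)

# The mean over all samples weights each rank by its sample count.
total = sum(sum(v) for v in rank_losses.values())
count = sum(len(v) for v in rank_losses.values())
global_mean = total / count

print(mean_of_means)  # 2.5
print(global_mean)    # 3.0
```

Reducing sums and counts separately and dividing afterwards is one way to recover the sample-weighted value.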

There's some discussion on the topic here and here.

We are not planning to work on this right now, but I'd be interested in following the work if you decide to go for it.
