calling iter twice messes up dataloaders with queues #19427

Open
ben-da6 opened this issue Feb 7, 2024 · 4 comments
Labels: bug (Something isn't working), data handling (Generic data-related topic), loops (Related to the Loop API), ver: 2.1.x
Milestone: 2.4.x

Comments

ben-da6 commented Feb 7, 2024

Bug description

This is a reappearance of the bug from #18414.

We now call iter() twice in different places, so the second iter() drains additional items from the queue (see the sketch after the repro below).

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):
        q.put((arr, ind))
    max_epoch = 1
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader)
    trainer.save_checkpoint("model.ckpt")

    # q now has the next 5 elems in
    # resuming training we will hit the double iter() issue
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch + 1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader, ckpt_path="model.ckpt")
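
For clarity, the failure boils down to this (a sketch reusing the dataloader above, not Lightning internals):

# Every iter() on the DataLoader restarts QueueDataset.__iter__ in the
# persistent worker, and the worker immediately starts prefetching from the
# shared queue, so calling iter() twice drains the queue twice.
it1 = iter(dataloader)  # worker pulls items for the epoch
it2 = iter(dataloader)  # worker restarts __iter__ and pulls fresh items;
                        # everything prefetched for it1 is silently lost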

Error messages and logs

The relevant logs are:

# first epoch all good
getting 0
got 0
getting 1
got 1
getting 2
got 2
getting 3
got 3
getting 4
got 4

# second epoch we start getting from the queue twice!
# from fit loop iter()
getting 0
got 5
getting 1
got 6
getting 2
got 7
# from training_epoch loop iter()
getting 0
got 8
getting 1
got 9
getting 2

Environment

lightning==2.1.4

More info

No response

cc @justusschock @awaelchli @carmocca

ben-da6 added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Feb 7, 2024
awaelchli (Contributor) commented:

This condition here is meant to prevent iter() from getting called a second time, because in this case restarting should be True:

# `iter()` was called once in `FitLoop.setup_data()` already
if self.trainer.current_epoch > 0 and not self.restarting:
    iter(data_fetcher)  # creates the iterator inside the fetcher

But it isn't. The problem is that the fit loop sets restarting=False even though we are resuming, due to the logic here:

def restarting(self, restarting: bool) -> None:
    # if the last epoch completely finished, we are not actually restarting
    values = self.epoch_progress.current.ready, self.epoch_progress.current.started
    epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
    restarting = restarting and epoch_unfinished or self._iteration_based_training()
    _Loop.restarting.fset(self, restarting)  # call the parent setter

This is tricky to solve @carmocca. The logic probably needs to be lifted up into the fit loop before epoch_loop.run(), with a different condition that does not rely on restarting.
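
For illustration only, the lifted version might look roughly like this (_iterator_created is a made-up flag, not Lightning API):

# Hypothetical sketch of FitLoop.advance(): track iterator creation
# explicitly instead of inferring it from `restarting`.
def advance(self) -> None:
    assert self._data_fetcher is not None
    if not self._iterator_created:    # made-up flag, for illustration only
        iter(self._data_fetcher)      # create the iterator exactly once per fit
        self._iterator_created = True
    self.epoch_loop.run(self._data_fetcher)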

awaelchli added the loops (Related to the Loop API) and data handling (Generic data-related topic) labels and removed the needs triage label on Feb 11, 2024
awaelchli added this to the 2.2.x milestone on Feb 11, 2024
carmocca (Contributor) commented:

I didn't look too deeply. Couldn't we check restarting too for the FitLoop's iter() call? We have a lot of tests around this, so if a solution passes them we should be good.
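
As an untested sketch (assuming the unconditional iter() is the one in FitLoop.setup_data(), per the comment quoted above):

# Untested sketch: mirror the epoch-loop guard on the FitLoop side so the
# initial iter() is also skipped when resuming.
if not self.restarting:
    iter(data_fetcher)

Though as the next comment points out, this alone cannot help while restarting is incorrectly False.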

ben-da6 (Author) commented Feb 14, 2024

The problem in the restarting property is that self._iteration_based_training() is False.
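
Concretely, since `and` binds tighter than `or`, the setter evaluates like this when resuming from a checkpoint saved at an epoch boundary:

# Values when resuming from a checkpoint saved at the end of an epoch:
restarting = True
epoch_unfinished = False           # the saved epoch completed fully
iteration_based_training = False   # epoch-based training, so this is False

# parsed as: (restarting and epoch_unfinished) or iteration_based_training
result = restarting and epoch_unfinished or iteration_based_training
print(result)  # False -> the fit loop believes it is NOT restarting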

ben-da6 (Author) commented Feb 14, 2024

Also, since this has appeared twice now and it's the sort of bug that is hard to track down, could we add a test like my example?
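
For example, something along these lines (a sketch reusing QueueDataset from the repro above; with the bug present, the resumed fit drains extra items and the worker times out on the queue):

import multiprocessing as mp

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader


def test_resume_does_not_call_iter_twice(tmp_path):
    # 10 items in the queue, 5 consumed per epoch over 2 epochs: if resuming
    # calls iter() twice, extra items are drained and the second epoch
    # times out waiting on the queue.
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):
        q.put((arr, ind))

    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader)
    ckpt = str(tmp_path / "model.ckpt")
    trainer.save_checkpoint(ckpt)

    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=2, enable_progress_bar=False, devices=1)
    # With the double-iter bug this raises; when fixed it completes cleanly.
    trainer.fit(BoringModel(), dataloader, ckpt_path=ckpt)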

awaelchli modified the milestones: 2.2.x → 2.3.x on Jun 13, 2024
awaelchli modified the milestones: 2.3.x → 2.4.x on Aug 7, 2024