Properly manage fetcher.done with dataloader_iter #18376

Merged: 5 commits merged on Aug 24, 2023
Changes from 1 commit
49 changes: 21 additions & 28 deletions src/lightning/pytorch/loops/fetchers.py
@@ -29,6 +29,7 @@ def __init__(self) -> None:
self.iterator: Optional[Iterator] = None
self.fetched: int = 0
self.done: bool = False
self.length: Optional[int] = None
self._start_profiler = _profile_nothing
self._stop_profiler = _profile_nothing

@@ -42,24 +43,32 @@ def combined_loader(self) -> CombinedLoader:

def setup(self, combined_loader: CombinedLoader) -> None:
self._combined_loader = combined_loader
self.length = sized_len(combined_loader)
Contributor review comment: I'm exploring some ideas how we could make this length respect the limit_x_batches settings from the loop: #18436
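A rough, hypothetical sketch of that idea (not part of this PR): cap the sized length by the limit setting, which Lightning expresses as either an int batch count or a float fraction. The helper name and signature below are illustrative only:

    from typing import Optional, Union

    def capped_length(sized_length: Optional[int], limit_batches: Union[int, float]) -> Optional[int]:
        # Hypothetical illustration, not Lightning code.
        if sized_length is None:
            # Without __len__ the number of batches is still unknown, so keep None.
            return None
        if isinstance(limit_batches, float):
            # A float limit is a fraction of the sized length, e.g. 0.25 of the batches.
            return min(sized_length, int(sized_length * limit_batches))
        return min(sized_length, limit_batches)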

self.done = self.length == 0

def __iter__(self) -> "_DataFetcher":
self.reset()
self.iterator = iter(self.combined_loader)
return self

def __next__(self) -> Any:
self._start_profiler()
assert self.iterator is not None
data = self._fetch_next_batch(self.iterator)
return data

def _fetch_next_batch(self, iterator: Iterator) -> None:
self._start_profiler()
try:
data = next(self.iterator)
except StopIteration as ex:
batch = next(iterator)
except StopIteration:
self.done = True
raise ex
raise
finally:
self._stop_profiler()
self.fetched += 1
return data
if self.length is not None:
self.done = self.fetched >= self.length
return batch

def reset(self) -> None:
self.fetched = 0
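
The setup change above relies on sized_len. For context, a minimal sketch of its assumed semantics (the length of a sized iterable, or None when no usable __len__ exists):

    from typing import Any, Optional

    def sized_len(iterable: Any) -> Optional[int]:
        # Assumed behaviour for illustration: defer to __len__ when available.
        try:
            return len(iterable)
        except (TypeError, NotImplementedError):
            return None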
@@ -77,7 +86,8 @@ class _PrefetchDataFetcher(_DataFetcher):

Args:
prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
whether a batch is the last one (available with :attr:`self.done`) when the length is not available.
whether a batch is the last one (available with :attr:`self.done`) when the length is not available. The
value of this argument is ignored when the length is available.

"""

@@ -87,15 +97,10 @@ def __init__(self, prefetch_batches: int = 1) -> None:
raise ValueError("`prefetch_batches` should at least be 0.")
self.prefetch_batches = prefetch_batches
self.batches: List[Any] = []
self._len: Optional[int] = None

def setup(self, combined_loader: CombinedLoader) -> None:
super().setup(combined_loader)
self._len = sized_len(combined_loader)

def __iter__(self) -> "_PrefetchDataFetcher":
super().__iter__()
if self._len is not None:
if self.length is not None:
# ignore pre-fetching, it's not necessary
return self
# prefetch batches to know when the iterator will be exhausted in advance
@@ -107,7 +112,6 @@ def __iter__(self) -> "_PrefetchDataFetcher":
except StopIteration:
# this would only happen when prefetch_batches > the number of batches available and makes
# `__next__` jump directly to the empty iterator case without trying to fetch again
self.done = True
break
return self

@@ -125,27 +129,16 @@ def __next__(self) -> Any:
self.done = not self.batches
elif not self.done:
# this will run only when no pre-fetching was done.
try:
self._fetch_next_batch(self.iterator)
# consume the batch we just fetched
batch = self.batches.pop(0)
except StopIteration as ex:
self.done = True
raise ex
self._fetch_next_batch(self.iterator)
# consume the batch we just fetched
batch = self.batches.pop(0)
else:
# the iterator is empty
raise StopIteration
return batch

def _fetch_next_batch(self, iterator: Iterator) -> None:
self._start_profiler()
try:
batch = next(iterator)
finally:
self._stop_profiler()
self.fetched += 1
if self._len is not None:
self.done = self.fetched >= self._len
batch = super()._fetch_next_batch(iterator)
self.batches.append(batch)

def reset(self) -> None:
83 changes: 64 additions & 19 deletions tests/tests_pytorch/loops/test_fetchers.py
@@ -43,17 +43,17 @@ def __getitem__(self, idx):
return idx + 1


@pytest.mark.parametrize("use_combined_loader", [False, True])
@pytest.mark.parametrize("multiple_iterables", [False, True])
@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(5)))
def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
def test_prefetch_iterator(multiple_iterables, dataset_cls, prefetch_batches):
fetcher = _PrefetchDataFetcher(prefetch_batches=prefetch_batches)
assert fetcher.prefetch_batches == prefetch_batches

if use_combined_loader:
if multiple_iterables:
loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
else:
loader = DataLoader(dataset_cls())
loader = CombinedLoader(DataLoader(dataset_cls()))
fetcher.setup(loader)

def generate():
@@ -67,7 +67,7 @@ def generate():
fetched = (
[1, 2, 3] if dataset_cls is SizedDataset else [1, 2, 3, 3, 3, 3, 3][prefetch_batches : prefetch_batches + 3]
)
batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
batches = [[1, 1], [2, 2], [3, 3]] if multiple_iterables else [1, 2, 3]
expected = list(zip(fetched, batches, is_last_batch))
assert len(expected) == 3

@@ -77,8 +77,8 @@ def generate():
assert fetcher.fetched == 3


@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_profiler_closing(use_combined_loader):
@pytest.mark.parametrize("multiple_iterables", [False, True])
def test_profiler_closing(multiple_iterables):
"""Tests if the profiler terminates upon raising a StopIteration on an iterable dataset."""

class TestDataset(IterableDataset):
@@ -89,10 +89,10 @@ def __iter__(self):
return iter(self.list)

fetcher = _PrefetchDataFetcher()
if use_combined_loader:
if multiple_iterables:
loader = CombinedLoader([DataLoader(TestDataset()), DataLoader(TestDataset())])
else:
loader = DataLoader(TestDataset())
loader = CombinedLoader(TestDataset())
fetcher.setup(loader)
profiler = SimpleProfiler()
fetcher._start_profiler = lambda: profiler.start("test")
@@ -115,11 +115,14 @@ def __len__(self):
@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(2)))
def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
loader = DataLoader(dataset_cls())
loader = CombinedLoader(DataLoader(dataset_cls()))
fetcher = _PrefetchDataFetcher(prefetch_batches=prefetch_batches)
fetcher.setup(loader)

assert not fetcher.done
if dataset_cls is EmptySizedDataset:
assert fetcher.done # for 0 length sized datasets we know we're done already
else:
assert not fetcher.done
assert not list(fetcher)
assert fetcher.done

@@ -192,12 +195,14 @@ def training_step(self, dataloader_iter, batch_idx):
opt.step()

def on_train_epoch_end(self):
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
# since the dataset is sized, the loop stops at the limit even though the training_step controls the
# consumption of batches
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 32
assert self.trainer.fit_loop._data_fetcher.fetched == 64
assert self.count == 64

model = TestModel(automatic_optimization=automatic_optimization)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="cpu")
trainer.fit(model)


@@ -227,7 +232,7 @@ def predict_step(self, dataloader_iter, batch_idx):
return super().test_step(batch, batch_idx)

model = TestModel()
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1)
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1, accelerator="cpu")
trainer_fn = getattr(trainer, fn)
trainer_fn(model)

@@ -275,13 +280,15 @@ def _async_op(self, batch: Any) -> DummyWaitable:
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
if self.batch_i_handle is None:
batch_i_raw = next(dataloader_iter)
self.num_batches_processed += 1
self.batch_i_handle = self._async_op(batch_i_raw)

# Invariant: _async_op for batch[i] has been initiated
batch_ip1_handle = None
is_last = False
try:
batch_ip1_raw = next(dataloader_iter)
self.num_batches_processed += 1
batch_ip1_handle = self._async_op(batch_ip1_raw)
except StopIteration:
is_last = True
Expand All @@ -294,7 +301,6 @@ def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
self.optimizers().zero_grad()

self.batch_i_handle = batch_ip1_handle
self.num_batches_processed += 1

return {"loss": loss, "is_last": is_last}

@@ -304,7 +310,7 @@ def train_dataloader(self):

def test_training_step_with_dataloader_access(tmpdir) -> None:
"""A baseline functional test for `training_step` with dataloader access."""
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, accelerator="cpu")
m = AsyncBoringModel()
trainer.fit(m)
assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."
@@ -333,7 +339,7 @@ def train_dataloader(self):
return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))

trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, accelerator="cpu")
m = TestModel(trigger_stop_iteration)
trainer.fit(m)
expected = EXPECT_NUM_BATCHES_PROCESSED
@@ -350,7 +356,7 @@ class InvalidModel(AsyncBoringModel):
def on_train_batch_start(self, batch, batch_idx):
pass

trainer = Trainer(fast_dev_run=1, default_root_dir=tmpdir)
trainer = Trainer(fast_dev_run=1, default_root_dir=tmpdir, accelerator="cpu")
m = InvalidModel()
with pytest.warns(match="InvalidModel.on_train_batch_start` hook may not match"):
trainer.fit(m)
@@ -364,7 +370,7 @@ class InvalidModel(AsyncBoringModel):
def on_train_batch_end(self, *_):
pass

trainer = Trainer(fast_dev_run=1, default_root_dir=tmpdir)
trainer = Trainer(fast_dev_run=1, default_root_dir=tmpdir, accelerator="cpu")
m = InvalidModel()
with pytest.warns(match="InvalidModel.on_train_batch_end` hook may not match"):
trainer.fit(m)
@@ -437,6 +443,7 @@ def val_dataloader(self):
enable_checkpointing=False,
enable_progress_bar=False,
logger=False,
accelerator="cpu",
)
trainer.fit(model)
trainer.test(model)
@@ -487,6 +494,7 @@ def training_step(self, dataloader_iter):
enable_checkpointing=False,
enable_progress_bar=False,
logger=False,
accelerator="cpu",
)
trainer.fit(model)

@@ -498,3 +506,40 @@ def training_step(self, dataloader_iter):
durations = profiler.recorded_durations[key]
assert len(durations) == 2 # 2 polls in training_step
assert all(d > 0 for d in durations)


def test_done_consistent_across_fetchers():
iterable = [0, 1, 2]
loader = CombinedLoader(iterable)
fetcher = _PrefetchDataFetcher(prefetch_batches=0)
fetcher.setup(loader)
iter(fetcher)
assert not fetcher.done
assert next(fetcher) == 0
assert not fetcher.done
assert next(fetcher) == 1
assert not fetcher.done
assert next(fetcher) == 2
assert fetcher.done
with pytest.raises(StopIteration):
next(fetcher)
assert fetcher.done

loader = CombinedLoader(iterable)
fetcher = _DataLoaderIterDataFetcher()
fetcher.setup(loader)
iter(fetcher)
assert not fetcher.done
for i in range(5): # doesn't matter how many times you next this, the iter itself needs to be consumed
dataloader_iter = next(fetcher)
assert not fetcher.done
assert fetcher is dataloader_iter.data_fetcher
assert next(dataloader_iter) == 0
assert not fetcher.done
assert next(dataloader_iter) == 1
assert not fetcher.done
assert next(dataloader_iter) == 2
assert fetcher.done
with pytest.raises(StopIteration):
next(dataloader_iter)
assert fetcher.done