Commit afe1412

Properly manage fetcher.done with dataloader_iter (#18376)
(cherry picked from commit 9496d9a)
carmocca authored and Borda committed Aug 28, 2023
1 parent 13cfe95 commit afe1412
Showing 3 changed files with 116 additions and 92 deletions.
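
For context, this fix concerns the `dataloader_iter` flavor of `training_step`, where the loop hands the user an iterator (a `_DataFetcherWrapper`) instead of a ready-made batch, and the fetcher's `done` flag must stay correct while the user drives the iteration. A minimal sketch of that usage follows. The class name is illustrative and the body mirrors the example in the `_DataLoaderIterDataFetcher` docstring; the `dataloader_iter.done` check is an assumption that relies on the `done`/`fetched`/`length` properties this diff adds to the wrapper:

from typing import Iterator

from lightning.pytorch import LightningModule


class MyModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        # the `dataloader_iter` flavor is documented alongside manual optimization
        self.automatic_optimization = False

    def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
        # fetching the batch and moving it to the right device is the user's responsibility
        batch = next(dataloader_iter)
        batch = batch.to(self.device)
        ...
        # assumption: the wrapper's new `done` property can be read here to detect
        # that the last batch of the epoch has been fetched
        if dataloader_iter.done:
            ...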
src/lightning/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@
 from lightning.pytorch.core import LightningDataModule, LightningModule  # noqa: E402
 from lightning.pytorch.trainer import Trainer  # noqa: E402

-import lightning.app  # isort: skip # noqa: E402 F401
+import lightning.app  # isort: skip # noqa: E402, F401


 __all__ = [
src/lightning/pytorch/loops/fetchers.py (126 changes: 54 additions & 72 deletions)
@@ -12,16 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union
+from typing import Any, Iterator, List, Optional, Tuple, Union

-from torch.utils.data.dataloader import DataLoader
-
-from lightning.fabric.utilities.data import has_len
-from lightning.pytorch.utilities.combined_loader import (
-    _Sequential,
-    _shutdown_workers_and_reset_iterator,
-    CombinedLoader,
-)
+from lightning.fabric.utilities.data import sized_len
+from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -31,61 +25,65 @@ def _profile_nothing() -> None:

 class _DataFetcher(Iterator):
     def __init__(self) -> None:
-        self._dataloader: Optional[Iterable] = None
-        self.dataloader_iter: Optional[Iterator] = None
+        self._combined_loader: Optional[CombinedLoader] = None
+        self.iterator: Optional[Iterator] = None
         self.fetched: int = 0
         self.done: bool = False
+        self.length: Optional[int] = None
         self._start_profiler = _profile_nothing
         self._stop_profiler = _profile_nothing

-    def setup(self, dataloader: Iterable) -> None:
-        self._dataloader = dataloader
-
     @property
-    def dataloader(self) -> Iterable:
-        if self._dataloader is None:
+    def combined_loader(self) -> CombinedLoader:
+        if self._combined_loader is None:
             raise MisconfigurationException(
-                f"`{self.__class__.__name__}` should have been `setup` with a dataloader iterable."
+                f"`{self.__class__.__name__}` should have been `setup` with a `CombinedLoader`."
             )
-        return self._dataloader
+        return self._combined_loader

+    def setup(self, combined_loader: CombinedLoader) -> None:
+        self._combined_loader = combined_loader
+        self.length = sized_len(combined_loader)
+        self.done = self.length == 0
+
     def __iter__(self) -> "_DataFetcher":
         self.reset()
-        self.dataloader_iter = iter(self.dataloader)
+        self.iterator = iter(self.combined_loader)
         return self

     def __next__(self) -> Any:
+        assert (iterator := self.iterator) is not None
         self._start_profiler()
-        assert self.dataloader_iter is not None
         try:
-            data = next(self.dataloader_iter)
-        except StopIteration as ex:
+            batch = next(iterator)
+        except StopIteration:
             self.done = True
-            raise ex
+            raise
         finally:
             self._stop_profiler()
         self.fetched += 1
-        return data
+        if self.length is not None:
+            self.done = self.fetched >= self.length
+        return batch

     def reset(self) -> None:
         self.fetched = 0
         self.done = False

     def teardown(self) -> None:
         self.reset()
-        if isinstance(self._dataloader, CombinedLoader):
-            self._dataloader.reset()
-        if isinstance(self._dataloader, DataLoader):
-            _shutdown_workers_and_reset_iterator(self._dataloader)
-        self.dataloader_iter = None
+        if self._combined_loader is not None:
+            self._combined_loader.reset()
+        self.iterator = None


 class _PrefetchDataFetcher(_DataFetcher):
     """This class is used to control batch fetching flow.

     Args:
         prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
-            whether a batch is the last one (available with :attr:`self.done`) when the length is not available.
+            whether a batch is the last one (available with :attr:`self.done`) when the length is not available. The
+            value of this argument is ignored when the length is available.
     """

@@ -95,70 +93,42 @@ def __init__(self, prefetch_batches: int = 1) -> None:
             raise ValueError("`prefetch_batches` should at least be 0.")
         self.prefetch_batches = prefetch_batches
         self.batches: List[Any] = []
-        self._has_len = False
-
-    def setup(self, dataloader: Iterable) -> None:
-        super().setup(dataloader)
-        self._has_len = has_len(dataloader)

     def __iter__(self) -> "_PrefetchDataFetcher":
         super().__iter__()
-        if self._has_len:
+        if self.length is not None:
             # ignore pre-fetching, it's not necessary
             return self
         # prefetch batches to know when the iterator will be exhausted in advance
-        iterator = self.dataloader_iter
-        assert iterator is not None
         for _ in range(self.prefetch_batches):
             try:
-                self._fetch_next_batch(iterator)
+                batch = super().__next__()
+                self.batches.append(batch)
             except StopIteration:
                 # this would only happen when prefetch_batches > the number of batches available and makes
                 # `__next__` jump directly to the empty iterator case without trying to fetch again
                 self.done = True
                 break
         return self

     def __next__(self) -> Any:
-        assert self.dataloader_iter is not None
         if self.batches:
             # there are pre-fetched batches already from a previous `prefetching` call.
             # consume one
             batch = self.batches.pop(0)
             try:
                 # refill the consumed batch
-                self._fetch_next_batch(self.dataloader_iter)
+                self.batches.append(super().__next__())
             except StopIteration:
                 # no more batches to fetch. we are done only if all pre-fetched batches were returned
                 self.done = not self.batches
         elif not self.done:
             # this will run only when no pre-fetching was done.
-            try:
-                self._fetch_next_batch(self.dataloader_iter)
-                # consume the batch we just fetched
-                batch = self.batches.pop(0)
-            except StopIteration as ex:
-                self.done = True
-                raise ex
+            batch = super().__next__()
         else:
             # the iterator is empty
             raise StopIteration
         return batch

-    def _fetch_next_batch(self, iterator: Iterator) -> None:
-        self._start_profiler()
-        try:
-            batch = next(iterator)
-        finally:
-            self._stop_profiler()
-        self.fetched += 1
-        if self._has_len:
-            # when we don't prefetch but the dataloader is sized, we use the length for `done`
-            dataloader = self.dataloader
-            assert isinstance(dataloader, Sized)  # `_has_len` is True
-            self.done = self.fetched >= len(dataloader)
-        self.batches.append(batch)
-
     def reset(self) -> None:
         super().reset()
         self.batches = []
@@ -180,30 +150,42 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:

     def __iter__(self) -> "_DataLoaderIterDataFetcher":
         super().__iter__()
-        self.iterator = iter(_DataFetcherWrapper(self))
+        self.iterator_wrapper = iter(_DataFetcherWrapper(self))
         return self

     def __next__(self) -> Union["_DataFetcherWrapper", Tuple["_DataFetcherWrapper", int, int]]:
         if self.done:
             raise StopIteration
-        assert isinstance(self.iterator, _DataFetcherWrapper)
+        assert isinstance(self.iterator_wrapper, _DataFetcherWrapper)
         if self._is_sequential:
-            sequential_mode = self.dataloader._iterator
-            assert isinstance(sequential_mode, _Sequential)
-            batch_idx = sequential_mode._idx
-            dataloader_idx = sequential_mode._iterator_idx
-            return self.iterator, batch_idx, dataloader_idx
-        return self.iterator
+            mode = self.combined_loader._iterator
+            assert isinstance(mode, _Sequential)
+            batch_idx = mode._idx
+            dataloader_idx = mode._iterator_idx
+            return self.iterator_wrapper, batch_idx, dataloader_idx
+        return self.iterator_wrapper

     @property
     def _is_sequential(self) -> bool:
-        return isinstance(self.dataloader, CombinedLoader) and self.dataloader._mode == "sequential"
+        return self.combined_loader._mode == "sequential"


 class _DataFetcherWrapper(Iterator):
     def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None:
         self.data_fetcher = data_fetcher

+    @property
+    def done(self) -> bool:
+        return self.data_fetcher.done
+
+    @property
+    def fetched(self) -> int:
+        return self.data_fetcher.fetched
+
+    @property
+    def length(self) -> Optional[int]:
+        return self.data_fetcher.length
+
     def __next__(self) -> Any:
         out = super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__()
         if self.data_fetcher._is_sequential:
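
Taken together, the fetchers.py changes drop the `has_len`/`_has_len` bookkeeping and the manual `_fetch_next_batch` helper: `setup` now records a `length` via `sized_len`, and `_DataFetcher.__next__` itself maintains `done`, either from `fetched >= length` when the length is known or on `StopIteration` otherwise. Below is a self-contained sketch of that pattern outside Lightning; the class and helper names are illustrative and only mirror the logic shown in the diff:

from typing import Any, Iterator, Optional


def sized_len(iterable: Any) -> Optional[int]:
    # rough stand-in for lightning.fabric.utilities.data.sized_len:
    # return `len()` when it is defined, otherwise None
    try:
        return len(iterable)
    except TypeError:
        return None


class DoneTrackingFetcher(Iterator):
    """Minimal sketch of the `done` bookkeeping used by `_DataFetcher` after this change."""

    def __init__(self, iterable: Any) -> None:
        self.iterable = iterable
        self.length = sized_len(iterable)
        self.fetched = 0
        self.done = self.length == 0  # an empty sized iterable is done up front
        self.iterator: Optional[Iterator] = None

    def __iter__(self) -> "DoneTrackingFetcher":
        self.fetched = 0
        self.done = self.length == 0
        self.iterator = iter(self.iterable)
        return self

    def __next__(self) -> Any:
        assert self.iterator is not None
        try:
            batch = next(self.iterator)
        except StopIteration:
            # unsized case: we only learn that we are done when the iterator raises
            self.done = True
            raise
        self.fetched += 1
        if self.length is not None:
            # sized case: the last batch is flagged as soon as it is fetched
            self.done = self.fetched >= self.length
        return batch


fetcher = DoneTrackingFetcher([1, 2, 3])
for batch in fetcher:
    print(batch, fetcher.done)  # prints: 1 False, then 2 False, then 3 True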
[Diff for the third changed file is not shown in this view.]
