From 49ce20f99b6b6b93d434ea9f67bd50cf7163d770 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 31 Aug 2023 00:10:48 +0200
Subject: [PATCH] fixes

---
Notes (free text after the three-dash line is not part of the commit message):
  * fetchers.py: `_DataFetcher.__iter__` now creates the iterator from
    `self.combined_loader` before calling `self.reset()` (order swapped).
  * combined_loader.py: re-enables `_MinSize.__len__`, computed from
    `_get_iterables_lengths(self.iterables)` and, when set, `self.limits`.
    `_get_iterables_lengths` now records `float("inf")` instead of raising
    `NotImplementedError` when `sized_len` returns None.
  * test_fetchers.py: adds `test_done_dataloader_iter_with_limit`.
  * NOTE(review): when `limits` is None and every length is unknown,
    `_MinSize.__len__` returns `float("inf")`, which does not match the
    `Optional[int]` annotation and is rejected by the builtin `len()`;
    confirm how callers consume it.
  * NOTE(review): line structure and indentation of this patch were
    reconstructed from a whitespace-collapsed copy; hunk line counts match
    the `@@` headers, but verify indentation against the target tree.

 src/lightning/pytorch/loops/fetchers.py       |  2 +-
 .../pytorch/utilities/combined_loader.py      | 14 +++++++++-----
 tests/tests_pytorch/loops/test_fetchers.py    | 19 +++++++++++++++++++
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py
index 907ab638fcbe8..d175320933718 100644
--- a/src/lightning/pytorch/loops/fetchers.py
+++ b/src/lightning/pytorch/loops/fetchers.py
@@ -48,8 +48,8 @@ def setup(self, combined_loader: CombinedLoader) -> None:
         # self.done = self.length == 0
 
     def __iter__(self) -> "_DataFetcher":
-        self.reset()
         self.iterator = iter(self.combined_loader)
+        self.reset()
         return self
 
     def __next__(self) -> Any:
diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py
index feda7a63acc7a..1f8cd30ca31dc 100644
--- a/src/lightning/pytorch/utilities/combined_loader.py
+++ b/src/lightning/pytorch/utilities/combined_loader.py
@@ -108,8 +108,11 @@ class _MinSize(_ModeIterator[List]):
     def __next__(self) -> List:
         return [next(it) for it in self.iterators]
 
-    # def __len__(self) -> Optional[int]:
-    #     return min(self.limits) if self.limits is not None else None
+    def __len__(self) -> Optional[int]:
+        lengths = _get_iterables_lengths(self.iterables)
+        if self.limits is not None:
+            return min([min(length, limit) for length, limit in zip(lengths, self.limits)])
+        return min(lengths)
 
 
 class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
@@ -355,8 +358,9 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
 def _get_iterables_lengths(iterables):
     lengths = []
     for iterable in iterables:
-        length = sized_len(iterable)
-        if length is None:
-            raise NotImplementedError(f"`{type(iterable).__name__}` does not define `__len__`")
+        if (length := sized_len(iterable)) is None:
+            length = float("inf")
+        # if length is None:
+        #     raise NotImplementedError(f"`{type(iterable).__name__}` does not define `__len__`")
         lengths.append(length)
     return lengths
diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py
index 5b60a278a0e5f..4255619553ef5 100644
--- a/tests/tests_pytorch/loops/test_fetchers.py
+++ b/tests/tests_pytorch/loops/test_fetchers.py
@@ -540,3 +540,22 @@ def test_done_dataloader_iter(iterable):
     with pytest.raises(StopIteration):
         next(dataloader_iter)
     assert dataloader_iter.done
+
+
+@pytest.mark.parametrize("iterable", [[1, 2, 3], IterDataset()])
+def test_done_dataloader_iter_with_limit(iterable):
+    loader = CombinedLoader(iterable)
+    fetcher = _DataLoaderIterDataFetcher()
+    fetcher.setup(loader)
+    iter(fetcher)
+    loader._iterator.limits = [1]
+    fetcher.reset()
+
+    assert not fetcher.done
+    dataloader_iter = next(fetcher)
+    assert not dataloader_iter.done
+    assert next(dataloader_iter) == 1
+    assert dataloader_iter.done
+    assert fetcher.done
+    # with pytest.raises(StopIteration):  # TODO why is it not raising?
+    #     next(dataloader_iter)