Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Aug 30, 2023
1 parent 67aac4d commit 49ce20f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/fetchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def setup(self, combined_loader: CombinedLoader) -> None:
# self.done = self.length == 0

def __iter__(self) -> "_DataFetcher":
self.reset()
self.iterator = iter(self.combined_loader)
self.reset()
return self

def __next__(self) -> Any:
Expand Down
14 changes: 9 additions & 5 deletions src/lightning/pytorch/utilities/combined_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@ class _MinSize(_ModeIterator[List]):
def __next__(self) -> List:
return [next(it) for it in self.iterators]

# def __len__(self) -> Optional[int]:
# return min(self.limits) if self.limits is not None else None
def __len__(self) -> Optional[int]:
lengths = _get_iterables_lengths(self.iterables)
if self.limits is not None:
return min([min(length, limit) for length, limit in zip(lengths, self.limits)])
return min(lengths)


class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
Expand Down Expand Up @@ -355,8 +358,9 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
def _get_iterables_lengths(iterables):
lengths = []
for iterable in iterables:
length = sized_len(iterable)
if length is None:
raise NotImplementedError(f"`{type(iterable).__name__}` does not define `__len__`")
if (length := sized_len(iterable)) is None:
length = float("inf")
# if length is None:
# raise NotImplementedError(f"`{type(iterable).__name__}` does not define `__len__`")
lengths.append(length)
return lengths
19 changes: 19 additions & 0 deletions tests/tests_pytorch/loops/test_fetchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,22 @@ def test_done_dataloader_iter(iterable):
with pytest.raises(StopIteration):
next(dataloader_iter)
assert dataloader_iter.done


@pytest.mark.parametrize("iterable", [[1, 2, 3], IterDataset()])
def test_done_dataloader_iter_with_limit(iterable):
loader = CombinedLoader(iterable)
fetcher = _DataLoaderIterDataFetcher()
fetcher.setup(loader)
iter(fetcher)
loader._iterator.limits = [1]
fetcher.reset()

assert not fetcher.done
dataloader_iter = next(fetcher)
assert not dataloader_iter.done
assert next(dataloader_iter) == 1
assert dataloader_iter.done
assert fetcher.done
# with pytest.raises(StopIteration): # TODO why is it not raising?
# next(dataloader_iter)

0 comments on commit 49ce20f

Please sign in to comment.