Skip to content

Commit

Permalink
add predict
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Apr 5, 2021
1 parent e20ca20 commit 2f6313b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):

def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
model = BoringModel()
original_dataset = model.train_dataloader().dataset

class IterableWithoutLen(IterableDataset):
Expand All @@ -660,6 +660,8 @@ def __len__(self):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.test(model, test_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.predict(model, dataloaders=[dataloader])

# without __len__ defined
dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
Expand All @@ -669,6 +671,7 @@ def __len__(self):
trainer.validate(model, val_dataloaders=dataloader)
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
trainer.test(model, test_dataloaders=dataloader)
trainer.predict(model, dataloaders=dataloader)


@RunIf(min_gpus=2)
Expand Down

0 comments on commit 2f6313b

Please sign in to comment.