Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix sampler instance support (#1204)
Browse files Browse the repository at this point in the history
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
  • Loading branch information
ethanwharris and krshrimali authored Mar 1, 2022
1 parent 8abc10a commit f6817b6
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where buffers in loss functions were not correctly registered in the `Task` ([#1203](https://github.com/PyTorchLightning/lightning-flash/pull/1203))

- Fixed support for passing a sampler instance to `from_*` methods / the `DataModule` ([#1204](https://github.com/PyTorchLightning/lightning-flash/pull/1204))

## [0.7.0] - 2022-02-15

### Added
Expand Down
37 changes: 35 additions & 2 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,37 @@ class DataModule(pl.LightningDataModule):
num_workers: The number of workers to use for parallelized loading.
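sampler: Either a sampler instance following the :class:`~torch.utils.data.sampler.Sampler` type, or a callable
(such as a sampler class) that returns a sampler when called with the training dataset. The resulting sampler
will be passed to the DataLoader for the training dataset. Defaults to None.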
Examples
________
.. testsetup::
>>> from flash import DataModule
>>> from flash.core.utilities.stages import RunningStage
>>> from torch.utils.data.sampler import SequentialSampler, WeightedRandomSampler
>>> class TestInput(Input):
... def train_load_data(self, _):
... return [(0, 1, 2, 3), (0, 1, 2, 3)]
>>> train_input = TestInput(RunningStage.TRAINING, [1])
You can provide the sampler to use for the train dataloader using the ``sampler`` argument.
The sampler can be a function or type that takes the training dataset as its argument:
.. doctest::
>>> datamodule = DataModule(train_input, sampler=SequentialSampler, batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.SequentialSampler object at ...>
Alternatively, you can pass a sampler instance:
.. doctest::
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.WeightedRandomSampler object at ...>
"""

input_transform_cls = InputTransform
Expand All @@ -84,7 +115,7 @@ def __init__(
val_split: Optional[float] = None,
batch_size: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
sampler: Optional[Union[Callable, Sampler, Type[Sampler]]] = None,
pin_memory: bool = True,
persistent_workers: bool = False,
) -> None:
Expand Down Expand Up @@ -206,8 +237,10 @@ def _train_dataloader(self) -> DataLoader:
if self.sampler is None:
sampler = None
shuffle = not isinstance(train_ds, IterableDataset)
else:
elif callable(self.sampler):
sampler = self.sampler(train_ds)
else:
sampler = self.sampler

if isinstance(getattr(self, "trainer", None), pl.Trainer):
dataloader = self.trainer.lightning_module.process_train_dataset(
Expand Down
18 changes: 12 additions & 6 deletions tests/core/data/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,22 +419,28 @@ def validation_step(self, batch, batch_idx):
trainer.fit(model, datamodule=datamodule)


@pytest.mark.parametrize("sampler, callable", [(mock.MagicMock(), True), (mock.NonCallableMock(), False)])
@mock.patch("flash.core.data.data_module.DataLoader")
def test_dataloaders_with_sampler(mock_dataloader):
mock_sampler = mock.MagicMock()
def test_dataloaders_with_sampler(mock_dataloader, sampler, callable):
train_input = TestInput(RunningStage.TRAINING, [1])
datamodule = DataModule(
TestInput(RunningStage.TRAINING, [1]),
train_input,
TestInput(RunningStage.VALIDATING, [1]),
TestInput(RunningStage.TESTING, [1]),
batch_size=2,
num_workers=0,
sampler=mock_sampler,
sampler=sampler,
)
assert datamodule.sampler is mock_sampler

assert datamodule.sampler is sampler
dl = datamodule.train_dataloader()

if callable:
sampler.assert_called_once_with(train_input)

kwargs = mock_dataloader.call_args[1]
assert "sampler" in kwargs
assert kwargs["sampler"] is mock_sampler.return_value
assert kwargs["sampler"] is (sampler.return_value if callable else sampler)
for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
assert "sampler" not in kwargs
Expand Down

0 comments on commit f6817b6

Please sign in to comment.