This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix sampler instance support #1204

Merged (5 commits) on Mar 1, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed a bug where buffers in loss functions were not correctly registered in the `Task` ([#1203](https://github.com/PyTorchLightning/lightning-flash/pull/1203))

+- Fixed support for passing a sampler instance to `from_*` methods / the `DataModule` ([#1204](https://github.com/PyTorchLightning/lightning-flash/pull/1204))
+
 ## [0.7.0] - 2022-02-15

 ### Added
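The `from_*` constructors forward their keyword arguments (including `sampler`) to the `DataModule`, so the fix applies to them as well. A minimal sketch of the newly supported call, assuming the flash 0.7 image-classification API and a hypothetical `data/train` folder:

```python
from torch.utils.data.sampler import WeightedRandomSampler

from flash.image import ImageClassificationData

# Hypothetical folder layout; weights and num_samples are illustrative only.
datamodule = ImageClassificationData.from_folders(
    train_folder="data/train",
    batch_size=2,
    sampler=WeightedRandomSampler(weights=[0.1, 0.5], num_samples=2),
)
```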
37 changes: 35 additions & 2 deletions flash/core/data/data_module.py
@@ -69,6 +69,37 @@ class DataModule(pl.LightningDataModule):
         num_workers: The number of workers to use for parallelized loading.
         sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type.
             Will be passed to the DataLoader for the training dataset. Defaults to None.
+
+    Examples
+    ________
+
+    .. testsetup::
+
+        >>> from flash import DataModule
+        >>> from flash.core.data.io.input import Input
+        >>> from flash.core.utilities.stages import RunningStage
+        >>> from torch.utils.data.sampler import SequentialSampler, WeightedRandomSampler
+        >>> class TestInput(Input):
+        ...     def train_load_data(self, _):
+        ...         return [(0, 1, 2, 3), (0, 1, 2, 3)]
+        >>> train_input = TestInput(RunningStage.TRAINING, [1])
+
+    You can provide the sampler to use for the train dataloader using the ``sampler`` argument.
+    The sampler can be a function or type that takes the dataset as an argument:
+
+    .. doctest::
+
+        >>> datamodule = DataModule(train_input, sampler=SequentialSampler, batch_size=1)
+        >>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+        <torch.utils.data.sampler.SequentialSampler object at ...>
+
+    Alternatively, you can pass a sampler instance:
+
+    .. doctest::
+
+        >>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
+        >>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+        <torch.utils.data.sampler.WeightedRandomSampler object at ...>
+
+    """

     input_transform_cls = InputTransform

Review comment (Member), on the ``WeightedRandomSampler`` doctest: so this is the alternative for CrossEntropyLoss weights?
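Because the new branch accepts any callable (see the `_train_dataloader` hunk below), a factory such as `functools.partial` also works for the callable form, not just a sampler class. A small sketch; the factory and its keyword values are illustrative, not part of this PR:

```python
from functools import partial

from torch.utils.data.sampler import RandomSampler

# A factory with pre-bound kwargs (values here are purely illustrative).
# When passed as `sampler=`, the `callable(self.sampler)` branch calls it
# with the train dataset, i.e. sampler_factory(train_ds).
sampler_factory = partial(RandomSampler, replacement=True, num_samples=4)
```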
@@ -84,7 +115,7 @@ def __init__(
             val_split: Optional[float] = None,
             batch_size: Optional[int] = None,
             num_workers: int = 0,
-            sampler: Optional[Type[Sampler]] = None,
+            sampler: Optional[Union[Callable, Sampler, Type[Sampler]]] = None,
             pin_memory: bool = True,
             persistent_workers: bool = False,
         ) -> None:
@@ -206,8 +237,10 @@ def _train_dataloader(self) -> DataLoader:
         if self.sampler is None:
             sampler = None
             shuffle = not isinstance(train_ds, IterableDataset)
-        else:
+        elif callable(self.sampler):
             sampler = self.sampler(train_ds)
+        else:
+            sampler = self.sampler

         if isinstance(getattr(self, "trainer", None), pl.Trainer):
             dataloader = self.trainer.lightning_module.process_train_dataset(
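The branching above is the heart of the fix: `None` preserves the default shuffling behavior, a callable (a sampler class or factory) is invoked with the training dataset, and anything else is forwarded to the `DataLoader` as a ready-made sampler instance. A standalone sketch of that dispatch; `resolve_sampler` is a hypothetical helper, not a function in the codebase:

```python
from typing import Callable, Optional, Type, Union

from torch.utils.data import Dataset, Sampler


def resolve_sampler(
    sampler: Optional[Union[Callable, Sampler, Type[Sampler]]], train_ds: Dataset
) -> Optional[Sampler]:
    """Mirror of the branching added in `_train_dataloader` (simplified)."""
    if sampler is None:
        # No sampler: the DataLoader will shuffle map-style datasets instead.
        return None
    if callable(sampler):
        # A sampler class or factory function: instantiate it with the dataset.
        return sampler(train_ds)
    # An already-constructed instance: forward it to the DataLoader unchanged.
    return sampler
```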
18 changes: 12 additions & 6 deletions tests/core/data/test_data_module.py
@@ -419,22 +419,28 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model, datamodule=datamodule)


+@pytest.mark.parametrize("sampler, callable", [(mock.MagicMock(), True), (mock.NonCallableMock(), False)])
 @mock.patch("flash.core.data.data_module.DataLoader")
-def test_dataloaders_with_sampler(mock_dataloader):
-    mock_sampler = mock.MagicMock()
+def test_dataloaders_with_sampler(mock_dataloader, sampler, callable):
+    train_input = TestInput(RunningStage.TRAINING, [1])
     datamodule = DataModule(
-        TestInput(RunningStage.TRAINING, [1]),
+        train_input,
         TestInput(RunningStage.VALIDATING, [1]),
         TestInput(RunningStage.TESTING, [1]),
         batch_size=2,
         num_workers=0,
-        sampler=mock_sampler,
+        sampler=sampler,
     )
-    assert datamodule.sampler is mock_sampler
+
+    assert datamodule.sampler is sampler
     dl = datamodule.train_dataloader()
+
+    if callable:
+        sampler.assert_called_once_with(train_input)
+
     kwargs = mock_dataloader.call_args[1]
     assert "sampler" in kwargs
-    assert kwargs["sampler"] is mock_sampler.return_value
+    assert kwargs["sampler"] is (sampler.return_value if callable else sampler)
     for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]:
         kwargs = mock_dataloader.call_args[1]
         assert "sampler" not in kwargs
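The parametrization relies on `mock.MagicMock()` being callable while `mock.NonCallableMock()` is not, so the two cases exercise the factory branch and the instance branch respectively. A quick illustration:

```python
from unittest import mock

# Stands in for a sampler class/factory: the DataModule calls it with the dataset.
assert callable(mock.MagicMock())

# Stands in for a ready-made sampler instance: it cannot be called.
assert not callable(mock.NonCallableMock())
```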