Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[3/N] Data sources - docs #272

Merged
merged 26 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 161 additions & 104 deletions docs/source/custom_task.rst

Large diffs are not rendered by default.

81 changes: 35 additions & 46 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,31 +185,6 @@ Example::
# Set ``preprocess_cls`` with your custom ``preprocess``.
preprocess_cls = ImageClassificationPreprocess

@classmethod
def from_folders(
cls,
train_folder: Optional[str],
val_folder: Optional[str],
test_folder: Optional[str],
predict_folder: Optional[str],
preprocess: Optional[Preprocess] = None,
**kwargs
):

# Set a custom ``Preprocess`` if none was provided
preprocess = preprocess or cls.preprocess_cls()

# {stage}_load_data_input will be given to your
# ``Preprocess`` ``{stage}_load_data`` function.
return cls.from_load_data_inputs(
train_load_data_input=train_folder,
val_load_data_input=val_folder,
test_load_data_input=test_folder,
predict_load_data_input=predict_folder,
preprocess=preprocess, # DON'T FORGET TO PASS THE CREATED PREPROCESS
**kwargs,
)


3. The Preprocess
__________________
Expand All @@ -218,9 +193,12 @@ Finally, implement your custom ``ImageClassificationPreprocess``.

Example::

from typing import Any, Callable, Dict, Optional, Tuple, Union
import os
import numpy as np
from flash.data.data_source import DefaultDataSources
from flash.data.process import Preprocess
from flash.vision.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
from PIL import Image
import torchvision.transforms as T
from torch import Tensor
Expand All @@ -231,29 +209,32 @@ Example::

to_tensor = T.ToTensor()

def load_data(self, folder: str, dataset: AutoDataset) -> Iterable:
# The AutoDataset is optional but can be useful to save some metadata.

# metadata contains the image path and its corresponding label with the following structure:
# [(image_path_1, label_1), ... (image_path_n, label_n)].
metadata = make_dataset(folder)

# for the train ``AutoDataset``, we want to store the ``num_classes``.
if self.training:
dataset.num_classes = len(np.unique([m[1] for m in metadata]))

return metadata
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
):
super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: ImagePathsDataSource(),
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSOR: ImageTensorDataSource(),
},
default_data_source=DefaultDataSources.PATHS,
)

def predict_load_data(self, predict_folder: str) -> Iterable:
# This returns [image_path_1, ... image_path_m].
return os.listdir(folder)
def get_state_dict(self) -> Dict[str, Any]:
return {**self.transforms}

def load_sample(self, sample: Union[str, Tuple[str, int]]) -> Tuple[Image, int]
if self.predicting:
return Image.open(image_path)
else:
image_path, label = sample
return Image.open(image_path), label
@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def to_tensor_transform(
self,
Expand Down Expand Up @@ -285,6 +266,14 @@ __________
.. autoclass:: flash.data.data_source.DataSource
:members:

.. autoclass:: flash.data.data_source.DefaultDataSources
:members:
:undoc-members:

.. autoclass:: flash.data.data_source.DefaultDataKeys
:members:
:undoc-members:


----------

Expand Down
21 changes: 12 additions & 9 deletions flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,20 @@


class BaseAutoDataset(Generic[DATA_TYPE]):

DATASET_KEY = "dataset"
"""This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. ``load_data``
will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` is provided and
``load_sample`` within ``__getitem__``.
"""The ``BaseAutoDataset`` class wraps the output of a call to :meth:`~flash.data.data_source.DataSource.load_data`
and a :class:`~fash.data.data_source.DataSource` and provides the ``_call_load_sample`` method to call
:meth:`~flash.data.data_source.DataSource.load_sample` with the correct
:class:`~flash.data.utils.CurrentRunningStageFuncContext` for the current ``running_stage``. Inheriting classes are
responsible for extracting samples from ``data`` to be given to ``_call_load_sample``.

Args:

data: The output of a call to :meth:`~flash.data.data_source.load_data`.

data: The output of a call to :meth:`~flash.data.data_source.DataSource.load_data`.
data_source: The :class:`~flash.data.data_source.DataSource` which has the ``load_sample`` method.

running_stage: The current running stage.
"""

DATASET_KEY = "dataset"

def __init__(
self,
data: DATA_TYPE,
Expand Down Expand Up @@ -93,6 +92,8 @@ def _call_load_sample(self, sample: Any) -> Any:


class AutoDataset(BaseAutoDataset[Sequence], Dataset):
"""The ``AutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.Dataset`. The `data` argument
must be a ``Sequence`` (it must have a length)."""

def __getitem__(self, index: int) -> Any:
return self._call_load_sample(self.data[index])
Expand All @@ -102,6 +103,8 @@ def __len__(self) -> int:


class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset):
"""The ``IterableAutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.IterableDataset`. The `data`
argument must be an ``Iterable``."""

def __iter__(self):
self.data_iter = iter(self.data)
Expand Down
32 changes: 21 additions & 11 deletions flash/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,16 @@ class BaseDataFetcher(FlashCallback):

from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.data_source import DataSource
from flash.data.process import Preprocess

class CustomPreprocess(Preprocess):

def __init__(**kwargs):
super().__init__(
data_sources = {"inputs": DataSource()},
**kwargs,
)

class PrintData(BaseDataFetcher):

Expand All @@ -90,6 +99,8 @@ def print(self):

class CustomDataModule(DataModule):

preprocess_cls = CustomPreprocess

@staticmethod
def configure_data_fetcher():
return PrintData()
Expand All @@ -100,17 +111,16 @@ def from_inputs(
train_data: Any,
val_data: Any,
test_data: Any,
predict_data: Any) -> "CustomDataModule":

preprocess = CustomPreprocess()

return cls.from_load_data_inputs(
train_load_data_input=train_data,
val_load_data_input=val_data,
test_load_data_input=test_data,
predict_load_data_input=predict_data,
preprocess=preprocess,
batch_size=5)
predict_data: Any,
) -> "CustomDataModule":
return cls.from_data_source(
"inputs",
train_data=train_data,
val_data=val_data,
test_data=test_data,
predict_data=predict_data,
batch_size=5,
)

dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5))
data_fetcher = dm.data_fetcher
Expand Down
Loading