Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add BaseViz Callback (2 / 2) #201

Merged
merged 37 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
168b231
wip
tchaton Mar 30, 2021
cda64d3
add base_viz + new features for DataPipeline
tchaton Mar 31, 2021
2b2c499
update
tchaton Mar 31, 2021
6db6b1c
resolve flake8
tchaton Mar 31, 2021
f61deea
update
tchaton Mar 31, 2021
cb85981
Merge branch 'master' into base_viz
tchaton Mar 31, 2021
ffaa7c7
resolve tests
tchaton Mar 31, 2021
596a523
update
tchaton Mar 31, 2021
2fdefbe
wip
tchaton Mar 31, 2021
4381441
update
tchaton Mar 31, 2021
d572248
resolve doc
tchaton Mar 31, 2021
b928fc5
resolve doc
tchaton Mar 31, 2021
9381d41
update doc
tchaton Mar 31, 2021
108a7cc
update
tchaton Apr 1, 2021
6da92b3
update
tchaton Apr 1, 2021
d4cf9f5
update
tchaton Apr 1, 2021
16deb7b
convert to staticmethod
tchaton Apr 1, 2021
4025eb0
initial visualisation implementation
edgarriba Apr 1, 2021
37c8084
Merge branch 'base_viz_2' of https://github.com/PyTorchLightning/ligh…
edgarriba Apr 1, 2021
d2076d4
implement test case using Kornia transforms
edgarriba Apr 1, 2021
ff8e1ad
update on comments
tchaton Apr 1, 2021
84eaa68
resolve bug
tchaton Apr 1, 2021
881851a
Merge branch 'data_pipeline_current_fn' into base_viz_2
tchaton Apr 1, 2021
fb25c04
update
tchaton Apr 1, 2021
cc760a5
Merge branch 'master' into base_viz_2
tchaton Apr 1, 2021
d3932c9
update
tchaton Apr 1, 2021
ee9f781
Merge branch 'base_viz_2' of https://github.com/PyTorchLightning/ligh…
tchaton Apr 1, 2021
f6f33b8
add test
tchaton Apr 1, 2021
2de0e15
update
tchaton Apr 1, 2021
631f06f
resolve tests
tchaton Apr 6, 2021
bda5ff2
resolve flake8
tchaton Apr 6, 2021
0e74167
update
tchaton Apr 6, 2021
d0fb78d
update
tchaton Apr 6, 2021
098d7ab
update
tchaton Apr 6, 2021
ba0a992
Merge branch 'master' into base_viz_2
tchaton Apr 6, 2021
9bdd179
resolve test
tchaton Apr 6, 2021
67ba94c
Merge branch 'base_viz_2' of https://github.com/PyTorchLightning/ligh…
tchaton Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import functools
from contextlib import contextmanager
from typing import Any, Callable

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage

from flash.data.data_pipeline import DataPipeline
from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX


class BaseViz(Callback):
"""
This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations.
It is disabled by default.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, enabled: bool = False):
self.batches = {k: {} for k in _STAGES_PREFIX.values()}
self.enabled = enabled
self._datamodule = None
self._preprocess = None

@contextmanager
def enable(self):
self.enabled = True
yield
self.enabled = False

def attach_to_preprocess(self, preprocess: Preprocess) -> None:
self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess)

def attach_to_datamodule(self, datamodule) -> None:
self._datamodule = datamodule
datamodule.viz = self

def _wrap_fn(
self,
fn: Callable,
) -> Callable:

@functools.wraps(fn)
def wrapper(*args) -> Any:
data = fn(*args)
if self.enabled:
batches = self.batches[_STAGES_PREFIX[self._preprocess.running_stage]]
if fn.__name__ not in batches:
batches[fn.__name__] = []
batches[fn.__name__].append(data)
return data

return wrapper

def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess):
self._preprocess = preprocess
fn_names = {
k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess)
for k in DataPipeline.PREPROCESS_FUNCS
}
for fn_name in fn_names:
fn = getattr(preprocess, fn_name)
setattr(preprocess, fn_name, self._wrap_fn(fn))
6 changes: 3 additions & 3 deletions flash/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def __init__(
self.stage = stage
self.on_device = on_device

extension = f"{'on_device' if self.on_device else ''}"
extension = f"{'_on_device' if self.on_device else ''}"
self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", preprocess)
self._collate_context = CurrentFuncContext("collate", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

def forward(self, samples: Sequence[Any]) -> Any:
with self._current_stage_context:
Expand Down
20 changes: 19 additions & 1 deletion flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch.utils.data.dataset import Subset

from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseViz
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess


Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
test_dataset: Optional[Dataset] = None,
predict_dataset: Optional[Dataset] = None,
batch_size: int = 1,
num_workers: Optional[int] = None,
num_workers: Optional[int] = 0,
) -> None:

super().__init__()
Expand Down Expand Up @@ -83,10 +84,23 @@ def __init__(

self._preprocess = None
self._postprocess = None
self._viz = None

# this may also trigger data preloading
self.set_running_stages()

@property
def viz(self) -> BaseViz:
return self._viz or DataModule.configure_vis()

@viz.setter
def viz(self, viz: BaseViz) -> None:
self._viz = viz

@classmethod
def configure_vis(cls) -> BaseViz:
return BaseViz()

@staticmethod
def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any:
if isinstance(dataset, Subset):
Expand Down Expand Up @@ -320,6 +334,9 @@ def from_load_data_inputs(
else:
data_pipeline = cls(**kwargs).data_pipeline

viz_callback = cls.configure_vis()
viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline)

train_dataset = cls._generate_dataset_if_possible(
train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
)
Expand All @@ -341,4 +358,5 @@ def from_load_data_inputs(
)
datamodule._preprocess = data_pipeline._preprocess_pipeline
datamodule._postprocess = data_pipeline._postprocess_pipeline
viz_callback.attach_to_datamodule(datamodule)
return datamodule
2 changes: 1 addition & 1 deletion flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.classification import ImageClassificationData, ImageClassifier, ImageClassificationDataViz
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
2 changes: 1 addition & 1 deletion flash/vision/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from flash.vision.classification.data import ImageClassificationData
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataViz
from flash.vision.classification.model import ImageClassifier
52 changes: 52 additions & 0 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def to_tensor_transform(self, sample: Any) -> Any:
def post_tensor_transform(self, sample: Any) -> Any:
return self.common_step(sample)

# todo bug (tchaton) where to place the collate. Need an indication.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
def per_batch_transform(self, sample: Any) -> Any:
return self.common_step(sample)

Expand Down Expand Up @@ -468,6 +469,7 @@ def from_filepaths(
folder/cat_asd932_.png

Args:

train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``.
train_labels: Sequence of labels for training dataset. Defaults to ``None``.
val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``.
Expand All @@ -484,6 +486,7 @@ def from_filepaths(
seed: Used for the train/val splits.

Returns:

ImageClassificationData: The constructed data module.
"""
# enable passing in a string which loads all files in that folder as a list
Expand Down Expand Up @@ -524,3 +527,52 @@ def from_filepaths(
seed=seed,
**kwargs
)


class ImageClassificationDataViz(ImageClassificationData):

def show_train_batch(self):
self.viz.enabled = True
# fetch batch and cache data
_ = next(iter(self.train_dataloader()))
self.viz.enabled = False

from typing import List

import kornia as K
tchaton marked this conversation as resolved.
Show resolved Hide resolved
import matplotlib.pyplot as plt
import numpy as np
import torchvision as tv
from PIL import Image

# plot row data
rows: int = 4 # chenge later
data_raw: List[Image] = self.viz.batches['train']['load_sample']
for num, x_data in enumerate(data_raw):
img, label = x_data
plt.subplot(rows, rows, num + 1)
plt.title(label)
plt.axis('off')
plt.imshow(np.array(img))
plt.title('load_sample')
plt.show(block=False)

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
tchaton marked this conversation as resolved.
Show resolved Hide resolved

# plot pre-process and after augmentations
data1, labels1 = self.viz.batches['train']['collate'][0] # this is before random transforms
data2, labels2 = self.viz.batches['train']['per_batch_transform'][0] # this should be after random transforms

data1 = K.enhance.denormalize(data1, mean, std)
data2 = K.enhance.denormalize(data2, mean, std)

# cast and prepare data for viualisation
data1_vis = K.tensor_to_image(tv.utils.make_grid(data1))
data2_vis = K.tensor_to_image(tv.utils.make_grid(data2))

# plot using matplotlib
fig, (ax1, ax2) = plt.subplots(2)
ax1.imshow(data1_vis)
ax2.imshow(data2_vis)
plt.show()
91 changes: 91 additions & 0 deletions tests/data/test_data_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import kornia as K
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from pytorch_lightning import seed_everything

from flash.data.utils import _STAGES_PREFIX
from flash.vision import ImageClassificationData


def _rand_image():
return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8"))


class ImageClassificationDataViz(ImageClassificationData):

def show_batch(self):
# viz needs to be enabled, so it doesn't store profile transforms during training
with self.viz.enable():
_ = next(iter(self.train_dataloader()))
_ = next(iter(self.val_dataloader()))
_ = next(iter(self.test_dataloader()))
_ = next(iter(self.predict_dataloader()))


def test_base_viz(tmpdir):
seed_everything(42)
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a" / "a_1.png")
_rand_image().save(tmpdir / "a" / "a_2.png")

_rand_image().save(tmpdir / "b" / "a_1.png")
_rand_image().save(tmpdir / "b" / "a_2.png")

img_data = ImageClassificationDataViz.from_filepaths(
train_filepaths=[tmpdir / "a", tmpdir / "b"],
train_labels=[0, 1],
val_filepaths=[tmpdir / "a", tmpdir / "b"],
val_labels=[0, 1],
test_filepaths=[tmpdir / "a", tmpdir / "b"],
test_labels=[0, 1],
predict_filepaths=[tmpdir / "a", tmpdir / "b"],
batch_size=2,
num_workers=0,
)

img_data.show_batch()
for stage in _STAGES_PREFIX.values():
is_predict = stage == "predict"

def extract_data(data):
if not is_predict:
return data[0][0]
return data[0]

assert isinstance(extract_data(img_data.viz.batches[stage]["load_sample"]), Image.Image)
if not is_predict:
assert isinstance(img_data.viz.batches[stage]["load_sample"][0][1], int)

assert isinstance(extract_data(img_data.viz.batches[stage]["to_tensor_transform"]), torch.Tensor)
if not is_predict:
assert isinstance(img_data.viz.batches[stage]["to_tensor_transform"][0][1], int)

assert extract_data(img_data.viz.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196])
if not is_predict:
assert img_data.viz.batches[stage]["collate"][0][1].shape == torch.Size([2])

generated = extract_data(img_data.viz.batches[stage]["per_batch_transform"]).shape
assert generated == torch.Size([2, 3, 196, 196])
if not is_predict:
assert img_data.viz.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2])