[Feature] Support RandomChoice for random mask #112

Merged 3 commits on Dec 13, 2023

86 changes: 86 additions & 0 deletions configs/_base_/datasets/dog_inpaint_multiple_mask.py
@@ -0,0 +1,86 @@
train_pipeline = [
dict(type="torchvision/Resize", size=512, interpolation="bilinear"),
dict(type="RandomCrop", size=512),
dict(type="RandomHorizontalFlip", p=0.5),
dict(type="RandomChoice",
transforms=[
[dict(
type="LoadMask",
mask_mode="irregular",
mask_config=dict(
num_vertices=(4, 10),
max_angle=6.0,
length_range=(20, 200),
brush_width=(10, 100),
area_ratio_range=(0.15, 0.65)))],
[dict(
type="LoadMask",
mask_mode="irregular",
mask_config=dict(
num_vertices=(1, 5),
max_angle=6.0,
length_range=(40, 450),
brush_width=(20, 250),
area_ratio_range=(0.15, 0.65)))],
[dict(
type="LoadMask",
mask_mode="irregular",
mask_config=dict(
num_vertices=(4, 70),
max_angle=6.0,
length_range=(15, 100),
brush_width=(5, 20),
area_ratio_range=(0.15, 0.65)))],
[dict(
type="LoadMask",
mask_mode="bbox",
mask_config=dict(
max_bbox_shape=(150, 150),
max_bbox_delta=50,
min_margin=0))],
[dict(
type="LoadMask",
mask_mode="bbox",
mask_config=dict(
max_bbox_shape=(300, 300),
max_bbox_delta=100,
min_margin=10))],
]),
dict(type="torchvision/ToTensor"),
dict(type="MaskToTensor"),
dict(type="DumpImage", max_imgs=10, dump_dir="work_dirs/dump"),
dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
dict(type="GetMaskedImage"),
dict(type="PackInputs",
input_keys=["img", "mask", "masked_image", "text"]),
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
dataset=dict(
type="HFDreamBoothDataset",
dataset="diffusers/dog-example",
instance_prompt="a photo of sks dog",
pipeline=train_pipeline,
class_prompt=None),
sampler=dict(type="InfiniteSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["a photo of sks dog"] * 4,
image=["https://github.com/okotaku/diffengine/assets/24734142/8e02bd0e-9dcc-49b6-94b0-86ab3b40bc2b"] * 4, # noqa
mask=["https://github.com/okotaku/diffengine/assets/24734142/d0de4fb9-9183-418a-970d-582e9324f05d"] * 2 + [ # noqa
"https://github.com/okotaku/diffengine/assets/24734142/a40d1a4f-9c47-4fa0-936e-88a49c92c8d7"] * 2, # noqa
by_epoch=False,
width=512,
height=512,
interval=100),
dict(type="SDCheckpointHook"),
]
8 changes: 8 additions & 0 deletions configs/stable_diffusion_inpaint/README.md
@@ -87,3 +87,11 @@ You can see more details on [`docs/source/run_guides/run_sd.md`](../../docs/source/run_guides/run_sd.md)
![mask](https://github.com/okotaku/diffengine/assets/24734142/d0de4fb9-9183-418a-970d-582e9324f05d)

![example](https://github.com/okotaku/diffengine/assets/24734142/f9ec820b-af75-4c74-8c0b-6558a0a19b95)

#### stable_diffusion_inpaint_dog_multi_mask

![input](https://github.com/okotaku/diffengine/assets/24734142/8e02bd0e-9dcc-49b6-94b0-86ab3b40bc2b)

![mask](https://github.com/okotaku/diffengine/assets/24734142/a40d1a4f-9c47-4fa0-936e-88a49c92c8d7)

![example](https://github.com/okotaku/diffengine/assets/24734142/f9766a71-0845-4dea-a037-f7dabfca200e)
6 changes: 6 additions & 0 deletions configs/stable_diffusion_inpaint/stable_diffusion_inpaint_dog_multi_mask.py
@@ -0,0 +1,6 @@
_base_ = [
"../_base_/models/stable_diffusion_inpaint.py",
"../_base_/datasets/dog_inpaint_multiple_mask.py",
"../_base_/schedules/stable_diffusion_1k.py",
"../_base_/default_runtime.py",
]
17 changes: 16 additions & 1 deletion diffengine/datasets/hf_dreambooth_datasets.py
@@ -1,6 +1,7 @@
# flake8: noqa: S311,RUF012
import copy
import hashlib
import os
import random
import shutil
from collections.abc import Sequence
@@ -45,6 +46,9 @@ class HFDreamBoothDataset(Dataset):
class_prompt (Optional[str]): The prompt to specify images in the same
class as provided instance images. Defaults to None.
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
csv (str, optional): Image path csv file name when loading local
folder. If None, the dataset will be loaded from image folders.
Defaults to None.
cache_dir (str, optional): The directory where the downloaded datasets
will be stored. Defaults to None.
"""
@@ -65,8 +69,12 @@ def __init__(self,
class_image_config: dict | None = None,
class_prompt: str | None = None,
pipeline: Sequence = (),
csv: str | None = None,
cache_dir: str | None = None) -> None:

self.dataset_name = dataset
self.csv = csv

if class_image_config is None:
class_image_config = {
"model": "runwayml/stable-diffusion-v1-5",
@@ -77,7 +85,12 @@ def __init__(self,
}
if Path(dataset).exists():
# load local folder
self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
if csv is not None:
data_file = os.path.join(dataset, csv)
self.dataset = load_dataset(
"csv", data_files=data_file, cache_dir=cache_dir)["train"]
else:
self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
else: # noqa
# load huggingface online
if dataset_sub_dir is not None:
@@ -172,6 +185,8 @@ def __getitem__(self, idx: int) -> dict:
data_info = self.dataset[idx]
image = data_info[self.image_column]
if isinstance(image, str):
if self.csv is not None:
image = os.path.join(self.dataset_name, image)
image = Image.open(image)
image = image.convert("RGB")
result = {"img": image, "text": self.instance_prompt}
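
The new `csv` option switches local loading from an image folder to a metadata file whose relative paths are resolved against the dataset folder. Below is a minimal sketch of that path, mirroring the test added later in this PR; the `data/dog` layout is hypothetical, and it assumes `HFDreamBoothDataset` is importable from `diffengine.datasets`.

```python
# Hypothetical local layout (not part of this PR):
#   data/dog/metadata.csv   -> header "file_name", one relative image path per row
#   data/dog/img0.jpg
from diffengine.datasets import HFDreamBoothDataset

dataset = HFDreamBoothDataset(
    dataset="data/dog",            # local folder, so Path(dataset).exists() is True
    csv="metadata.csv",            # loaded with load_dataset("csv", data_files=...)
    image_column="file_name",      # relative paths are joined with the folder in __getitem__
    instance_prompt="a photo of sks dog")

item = dataset[0]                  # {"img": <PIL.Image>, "text": "a photo of sks dog"}
```
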
4 changes: 4 additions & 0 deletions diffengine/datasets/transforms/__init__.py
@@ -4,6 +4,7 @@
from .loading import LoadMask
from .processing import (
TRANSFORMS,
AddConstantCaption,
CenterCrop,
CLIPImageProcessor,
ComputePixArtImgInfo,
@@ -17,6 +18,7 @@
SaveImageShape,
T5TextPreprocess,
)
from .wrappers import RandomChoice

__all__ = [
"BaseTransform",
@@ -36,4 +38,6 @@
"LoadMask",
"MaskToTensor",
"GetMaskedImage",
"RandomChoice",
"AddConstantCaption",
]
31 changes: 31 additions & 0 deletions diffengine/datasets/transforms/processing.py
@@ -687,3 +687,34 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""
results[self.key] = results["img"] * results["mask"]
return results


@TRANSFORMS.register_module()
class AddConstantCaption(BaseTransform):
"""AddConstantCaption.

Example. "a dog." * constant_caption="in szn style"
-> "a dog. in szn style"

Args:
----
constant_caption (str): Constant caption appended to the text.
keys (List[str]): `keys` to apply augmentation from results.
    Defaults to ["text"].
"""

def __init__(self, constant_caption: str,
             keys: list[str] | None = None) -> None:
if keys is None:
keys = ["text"]
self.constant_caption: str = constant_caption
self.keys = keys

def transform(self,
results: dict) -> dict | tuple[list, list] | None:
"""Transform.

Args:
----
results (dict): The result dict.
"""
for k in self.keys:
results[k] = results[k] + " " + self.constant_caption
return results
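
In practice `AddConstantCaption` sits inside a training pipeline so that every caption picks up a fixed style token. A minimal sketch under that assumption; the surrounding transforms and the "in szn style" token are illustrative, borrowed from the configs and tests in this PR rather than from a shipped config.

```python
train_pipeline = [
    dict(type="torchvision/Resize", size=512, interpolation="bilinear"),
    dict(type="RandomHorizontalFlip", p=0.5),
    # Appends " in szn style" to results["text"] before packing.
    dict(type="AddConstantCaption", constant_caption="in szn style"),
    dict(type="torchvision/ToTensor"),
    dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
    dict(type="PackInputs", input_keys=["img", "text"]),
]
```
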
69 changes: 69 additions & 0 deletions diffengine/datasets/transforms/wrappers.py
@@ -0,0 +1,69 @@
from collections.abc import Callable, Iterator

import mmengine
import numpy as np
from mmengine.dataset.base_dataset import Compose

from diffengine.datasets.transforms.base import BaseTransform
from diffengine.registry import TRANSFORMS

Transform = dict | Callable[[dict], dict]


@TRANSFORMS.register_module()
class RandomChoice(BaseTransform):
"""Process data with a randomly chosen transform from given candidates.

Copied from mmcv/transforms/wrappers.py.

Args:
----
transforms (list[list]): A list of transform candidates, each is a
sequence of transforms.
prob (list[float], optional): The probabilities associated
with each pipeline. The length should be equal to the pipeline
number and the sum should be 1. If not given, a uniform
distribution will be assumed.

Examples:
--------
>>> # config
>>> pipeline = [
>>> dict(type='RandomChoice',
>>> transforms=[
>>> [dict(type='RandomHorizontalFlip')], # subpipeline 1
>>> [dict(type='RandomRotate')], # subpipeline 2
>>> ]
>>> )
>>> ]
"""

def __init__(self,
transforms: list[Transform | list[Transform]],
prob: list[float] | None = None) -> None:

super().__init__()

if prob is not None:
assert mmengine.is_seq_of(prob, float)
assert len(transforms) == len(prob), (
"``transforms`` and ``prob`` must have same lengths. "
f"Got {len(transforms)} vs {len(prob)}.")
assert sum(prob) == 1

self.prob = prob
self.transforms = [Compose(transform) for transform in transforms]

def __iter__(self) -> Iterator:
"""Iterate over transforms."""
return iter(self.transforms)

def random_pipeline_index(self) -> int:
"""Return a random transform index."""
indices = np.arange(len(self.transforms))
return np.random.choice(indices, p=self.prob) # noqa

def transform(self, results: dict) -> dict | None:
"""Randomly choose a transform to apply."""
idx = self.random_pipeline_index()
return self.transforms[idx](results)
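
The `dog_inpaint_multiple_mask` config above leaves `prob` unset, so the five mask candidates are sampled uniformly. Below is a minimal sketch of weighting candidates instead, reusing the bbox settings from that config; the 0.75/0.25 split is illustrative, and note that the assert requires the weights to sum to exactly 1.

```python
from diffengine.datasets.transforms import LoadMask, RandomChoice

random_mask = RandomChoice(
    transforms=[
        [LoadMask(mask_mode="bbox",
                  mask_config=dict(max_bbox_shape=(150, 150),
                                   max_bbox_delta=50, min_margin=0))],
        [LoadMask(mask_mode="bbox",
                  mask_config=dict(max_bbox_shape=(300, 300),
                                   max_bbox_delta=100, min_margin=10))],
    ],
    prob=[0.75, 0.25],  # one float per sub-pipeline; must sum exactly to 1
)

# Each call picks one sub-pipeline and applies it to the result dict.
idx = random_mask.random_pipeline_index()  # 0 with p=0.75, 1 with p=0.25
```
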
2 changes: 2 additions & 0 deletions diffengine/engine/hooks/peft_save_hook.py
@@ -37,6 +37,8 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
model.unet.save_pretrained(osp.join(ckpt_path, "unet"))
model_keys = ["unet"]
elif hasattr(model, "prior"):
# TODO(takuoko): Delete if bug is fixed in diffusers. # noqa
model.prior._internal_dict["_name_or_path"] = "prior" # noqa
model.prior.save_pretrained(osp.join(ckpt_path, "prior"))
model_keys = ["prior"]
elif hasattr(model, "transformer"):
2 changes: 1 addition & 1 deletion diffengine/models/editors/pixart_alpha/pixart_alpha.py
@@ -239,7 +239,7 @@ def infer(self,
torch_dtype=torch.float32,
)
if self.finetune_text_encoder:
# todo[takuoko]: When parsing text_encoder directly, the # noqa
# TODO(takuoko): When parsing text_encoder directly, the # noqa
# results are different. So we need to parse here.
pipeline.text_encoder = self.text_encoder
pipeline.to(self.device)
2 changes: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def prepare_model(self) -> None:
self.unet, is_sdxl=True)

if self.gradient_checkpointing:
# todo[takuoko]: Support ControlNetXSModel for gradient # noqa
# TODO(takuoko): Support ControlNetXSModel for gradient # noqa
# checkpointing
# self.controlnet.enable_gradient_checkpointing()
self.unet.enable_gradient_checkpointing()
13 changes: 13 additions & 0 deletions tests/test_datasets/test_hf_dreambooth_datasets.py
@@ -69,3 +69,16 @@ def test_dataset_from_local(self):
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
assert data["img"].width == 400

def test_dataset_from_local_with_csv(self):
dataset = HFDreamBoothDataset(
dataset="tests/testdata/dataset",
csv="metadata.csv",
image_column="file_name",
instance_prompt="a photo of sks dog")
assert len(dataset) == 1

data = dataset[0]
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
assert data["img"].width == 400
17 changes: 17 additions & 0 deletions tests/test_datasets/test_transforms/test_processing.py
@@ -502,3 +502,20 @@ def test_transform(self):
assert data["masked_image"].shape == img.shape
assert torch.allclose(data["masked_image"][10:, 10:], img[10:, 10:])
assert data["masked_image"][:10, :10].sum() == 0


class TestAddConstantCaption(TestCase):

def test_register(self):
assert "AddConstantCaption" in TRANSFORMS

def test_transform(self):
data = {
"text": "a dog.",
}

# test transform
trans = TRANSFORMS.build(dict(type="AddConstantCaption",
constant_caption="in szn style"))
data = trans(data)
assert data["text"] == "a dog. in szn style"