diff --git a/configs/_base_/datasets/dog_inpaint_multiple_mask.py b/configs/_base_/datasets/dog_inpaint_multiple_mask.py new file mode 100644 index 0000000..eee5a8f --- /dev/null +++ b/configs/_base_/datasets/dog_inpaint_multiple_mask.py @@ -0,0 +1,86 @@ +train_pipeline = [ + dict(type="torchvision/Resize", size=512, interpolation="bilinear"), + dict(type="RandomCrop", size=512), + dict(type="RandomHorizontalFlip", p=0.5), + dict(type="RandomChoice", + transforms=[ + [dict( + type="LoadMask", + mask_mode="irregular", + mask_config=dict( + num_vertices=(4, 10), + max_angle=6.0, + length_range=(20, 200), + brush_width=(10, 100), + area_ratio_range=(0.15, 0.65)))], + [dict( + type="LoadMask", + mask_mode="irregular", + mask_config=dict( + num_vertices=(1, 5), + max_angle=6.0, + length_range=(40, 450), + brush_width=(20, 250), + area_ratio_range=(0.15, 0.65)))], + [dict( + type="LoadMask", + mask_mode="irregular", + mask_config=dict( + num_vertices=(4, 70), + max_angle=6.0, + length_range=(15, 100), + brush_width=(5, 20), + area_ratio_range=(0.15, 0.65)))], + [dict( + type="LoadMask", + mask_mode="bbox", + mask_config=dict( + max_bbox_shape=(150, 150), + max_bbox_delta=50, + min_margin=0))], + [dict( + type="LoadMask", + mask_mode="bbox", + mask_config=dict( + max_bbox_shape=(300, 300), + max_bbox_delta=100, + min_margin=10))], + ]), + dict(type="torchvision/ToTensor"), + dict(type="MaskToTensor"), + dict(type="DumpImage", max_imgs=10, dump_dir="work_dirs/dump"), + dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]), + dict(type="GetMaskedImage"), + dict(type="PackInputs", + input_keys=["img", "mask", "masked_image", "text"]), +] +train_dataloader = dict( + batch_size=4, + num_workers=4, + dataset=dict( + type="HFDreamBoothDataset", + dataset="diffusers/dog-example", + instance_prompt="a photo of sks dog", + pipeline=train_pipeline, + class_prompt=None), + sampler=dict(type="InfiniteSampler", shuffle=True), +) + +val_dataloader = None +val_evaluator = None +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +custom_hooks = [ + dict( + type="VisualizationHook", + prompt=["a photo of sks dog"] * 4, + image=["https://github.com/okotaku/diffengine/assets/24734142/8e02bd0e-9dcc-49b6-94b0-86ab3b40bc2b"] * 4, # noqa + mask=["https://github.com/okotaku/diffengine/assets/24734142/d0de4fb9-9183-418a-970d-582e9324f05d"] * 2 + [ # noqa + "https://github.com/okotaku/diffengine/assets/24734142/a40d1a4f-9c47-4fa0-936e-88a49c92c8d7"] * 2, # noqa + by_epoch=False, + width=512, + height=512, + interval=100), + dict(type="SDCheckpointHook"), +] diff --git a/configs/stable_diffusion_inpaint/README.md b/configs/stable_diffusion_inpaint/README.md index 7005ffd..28acb4a 100644 --- a/configs/stable_diffusion_inpaint/README.md +++ b/configs/stable_diffusion_inpaint/README.md @@ -87,3 +87,11 @@ You can see more details on [`docs/source/run_guides/run_sd.md`](../../docs/sour ![mask](https://github.com/okotaku/diffengine/assets/24734142/d0de4fb9-9183-418a-970d-582e9324f05d) ![example](https://github.com/okotaku/diffengine/assets/24734142/f9ec820b-af75-4c74-8c0b-6558a0a19b95) + +#### stable_diffusion_inpaint_dog_multi_mask + +![input](https://github.com/okotaku/diffengine/assets/24734142/8e02bd0e-9dcc-49b6-94b0-86ab3b40bc2b) + +![mask](https://github.com/okotaku/diffengine/assets/24734142/a40d1a4f-9c47-4fa0-936e-88a49c92c8d7) + +![example](https://github.com/okotaku/diffengine/assets/24734142/f9766a71-0845-4dea-a037-f7dabfca200e) diff --git a/configs/stable_diffusion_inpaint/stable_diffusion_inpaint_dog_multi_mask.py b/configs/stable_diffusion_inpaint/stable_diffusion_inpaint_dog_multi_mask.py new file mode 100644 index 0000000..47a392a --- /dev/null +++ b/configs/stable_diffusion_inpaint/stable_diffusion_inpaint_dog_multi_mask.py @@ -0,0 +1,6 @@ +_base_ = [ + "../_base_/models/stable_diffusion_inpaint.py", + "../_base_/datasets/dog_inpaint_multiple_mask.py", + "../_base_/schedules/stable_diffusion_1k.py", + "../_base_/default_runtime.py", +] diff --git a/diffengine/datasets/hf_dreambooth_datasets.py b/diffengine/datasets/hf_dreambooth_datasets.py index 7279cee..fda3984 100644 --- a/diffengine/datasets/hf_dreambooth_datasets.py +++ b/diffengine/datasets/hf_dreambooth_datasets.py @@ -1,6 +1,7 @@ # flake8: noqa: S311,RUF012 import copy import hashlib +import os import random import shutil from collections.abc import Sequence @@ -45,6 +46,9 @@ class HFDreamBoothDataset(Dataset): class_prompt (Optional[str]): The prompt to specify images in the same class as provided instance images. Defaults to None. pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + csv (str, optional): Image path csv file name when loading local + folder. If None, the dataset will be loaded from image folders. + Defaults to None. cache_dir (str, optional): The directory where the downloaded datasets will be stored.Defaults to None. """ @@ -65,8 +69,12 @@ def __init__(self, class_image_config: dict | None = None, class_prompt: str | None = None, pipeline: Sequence = (), + csv: str | None = None, cache_dir: str | None = None) -> None: + self.dataset_name = dataset + self.csv = csv + if class_image_config is None: class_image_config = { "model": "runwayml/stable-diffusion-v1-5", @@ -77,7 +85,12 @@ def __init__(self, } if Path(dataset).exists(): # load local folder - self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"] + if csv is not None: + data_file = os.path.join(dataset, csv) + self.dataset = load_dataset( + "csv", data_files=data_file, cache_dir=cache_dir)["train"] + else: + self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"] else: # noqa # load huggingface online if dataset_sub_dir is not None: @@ -172,6 +185,8 @@ def __getitem__(self, idx: int) -> dict: data_info = self.dataset[idx] image = data_info[self.image_column] if isinstance(image, str): + if self.csv is not None: + image = os.path.join(self.dataset_name, image) image = Image.open(image) image = image.convert("RGB") result = {"img": image, "text": self.instance_prompt} diff --git a/diffengine/datasets/transforms/__init__.py b/diffengine/datasets/transforms/__init__.py index 72ff532..2a9aab1 100644 --- a/diffengine/datasets/transforms/__init__.py +++ b/diffengine/datasets/transforms/__init__.py @@ -4,6 +4,7 @@ from .loading import LoadMask from .processing import ( TRANSFORMS, + AddConstantCaption, CenterCrop, CLIPImageProcessor, ComputePixArtImgInfo, @@ -17,6 +18,7 @@ SaveImageShape, T5TextPreprocess, ) +from .wrappers import RandomChoice __all__ = [ "BaseTransform", @@ -36,4 +38,6 @@ "LoadMask", "MaskToTensor", "GetMaskedImage", + "RandomChoice", + "AddConstantCaption", ] diff --git a/diffengine/datasets/transforms/processing.py b/diffengine/datasets/transforms/processing.py index ea5cb40..f050e88 100644 --- a/diffengine/datasets/transforms/processing.py +++ b/diffengine/datasets/transforms/processing.py @@ -687,3 +687,34 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None: """ results[self.key] = results["img"] * results["mask"] return results + + +@TRANSFORMS.register_module() +class AddConstantCaption(BaseTransform): + """AddConstantCaption. + + Example. "a dog." * constant_caption="in szn style" + -> "a dog. in szn style" + + Args: + ---- + keys (List[str]): `keys` to apply augmentation from results. + """ + + def __init__(self, constant_caption: str, keys=None) -> None: + if keys is None: + keys = ["text"] + self.constant_caption: str = constant_caption + self.keys = keys + + def transform(self, + results: dict) -> dict | tuple[list, list] | None: + """Transform. + + Args: + ---- + results (dict): The result dict. + """ + for k in self.keys: + results[k] = results[k] + " " + self.constant_caption + return results diff --git a/diffengine/datasets/transforms/wrappers.py b/diffengine/datasets/transforms/wrappers.py new file mode 100644 index 0000000..8d61e8e --- /dev/null +++ b/diffengine/datasets/transforms/wrappers.py @@ -0,0 +1,69 @@ +from collections.abc import Callable, Iterator + +import mmengine +import numpy as np +from mmengine.dataset.base_dataset import Compose + +from diffengine.datasets.transforms.base import BaseTransform +from diffengine.registry import TRANSFORMS + +Transform = dict | Callable[[dict], dict] + + +@TRANSFORMS.register_module() +class RandomChoice(BaseTransform): + """Process data with a randomly chosen transform from given candidates. + + Copied from mmcv/transforms/wrappers.py. + + Args: + ---- + transforms (list[list]): A list of transform candidates, each is a + sequence of transforms. + prob (list[float], optional): The probabilities associated + with each pipeline. The length should be equal to the pipeline + number and the sum should be 1. If not given, a uniform + distribution will be assumed. + + Examples: + -------- + >>> # config + >>> pipeline = [ + >>> dict(type='RandomChoice', + >>> transforms=[ + >>> [dict(type='RandomHorizontalFlip')], # subpipeline 1 + >>> [dict(type='RandomRotate')], # subpipeline 2 + >>> ] + >>> ) + >>> ] + """ + + def __init__(self, + transforms: list[Transform | list[Transform]], + prob: list[float] | None = None) -> None: + + super().__init__() + + if prob is not None: + assert mmengine.is_seq_of(prob, float) + assert len(transforms) == len(prob),( + "``transforms`` and ``prob`` must have same lengths. " + f"Got {len(transforms)} vs {len(prob)}.") + assert sum(prob) == 1 + + self.prob = prob + self.transforms = [Compose(transforms) for transforms in transforms] + + def __iter__(self) -> Iterator: + """Iterate over transforms.""" + return iter(self.transforms) + + def random_pipeline_index(self) -> int: + """Return a random transform index.""" + indices = np.arange(len(self.transforms)) + return np.random.choice(indices, p=self.prob) # noqa + + def transform(self, results: dict) -> dict | None: + """Randomly choose a transform to apply.""" + idx = self.random_pipeline_index() + return self.transforms[idx](results) diff --git a/diffengine/engine/hooks/peft_save_hook.py b/diffengine/engine/hooks/peft_save_hook.py index 6a82ed5..5c025eb 100644 --- a/diffengine/engine/hooks/peft_save_hook.py +++ b/diffengine/engine/hooks/peft_save_hook.py @@ -37,6 +37,8 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None: model.unet.save_pretrained(osp.join(ckpt_path, "unet")) model_keys = ["unet"] elif hasattr(model, "prior"): + # TODO(takuoko): Delete if bug is fixed in diffusers. # noqa + model.prior._internal_dict["_name_or_path"] = "prior" # noqa model.prior.save_pretrained(osp.join(ckpt_path, "prior")) model_keys = ["prior"] elif hasattr(model, "transformer"): diff --git a/diffengine/models/editors/pixart_alpha/pixart_alpha.py b/diffengine/models/editors/pixart_alpha/pixart_alpha.py index 612fe49..60ced05 100644 --- a/diffengine/models/editors/pixart_alpha/pixart_alpha.py +++ b/diffengine/models/editors/pixart_alpha/pixart_alpha.py @@ -239,7 +239,7 @@ def infer(self, torch_dtype=torch.float32, ) if self.finetune_text_encoder: - # todo[takuoko]: When parsing text_encoder directly, the # noqa + # TODO(takuoko): When parsing text_encoder directly, the # noqa # results are different. So we need to parse here. pipeline.text_encoder = self.text_encoder pipeline.to(self.device) diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnetxs.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnetxs.py index 778f020..20ad522 100644 --- a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnetxs.py +++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnetxs.py @@ -23,7 +23,7 @@ def prepare_model(self) -> None: self.unet, is_sdxl=True) if self.gradient_checkpointing: - # todo[takuoko]: Support ControlNetXSModel for gradient # noqa + # TODO(takuoko): Support ControlNetXSModel for gradient # noqa # checkpointing # self.controlnet.enable_gradient_checkpointing() self.unet.enable_gradient_checkpointing() diff --git a/tests/test_datasets/test_hf_dreambooth_datasets.py b/tests/test_datasets/test_hf_dreambooth_datasets.py index c6b0a3e..985b9a1 100644 --- a/tests/test_datasets/test_hf_dreambooth_datasets.py +++ b/tests/test_datasets/test_hf_dreambooth_datasets.py @@ -69,3 +69,16 @@ def test_dataset_from_local(self): assert data["text"] == "a photo of sks dog" assert isinstance(data["img"], Image.Image) assert data["img"].width == 400 + + def test_dataset_from_local_with_csv(self): + dataset = HFDreamBoothDataset( + dataset="tests/testdata/dataset", + csv="metadata.csv", + image_column="file_name", + instance_prompt="a photo of sks dog") + assert len(dataset) == 1 + + data = dataset[0] + assert data["text"] == "a photo of sks dog" + assert isinstance(data["img"], Image.Image) + assert data["img"].width == 400 diff --git a/tests/test_datasets/test_transforms/test_processing.py b/tests/test_datasets/test_transforms/test_processing.py index dc06c9e..117ad50 100644 --- a/tests/test_datasets/test_transforms/test_processing.py +++ b/tests/test_datasets/test_transforms/test_processing.py @@ -502,3 +502,20 @@ def test_transform(self): assert data["masked_image"].shape == img.shape assert torch.allclose(data["masked_image"][10:, 10:], img[10:, 10:]) assert data["masked_image"][:10, :10].sum() == 0 + + +class TestAddConstantCaption(TestCase): + + def test_register(self): + assert "AddConstantCaption" in TRANSFORMS + + def test_transform(self): + data = { + "text": "a dog.", + } + + # test transform + trans = TRANSFORMS.build(dict(type="AddConstantCaption", + constant_caption="in szn style")) + data = trans(data) + assert data["text"] == "a dog. in szn style" diff --git a/tests/test_datasets/test_transforms/test_wrappers.py b/tests/test_datasets/test_transforms/test_wrappers.py new file mode 100644 index 0000000..bd20c2d --- /dev/null +++ b/tests/test_datasets/test_transforms/test_wrappers.py @@ -0,0 +1,66 @@ +import warnings +from unittest import TestCase + +from diffengine.datasets.transforms.base import BaseTransform +from diffengine.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class AddToValue(BaseTransform): + """Dummy transform to add a given addend to results['value']""" + + def __init__(self, addend=0) -> None: + super().__init__() + self.addend = addend + + def add(self, results, addend): + augend = results["value"] + + if isinstance(augend, list): + warnings.warn("value is a list", UserWarning) + if isinstance(augend, dict): + warnings.warn("value is a dict", UserWarning) + + def _add_to_value(augend, addend): + if isinstance(augend, list): + return [_add_to_value(v, addend) for v in augend] + if isinstance(augend, dict): + return {k: _add_to_value(v, addend) for k, v in augend.items()} + return augend + addend + + results["value"] = _add_to_value(results["value"], addend) + return results + + def transform(self, results): + return self.add(results, self.addend) + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"addend = {self.addend}" + return repr_str + + +class TestRandomChoice(TestCase): + + def test_register(self): + assert "RandomChoice" in TRANSFORMS + + def test_transform(self): + data = dict(value=1) + + # test transform + trans = TRANSFORMS.build(dict(type="RandomChoice", + transforms=[ + [AddToValue(addend=1.0)], + [AddToValue(addend=2.0)]], + prob=[1.0, 0.0])) + data = trans(data) + assert data["value"] == 2 + + # Case 2: default probability + trans = TRANSFORMS.build(dict(type="RandomChoice", + transforms=[ + [AddToValue(addend=1.0)], + [AddToValue(addend=2.0)]])) + + _ = trans(dict(value=1)) diff --git a/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnetxs.py b/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnetxs.py index 71593bd..66e6843 100644 --- a/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnetxs.py +++ b/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnetxs.py @@ -183,7 +183,7 @@ def prepare_model(self) -> None: ) if self.gradient_checkpointing: - # todo[takuoko]: Support ControlNetXSModel for gradient # noqa + # TODO(takuoko): Support ControlNetXSModel for gradient # noqa # checkpointing # self.controlnet.enable_gradient_checkpointing() self.unet.enable_gradient_checkpointing()