P+: Extended Textual Conditioning in Text-to-Image Generation #327

Merged: 3 commits, Mar 30, 2023


@jakaline-dev commented Mar 25, 2023

Implemented from https://prompt-plus.github.io/
Image (top: TI, bottom: XTI)

This method trains a separate TI embedding for each cross-attention layer.
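
To make the difference concrete, here is a toy sketch (not the code in this PR). It assumes the learned conditioning can be represented as one plain vector per layer and ignores the text encoder entirely; the function names are made up for illustration.

```python
from typing import Dict, List, Sequence

Vector = List[float]


def plain_ti_conditioning(layers: Sequence[str], learned: Vector) -> Dict[str, Vector]:
    """Plain TI: every cross-attention layer receives the same learned vector."""
    return {name: learned for name in layers}


def xti_conditioning(
    layers: Sequence[str], learned_per_layer: Dict[str, Vector]
) -> Dict[str, Vector]:
    """XTI: each cross-attention layer looks up its own learned vector."""
    missing = [name for name in layers if name not in learned_per_layer]
    if missing:
        raise KeyError(f"no learned vector for layers: {missing}")
    return {name: learned_per_layer[name] for name in layers}


# Hypothetical three-layer model with 2-dimensional toy vectors.
layers = ["IN01", "MID", "OUT03"]
shared = plain_ti_conditioning(layers, [0.1, 0.2])
separate = xti_conditioning(
    layers, {"IN01": [0.1, 0.2], "MID": [0.3, 0.4], "OUT03": [0.5, 0.6]}
)
print(shared["IN01"] is shared["MID"])   # the same object is reused for every layer
print(separate["IN01"], separate["MID"])  # each layer has its own values
```

In this toy picture, plain TI learns one vector while XTI learns one vector per layer, so the number of learned values grows with the number of cross-attention layers.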

The paper says the optimal training setup for XTI is 500 steps with lr=0.005, and claims this converges faster than original TI at 5000 steps. For me, though, training took about the same time.

I hardcoded the saved XTI safetensors file to have 16 keys: ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'].
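
As a rough sanity check for files in this layout, something like the sketch below could compare a file's key names against the 16 keys listed above. The helper name `check_xti_keys` is hypothetical and not part of this PR; it works on any iterable of key names, so it does not depend on how the file is opened.

```python
# The 16 per-layer keys listed in the PR description.
XTI_KEYS = [
    "IN01", "IN02", "IN04", "IN05", "IN07", "IN08", "MID",
    "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08",
    "OUT09", "OUT10", "OUT11",
]


def check_xti_keys(keys):
    """Return (missing, unexpected) key names relative to XTI_KEYS."""
    found = set(keys)
    expected = set(XTI_KEYS)
    return sorted(expected - found), sorted(found - expected)


# Example with a deliberately incomplete, made-up key list.
missing, unexpected = check_xti_keys(["MID", "IN01", "emb_params"])
print(len(missing), unexpected)
```

The key names of a real file can be passed in the same way, for example the result of `f.keys()` from a safetensors `safe_open` handle.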

To train, use train_textual_inversion_XTI.py. The training arguments are the same as for original TI.
For inference, use '--XTI_embeddings' the same way as '--textual_inversion_embeddings'.

Comparable to LoRA? We will have to see.

@TingTingin
Contributor

How many images did you use, how long did training take, and what GPU did you use?

@jakaline-dev
Author

jakaline-dev commented Mar 26, 2023

> How many images did you use, how long did training take, and what GPU did you use?

~150 images, ~60 minutes on 4080
Memory usage is similar to plain TI training

@kohya-ss
Owner

Thank you for the great PR! This is very interesting. It will take me a little time to check it out, so I will merge the other PRs and then get to it.

@catboxanon
Contributor

Did you take a look at https://github.com/cloneofsimo/promptplusplus? This seems to bring a few improvements to the original paper.

@kohya-ss
Owner

Hi, @jakaline-dev !

I've tested the PR, and the training and the image generation seem to work fine! However, the sample generation during the training raises an error.

I think the sample generation may need modifications similar to those in gen_img_diffusers.py, but it will take me some time to work out how.

Do you have any ideas?

prompt: usu frog
negative_prompt: None
height: 512
width: 512
sample_steps: 30
scale: 7.5
  0%|                                                                                           | 0/30 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\einops\einops.py", line 412, in reduce
    return _apply_recipe(recipe, tensor, reduction_type=reduction)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\einops\einops.py", line 235, in _apply_recipe
    _reconstruct_from_shape(recipe, backend.shape(tensor))
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\einops\einops.py", line 165, in _reconstruct_from_shape_uncached
    raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))
einops.EinopsError: Expected 3 dimensions, got 2

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\Work\SD\dev\sd-scripts\train_textual_inversion_XTI.py", line 586, in <module>
    train(args)
  File "D:\Work\SD\dev\sd-scripts\train_textual_inversion_XTI.py", line 469, in train
    train_util.sample_images(
  File "D:\Work\SD\dev\sd-scripts\library\train_util.py", line 2944, in sample_images
    image = pipeline(
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Work\SD\dev\sd-scripts\library\lpw_stable_diffusion.py", line 855, in __call__
    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Work\SD\dev\sd-scripts\XTI_hijack.py", line 83, in unet_forward_XTI
    sample, res_samples = downsample_block(
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Work\SD\dev\sd-scripts\XTI_hijack.py", line 152, in downblock_forward_XTI
    hidden_states = torch.utils.checkpoint.checkpoint(
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\utils\checkpoint.py", line 235, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\utils\checkpoint.py", line 96, in forward
    outputs = run_function(*args)
  File "D:\Work\SD\dev\sd-scripts\XTI_hijack.py", line 145, in custom_forward
    return module(*inputs, return_dict=return_dict)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\diffusers\models\attention.py", line 216, in forward
    hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\diffusers\models\attention.py", line 491, in forward
    hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Work\SD\dev\sd-scripts\library\train_util.py", line 1717, in forward_xformers
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
  File "D:\Work\SD\dev\sd-scripts\library\train_util.py", line 1717, in <lambda>
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\einops\einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
  File "D:\Work\SD\dev\sd-scripts\venv\lib\site-packages\einops\einops.py", line 420, in reduce
    raise EinopsError(message + '\n {}'.format(e))
einops.EinopsError:  Error while processing rearrange-reduction pattern "b n (h d) -> b n h d".
 Input tensor shape: torch.Size([77, 320]). Additional info: {'h': 8}.
 Expected 3 dimensions, got 2

@jakaline-dev
Author

Looking at the code right now, it seems that the sample generation part uses a different pipeline from gen_img_diffusers.py.
I'll make some changes to fix this (and bring the branch up to date with the latest patches), but I would appreciate it if you updated the codebase to the latest diffusers (0.14) sooner or later. It has some useful functions that make the code more manageable, such as pipe.unet.set_attn_processor (see https://github.com/cloneofsimo/promptplusplus).

@jakaline-dev
Author

Fixing sampling during training is too complicated, so I took the quick-and-dirty route and disabled sampling for now. If it's not urgent, it can be fixed in the future.

@jakaline-dev
Author

Image (untitled)

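from safetensors import safe_open
from safetensors.torch import save_file

# Layers to take from the second embedding, in the order they are added.
subsets = ["MID", "IN08", "OUT03", "IN07", "OUT04", "OUT05"]


def load_embedding(path):
    # Read every per-layer tensor from an XTI safetensors file into a dict.
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


base = load_embedding("emu_prsk.safetensors")
donor = load_embedding("hayate.safetensors")

# Output i replaces the first i layers of `subsets` with the donor's tensors.
for i in range(1, len(subsets) + 1):
    mixed = dict(base)
    for key in subsets[:i]:
        mixed[key] = donor[key]
    save_file(mixed, f"emu_hayate_subset_{i}.safetensors")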

Meanwhile, the code above mixes layers between two embeddings (just as the paper did).

@kohya-ss
Owner

> Fixing sampling during training is too complicated, so I took the quick-and-dirty route and disabled sampling for now. If it's not urgent, it can be fixed in the future.

Thank you for taking a look. I think disabling sampling is OK. And sorry to bother you with changing the file format.

I know I will have to update Diffusers at some point, but this repo contains many other hacks besides TI. Diffusers changes quickly and roughly and is not backward compatible, so updating it is a bit of a pain...

I will review and merge after work!

@kohya-ss kohya-ss changed the base branch from main to dev March 30, 2023 10:44
@kohya-ss kohya-ss merged commit 935d477 into kohya-ss:dev Mar 30, 2023
@kohya-ss
Owner

I've made some modifications after merging. Please let me know if you notice anything.

Thank you for this great work!

@kgonia

kgonia commented Apr 4, 2023

@jakaline-dev can prompt+ be used for standard finetuning?

@jakaline-dev
Author

@kgonia You can fine-tune starting from a pretrained P+ embedding, but fine-tuning is usually done with frozen embeddings. It would probably be possible if you adapted the code for text encoder fine-tuning, although I haven't seen anyone do that even with plain TI.
