Skip to content

Commit

Permalink
Merge pull request bmaltais#831 from bmaltais/dev2
Browse files Browse the repository at this point in the history
v21.5.12
  • Loading branch information
bmaltais authored May 23, 2023
2 parents f4a9d48 + 68246d8 commit 30b054b
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 27 deletions.
31 changes: 19 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,16 @@ This repository provides a Windows-focused Gradio GUI for [Kohya's Stable Diffus

### Table of Contents

<<<<<<< HEAD
- [Tutorials](#tutorials)
=======
[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!

>>>>>>> 6d6df18387a72193af62c651473fe1369b6a2040
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
* [Dataset config](./docs/config_README-ja.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [Training LoRA](./docs/train_network_README-ja.md)
* [training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
* [Dataset config](./docs/config_README-ja.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [Training LoRA](./docs/train_network_README-ja.md)
* [training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
- [Required Dependencies](#required-dependencies)
- [Linux/macOS](#linux-and-macos-dependencies)
- [Installation](#installation)
Expand Down Expand Up @@ -58,6 +53,10 @@ Newer Tutorial: [Generate Studio Quality Realistic Photos By Kohya LoRA Stable D

[![Newer Tutorial: Generate Studio Quality Realistic Photos By Kohya LoRA Stable Diffusion Training](https://user-images.githubusercontent.com/19240467/235306147-85dd8126-f397-406b-83f2-368927fa0281.png)](https://www.youtube.com/watch?v=TpuDOsuKIBo)

Newer Tutorial: [How To Install And Use Kohya LoRA GUI / Web UI on RunPod IO](https://www.youtube.com/watch?v=3uzCNrQao3o):

[![How To Install And Use Kohya LoRA GUI / Web UI on RunPod IO With Stable Diffusion & Automatic1111](https://github-production-user-asset-6210df.s3.amazonaws.com/19240467/238678226-0c9c3f7d-c308-4793-b790-999fdc271372.png)](https://www.youtube.com/watch?v=3uzCNrQao3o)

## Required Dependencies

- Install [Python 3.10](https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe)
Expand Down Expand Up @@ -346,6 +345,14 @@ This will store a backup file with your current locally installed pip packages a

## Change History

* 2023/07/15 (v21.5.12)
- Fixed several bugs.
- The state is saved even when the `--save_state` option is not specified in `fine_tune.py` and `train_db.py`. [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) Thanks to akshaal!
- Cannot load LoRA without `alpha`. [PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Thanks to Manjiz!
- Minor changes to console output during sample generation. [PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) Thanks to yanhuifair!
- The generation script now uses xformers for VAE as well.
- Fixed an issue where an error would occur if the encoding of the prompt file was different from the default. [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) Thanks to sdbds!
- Please save the prompt file in UTF-8.
* 2023/07/15 (v21.5.11)
- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova!
- It is useful in combination with `resize_lora.py`. Please see the PR for details.
Expand Down
107 changes: 103 additions & 4 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def backward(ctx, do):
return dq, dk, dv, None, None, None, None


# TODO common train_util.py
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
if mem_eff_attn:
replace_unet_cross_attn_to_memory_efficient()
Expand All @@ -319,7 +320,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio


def replace_unet_cross_attn_to_memory_efficient():
print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
flash_func = FlashAttentionFunction

def forward_flash_attn(self, x, context=None, mask=None):
Expand Down Expand Up @@ -359,7 +360,7 @@ def forward_flash_attn(self, x, context=None, mask=None):


def replace_unet_cross_attn_to_xformers():
print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
try:
import xformers.ops
except ImportError:
Expand Down Expand Up @@ -401,6 +402,104 @@ def forward_xformers(self, x, context=None, mask=None):
diffusers.models.attention.CrossAttention.forward = forward_xformers


def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
if mem_eff_attn:
replace_vae_attn_to_memory_efficient()
elif xformers:
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
print("Use Diffusers xformers for VAE")
vae.set_use_memory_efficient_attention_xformers(True)

"""
# VAEがbfloat16でメモリ消費が大きい問題を解決する
upsamplers = []
for block in vae.decoder.up_blocks:
if block.upsamplers is not None:
upsamplers.extend(block.upsamplers)
def forward_upsample(_self, hidden_states, output_size=None):
assert hidden_states.shape[1] == _self.channels
if _self.use_conv_transpose:
return _self.conv(hidden_states)
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
assert output_size is None
# repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する
hidden_states = hidden_states.repeat_interleave(2, dim=-1)
hidden_states = hidden_states.repeat_interleave(2, dim=-2)
else:
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")
if _self.use_conv:
if _self.name == "conv":
hidden_states = _self.conv(hidden_states)
else:
hidden_states = _self.Conv2d_0(hidden_states)
return hidden_states
# replace upsamplers
for upsampler in upsamplers:
# make new scope
def make_replacer(upsampler):
def forward(hidden_states, output_size=None):
return forward_upsample(upsampler, hidden_states, output_size)
return forward
upsampler.forward = make_replacer(upsampler)
"""


def replace_vae_attn_to_memory_efficient():
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction

def forward_flash_attn(self, hidden_states):
print("forward_flash_attn")
q_bucket_size = 512
k_bucket_size = 1024

residual = hidden_states
batch, channel, height, width = hidden_states.shape

# norm
hidden_states = self.group_norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)

query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)

out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)

out = rearrange(out, "b h n d -> b n (h d)")

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states

diffusers.models.attention.AttentionBlock.forward = forward_flash_attn


# endregion

# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正
Expand Down Expand Up @@ -2142,6 +2241,7 @@ def main(args):
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
replace_unet_modules(unet, not args.xformers, args.xformers)
replace_vae_modules(vae, not args.xformers, args.xformers)

# tokenizerを読み込む
print("loading tokenizer")
Expand Down Expand Up @@ -3175,8 +3275,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--vae_slices",
type=int,
default=None,
help=
"number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨"
help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
)
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
parser.add_argument(
Expand Down
76 changes: 66 additions & 10 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,14 +1765,15 @@ def backward(ctx, do):


def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
# unet is not used currently, but it is here for future use
if mem_eff_attn:
replace_unet_cross_attn_to_memory_efficient()
elif xformers:
replace_unet_cross_attn_to_xformers()


def replace_unet_cross_attn_to_memory_efficient():
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction

def forward_flash_attn(self, x, context=None, mask=None):
Expand Down Expand Up @@ -1812,7 +1813,7 @@ def forward_flash_attn(self, x, context=None, mask=None):


def replace_unet_cross_attn_to_xformers():
print("Replace CrossAttention.forward to use xformers")
print("CrossAttention.forward has been replaced to enable xformers.")
try:
import xformers.ops
except ImportError:
Expand Down Expand Up @@ -1854,6 +1855,60 @@ def forward_xformers(self, x, context=None, mask=None):
diffusers.models.attention.CrossAttention.forward = forward_xformers


"""
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
# vae is not used currently, but it is here for future use
if mem_eff_attn:
replace_vae_attn_to_memory_efficient()
elif xformers:
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
print("Use Diffusers xformers for VAE")
vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
def replace_vae_attn_to_memory_efficient():
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states):
print("forward_flash_attn")
q_bucket_size = 512
k_bucket_size = 1024
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
"""


# endregion


Expand Down Expand Up @@ -3167,10 +3222,11 @@ def save_sd_model_on_epoch_end_or_stepwise(
print(f"removing old model: {remove_out_dir}")
shutil.rmtree(remove_out_dir)

if on_epoch_end:
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
else:
save_and_remove_state_stepwise(args, accelerator, global_step)
if args.save_state:
if on_epoch_end:
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
else:
save_and_remove_state_stepwise(args, accelerator, global_step)


def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
Expand Down Expand Up @@ -3294,7 +3350,7 @@ def sample_images(
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return

print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
Expand All @@ -3308,15 +3364,15 @@ def sample_images(
# prompts = f.readlines()

if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r") as f:
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r") as f:
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r") as f:
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)

# schedulerを用意する
Expand Down
2 changes: 1 addition & 1 deletion networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
# support old LoRA without alpha
for key in modules_dim.keys():
if key not in modules_alpha:
modules_alpha = modules_dim[key]
modules_alpha[key] = modules_dim[key]

module_class = LoRAInfModule if for_inference else LoRAModule

Expand Down
Loading

0 comments on commit 30b054b

Please sign in to comment.