
add an option to unload models during hypernetwork training to save VRAM
AUTOMATIC1111 committed Oct 11, 2022
1 parent 6d09b8d commit d4ea5f4
Showing 5 changed files with 46 additions and 18 deletions.

modules/hypernetworks/hypernetwork.py (18 additions & 7 deletions)
@@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+    unload = shared.opts.unload_models_when_training
 
     if save_hypernetwork_every > 0:
         hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
@@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     else:
         images_dir = None
 
-    cond_model = shared.sd_model.cond_stage_model
-
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+    if unload:
+        shared.sd_model.cond_stage_model.to(devices.cpu)
+        shared.sd_model.first_stage_model.to(devices.cpu)
 
     hypernetwork = shared.loaded_hypernetwork
     weights = hypernetwork.weights()
@@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
         return hypernetwork, filename
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, (x, text) in pbar:
+    for i, (x, text, cond) in pbar:
         hypernetwork.step = i + ititial_step
 
         if hypernetwork.step > steps:
@@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             break
 
         with torch.autocast("cuda"):
-            c = cond_model([text])
-
+            cond = cond.to(devices.device)
             x = x.to(devices.device)
-            loss = shared.sd_model(x.unsqueeze(0), c)[0]
+            loss = shared.sd_model(x.unsqueeze(0), cond)[0]
             del x
+            del cond
 
         losses[hypernetwork.step % losses.shape[0]] = loss.item()
 
@@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
 
             preview_text = text if preview_image_prompt == "" else preview_image_prompt
 
+            optimizer.zero_grad()
+            shared.sd_model.cond_stage_model.to(devices.device)
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 prompt=preview_text,
@@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             processed = processing.process_images(p)
             image = processed.images[0]
 
+            if unload:
+                shared.sd_model.cond_stage_model.to(devices.cpu)
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image
             image.save(last_saved_image)
 
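
Note on the hunks above: a hypernetwork trains only against the UNet, so once the dataset has been encoded into VAE latents and CLIP conditioning vectors, neither encoder is needed again until preview rendering. That is what makes it safe to park both on the CPU. A minimal sketch of the unload/reload pattern outside the webui, assuming a CUDA device (the Linear module is a hypothetical stand-in for CLIP or the VAE):

    import torch

    # Hypothetical stand-in for a frozen encoder occupying VRAM.
    encoder = torch.nn.Linear(4096, 4096).to("cuda")

    # Moving the module to system RAM frees its VRAM back to PyTorch's
    # caching allocator, so later training allocations can reuse it.
    encoder.to("cpu")

    # ... training runs here with the reclaimed memory ...

    encoder.to("cuda")  # reload only when the encoder is needed again

The commit never calls torch.cuda.empty_cache(): the caching allocator reuses freed blocks within the same process, which is all the training loop needs.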

modules/hypernetworks/ui.py (3 additions & 1 deletion)
@@ -5,7 +5,7 @@
 
 import modules.textual_inversion.textual_inversion
 import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
+from modules import sd_hijack, shared, devices
 from modules.hypernetworks import hypernetwork
 

@@ -41,5 +41,7 @@ def train_hypernetwork(*args):
         raise
     finally:
         shared.loaded_hypernetwork = initial_hypernetwork
+        shared.sd_model.cond_stage_model.to(devices.device)
+        shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()
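
The finally block is what keeps this safe: training may have moved the encoders to the CPU, or raised partway through, and the restore runs either way, so normal image generation keeps working after training ends. Reduced to a generic sketch (names hypothetical):

    import torch

    def run_with_unload(module: torch.nn.Module, device: torch.device) -> None:
        try:
            module.to("cpu")   # free VRAM for the duration of the job
            ...                # the job itself; may raise or be interrupted
        finally:
            module.to(device)  # runs on success and on failure alike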


modules/shared.py (4 additions & 0 deletions)
@@ -228,6 +228,10 @@ def options_section(section_identifier, options_dict):
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
 }))
 
+options_templates.update(options_section(('training', "Training"), {
+    "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+}))
+
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
     "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
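
This registers a new "Training" section in the settings UI with a single checkbox; in OptionInfo(False, ...), the first argument is the default (off) and the second is the label. At runtime the value is read back through shared.opts, which is exactly what hypernetwork.py above does:

    # as in modules/hypernetworks/hypernetwork.py above
    unload = shared.opts.unload_models_when_training  # False unless enabled in settings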

modules/textual_inversion/dataset.py (20 additions & 9 deletions)
@@ -8,14 +8,14 @@
 
 import random
 import tqdm
-from modules import devices
+from modules import devices, shared
 import re
 
 re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
 
 
 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
 
         self.placeholder_token = placeholder_token
 
@@ -32,6 +32,8 @@ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_to
 
         assert data_root, 'dataset directory not specified'
 
+        cond_model = shared.sd_model.cond_stage_model
+
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
@@ -53,7 +55,13 @@ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_to
             init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
             init_latent = init_latent.to(devices.cpu)
 
-            self.dataset.append((init_latent, filename_tokens))
+            if include_cond:
+                text = self.create_text(filename_tokens)
+                cond = cond_model([text]).to(devices.cpu)
+            else:
+                cond = None
+
+            self.dataset.append((init_latent, filename_tokens, cond))
 
         self.length = len(self.dataset) * repeats
 
@@ -64,6 +72,12 @@ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_to
     def shuffle(self):
         self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
 
+    def create_text(self, filename_tokens):
+        text = random.choice(self.lines)
+        text = text.replace("[name]", self.placeholder_token)
+        text = text.replace("[filewords]", ' '.join(filename_tokens))
+        return text
+
     def __len__(self):
         return self.length
 
@@ -72,10 +86,7 @@ def __getitem__(self, i):
             self.shuffle()
 
         index = self.indexes[i % len(self.indexes)]
-        x, filename_tokens = self.dataset[index]
-
-        text = random.choice(self.lines)
-        text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", ' '.join(filename_tokens))
+        x, filename_tokens, cond = self.dataset[index]
 
-        return x, text
+        text = self.create_text(filename_tokens)
+        return x, text, cond
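
include_cond is the half of the feature that makes unloading possible: each caption is run through the text encoder once at dataset-preparation time, while CLIP is still on the GPU, and the result is parked in system RAM with .to(devices.cpu); the training loop then pages one small tensor back per step instead of keeping the whole encoder resident. A self-contained sketch of the same precompute-then-page-in pattern (the toy encoder and tensor shapes are hypothetical):

    import torch

    class ToyEncoder(torch.nn.Module):
        # Hypothetical stand-in for the CLIP text encoder.
        def forward(self, texts):
            return torch.randn(len(texts), 77, 768)  # (batch, tokens, channels)

    encoder = ToyEncoder()
    captions = ["a photo of a cat", "a photo of a dog"]

    # Preparation pass: encode every caption once, keep results in system RAM.
    conds = [encoder([c]).squeeze(0).to("cpu") for c in captions]
    # The encoder could now be moved off the GPU entirely.

    device = "cuda" if torch.cuda.is_available() else "cpu"
    for cond in conds:
        step_cond = cond.to(device)  # one small tensor per step, not a whole model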

modules/textual_inversion/textual_inversion.py (1 addition & 1 deletion)
@@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         return embedding, filename
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, (x, text) in pbar:
+    for i, (x, text, _) in pbar:
         embedding.step = i + ititial_step
 
         if embedding.step > steps:
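
Embedding training discards the dataset's new third element (the underscore) rather than using it: the embedding being trained changes what the text encoder produces, so a conditioning vector cached at dataset-preparation time would go stale after the first optimizer step. Hypernetworks leave the text encoder untouched, which is what makes the cached cond, and with it the unload option, safe for them.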
