Embedding merging #1526

Merged 24 commits from embedding-merging into development on Nov 28, 2022

Changes from 18 commits

Commits (24)
ed3b26f
add whole <style token> to vocab for concept library embeddings
lstein Nov 19, 2022
65c018e
add ability to load multiple concept .bin files
lstein Nov 20, 2022
af98ae1
make --log_tokenization respect custom tokens
damian0815 Nov 20, 2022
83cf3a2
start working on concept downloading system
lstein Nov 20, 2022
e68f7ec
Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in…
lstein Nov 20, 2022
f43ba5a
Merge branch 'development' into embedding-merging
lstein Nov 20, 2022
5595b87
preliminary support for dynamic loading and merging of multiple embed…
lstein Nov 21, 2022
6cb31fc
fix loading .pt embeddings; allow multi-vector embeddings; warn on dupes
damian0815 Nov 21, 2022
53d99cc
simplify replacement logic and remove cuda assumption
damian0815 Nov 21, 2022
f683892
download list of concepts from hugging face
lstein Nov 21, 2022
c54131e
remove misleading customization of '*' placeholder
damian0815 Nov 21, 2022
1b946e4
Merge branch 'embedding-merging' of github.com:invoke-ai/InvokeAI int…
damian0815 Nov 21, 2022
8033790
address all the issues raised by damian0815 in review of PR #1526
lstein Nov 22, 2022
27966fb
Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in…
lstein Nov 22, 2022
16d8d7c
actually resize the token_embeddings
damian0815 Nov 22, 2022
5e448e1
Merge branch 'development' into embedding-merging
lstein Nov 26, 2022
87c6b5d
multiple improvements to the concept loader based on code reviews
lstein Nov 26, 2022
60bc394
Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in…
lstein Nov 26, 2022
0ef9f67
Merge branch 'development' into embedding-merging
lstein Nov 27, 2022
0843bbc
autocomplete terms end with ">" now
lstein Nov 27, 2022
8961a73
Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in…
lstein Nov 27, 2022
cedcd95
fix startup error and network unreachable
lstein Nov 27, 2022
49aa405
fix misformatted error string
lstein Nov 28, 2022
0666efe
Merge branch 'development' into embedding-merging
lstein Nov 28, 2022
803 changes: 803 additions & 0 deletions configs/sd-concepts.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion configs/stable-diffusion/v1-inference.yaml
@@ -32,7 +32,7 @@ model:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 1
num_vectors_per_token: 8
progressive_words: False

unet_config:
2 changes: 1 addition & 1 deletion configs/stable-diffusion/v1-inpainting-inference.yaml
@@ -32,7 +32,7 @@ model:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 1
num_vectors_per_token: 8
progressive_words: False

unet_config:
7 changes: 7 additions & 0 deletions ldm/generate.py
@@ -39,6 +39,7 @@
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from ldm.invoke.concepts_lib import Concepts

def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -858,6 +859,12 @@ def set_model(self,model_name):
self.model_name = model_name
return self.model

def load_concepts(self,concepts:list[str]):
self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast')

def concept_lib(self)->Concepts:
return self.model.embedding_manager.concepts_library

def correct_colors(self,
image_list,
reference_image_path,
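The two new methods give callers a one-call way to pull concepts into the loaded model. A minimal sketch of how they might be used (illustrative only; the constructor argument and concept names are assumptions, not part of this diff):

# Illustrative sketch of the new API (not part of the diff).
# 'stable-diffusion-1.5' and the concept names are assumptions.
from ldm.generate import Generate

gen = Generate(model='stable-diffusion-1.5')

# Download (if necessary) and merge the named concepts into the
# tokenizer vocabulary and embedding table of the current model.
gen.load_concepts(['midjourney-style', 'birb-style'])

# Map a concept name to the trigger phrase stored in its
# token_identifier.txt file.
print(gen.concept_lib().concept_to_trigger('midjourney-style'))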
37 changes: 34 additions & 3 deletions ldm/invoke/CLI.py
@@ -15,6 +15,7 @@
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.concepts_lib import Concepts
from omegaconf import OmegaConf
from pathlib import Path
import pyparsing
@@ -62,6 +63,12 @@ def main():
if not os.path.isabs(opt.conf):
opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf))

if opt.embeddings:
    if not os.path.isabs(opt.embedding_path):
        embedding_path = os.path.normpath(os.path.join(Globals.root,opt.embedding_path))
    else:
        embedding_path = opt.embedding_path
else:
    embedding_path = None   # --no-embeddings: skip loading embeddings entirely

# load the infile as a list of lines
if opt.infile:
try:
@@ -81,7 +88,7 @@
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
embedding_path = embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
gfpgan=gfpgan,
@@ -138,6 +145,7 @@ def main_loop(gen, opt):
# changing the history file midstream when the output directory is changed.
completer = get_completer(opt, models=list(model_config.keys()))
set_default_output_dir(opt, completer)
add_embedding_terms(gen, completer)
output_cntr = completer.get_current_history_length()+1

# os.pathconf is not available on Windows
@@ -215,7 +223,7 @@ def main_loop(gen, opt):
set_default_output_dir(opt,completer)

# try to relativize pathnames
for attr in ('init_img','init_mask','init_color','embedding_path'):
for attr in ('init_img','init_mask','init_color'):
if getattr(opt,attr) and not os.path.exists(getattr(opt,attr)):
basename = getattr(opt,attr)
path = os.path.join(opt.outdir,basename)
@@ -298,6 +306,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None):
if use_prefix is not None:
prefix = use_prefix
postprocessed = upscaled if upscaled else operation=='postprocess'
opt.prompt = triggers_to_concepts(gen, opt.prompt) # to avoid the problem of non-unique concept triggers
filename, formatted_dream_prompt = prepare_image_metadata(
opt,
prefix,
@@ -341,6 +350,8 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None):
last_results.append([path, seed])

if operation == 'generate':
# load any <embeddings> from the SD concepts library
opt.prompt = concepts_to_triggers(gen, opt.prompt)
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
try:
@@ -416,6 +427,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
elif command.startswith('!switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
add_embedding_terms(gen, completer)
completer.add_history(command)
operation = None

@@ -489,6 +501,19 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
command = '-h'
return command, operation

def concepts_to_triggers(gen, prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
gen.load_concepts(concepts)
return gen.concept_lib().replace_concepts_with_triggers(prompt)

def triggers_to_concepts(gen,prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
return gen.concept_lib().replace_triggers_with_concepts(prompt)

def set_default_output_dir(opt:Args, completer:Completer):
'''
If opt.outdir is relative, we add the root directory to it
@@ -790,7 +815,13 @@ def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
except KeyboardInterrupt:
pass


def add_embedding_terms(gen,completer):
'''
Called after setting the model, updates the autocompleter with
any terms loaded by the embedding manager.
'''
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())

def split_variations(variations_string) -> list:
# shotgun parsing, woo
parts = []
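Together, concepts_to_triggers() and triggers_to_concepts() form a round trip: concept tags in the prompt are swapped for trigger phrases before generation, and triggers are swapped back to the unique concept names before the prompt is written into image metadata. A minimal sketch of the round trip using the Concepts object directly, assuming the example concept has already been downloaded:

# Sketch only; assumes 'midjourney-style' is a downloaded concept whose
# trigger phrase matches its name.
from ldm.invoke.concepts_lib import Concepts

lib = Concepts()
prompt = 'a city street, in the style of <midjourney-style>'

with_triggers = lib.replace_concepts_with_triggers(prompt)          # names -> triggers
round_trip = lib.replace_triggers_with_concepts(with_triggers)      # triggers -> names
assert round_trip == prompt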
14 changes: 12 additions & 2 deletions ldm/invoke/args.py
@@ -92,6 +92,7 @@
import copy
import base64
import functools
import warnings
import ldm.invoke.pngwriter
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import split_weighted_subprompts
@@ -116,7 +117,7 @@

# is there a way to pick this up during git commits?
APP_ID = 'invoke-ai/InvokeAI'
APP_VERSION = 'v2.1.2'
APP_VERSION = 'v2.2.0'

class ArgFormatter(argparse.RawTextHelpFormatter):
# use defined argument order to display usage
@@ -546,9 +547,18 @@ def _create_arg_parser(self):
help='generate a grid'
)
render_group.add_argument(
'--embedding_directory',
'--embedding_path',
dest='embedding_path',
default='embeddings',
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
)
render_group.add_argument(
'--embeddings',
action=argparse.BooleanOptionalAction,
default=True,
help='Enable embedding directory (default). Use --no-embeddings to disable.',
)
render_group.add_argument(
'--enable_image_debugging',
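argparse.BooleanOptionalAction (available since Python 3.9) is what makes the paired --embeddings/--no-embeddings switches work; a self-contained sketch of the pattern used above:

# Stand-alone demonstration of the --embeddings / --no-embeddings pattern
# (argparse.BooleanOptionalAction requires Python 3.9+).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--embeddings',
    action=argparse.BooleanOptionalAction,
    default=True,
)

print(parser.parse_args([]).embeddings)                   # True  (default)
print(parser.parse_args(['--no-embeddings']).embeddings)  # False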
147 changes: 147 additions & 0 deletions ldm/invoke/concepts_lib.py
@@ -0,0 +1,147 @@
"""
Query and install embeddings from the HuggingFace SD Concepts Library
at https://huggingface.co/sd-concepts-library.

The interface is through the Concepts() object.
"""
import os
import re
import traceback
from urllib import request
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals

class Concepts(object):
def __init__(self, root=None):
'''
Initialize the Concepts object. May optionally pass a root directory.
'''
self.root = root or Globals.root
self.hf_api = HfApi()
self.concept_list = None
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile(r'(<[\w\-]+>)')
self.match_concept = re.compile(r'<([\w\-]+)>')

def list_concepts(self)->list:
'''
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
'''
if self.concept_list is not None:
return self.concept_list
try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
self.concept_list = [a.id.split('/')[1] for a in models]
except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
return self.concept_list

def get_concept_model_path(self, concept_name:str)->str:
'''
Returns the path to the 'learned_embeds.bin' file in
the named concept. Returns None if invalid or cannot
be downloaded.
'''
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')

def concept_to_trigger(self, concept_name:str)->str:
'''
Given a concept name returns its trigger by looking in the
"token_identifier.txt" file.
'''
if concept_name in self.triggers:
return self.triggers[concept_name]
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
if not file:
return None
with open(file,'r') as f:
trigger = f.readline()
trigger = trigger.strip()
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger

def trigger_to_concept(self, trigger:str)->str:
'''
Given a trigger phrase, maps it to the concept library name.
Only works if concept_to_trigger() has previously been called
on this library. There needs to be a persistent database for
this.
'''
concept = self.concept_names.get(trigger,None)
return f'<{concept}>' if concept else trigger

def replace_triggers_with_concepts(self, prompt:str)->str:
'''
Given a prompt string that contains <trigger> tags, replace these
tags with the concept name. The reason for this is so that the
concept names get stored in the prompt metadata. There is no
controlling of colliding triggers in the SD library, so it is
better to store the concept name (unique) than the concept trigger
(not necessarily unique!)
'''
def do_replace(match)->str:
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt)

def replace_concepts_with_triggers(self, prompt:str)->str:
'''
Given a prompt string that contains <concept_name> tags, replace
these tags with the appropriate trigger.
'''
def do_replace(match)->str:
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt)

def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
if not self.concept_is_downloaded(concept_name) and not local_only:
self.download_concept(concept_name)
path = os.path.join(self._concept_path(concept_name), file_name)
return path if os.path.exists(path) else None

def concept_is_downloaded(self, concept_name)->bool:
concept_directory = self._concept_path(concept_name)
return os.path.exists(concept_directory)

def download_concept(self,concept_name)->bool:
repo_id = self._concept_id(concept_name)
dest = self._concept_path(concept_name)

access_token = HfFolder.get_token()
header = [("Authorization", f'Bearer {access_token}')] if access_token else []
opener = request.build_opener()
opener.addheaders = header
request.install_opener(opener)

os.makedirs(dest, exist_ok=True)
succeeded = True

bytes = 0
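# urlretrieve invokes the reporthook as reporthook(block_number, block_size, total_size);
# on the first block of each file (block 0) we add that file's total size to the tally.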
def tally_download_size(chunk, size, total):
nonlocal bytes
if chunk==0:
bytes += total

print(f'>> Downloading {repo_id}...',end='')
try:
for file in ('README.md','learned_embeds.bin','token_identifier.txt','type_of_concept.txt'):
url = hf_hub_url(repo_id, file)
request.urlretrieve(url, os.path.join(dest,file),reporthook=tally_download_size)
except Exception as e:
    if getattr(e, 'code', None) == 404:
        print('This concept is not known to the Hugging Face library. Generation will continue without the concept.')
    else:
        print(f'Failed to download {concept_name}/{file} ({str(e)}). This may be due to a network error. Generation will continue without the concept.')
    os.rmdir(dest)
    return False
print('...{:.2f}Kb'.format(bytes/1024))
return succeeded

def _concept_id(self, concept_name:str)->str:
return f'sd-concepts-library/{concept_name}'

def _concept_path(self, concept_name:str)->str:
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
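A short usage sketch of the interface described in the module docstring (illustrative; 'birb-style' is just one example concept from the library, and network access is assumed):

# Illustrative usage of Concepts (not part of the diff).
from ldm.invoke.concepts_lib import Concepts

lib = Concepts()
names = lib.list_concepts()                       # concept names, without the 'sd-concepts-library/' prefix
path = lib.get_concept_model_path('birb-style')   # downloads on first use; path to learned_embeds.bin
trigger = lib.concept_to_trigger('birb-style')    # trigger phrase read from token_identifier.txt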
4 changes: 2 additions & 2 deletions ldm/invoke/model_cache.py
@@ -206,7 +206,7 @@ def _load_model(self, model_name:str):
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root,weights))
# scan model
self._scan_model(model_name, weights)
self.scan_model(model_name, weights)

print(f'>> Loading {model_name} from {weights}')

@@ -288,7 +288,7 @@ def offload_model(self, model_name:str) -> None:
if self._has_cuda():
torch.cuda.empty_cache()

def _scan_model(self, model_name, checkpoint):
def scan_model(self, model_name, checkpoint):
# scan model
print(f'>> Scanning Model: {model_name}')
scan_result = scan_file_path(checkpoint)
4 changes: 2 additions & 2 deletions ldm/invoke/prompt_parser.py
@@ -646,13 +646,13 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, display_label=None):
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', 'x` ')
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
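For context: the CLIP BPE tokenizer marks the end of each word with a '</w>' suffix, which log_tokenization strips for readability. A hedged sketch of what the tokenizer produces, assuming the transformers CLIPTokenizer that the cond_stage_model wraps:

# Sketch: what log_tokenization receives from the tokenizer
# (assumes transformers' CLIPTokenizer).
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
print(tokenizer.tokenize('a red sphere'))
# ['a</w>', 'red</w>', 'sphere</w>']  -- '</w>' marks end-of-word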