Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error with path initializing StylesDatabase #8

Merged
merged 3 commits into from
Apr 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 45 additions & 60 deletions modules/styles.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
from pathlib import Path
from modules import errors
import csv
import os
import typing
Expand Down Expand Up @@ -45,13 +44,13 @@ def extract_style_text_from_prompt(style_text, prompt):
if "{prompt}" in stripped_style_text:
left, _, right = stripped_style_text.partition("{prompt}")
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
else:
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]

if prompt.endswith(', '):
if prompt.endswith(", "):
prompt = prompt[:-2]

return True, prompt
Expand All @@ -68,11 +67,15 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt

match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
match_positive, extracted_positive = extract_style_text_from_prompt(
style.prompt, prompt
)
if not match_positive:
return False, prompt, negative_prompt

match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
match_negative, extracted_negative = extract_style_text_from_prompt(
style.negative_prompt, negative_prompt
)
if not match_negative:
return False, prompt, negative_prompt

Expand All @@ -89,20 +92,28 @@ def _format_divider(file: str) -> str:
return divider


def _expand_path(path: list[str | Path] | str | Path) -> list[str]:
if isinstance(path, (str, Path)):
return [str(Path(path))]

paths = []
for pattern in path:
folder, file = os.path.split(pattern)
if "*" in file or "?" in file:
matching_files = Path(folder).glob(file)
[paths.append(str(file)) for file in matching_files]
else:
paths.append(str(Path(pattern)))

return paths


class StyleDatabase:
def __init__(self, paths: list[str | Path]):
def __init__(self, path: str | Path):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

# The default path will be self.path with any wildcard removed. If it
# doesn't exist, the reload() method updates this to be 'styles.csv'.
self.default_file = "styles.csv"
folder, file = os.path.split(self.path)
filename, _, ext = file.partition('*')
self.default_path = os.path.join(folder, filename + ext)

self.reload()

def reload(self):
Expand All @@ -112,50 +123,18 @@ def reload(self):
"""
self.styles.clear()

# scans for all styles files
all_styles_files = []
for pattern in self.paths:
folder, file = os.path.split(pattern)
if '*' in file or '?' in file:
found_files = Path(folder).glob(file)
[all_styles_files.append(file) for file in found_files]
else:
# if os.path.exists(pattern):
all_styles_files.append(Path(pattern))

if "*" in filename:
fileglob = filename.split("*")[0] + "*.csv"
filelist = []
for file in os.listdir(path):
if fnmatch.fnmatch(file, fileglob):
filelist.append(file)
# Add a visible divider to the style list
divider = _format_divider(file)
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
# Add styles from this CSV file
self.load_from_csv(os.path.join(path, file))

# Ensure the default file is loaded, else its contents may be lost:
if os.path.split(self.default_path)[1] not in filelist:
self.default_path = os.path.join(path, self.default_file)
divider = _format_divider(self.default_file)
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
self.load_from_csv(os.path.join(path, self.default_file))

if len(filelist) == 0:
print(f"No styles found in {path} matching {fileglob}")
self.load_from_csv(self.default_path)
return
# Expand the path to a list of full paths, expanding any wildcards. The
# default path will be the first of these:
style_files = _expand_path(self.path)
self.default_path = style_files[0]

elif not os.path.exists(self.path):
print(f"Style database not found: {self.path}")
return
else:
self.load_from_csv(self.path)
for file in style_files:
_, filename = os.path.split(file)
# Add a visible divider to the style list
divider = _format_divider(filename)
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
# Add styles from this CSV file
self.load_from_csv(file)

def load_from_csv(self, path: str):
with open(path, "r", encoding="utf-8-sig", newline="") as file:
Expand All @@ -173,7 +152,10 @@ def load_from_csv(self, path: str):
)

def get_style_paths(self) -> set:
"""Returns a set of all distinct paths of files that styles are loaded from."""
"""
Using the collection of styles in the StyleDatabase, returns a set of
all distinct files that styles are loaded from.
"""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
Expand Down Expand Up @@ -224,14 +206,17 @@ def save_styles(self, path: str = None) -> None:
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
writer.writeheader()
for style in (s for s in self.styles.values() if s.path == style_path):
# Skip style list dividers, e.g. "STYLES.CSV"
# Skip style list divider entries, e.g. "## STYLES.CSV ##"
if style.name.lower().strip("# ") in csv_names:
continue
# Write style fields, ignoring the path field
writer.writerow(
{k: v for k, v in style._asdict().items() if k != "path"}
)

# Reloading the styles to re-order the drop-down lists
self.reload()

def extract_styles_from_prompt(self, prompt, negative_prompt):
extracted = []

Expand Down
Loading