disable xformers/sdp if cannot be used
vladmandic committed Apr 21, 2023
1 parent cf80857 commit 57204b3
Showing 7 changed files with 29 additions and 26 deletions.
6 changes: 0 additions & 6 deletions TODO.md
@@ -56,9 +56,3 @@ Tech that can be integrated as part of the core workflow...
 - Bunch of stuff: <https://pharmapsychotic.com/tools.html>
 
 ### Pending Code Updates
-
-- autodetect which system libs should be installed
-  this is a first pass of autoconfig for nvidia vs amd environments
-- fix parse cmd line args from extensions
-- merge tomesd token merging
-- merge 23 prs pending from a1111 backlog
3 changes: 0 additions & 3 deletions modules/call_queue.py
@@ -1,4 +1,3 @@
-import sys
 import html
 import threading
 import time
@@ -52,8 +51,6 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
         if run_memmon:
             shared.mem_mon.monitor()
         t = time.perf_counter()
-        if sys.modules['xformers'] is not None and shared.opts.cross_attention_optimization != "xFormers":
-            sys.modules["xformers"] = None
         try:
             if shared.cmd_opts.profile:
                 pr = cProfile.Profile()
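A side note on the context lines above: the call queue wraps each call in an optional cProfile pass when profiling is enabled. A minimal standalone sketch of that profiling pattern (`profiled_call` and everything inside it is illustrative, not code from this repo):

```python
import cProfile
import io
import pstats

def profiled_call(fn, *args, **kwargs):
    # profile a single call and print the ten hottest functions,
    # roughly what the `if shared.cmd_opts.profile:` branch above enables
    pr = cProfile.Profile()
    pr.enable()
    result = fn(*args, **kwargs)
    pr.disable()
    stream = io.StringIO()
    pstats.Stats(pr, stream=stream).sort_stats("cumulative").print_stats(10)
    print(stream.getvalue())
    return result

profiled_call(sum, range(1_000_000))
```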
6 changes: 0 additions & 6 deletions modules/devices.py
@@ -22,20 +22,16 @@ def extract_device_id(args, name):
 
 def get_cuda_device_string():
     from modules import shared
-
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
-
     return "cuda"
 
 
 def get_optimal_device_name():
     if torch.cuda.is_available():
         return get_cuda_device_string()
-
     if has_mps():
         return "mps"
-
     return "cpu"
 
 
@@ -45,10 +41,8 @@ def get_optimal_device():
 
 def get_device_for(task):
     from modules import shared
-
     if task in shared.cmd_opts.use_cpu:
         return cpu
-
     return get_optimal_device()
 
 
2 changes: 1 addition & 1 deletion modules/hashes.py
@@ -72,7 +72,7 @@ def sha256(filename, title):
     if shared.cmd_opts.no_hashing:
         return None
 
-    print(f"Calculating sha256 for {filename}: ", end='')
+    print(f"Calculating sha256: {filename}", end='')
     sha256_value = calculate_sha256(filename)
     print(f"{sha256_value}")
 
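calculate_sha256 is defined elsewhere in modules/hashes.py and is not part of this diff; for context, a minimal sketch of the standard chunked-read hashing it presumably performs (the 1 MB block size is an assumption):

```python
import hashlib

def calculate_sha256(filename: str) -> str:
    # hash the file in fixed-size chunks so multi-GB model
    # checkpoints are never read into memory at once
    sha = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):  # assumed 1 MB blocks
            sha.update(chunk)
    return sha.hexdigest()
```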
9 changes: 9 additions & 0 deletions modules/import_hook.py
@@ -2,5 +2,14 @@
 from modules.shared import opts
 
 # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
+try:
+    import xformers # pylint: disable=unused-import
+    import xformers.ops # pylint: disable=unused-import
+except:
+    pass
+
 if opts.cross_attention_optimization != "xFormers":
+    if sys.modules.get("xformers", None) is not None:
+        print('Unloading xFormers')
     sys.modules["xformers"] = None
+    sys.modules["xformers.ops"] = None
10 changes: 7 additions & 3 deletions modules/sd_hijack.py
@@ -31,13 +31,17 @@
 
 def apply_optimizations():
     undo_optimizations()
-
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
     optimization_method = None
-
     can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))
+    if devices.device == torch.device("cpu"):
+        if opts.cross_attention_optimization == "Scaled-Dot-Product":
+            print("Scaled dot product cross attention is not available on CPU")
+            can_use_sdp = False
+        if opts.cross_attention_optimization == "xFormers":
+            print("xFormers cross attention is not available on CPU")
+            shared.xformers_available = False
 
     if opts.cross_attention_optimization == "Disable cross-attention layer optimization":
         print("Cross-attention optimization disabled")
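The can_use_sdp line is plain feature detection: torch.nn.functional.scaled_dot_product_attention only exists from PyTorch 2.0 onward. A minimal sketch of the same check followed by a guarded call (the tensor shapes are arbitrary demo values):

```python
import torch

can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention")

if can_use_sdp:
    # batch=1, heads=2, tokens=4, head_dim=8: arbitrary demo shapes
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 2, 4, 8])
else:
    print("SDP unavailable, fall back to another cross-attention optimization")
```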
19 changes: 12 additions & 7 deletions setup.py
@@ -23,6 +23,7 @@ class Dot(dict): # dot notation access to dictionary attributes
 args = Dot({ 'debug': False, 'upgrade': False, 'noupdate': False, 'skip-extensions': False, 'skip-requirements': False, 'reset': False })
 quick_allowed = True
 errors = 0
+opts = {}
 
 
 # setup console and file logging
@@ -174,7 +175,7 @@ def check_torch():
     if shutil.which('nvidia-smi') is not None:
         log.info('nVidia toolkit detected')
         torch_command = os.environ.get('TORCH_COMMAND', 'torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu118')
-        xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
+        xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17' if opts.get('cross_attention_optimization', '') == 'xFormers' else 'none')
     elif shutil.which('rocm-smi') is not None:
         log.info('AMD toolkit detected')
         torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2')
@@ -266,14 +267,10 @@ def run_extension_installer(folder):
 
 # get list of all enabled extensions
 def list_extensions(folder):
-    settings = {}
-    if os.path.isfile(args.ui_settings_file):
-        with open(args.ui_settings_file, "r", encoding="utf8") as file:
-            settings = json.load(file)
-    if settings.get('disable_all_extensions', 'none') != 'none':
+    if opts.get('disable_all_extensions', 'none') != 'none':
         log.debug('Disabled extensions: all')
         return []
-    disabled_extensions = set(settings.get('disabled_extensions', []))
+    disabled_extensions = set(opts.get('disabled_extensions', []))
     if len(disabled_extensions) > 0:
         log.debug(f'Disabled extensions: {disabled_extensions}')
     return [x for x in os.listdir(folder) if x not in disabled_extensions and not x.startswith('.')]
@@ -480,9 +477,17 @@ def git_reset():
     log.info('GIT reset complete')
 
 
+def read_options():
+    global opts # pylint: disable=global-statement
+    if os.path.isfile(args.ui_settings_file):
+        with open(args.ui_settings_file, "r", encoding="utf8") as file:
+            opts = json.load(file)
+
+
 # entry method when used as module
 def run_setup():
     setup_logging(args.upgrade)
+    read_options()
     check_python()
     if args.reset:
         git_reset()
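Taken together, the setup.py changes reduce to one pattern: read the UI settings JSON once into a module-level opts dict, then consult it with .get() and a default wherever behavior depends on user settings. A condensed, standalone sketch of that pattern (config.json stands in for args.ui_settings_file):

```python
import json
import os

opts = {}

def read_options(path: str = "config.json") -> None:  # path stands in for args.ui_settings_file
    global opts
    if os.path.isfile(path):
        with open(path, "r", encoding="utf8") as f:
            opts = json.load(f)

read_options()
# the same decision check_torch() now makes: only install xformers when the UI asked for it
xformers_package = "xformers==0.0.17" if opts.get("cross_attention_optimization", "") == "xFormers" else "none"
print(xformers_package)
```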
