Basic support for Intel XPU (Arc Graphics) #409

Merged · 6 commits · Apr 7, 2023
Changes from 4 commits
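
This PR teaches comfy/model_management.py to use Intel's XPU backend when it is present: if intel_extension_for_pytorch imports and torch.xpu.is_available() reports a device, models are moved with .to("xpu") instead of .cuda(), total and free VRAM are read through torch.xpu, accelerate dispatch uses "xpu" as the main device, and fp16 is disabled on XPU for now. The requirements floor for safetensors is raised to 0.3.0. As a quick, hedged sanity check (a sketch of mine, not part of the diff), one can confirm that PyTorch sees the Arc GPU before trying the patched code:

# Hedged sketch: confirm torch.xpu sees an Intel Arc GPU before relying on the patched paths.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the torch.xpu backend)

if torch.xpu.is_available():
    dev = torch.xpu.current_device()
    total_mb = torch.xpu.get_device_properties(dev).total_memory / (1024 * 1024)
    print("XPU device %s found with %.0f MiB of memory" % (dev, total_mb))
    x = torch.randn(2, 2).to("xpu")  # tensors and models move with .to("xpu") rather than .cuda()
    print(x.device)
else:
    print("No XPU detected; the patch falls back to the existing CUDA/CPU behaviour.")
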
comfy/model_management.py (48 changes: 36 additions & 12 deletions)
@@ -7,6 +7,7 @@
 MPS = 5
 
 accelerate_enabled = False
+xpu_available = False
 vram_state = NORMAL_VRAM
 
 total_vram = 0
@@ -21,7 +22,12 @@
 
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
+        total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    else:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     forced_normal_vram = "--normalvram" in sys.argv
     if not forced_normal_vram and not forced_cpu:
@@ -125,6 +131,7 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
+    global xpu_available
 
     if model is current_loaded_model:
         return
@@ -143,14 +150,17 @@ def load_model_gpu(model):
         pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        if xpu_available:
+            real_model.to("xpu")
+        else:
+            real_model.cuda()
     else:
         if vram_state == NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
         model_accelerated = True
     return current_loaded_model
 
@@ -176,8 +186,12 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
+    global xpu_available
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+        if xpu_available:
+            return model.to("xpu")
+        else:
+            return model.cuda()
     return model
 
 def unload_if_low_vram(model):
@@ -187,12 +201,16 @@ def unload_if_low_vram(model):
     return model
 
 def get_torch_device():
+    global xpu_available
     if vram_state == MPS:
         return torch.device("mps")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -222,19 +240,24 @@ def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
 
 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch
 
     if torch_free_too:
         return (mem_free_total, mem_free_torch)
@@ -259,7 +282,8 @@ def mps_mode():
     return vram_state == MPS
 
 def should_use_fp16():
-    if cpu_mode() or mps_mode():
+    global xpu_available
+    if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?
 
     if torch.cuda.is_bf16_supported():
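
As the get_free_memory change above shows, free XPU memory is estimated as the device's total memory minus what PyTorch has currently allocated, rather than an exact free-memory query. A minimal standalone sketch of that estimate (assumes intel_extension_for_pytorch is installed and an Arc GPU is present; not part of this PR):

# Standalone sketch mirroring the XPU branch of get_free_memory() above.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the torch.xpu backend)

if torch.xpu.is_available():
    dev = torch.xpu.current_device()
    props = torch.xpu.get_device_properties(dev)
    free_mb = (props.total_memory - torch.xpu.memory_allocated(dev)) / (1024 * 1024)
    print("Approximate free XPU memory: %.0f MiB" % free_mb)
else:
    print("torch.xpu reports no device; the CUDA/CPU code paths apply instead.")
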
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ torchsde
 einops
 open-clip-torch
 transformers>=4.25.1
-safetensors
+safetensors>=0.3.0
 pytorch_lightning
 aiohttp
 accelerate
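
The safetensors floor is raised to 0.3.0 here. A quick sketch (not part of the PR; assumes the common packaging module is available) to confirm an existing environment meets the new requirement:

# Hedged sketch: check the installed safetensors against the new >=0.3.0 floor in requirements.txt.
from importlib.metadata import version   # Python 3.8+ standard library
from packaging.version import Version    # assumes the 'packaging' package is installed

installed = version("safetensors")
print("safetensors", installed, "ok" if Version(installed) >= Version("0.3.0") else "too old, upgrade needed")
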