
Commit

optimization
Haoming02 committed Sep 2, 2024
1 parent 8c9fcc9 commit 4ce54de
Showing 11 changed files with 94 additions and 99 deletions.
24 changes: 8 additions & 16 deletions README.md
@@ -1,33 +1,25 @@
# ComfyUI Diffusion Color Grading
This is an Extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI), which is the joint research between me and <ins>TimothyAlexisVass</ins>.

For more information, check out the original [Extension](https://github.com/Haoming02/sd-webui-diffusion-cg) for **Automatic1111**.
For more information, check out the original [Extension](https://github.com/Haoming02/sd-webui-diffusion-cg) for **Automatic1111** Webui.

## How to Use
> Example workflows are included~
> An example workflow is included~
- Attach the **Recenter** or **RecenterXL** node between `Empty Latent` and `KSampler` nodes
- Adjust the **strength** and **color** sliders as needed
- Attach the **Normalization** or **NormalizationXL** node between `KSampler` and `VAE Decode` nodes
- Attach the **Normalization** node between `KSampler` and `VAE Decode` nodes

### Important:
- The **Recenter** is "global." If you want to disable it during later part of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **Recenter** node and set its `strength` to `0.0`.
- The **Recenter** effect is "global." If you want to disable it during other parts of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **Recenter** node with its `strength` set to `0.0`.

## Examples

<p align="center">
<b>SD 1.5</b><br>
<img src="examples/1.5_off.jpg" width=384>
<img src="examples/1.5_on.png" width=384>
<br><code>Off | On</code>
</p>

<p align="center">
<b>SDXL</b><br>
<img src="examples/xl_off.jpg" width=384>
<img src="examples/xl_on.png" width=384>
<br><code>Off | On</code>
<img src="examples/off.jpg" width=384>
<img src="examples/on.png" width=384><br>
<code>Off | On</code>
</p>

## Known Issue
- Doesn't work with some of the Samplers
- Does not work with certain Samplers
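
To make the node order from "How to Use" concrete, here is a hypothetical API-format prompt fragment, written as a Python dict, showing where the Recenter and Normalization nodes sit in the chain. The node IDs and the built-in node names and input names (`EmptyLatentImage`, `KSampler`, `VAEDecode`) are assumptions for illustration; only the `Recenter` and `Normalization` inputs come from this extension.

```python
# Hypothetical prompt fragment: Empty Latent -> Recenter -> KSampler -> Normalization -> VAE Decode
prompt = {
    "1": {"class_type": "EmptyLatentImage",
          "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
    "2": {"class_type": "Recenter",
          "inputs": {"latent": ["1", 0], "strength": 1.0,
                     "C": 0.0, "M": 0.0, "Y": 0.0, "K": 0.0}},
    "3": {"class_type": "KSampler",
          "inputs": {"latent_image": ["2", 0]}},  # model/conditioning inputs omitted
    "4": {"class_type": "Normalization",
          "inputs": {"latent": ["3", 0], "sdxl": False}},
    "5": {"class_type": "VAEDecode",
          "inputs": {"samples": ["4", 0]}},  # vae input omitted
}
```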
16 changes: 8 additions & 8 deletions __init__.py
@@ -1,17 +1,16 @@
from .normalization import Normalization, NormalizationXL
from .recenter import Recenter, RecenterXL, reset_str
from .normalization import Normalization
from .recenter import Recenter, RecenterXL, disable_recenter
from functools import wraps
import execution

NODE_CLASS_MAPPINGS = {
"Normalization": Normalization,
"NormalizationXL": NormalizationXL,
"Recenter": Recenter,
"Recenter XL": RecenterXL,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"Normalization": "Normalization",
"NormalizationXL": "NormalizationXL",
"Recenter": "Recenter",
"Recenter XL": "RecenterXL",
}
@@ -20,8 +19,8 @@
def find_node(prompt: dict) -> bool:
"""Find any ReCenter Node"""

for k, v in prompt.items():
if v["class_type"] in ("Recenter", "Recenter XL"):
for node in prompt.values():
if node.get("class_type", None) in ("Recenter", "Recenter XL"):
return True

return False
@@ -30,10 +29,11 @@ def find_node(prompt: dict) -> bool:
original_validate = execution.validate_prompt


def hijack_validate(prompt):
@wraps(original_validate)
def hijack_validate(prompt: dict):

if not find_node(prompt):
reset_str()
disable_recenter()

return original_validate(prompt)

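The hunks above amount to a wrap-and-delegate hook around `execution.validate_prompt`: if the queued workflow contains no Recenter node, the global recenter state is cleared before validation proceeds. A minimal self-contained sketch of that pattern, using stand-ins for ComfyUI's real `execution` module and for `disable_recenter`:

```python
from functools import wraps


def original_validate(prompt: dict):
    """Stand-in for execution.validate_prompt."""
    return True


def find_node(prompt: dict) -> bool:
    """Return True if any node in the prompt is a Recenter node."""
    return any(
        node.get("class_type", None) in ("Recenter", "Recenter XL")
        for node in prompt.values()
    )


def disable_recenter():
    """Stand-in for recenter.disable_recenter: would clear the global state."""
    print("no Recenter node in this workflow; recenter disabled")


@wraps(original_validate)
def hijack_validate(prompt: dict):
    # Clear the global recenter state for workflows without a Recenter node,
    # then defer to the original validation.
    if not find_node(prompt):
        disable_recenter()
    return original_validate(prompt)


# Hypothetical usage
print(hijack_validate({"3": {"class_type": "KSampler", "inputs": {}}}))
```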
Binary file removed examples/1.5_off.jpg
Binary file not shown.
Binary file removed examples/1.5_on.png
Binary file not shown.
Binary file added examples/off.jpg
Binary file not shown.
Binary file added examples/on.png
Binary file not shown.
Binary file removed examples/xl_off.jpg
Binary file not shown.
Binary file removed examples/xl_on.png
Binary file not shown.
50 changes: 20 additions & 30 deletions normalization.py
@@ -1,53 +1,43 @@
DYNAMIC_RANGE = [18, 14, 14, 14]
DYNAMIC_RANGE_XL = [20, 16, 16]
import torch

DYNAMIC_RANGE: float = 1.0 / 0.18215 / 0.13025

def normalize_tensor(x, r):

def normalize_tensor(x: torch.Tensor, r: float) -> torch.Tensor:
ratio = r / max(abs(float(x.min())), abs(float(x.max())))
return x * max(ratio, 0.99)
return x * max(ratio, 1.0)


def clone_latent(latent):
def clone_latent(latent: dict) -> dict:
return {"samples": latent["samples"].detach().clone()}


class Normalization:

@classmethod
def INPUT_TYPES(s):
return {"required": {"latent": ("LATENT",)}}
return {
"required": {
"latent": ("LATENT",),
"sdxl": ("BOOLEAN",),
}
}

RETURN_TYPES = ("LATENT",)
FUNCTION = "normalize"
CATEGORY = "latent"

def normalize(self, latent):
@torch.inference_mode()
def normalize(self, latent: dict, sdxl: bool):
norm_latent = clone_latent(latent)
batches = latent["samples"].size(0)
for b in range(batches):
for c in range(4):
norm_latent["samples"][b][c] = normalize_tensor(
norm_latent["samples"][b][c], DYNAMIC_RANGE[c]
)

return (norm_latent,)


class NormalizationXL:
@classmethod
def INPUT_TYPES(s):
return {"required": {"latent": ("LATENT",)}}
batchSize: int = latent["samples"].size(0)
channels: int = 3 if sdxl else 4

RETURN_TYPES = ("LATENT",)
FUNCTION = "normalize"
CATEGORY = "latent"

def normalize(self, latent):
norm_latent = clone_latent(latent)
batches = latent["samples"].size(0)
for b in range(batches):
for c in range(3):
for b in range(batchSize):
for c in range(channels):
norm_latent["samples"][b][c] = normalize_tensor(
norm_latent["samples"][b][c], DYNAMIC_RANGE_XL[c]
norm_latent["samples"][b][c], DYNAMIC_RANGE / 2.5
)

return (norm_latent,)
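
A standalone sketch of what the reworked `Normalization` node does to a latent batch. The tensor shape and the `sdxl` flag are illustrative; the constants and the per-channel loop mirror the new code in the diff above:

```python
import torch

# Reciprocal of the SD1.5 (0.18215) and SDXL (0.13025) latent scale factors,
# matching DYNAMIC_RANGE in the diff above
DYNAMIC_RANGE: float = 1.0 / 0.18215 / 0.13025


def normalize_tensor(x: torch.Tensor, r: float) -> torch.Tensor:
    # Stretch the channel so its most extreme value reaches +/- r;
    # channels that already exceed the range are left untouched (ratio < 1 is clamped to 1)
    ratio = r / max(abs(float(x.min())), abs(float(x.max())))
    return x * max(ratio, 1.0)


# Hypothetical latent batch: 2 images, 4 channels, 128x128
latent = {"samples": torch.randn(2, 4, 128, 128)}
sdxl = False
channels = 3 if sdxl else 4  # only the first 3 channels are touched for SDXL

for b in range(latent["samples"].size(0)):
    for c in range(channels):
        latent["samples"][b][c] = normalize_tensor(
            latent["samples"][b][c], DYNAMIC_RANGE / 2.5
        )
```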
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-diffusion-cg"
description = "Color Grading for Stable Diffusion"
version = "1.0.1"
version = "1.1.0"
license = { text = "MIT License" }
dependencies = []

101 changes: 57 additions & 44 deletions recenter.py
@@ -1,12 +1,17 @@
from functools import wraps
import datetime
import comfy
import torch

rc_strength = 0.0
LUTs = None
RECENTER: float = 0.0
LUTS: list[float] = None


def reset_str():
global rc_strength
rc_strength = 0.0
def disable_recenter():
global RECENTER
RECENTER = 0.0
global LUTS
LUTS = None


ORIGINAL_SAMPLE = comfy.sample.sample
@@ -15,20 +20,24 @@ def reset_str():

def hijack(SAMPLE):

@wraps(SAMPLE)
def sample_center(*args, **kwargs):
original_callback = kwargs["callback"]

@torch.inference_mode()
@wraps(original_callback)
def hijack_callback(step, x0, x, total_steps):
global rc_strength
global LUTs

if not rc_strength or not LUTs:
if (not RECENTER) or (not LUTS):
return original_callback(step, x0, x, total_steps)

batchSize = x.size(0)
X = x.detach().clone()
batchSize: int = X.size(0)
channels: int = len(LUTS)

for b in range(batchSize):
for c in range(len(LUTs)):
x[b][c] += (LUTs[c] - x[b][c].detach().clone().mean()) * rc_strength
for c in range(channels):
x[b][c] += (LUTS[c] - X[b][c].mean()) * RECENTER

return original_callback(step, x0, x, total_steps)

@@ -43,37 +52,31 @@ def hijack_callback(step, x0, x, total_steps):


class Recenter:

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"latent": ("LATENT",),
"strength": (
"FLOAT",
{
"default": 0.0,
"min": 0.0,
"max": 1.0,
"step": 0.1,
"round": 0.1,
"display": "slider",
},
{"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.05},
),
"C": (
"FLOAT",
{"default": 0.01, "min": -1.00, "max": 1.00, "step": 0.01},
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"M": (
"FLOAT",
{"default": 0.50, "min": -1.00, "max": 1.00, "step": 0.01},
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"Y": (
"FLOAT",
{"default": -0.13, "min": -1.00, "max": 1.00, "step": 0.01},
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"K": (
"FLOAT",
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.01},
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
}
}
@@ -82,14 +85,18 @@ def INPUT_TYPES(s):
FUNCTION = "hook"
CATEGORY = "latent"

def hook(self, latent, strength, C, M, Y, K):
global rc_strength
rc_strength = strength
global LUTs
LUTs = [-K, -M, C, Y]
def hook(self, latent, strength: float, C: float, M: float, Y: float, K: float):
global RECENTER
RECENTER = strength
global LUTS
LUTS = [-K, -M, C, Y]

return (latent,)

@classmethod
def IS_CHANGED(*args, **kwargs):
return str(datetime.datetime.now())


class RecenterXL:
@classmethod
@@ -99,29 +106,35 @@ def INPUT_TYPES(s):
"latent": ("LATENT",),
"strength": (
"FLOAT",
{
"default": 0.0,
"min": 0.0,
"max": 1.0,
"step": 0.1,
"round": 0.1,
"display": "slider",
},
{"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.05},
),
"Y": (
"FLOAT",
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"Cb": (
"FLOAT",
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"Cr": (
"FLOAT",
{"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"L": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
"a": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
"b": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
}
}

RETURN_TYPES = ("LATENT",)
FUNCTION = "hook"
CATEGORY = "latent"

def hook(self, latent, strength, L, a, b):
global rc_strength
rc_strength = strength
global LUTs
LUTs = [L, -a, b]
def hook(self, latent, strength: float, Y: float, Cb: float, Cr: float):
global RECENTER
RECENTER = strength
global LUTS
LUTS = [Y, -Cr, -Cb]

return (latent,)

@classmethod
def IS_CHANGED(*args, **kwargs):
return str(datetime.datetime.now())
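
Finally, a self-contained sketch of the core recenter step that the hijacked sampler callback applies at each denoising step: every latent channel's mean is pulled toward its target value by the configured strength. The strength, targets, and tensor shape here are made up for illustration; the update rule matches the new `hijack_callback` above:

```python
import torch

RECENTER: float = 0.6                         # "strength" from the Recenter node
LUTS: list[float] = [0.0, -0.5, 0.01, -0.13]  # per-channel targets, i.e. [-K, -M, C, Y]


def recenter_step(x: torch.Tensor) -> torch.Tensor:
    """Shift each channel's mean toward its target by RECENTER of the gap."""
    X = x.detach().clone()  # channel means are read from a detached copy of the input
    for b in range(X.size(0)):
        for c in range(len(LUTS)):
            x[b][c] += (LUTS[c] - X[b][c].mean()) * RECENTER
    return x


# Hypothetical latent at some denoising step: 1 image, 4 channels, 64x64
latent = recenter_step(torch.randn(1, 4, 64, 64))
print(latent[0, 1].mean())  # the channel-1 mean has moved 60% of the way toward -0.5
```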
