diff --git a/README.md b/README.md
index dbc2386..b6163a6 100644
--- a/README.md
+++ b/README.md
@@ -1,33 +1,25 @@
 # ComfyUI Diffusion Color Grading
 This is an Extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI), which is the joint research between me and TimothyAlexisVass.
-For more information, check out the original [Extension](https://github.com/Haoming02/sd-webui-diffusion-cg) for **Automatic1111**.
+For more information, check out the original [Extension](https://github.com/Haoming02/sd-webui-diffusion-cg) for **Automatic1111** Webui.
 
 ## How to Use
-> Example workflows are included~
+> An example workflow is included~
 
 - Attach the **Recenter** or **RecenterXL** node between `Empty Latent` and `KSampler` nodes
 - Adjust the **strength** and **color** sliders as needed
-- Attach the **Normalization** or **NormalizationXL** node between `KSampler` and `VAE Decode` nodes
+- Attach the **Normalization** node between `KSampler` and `VAE Decode` nodes
 
 ### Important:
-- The **Recenter** is "global." If you want to disable it during later part of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **Recenter** node and set its `strength` to `0.0`.
+- The **Recenter** effect is "global." If you want to disable it during other parts of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **Recenter** node with its `strength` set to `0.0`.
 
 ## Examples
-SD 1.5
-Off | On
-
-SDXL
-Off | On
+Off | On
 
 ## Known Issue
-- Doesn't work with some of the Samplers
+- Does not work with certain Samplers
diff --git a/__init__.py b/__init__.py
index a7cad14..03b81fe 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,17 +1,16 @@
-from .normalization import Normalization, NormalizationXL
-from .recenter import Recenter, RecenterXL, reset_str
+from .normalization import Normalization
+from .recenter import Recenter, RecenterXL, disable_recenter
+from functools import wraps
 
 import execution
 
 NODE_CLASS_MAPPINGS = {
     "Normalization": Normalization,
-    "NormalizationXL": NormalizationXL,
     "Recenter": Recenter,
     "Recenter XL": RecenterXL,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "Normalization": "Normalization",
-    "NormalizationXL": "NormalizationXL",
     "Recenter": "Recenter",
     "Recenter XL": "RecenterXL",
 }
@@ -20,8 +19,8 @@
 def find_node(prompt: dict) -> bool:
     """Find any ReCenter Node"""
 
-    for k, v in prompt.items():
-        if v["class_type"] in ("Recenter", "Recenter XL"):
+    for node in prompt.values():
+        if node.get("class_type", None) in ("Recenter", "Recenter XL"):
             return True
 
     return False
@@ -30,10 +29,11 @@ def find_node(prompt: dict) -> bool:
 
 original_validate = execution.validate_prompt
 
 
-def hijack_validate(prompt):
+@wraps(original_validate)
+def hijack_validate(prompt: dict):
     if not find_node(prompt):
-        reset_str()
+        disable_recenter()
 
     return original_validate(prompt)
diff --git a/examples/1.5_off.jpg b/examples/1.5_off.jpg
deleted file mode 100644
index 34471bc..0000000
Binary files a/examples/1.5_off.jpg and /dev/null differ
diff --git a/examples/1.5_on.png b/examples/1.5_on.png
deleted file mode 100644
index d972c5d..0000000
Binary files a/examples/1.5_on.png and /dev/null differ
diff --git a/examples/off.jpg b/examples/off.jpg
new file mode 100644
index 0000000..1ccafc1
Binary files /dev/null and b/examples/off.jpg differ
diff --git a/examples/on.png b/examples/on.png
new file mode 100644
index 0000000..0a3ee6c
Binary files /dev/null and b/examples/on.png differ
diff --git a/examples/xl_off.jpg b/examples/xl_off.jpg
deleted file mode 100644
index 1ef39fd..0000000
Binary files a/examples/xl_off.jpg and /dev/null differ
diff --git a/examples/xl_on.png b/examples/xl_on.png
deleted file mode 100644
index 1a083a8..0000000
Binary files a/examples/xl_on.png and /dev/null differ
diff --git a/normalization.py b/normalization.py
index 63c2913..7aad18d 100644
--- a/normalization.py
+++ b/normalization.py
@@ -1,53 +1,43 @@
-DYNAMIC_RANGE = [18, 14, 14, 14]
-DYNAMIC_RANGE_XL = [20, 16, 16]
+import torch
 
+DYNAMIC_RANGE: float = 1.0 / 0.18215 / 0.13025
 
-def normalize_tensor(x, r):
+
+def normalize_tensor(x: torch.Tensor, r: float) -> torch.Tensor:
     ratio = r / max(abs(float(x.min())), abs(float(x.max())))
-    return x * max(ratio, 0.99)
+    return x * max(ratio, 1.0)
 
 
-def clone_latent(latent):
+def clone_latent(latent: dict) -> dict:
     return {"samples": latent["samples"].detach().clone()}
 
 
 class Normalization:
+
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"latent": ("LATENT",)}}
+        return {
+            "required": {
+                "latent": ("LATENT",),
+                "sdxl": ("BOOLEAN",),
+            }
+        }
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "normalize"
     CATEGORY = "latent"
 
-    def normalize(self, latent):
+    @torch.inference_mode()
+    def normalize(self, latent: dict, sdxl: bool):
         norm_latent = clone_latent(latent)
-        batches = latent["samples"].size(0)
-        for b in range(batches):
-            for c in range(4):
-                norm_latent["samples"][b][c] = normalize_tensor(
-                    norm_latent["samples"][b][c], DYNAMIC_RANGE[c]
-                )
-
-        return (norm_latent,)
-
-class NormalizationXL:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"latent": ("LATENT",)}}
+        batchSize: int = latent["samples"].size(0)
+        channels: int = 3 if sdxl else 4
 
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "normalize"
-    CATEGORY = "latent"
-
-    def normalize(self, latent):
-        norm_latent = clone_latent(latent)
-        batches = latent["samples"].size(0)
-        for b in range(batches):
-            for c in range(3):
+        for b in range(batchSize):
+            for c in range(channels):
                 norm_latent["samples"][b][c] = normalize_tensor(
-                    norm_latent["samples"][b][c], DYNAMIC_RANGE_XL[c]
+                    norm_latent["samples"][b][c], DYNAMIC_RANGE / 2.5
                 )
 
         return (norm_latent,)
diff --git a/pyproject.toml b/pyproject.toml
index c630656..e6e23aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-diffusion-cg"
 description = "Color Grading for Stable Diffusion"
-version = "1.0.1"
+version = "1.1.0"
 license = { text = "MIT License" }
 dependencies = []
diff --git a/recenter.py b/recenter.py
index 802b969..05fc6cc 100644
--- a/recenter.py
+++ b/recenter.py
@@ -1,12 +1,17 @@
+from functools import wraps
+import datetime
 import comfy
+import torch
 
-rc_strength = 0.0
-LUTs = None
+RECENTER: float = 0.0
+LUTS: list[float] = None
 
-def reset_str():
-    global rc_strength
-    rc_strength = 0.0
+def disable_recenter():
+    global RECENTER
+    RECENTER = 0.0
+    global LUTS
+    LUTS = None
 
 ORIGINAL_SAMPLE = comfy.sample.sample
@@ -15,20 +20,24 @@ def reset_str():
 
 def hijack(SAMPLE):
+    @wraps(SAMPLE)
     def sample_center(*args, **kwargs):
         original_callback = kwargs["callback"]
 
+        @torch.inference_mode()
+        @wraps(original_callback)
         def hijack_callback(step, x0, x, total_steps):
-            global rc_strength
-            global LUTs
-            if not rc_strength or not LUTs:
+            if (not RECENTER) or (not LUTS):
                 return original_callback(step, x0, x, total_steps)
 
-            batchSize = x.size(0)
+            X = x.detach().clone()
+            batchSize: int = X.size(0)
+            channels: int = len(LUTS)
+
             for b in range(batchSize):
-                for c in range(len(LUTs)):
-                    x[b][c] += (LUTs[c] - x[b][c].detach().clone().mean()) * rc_strength
+                for c in range(channels):
+                    x[b][c] += (LUTS[c] - X[b][c].mean()) * RECENTER
 
             return original_callback(step, x0, x, total_steps)
@@ -43,6 +52,7 @@ def hijack_callback(step, x0, x, total_steps):
 
 
 class Recenter:
+
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -50,30 +60,23 @@ def INPUT_TYPES(s):
                 "latent": ("LATENT",),
                 "strength": (
                     "FLOAT",
-                    {
-                        "default": 0.0,
-                        "min": 0.0,
-                        "max": 1.0,
-                        "step": 0.1,
-                        "round": 0.1,
-                        "display": "slider",
-                    },
+                    {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.05},
                 ),
                 "C": (
                     "FLOAT",
-                    {"default": 0.01, "min": -1.00, "max": 1.00, "step": 0.01},
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
                 ),
                 "M": (
                     "FLOAT",
-                    {"default": 0.50, "min": -1.00, "max": 1.00, "step": 0.01},
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
                 ),
                 "Y": (
                     "FLOAT",
-                    {"default": -0.13, "min": -1.00, "max": 1.00, "step": 0.01},
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
                 ),
                 "K": (
                     "FLOAT",
-                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.01},
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
                 ),
             }
         }
@@ -82,14 +85,18 @@ def INPUT_TYPES(s):
     FUNCTION = "hook"
     CATEGORY = "latent"
 
-    def hook(self, latent, strength, C, M, Y, K):
-        global rc_strength
-        rc_strength = strength
-        global LUTs
-        LUTs = [-K, -M, C, Y]
+    def hook(self, latent, strength: float, C: float, M: float, Y: float, K: float):
+        global RECENTER
+        RECENTER = strength
+        global LUTS
+        LUTS = [-K, -M, C, Y]
 
         return (latent,)
 
+    @classmethod
+    def IS_CHANGED(*args, **kwargs):
+        return str(datetime.datetime.now())
+
 
 class RecenterXL:
     @classmethod
@@ -99,18 +106,20 @@ def INPUT_TYPES(s):
                 "latent": ("LATENT",),
                 "strength": (
                     "FLOAT",
-                    {
-                        "default": 0.0,
-                        "min": 0.0,
-                        "max": 1.0,
-                        "step": 0.1,
-                        "round": 0.1,
-                        "display": "slider",
-                    },
+                    {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.05},
+                ),
+                "Y": (
+                    "FLOAT",
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
+                ),
+                "Cb": (
+                    "FLOAT",
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
+                ),
+                "Cr": (
+                    "FLOAT",
+                    {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
                 ),
-                "L": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
-                "a": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
-                "b": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
             }
         }
@@ -118,10 +127,14 @@ def INPUT_TYPES(s):
     FUNCTION = "hook"
     CATEGORY = "latent"
 
-    def hook(self, latent, strength, L, a, b):
-        global rc_strength
-        rc_strength = strength
-        global LUTs
-        LUTs = [L, -a, b]
+    def hook(self, latent, strength: float, Y: float, Cb: float, Cr: float):
+        global RECENTER
+        RECENTER = strength
+        global LUTS
+        LUTS = [Y, -Cr, -Cb]
 
         return (latent,)
+
+    @classmethod
+    def IS_CHANGED(*args, **kwargs):
+        return str(datetime.datetime.now())
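
Read as plain Python, the color grading this patch implements in `recenter.py` and `normalization.py` reduces to two per-channel operations on the latent: during sampling, **Recenter** nudges each channel's mean toward a user-chosen target (`LUTS = [-K, -M, C, Y]` for SD 1.5, `[Y, -Cr, -Cb]` for SDXL), scaled by `strength`; before decoding, **Normalization** scales each channel up so its extremum fills the target dynamic range. The sketch below is a minimal standalone restatement of that math, not the extension's actual API: the helper names `recenter_step` / `normalize_latent` and the random stand-in latent are assumptions made for illustration, while the formulas, the `DYNAMIC_RANGE` constant, and the example target values (the old C/M/Y/K defaults) come from the diff itself.

```python
import torch

# Shared scaling constant taken from normalization.py (~42.2).
DYNAMIC_RANGE: float = 1.0 / 0.18215 / 0.13025


def recenter_step(x: torch.Tensor, luts: list[float], strength: float) -> torch.Tensor:
    """Shift each channel's mean toward its target value, scaled by strength.

    Mirrors the per-step adjustment inside the hijacked sampler callback:
    x[b][c] += (LUTS[c] - mean) * RECENTER
    """
    x = x.clone()  # illustrative helper, not part of the extension
    for b in range(x.size(0)):
        for c, target in enumerate(luts):
            x[b][c] += (target - x[b][c].mean()) * strength
    return x


def normalize_latent(x: torch.Tensor, channels: int, dynamic_range: float) -> torch.Tensor:
    """Scale each channel up (never down) so its extremum reaches dynamic_range."""
    x = x.clone()
    for b in range(x.size(0)):
        for c in range(channels):
            peak = max(abs(float(x[b][c].min())), abs(float(x[b][c].max())))
            if peak == 0.0:
                continue  # nothing to scale
            x[b][c] *= max(dynamic_range / peak, 1.0)
    return x


if __name__ == "__main__":
    latent = torch.randn(1, 4, 64, 64)      # stand-in for an SD 1.5 latent (4 channels)
    luts = [-0.00, -0.50, 0.01, -0.13]      # [-K, -M, C, Y] using the old default values
    recentered = recenter_step(latent, luts, strength=0.3)
    graded = normalize_latent(recentered, channels=4, dynamic_range=DYNAMIC_RANGE / 2.5)
    print(graded.shape, float(graded.min()), float(graded.max()))
```

In the extension proper, the recenter step runs on every denoising step through the hijacked `comfy.sample.sample` callback, whereas normalization runs once on the finished latent between `KSampler` and `VAE Decode`.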