diff --git a/README.md b/README.md
index dbc2386..b6163a6 100644
--- a/README.md
+++ b/README.md
@@ -1,33 +1,25 @@
# ComfyUI Diffusion Color Grading
This is an Extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI), which is the joint research between me and TimothyAlexisVass.
-For more information, check out the original [Extension](https://github.com/Haoming02/sd-webui-diffusion-cg) for **Automatic1111**.
+For more information, check out the original [Extension](https://github.com/Haoming02/sd-webui-diffusion-cg) for **Automatic1111** Webui.
## How to Use
-> Example workflows are included~
+> An example workflow is included~
- Attach the **Recenter** or **RecenterXL** node between `Empty Latent` and `KSampler` nodes
- Adjust the **strength** and **color** sliders as needed
-- Attach the **Normalization** or **NormalizationXL** node between `KSampler` and `VAE Decode` nodes
+- Attach the **Normalization** node between `KSampler` and `VAE Decode` nodes
### Important:
-- The **Recenter** is "global." If you want to disable it during later part of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **Recenter** node and set its `strength` to `0.0`.
+- The **Recenter** effect is "global." If you want to disable it during other parts of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **Recenter** node with its `strength` set to `0.0`.
## Examples
-SD 1.5
-
-
-
Off | On
-
-
-
-SDXL
-
-
-
Off | On
+
+![](examples/on.png)
+Off | On
## Known Issue
-- Doesn't work with some of the Samplers
+- Does not work with certain Samplers
diff --git a/__init__.py b/__init__.py
index a7cad14..03b81fe 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,17 +1,16 @@
-from .normalization import Normalization, NormalizationXL
-from .recenter import Recenter, RecenterXL, reset_str
+from .normalization import Normalization
+from .recenter import Recenter, RecenterXL, disable_recenter
+from functools import wraps
import execution
NODE_CLASS_MAPPINGS = {
"Normalization": Normalization,
- "NormalizationXL": NormalizationXL,
"Recenter": Recenter,
"Recenter XL": RecenterXL,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Normalization": "Normalization",
- "NormalizationXL": "NormalizationXL",
"Recenter": "Recenter",
"Recenter XL": "RecenterXL",
}
@@ -20,8 +19,8 @@
def find_node(prompt: dict) -> bool:
"""Find any ReCenter Node"""
- for k, v in prompt.items():
- if v["class_type"] in ("Recenter", "Recenter XL"):
+ for node in prompt.values():
+ if node.get("class_type", None) in ("Recenter", "Recenter XL"):
return True
return False
@@ -30,10 +29,11 @@ def find_node(prompt: dict) -> bool:
original_validate = execution.validate_prompt
-def hijack_validate(prompt):
+@wraps(original_validate)
+def hijack_validate(prompt: dict):
if not find_node(prompt):
- reset_str()
+ disable_recenter()
return original_validate(prompt)
diff --git a/examples/1.5_off.jpg b/examples/1.5_off.jpg
deleted file mode 100644
index 34471bc..0000000
Binary files a/examples/1.5_off.jpg and /dev/null differ
diff --git a/examples/1.5_on.png b/examples/1.5_on.png
deleted file mode 100644
index d972c5d..0000000
Binary files a/examples/1.5_on.png and /dev/null differ
diff --git a/examples/off.jpg b/examples/off.jpg
new file mode 100644
index 0000000..1ccafc1
Binary files /dev/null and b/examples/off.jpg differ
diff --git a/examples/on.png b/examples/on.png
new file mode 100644
index 0000000..0a3ee6c
Binary files /dev/null and b/examples/on.png differ
diff --git a/examples/xl_off.jpg b/examples/xl_off.jpg
deleted file mode 100644
index 1ef39fd..0000000
Binary files a/examples/xl_off.jpg and /dev/null differ
diff --git a/examples/xl_on.png b/examples/xl_on.png
deleted file mode 100644
index 1a083a8..0000000
Binary files a/examples/xl_on.png and /dev/null differ
diff --git a/normalization.py b/normalization.py
index 63c2913..7aad18d 100644
--- a/normalization.py
+++ b/normalization.py
@@ -1,53 +1,43 @@
-DYNAMIC_RANGE = [18, 14, 14, 14]
-DYNAMIC_RANGE_XL = [20, 16, 16]
+import torch
+DYNAMIC_RANGE: float = 1.0 / 0.18215 / 0.13025
-def normalize_tensor(x, r):
+
+def normalize_tensor(x: torch.Tensor, r: float) -> torch.Tensor:
ratio = r / max(abs(float(x.min())), abs(float(x.max())))
- return x * max(ratio, 0.99)
+ return x * max(ratio, 1.0)
-def clone_latent(latent):
+def clone_latent(latent: dict) -> dict:
return {"samples": latent["samples"].detach().clone()}
class Normalization:
+
@classmethod
def INPUT_TYPES(s):
- return {"required": {"latent": ("LATENT",)}}
+ return {
+ "required": {
+ "latent": ("LATENT",),
+ "sdxl": ("BOOLEAN",),
+ }
+ }
RETURN_TYPES = ("LATENT",)
FUNCTION = "normalize"
CATEGORY = "latent"
- def normalize(self, latent):
+ @torch.inference_mode()
+ def normalize(self, latent: dict, sdxl: bool):
norm_latent = clone_latent(latent)
- batches = latent["samples"].size(0)
- for b in range(batches):
- for c in range(4):
- norm_latent["samples"][b][c] = normalize_tensor(
- norm_latent["samples"][b][c], DYNAMIC_RANGE[c]
- )
-
- return (norm_latent,)
-
-class NormalizationXL:
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {"latent": ("LATENT",)}}
+ batchSize: int = latent["samples"].size(0)
+ channels: int = 3 if sdxl else 4
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "normalize"
- CATEGORY = "latent"
-
- def normalize(self, latent):
- norm_latent = clone_latent(latent)
- batches = latent["samples"].size(0)
- for b in range(batches):
- for c in range(3):
+ for b in range(batchSize):
+ for c in range(channels):
norm_latent["samples"][b][c] = normalize_tensor(
- norm_latent["samples"][b][c], DYNAMIC_RANGE_XL[c]
+ norm_latent["samples"][b][c], DYNAMIC_RANGE / 2.5
)
return (norm_latent,)
diff --git a/pyproject.toml b/pyproject.toml
index c630656..e6e23aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-diffusion-cg"
description = "Color Grading for Stable Diffusion"
-version = "1.0.1"
+version = "1.1.0"
license = { text = "MIT License" }
dependencies = []
diff --git a/recenter.py b/recenter.py
index 802b969..05fc6cc 100644
--- a/recenter.py
+++ b/recenter.py
@@ -1,12 +1,17 @@
+from functools import wraps
+import datetime
import comfy
+import torch
-rc_strength = 0.0
-LUTs = None
+RECENTER: float = 0.0
+LUTS: list[float] = None
-def reset_str():
- global rc_strength
- rc_strength = 0.0
+def disable_recenter():
+ global RECENTER
+ RECENTER = 0.0
+ global LUTS
+ LUTS = None
ORIGINAL_SAMPLE = comfy.sample.sample
@@ -15,20 +20,24 @@ def reset_str():
def hijack(SAMPLE):
+ @wraps(SAMPLE)
def sample_center(*args, **kwargs):
original_callback = kwargs["callback"]
+ @torch.inference_mode()
+ @wraps(original_callback)
def hijack_callback(step, x0, x, total_steps):
- global rc_strength
- global LUTs
- if not rc_strength or not LUTs:
+ if (not RECENTER) or (not LUTS):
return original_callback(step, x0, x, total_steps)
- batchSize = x.size(0)
+ X = x.detach().clone()
+ batchSize: int = X.size(0)
+ channels: int = len(LUTS)
+
for b in range(batchSize):
- for c in range(len(LUTs)):
- x[b][c] += (LUTs[c] - x[b][c].detach().clone().mean()) * rc_strength
+ for c in range(channels):
+ x[b][c] += (LUTS[c] - X[b][c].mean()) * RECENTER
return original_callback(step, x0, x, total_steps)
@@ -43,6 +52,7 @@ def hijack_callback(step, x0, x, total_steps):
class Recenter:
+
@classmethod
def INPUT_TYPES(s):
return {
@@ -50,30 +60,23 @@ def INPUT_TYPES(s):
"latent": ("LATENT",),
"strength": (
"FLOAT",
- {
- "default": 0.0,
- "min": 0.0,
- "max": 1.0,
- "step": 0.1,
- "round": 0.1,
- "display": "slider",
- },
+ {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.05},
),
"C": (
"FLOAT",
- {"default": 0.01, "min": -1.00, "max": 1.00, "step": 0.01},
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"M": (
"FLOAT",
- {"default": 0.50, "min": -1.00, "max": 1.00, "step": 0.01},
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"Y": (
"FLOAT",
- {"default": -0.13, "min": -1.00, "max": 1.00, "step": 0.01},
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
"K": (
"FLOAT",
- {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.01},
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
}
}
@@ -82,14 +85,18 @@ def INPUT_TYPES(s):
FUNCTION = "hook"
CATEGORY = "latent"
- def hook(self, latent, strength, C, M, Y, K):
- global rc_strength
- rc_strength = strength
- global LUTs
- LUTs = [-K, -M, C, Y]
+ def hook(self, latent, strength: float, C: float, M: float, Y: float, K: float):
+ global RECENTER
+ RECENTER = strength
+ global LUTS
+ LUTS = [-K, -M, C, Y]
return (latent,)
+ @classmethod
+ def IS_CHANGED(*args, **kwargs):
+ return str(datetime.datetime.now())
+
class RecenterXL:
@classmethod
@@ -99,18 +106,20 @@ def INPUT_TYPES(s):
"latent": ("LATENT",),
"strength": (
"FLOAT",
- {
- "default": 0.0,
- "min": 0.0,
- "max": 1.0,
- "step": 0.1,
- "round": 0.1,
- "display": "slider",
- },
+ {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.05},
+ ),
+ "Y": (
+ "FLOAT",
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
+ ),
+ "Cb": (
+ "FLOAT",
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
+ ),
+ "Cr": (
+ "FLOAT",
+ {"default": 0.00, "min": -1.00, "max": 1.00, "step": 0.05},
),
- "L": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
- "a": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
- "b": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}),
}
}
@@ -118,10 +127,14 @@ def INPUT_TYPES(s):
FUNCTION = "hook"
CATEGORY = "latent"
- def hook(self, latent, strength, L, a, b):
- global rc_strength
- rc_strength = strength
- global LUTs
- LUTs = [L, -a, b]
+ def hook(self, latent, strength: float, Y: float, Cb: float, Cr: float):
+ global RECENTER
+ RECENTER = strength
+ global LUTS
+ LUTS = [Y, -Cr, -Cb]
return (latent,)
+
+ @classmethod
+ def IS_CHANGED(*args, **kwargs):
+ return str(datetime.datetime.now())