-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathrecenter.py
132 lines (104 loc) · 3.42 KB
/
recenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from functools import wraps
import datetime
import comfy
import torch
# Correction strength written by the Recenter nodes (their slider range is
# 0.0-1.0). A falsy value (the 0.0 default) makes the sampler hook a no-op.
RECENTER: float = 0.0
# Per-channel target means written by the Recenter nodes. While this is
# None/empty, the sampler hook leaves latents untouched.
LUTS: "list[float] | None" = None
def disable_recenter():
    """Reset the module-level recentering state to its defaults.

    Sets ``RECENTER`` back to ``0.0`` and ``LUTS`` back to ``None``. With
    either value falsy, the patched sampler callback skips recentering and
    simply forwards to the caller's callback.
    """
    global RECENTER, LUTS
    RECENTER, LUTS = 0.0, None
# Keep references to ComfyUI's unpatched sampling functions. The wrappers
# produced by `hijack` delegate to these after injecting their callback.
ORIGINAL_SAMPLE, ORIGINAL_SAMPLE_CUSTOM = (
    comfy.sample.sample,
    comfy.sample.sample_custom,
)
def _recenter_inplace(x, luts, strength):
    """Pull each (sample, channel) mean of ``x`` toward ``luts``, in place.

    ``x`` is indexed as ``x[batch, channel, ...]``. For every sample ``b`` and
    every channel ``c < min(len(luts), x.size(1))``, the whole ``x[b, c]``
    slice is offset by ``(luts[c] - mean(x[b, c])) * strength``. With
    ``strength == 1`` the slice mean lands exactly on ``luts[c]``. Channels
    without a target are left alone, and surplus targets are ignored rather
    than raising ``IndexError``.

    Args:
        x: tensor with at least 2 dims, laid out as (batch, channel, ...).
            It is modified in place.
        luts: per-channel target means (sequence of floats).
        strength: scale applied to the correction.
    """
    channels = min(len(luts), x.size(1))
    if channels == 0 or x.numel() == 0:
        # Nothing to shift. With an empty latent the old per-slice loop was a no-op too.
        return
    batch = x.size(0)
    head = x[:, :channels]  # basic slicing gives a view, so `+=` below writes into x
    # Every mean is computed before any write, matching the old snapshot-clone semantics.
    means = head.reshape(batch, channels, -1).mean(dim=2)  # (batch, channels)
    targets = torch.tensor(list(luts[:channels]), dtype=x.dtype, device=x.device)
    shift = (targets - means) * strength
    # Broadcast the (batch, channels) shift over all trailing dims of x.
    head += shift.view(batch, channels, *([1] * (x.ndim - 2)))


def hijack(SAMPLE):
    """Wrap a sampling function so every step's latent can be recentered.

    The returned wrapper replaces the ``callback`` keyword argument with its
    own per-step callback and then calls ``SAMPLE``. On each step, that
    callback reads the module-level ``RECENTER`` / ``LUTS``. When both are
    truthy, it shifts the latent ``x`` in place (see ``_recenter_inplace``).
    It then forwards to the caller's original callback, if one was given, and
    returns that callback's result.

    A missing or ``None`` callback is tolerated: recentering still happens,
    and the injected callback returns ``None``.

    Args:
        SAMPLE: the sampling function to wrap. It must accept ``callback`` as a
            keyword argument with signature ``(step, x0, x, total_steps)``.

    Returns:
        The wrapper, carrying ``SAMPLE``'s metadata via ``functools.wraps``.
    """
    @wraps(SAMPLE)
    def sample_center(*args, **kwargs):
        original_callback = kwargs.get("callback")

        def hijack_callback(step, x0, x, total_steps):
            # Globals are read per call, so the latest node settings apply.
            if RECENTER and LUTS:
                _recenter_inplace(x, LUTS, RECENTER)
            if original_callback is None:
                return None
            return original_callback(step, x0, x, total_steps)

        if original_callback is not None:
            hijack_callback = wraps(original_callback)(hijack_callback)
        # In-place edits run under inference mode. NOTE(review): presumably
        # because the sampler's latents are inference tensors, which PyTorch
        # only allows to be modified in place inside inference mode.
        kwargs["callback"] = torch.inference_mode()(hijack_callback)
        return SAMPLE(*args, **kwargs)

    return sample_center
# Monkeypatch ComfyUI's sampling functions at import time with the
# recentering wrappers around the originals captured above.
comfy.sample.sample, comfy.sample.sample_custom = (
    hijack(ORIGINAL_SAMPLE),
    hijack(ORIGINAL_SAMPLE_CUSTOM),
)
class Recenter:
    """ComfyUI node that arms the sampler recentering hook with four targets.

    The inputs are labelled C, M, Y and K in the UI. How those labels relate
    to colours is not established in this module. The node returns the latent
    unchanged. Its effect is the module-level ``RECENTER`` / ``LUTS`` state it
    writes, which the patched sampler callback reads on every step.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "hook"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: latent, strength in [0, 1], C/M/Y/K in [-1, 1]."""
        def slider(lowest):
            # A fresh options dict per input, as the literal-based version had.
            return ("FLOAT", {"default": 0.00, "min": lowest, "max": 1.00, "step": 0.05})

        required = {"latent": ("LATENT",), "strength": slider(0.00)}
        for label in ("C", "M", "Y", "K"):
            required[label] = slider(-1.00)
        return {"required": required}

    def hook(self, latent, strength: float, C: float, M: float, Y: float, K: float):
        """Store strength and per-channel targets for the sampler hook.

        Returns:
            A 1-tuple containing ``latent`` unchanged.
        """
        global RECENTER, LUTS
        RECENTER = strength
        # Channel order and signs are kept exactly as authored: (-K, -M, C, Y).
        LUTS = [-K, -M, C, Y]
        return (latent,)
class RecenterXL:
    """ComfyUI node that arms the sampler recentering hook with three targets.

    The inputs are labelled Y, Cb and Cr in the UI. Their colour meaning is
    not established in this module. The node returns the latent unchanged.
    Its effect is the module-level ``RECENTER`` / ``LUTS`` state it writes,
    which the patched sampler callback reads on every step.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "hook"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: latent, strength in [0, 1], Y/Cb/Cr in [-1, 1]."""
        def slider(lowest):
            # A fresh options dict per input, as the literal-based version had.
            return ("FLOAT", {"default": 0.00, "min": lowest, "max": 1.00, "step": 0.05})

        required = {"latent": ("LATENT",), "strength": slider(0.00)}
        for label in ("Y", "Cb", "Cr"):
            required[label] = slider(-1.00)
        return {"required": required}

    def hook(self, latent, strength: float, Y: float, Cb: float, Cr: float):
        """Store strength and per-channel targets for the sampler hook.

        Returns:
            A 1-tuple containing ``latent`` unchanged.
        """
        global RECENTER, LUTS
        RECENTER = strength
        # Channel order and signs are kept exactly as authored: (Y, -Cr, -Cb).
        LUTS = [Y, -Cr, -Cb]
        return (latent,)