Skip to content

Commit

Permalink
CFGDenoiser and script_callbacks mod for SAG
Browse files Browse the repository at this point in the history
  • Loading branch information
gitadmin0608 committed Apr 21, 2023
1 parent c205251 commit 0e39aa7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
36 changes: 36 additions & 0 deletions modules/script_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te


class CFGDenoisedParams:
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""

self.sampling_step = sampling_step
"""Current Sampling step number"""

self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""

self.inner_model = inner_model
"""Inner model reference used for denoising"""


class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
Expand All @@ -60,6 +75,10 @@ def __init__(self, x, sampling_step, total_sampling_steps):
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""

self.output_altered = False
"""A flag for CFGDenoiser indicating whether the output has been altered by the callback"""



class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
Expand All @@ -84,6 +103,7 @@ def __init__(self, imgs, cols, rows):
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
Expand Down Expand Up @@ -174,6 +194,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(e, c, 'cfg_denoised_callback')


def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']:
try:
c.callback(params)
except Exception as e:
report_exception(e, c, 'cfg_after_cfg_callback')


def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
Expand Down Expand Up @@ -315,6 +343,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)


def on_cfg_after_cfg(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
"""
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)


def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
Expand Down
8 changes: 7 additions & 1 deletion modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback

samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
Expand Down Expand Up @@ -145,7 +146,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):

x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))

denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)

devices.test_for_nans(x_out, "unet")
Expand All @@ -165,6 +166,11 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised

after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
if after_cfg_callback_params.output_altered:
denoised = after_cfg_callback_params.x

self.step += 1

return denoised
Expand Down

0 comments on commit 0e39aa7

Please sign in to comment.