Skip to content

Commit

Permalink
Make new LATENT_PREVIEWER type for declaring KSampler preview methods
Browse files Browse the repository at this point in the history
  • Loading branch information
space-nuko committed May 31, 2023
1 parent 32ebc08 commit 2c9799d
Showing 1 changed file with 39 additions and 13 deletions.
52 changes: 39 additions & 13 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
import folder_paths


class LatentPreviewer:
def decode_latent_to_preview(self, device, x0):
pass


def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()

Expand Down Expand Up @@ -282,6 +287,27 @@ def encode(self, taesd, pixels):
samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
return ({"samples": samples}, )

class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd

def decode_latent_to_preview(self, device, x0):
x_sample = self.taesd.decoder(x0.to(device))[0].detach()
x_sample = self.taesd.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
return x_sample

class TAESDPreviewer:
@classmethod
def INPUT_TYPES(s):
return {"required": { "taesd": ("TAESD", ), }}
RETURN_TYPES = ("LATENT_PREVIEWER",)
FUNCTION = "make_previewer"

CATEGORY = "latent/previewer"

def make_previewer(self, taesd):
return (TAESDPreviewerImpl(taesd), )

class SaveLatent:
def __init__(self):
Expand Down Expand Up @@ -986,10 +1012,8 @@ def set_mask(self, samples, mask):
return (s,)


def decode_latent_to_preview_image(taesd, device, preview_format, x0):
x_sample = taesd.decoder(x0.to(device))[0].detach()
x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
def decode_latent_to_preview_image(previewer, device, preview_format, x0):
x_sample = previewer.decode_latent_to_preview(device, x0)

x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
Expand All @@ -1015,7 +1039,7 @@ def decode_latent_to_preview_image(taesd, device, preview_format, x0):
return preview_bytes


def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]

Expand All @@ -1036,8 +1060,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
preview_bytes = None
if taesd:
preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
if previewer:
preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)

samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
Expand All @@ -1063,16 +1087,16 @@ def INPUT_TYPES(s):
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
"optional": {
"taesd": ("TAESD",)
"previewer": ("LATENT_PREVIEWER",)
}}

RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"

CATEGORY = "sampling"

def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd)
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer)

class KSamplerAdvanced:
@classmethod
Expand All @@ -1093,22 +1117,22 @@ def INPUT_TYPES(s):
"return_with_leftover_noise": (["disable", "enable"], ),
},
"optional": {
"taesd": ("TAESD",)
"previewer": ("LATENT_PREVIEWER",)
}}

RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"

CATEGORY = "sampling"

def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None):
force_full_denoise = True
if return_with_leftover_noise == "enable":
force_full_denoise = False
disable_noise = False
if add_noise == "disable":
disable_noise = True
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer)

class SaveImage:
def __init__(self):
Expand Down Expand Up @@ -1369,6 +1393,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
"VAELoader": VAELoader,
"TAESDDecode": TAESDDecode,
"TAESDEncode": TAESDEncode,
"TAESDPreviewer": TAESDPreviewer,
"TAESDLoader": TAESDLoader,
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
Expand Down Expand Up @@ -1425,6 +1450,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"TAESDLoader": "Load TAESD",
"TAESDPreviewer": "TAESD Previewer",
"LoraLoader": "Load LoRA",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",
Expand Down

0 comments on commit 2c9799d

Please sign in to comment.