Skip to content

Commit

Permalink
Update SDXL_DeepCache.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WentianZhang-ML authored Apr 14, 2024
1 parent 4b9c6a9 commit 04ea1d0
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/tgate/SDXL_DeepCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,15 +467,35 @@ def tgate(
xm.mark_step()
self.deepcache.disable()
self.deepcache.enable()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
latents = latents / self.vae.config.scaling_factor

image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents, return_dict=False)[0]

# cast back to fp16 if needed
if needs_upcasting:
Expand Down

0 comments on commit 04ea1d0

Please sign in to comment.