diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 9be338695c..cbc9544383 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -36,7 +36,6 @@ def __init__( first_conv: bool = False, maxpool1: bool = False, enc_out_dim: int = 512, - kl_coeff: float = 0.1, latent_dim: int = 256, lr: float = 1e-4, **kwargs diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index ef19110a02..b9f449ac21 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -127,6 +127,7 @@ def step(self, batch, batch_idx): kl = log_qz - log_pz kl = kl.mean() + kl *= self.kl_coeff loss = kl + recon_loss