How to save and load a LightningModule whose inputs contain pretrained modules? #10037
-
Hi, I'm applying PyTorch Lightning to a VAE and a downstream model. We first train the VAE, then use the best checkpoint of the pretrained VAE as the initial weights of our model:

```python
import copy

from pytorch_lightning import LightningModule, Trainer

# STEP 1. Train the VAE.
vae = VAE(...)
trainer = Trainer(...)
trainer.fit(vae)

# STEP 2. Initialize the downstream model from the pretrained VAE.
vae = VAE.load_from_checkpoint(...)

class Model(LightningModule):
    def __init__(self, encoder, decoder, learning_rate):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Only learning_rate is recorded; encoder/decoder are not saved as hparams.
        self.save_hyperparameters("learning_rate")
        ...

encoder = copy.deepcopy(vae.encoder)
decoder = copy.deepcopy(vae.decoder)
model = Model(
    encoder=encoder,
    decoder=decoder,
    ...
)
trainer.fit(model)
```

The problem arises when I load the model after training ends. Since the torch modules are passed as input arguments to `Model`, the common approach `model = Model.load_from_checkpoint(...)` fails with an error.
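For context on why the plain call fails: `load_from_checkpoint` reconstructs the module from its saved hyperparameters, and only `learning_rate` was recorded by `save_hyperparameters`, so the required `encoder` and `decoder` arguments are missing. A minimal sketch to confirm this, assuming the checkpoint lives at the hypothetical path `best.ckpt`:

```python
import torch

# Hypothetical checkpoint path; substitute your real best-checkpoint file.
ckpt = torch.load("best.ckpt", map_location="cpu")

# Lightning stores saved hyperparameters under "hyper_parameters". Only
# learning_rate is present, so encoder/decoder cannot be reconstructed.
print(ckpt["hyper_parameters"])  # e.g. {'learning_rate': 0.001}
```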
-
Since part of your model is passed in through arguments, you can randomly initialize your VAE and let the new checkpoint configure its weights:

```python
import copy

vae = VAE()
encoder = copy.deepcopy(vae.encoder)
decoder = copy.deepcopy(vae.decoder)
# The remaining hyperparameters (e.g. learning_rate) are restored from the checkpoint.
model = Model.load_from_checkpoint(..., encoder=encoder, decoder=decoder)
```
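An equivalent alternative, if you prefer to restore the weights yourself, is to build the model with fresh submodules and load the checkpoint's `state_dict` directly. A minimal sketch, where the path `best.ckpt` and the `learning_rate` value are placeholders:

```python
import copy

import torch

# Rebuild the model with randomly initialized submodules...
vae = VAE()
model = Model(
    encoder=copy.deepcopy(vae.encoder),
    decoder=copy.deepcopy(vae.decoder),
    learning_rate=1e-3,  # placeholder; only a constructor argument here
)

# ...then overwrite every weight (encoder, decoder, and the rest) from the
# checkpoint. Lightning stores module weights under the "state_dict" key.
ckpt = torch.load("best.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
```

Relatedly, in recent PyTorch Lightning versions you can write `self.save_hyperparameters(ignore=["encoder", "decoder"])` to make it explicit that the module arguments are excluded from the saved hyperparameters.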