Best practice when the model and datamodule have conflicting hyperparameters? #9195
Erotemic asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
-
I don't think there's a solution for this, considering that …

Some repro code:

```python
from pytorch_lightning import Trainer

# tests.helpers lives in the Lightning repository's test suite, not the installed package.
from tests.helpers import BoringDataModule, BoringModel


class BoringDataModule(BoringDataModule):
    def __init__(self, x: int):
        super().__init__()
        # Saves {"x": ...} into the datamodule's hparams.
        self.save_hyperparameters()


class BoringModel(BoringModel):
    def __init__(self, x: int):
        super().__init__()
        # Saves {"x": ...} into the module's hparams, so the key collides with the datamodule's.
        self.save_hyperparameters()


def run():
    x = 5
    dm = BoringDataModule(x)
    model = BoringModel(dm.hparams.x)
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, dm)


if __name__ == "__main__":
    run()
```

@kaushikb11 do you have any thoughts?
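For context on why the repro fails: as I understand it, when `trainer.fit` logs hyperparameters, Lightning combines the module's and the datamodule's saved hparams into one dict, and at the time of this thread any key present in both was rejected (later versions, as far as I know, only complain when the shared values differ). The snippet below is a hypothetical, simplified model of that merge in plain Python, not Lightning's code; `merge_hparams` and its `strict` flag are made-up names.

```python
from typing import Any, Dict


def merge_hparams(
    module_hparams: Dict[str, Any],
    datamodule_hparams: Dict[str, Any],
    strict: bool = True,
) -> Dict[str, Any]:
    """Hypothetical model of merging two hparams dicts for logging.

    strict=True rejects any key present in both dicts;
    strict=False rejects only shared keys whose values differ.
    """
    shared = module_hparams.keys() & datamodule_hparams.keys()
    if strict:
        bad = sorted(shared)
    else:
        bad = sorted(k for k in shared if module_hparams[k] != datamodule_hparams[k])
    if bad:
        raise ValueError(
            f"keys {bad} are present in both the module's and the datamodule's hparams"
        )
    merged = dict(module_hparams)
    merged.update(datamodule_hparams)
    return merged


# Mirrors the repro: both objects save x=5.
module_hp = {"x": 5}
dm_hp = {"x": 5}

try:
    merge_hparams(module_hp, dm_hp, strict=True)
except ValueError as err:
    print("strict:", err)

# With the value-aware check, identical values merge cleanly.
print("lenient:", merge_hparams(module_hp, dm_hp, strict=False))
```

Under the strict rule, sharing a parameter name is an error even when both sides agree, which is exactly the situation in the question.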
-
I'm getting an error when trying to use my lightning module and datamodule together:
This is because they both take "channels" as an argument. The value of this argument is a string code indicating the input modalities that will be fed to the network.
For instance:
From a design perspective, it makes sense that both the datamodule and the lightning module take this parameter. The datamodule needs to know which channels from its available pool its torch datasets will produce, and the lightning module needs to know the number of input modalities and the bands within those modalities.
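Since both sides consume the same code, one option is to parse it in one place and let each side derive what it needs. This is a hedged sketch: the `|`/`:`/`,` code format (e.g. `"rgb:r,g,b|ir:nir"`) and the names `ChannelSpec` and `parse_channels` are hypothetical, because the actual format from the question isn't shown.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class ChannelSpec:
    """Hypothetical parsed form of a channel code string."""

    code: str
    modalities: Dict[str, Tuple[str, ...]]

    @property
    def num_modalities(self) -> int:
        # What the lightning module needs: how many input streams to build.
        return len(self.modalities)

    def bands(self, modality: str) -> Tuple[str, ...]:
        # What the module needs per stream: the bands inside each modality.
        return self.modalities[modality]


def parse_channels(code: str) -> ChannelSpec:
    """Parse an assumed format like "rgb:r,g,b|ir:nir" (illustrative only)."""
    modalities: Dict[str, Tuple[str, ...]] = {}
    for part in code.split("|"):
        name, _, bands = part.partition(":")
        # Empty entries are dropped, so "sar" (no colon) yields no bands.
        modalities[name] = tuple(b for b in bands.split(",") if b)
    return ChannelSpec(code=code, modalities=modalities)


spec = parse_channels("rgb:r,g,b|ir:nir")
print(spec.num_modalities)  # 2
print(spec.bands("rgb"))    # ('r', 'g', 'b')
```

The datamodule could use `spec.modalities` to decide which sources to load, while the module sizes its inputs from `num_modalities` and `bands(...)`; only the raw `code` string needs to be passed around.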
Typically, when I construct a training session, I create my datamodule and then pass the `datamodule.channels` attribute as a constructor kwarg to the module. However, at inference time, when I load the model, I construct a datamodule based on `model.channels`, so I ensure that it can produce what the model expects.

I suppose in the meantime I'll just rename one of the parameters so they don't conflict, but in general, what is the best practice to avoid this issue? Is this a shortcoming of Lightning, or could I be doing something better?
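A plain-Python sketch of the round trip described above, with the rename applied: `FakeDataModule` and `FakeModule` are hypothetical stand-ins (not Lightning classes) that record their hyperparameters in a dict, and `input_channels` is an illustrative new name for the module's parameter. Training builds the module from the datamodule, inference rebuilds a datamodule from the module, and the two hparams dicts no longer share a key.

```python
from typing import Any, Dict


class FakeDataModule:
    """Hypothetical stand-in for a datamodule that records its hparams."""

    def __init__(self, channels: str) -> None:
        self.channels = channels
        self.hparams: Dict[str, Any] = {"channels": channels}


class FakeModule:
    """Hypothetical stand-in for a module; its parameter is renamed to avoid a clash."""

    def __init__(self, input_channels: str) -> None:
        self.input_channels = input_channels
        self.hparams: Dict[str, Any] = {"input_channels": input_channels}


# Training: the datamodule is the source of truth.
dm = FakeDataModule("rgb|ir")
model = FakeModule(input_channels=dm.channels)

# Inference: rebuild a datamodule from what the trained model expects.
dm_for_inference = FakeDataModule(model.input_channels)

# The two hparams dicts no longer share a key, so merging them is unambiguous.
shared = model.hparams.keys() & dm.hparams.keys()
print(shared)  # set()
```

The cost of this approach is that the same concept now has two names, so the mapping between `channels` and `input_channels` has to be maintained by hand on both the training and inference paths.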