Remove Trainer reference from lightning module and datamodule #7315
Hello @edenlightning @ananthsub, may I ask a question about this proposal and for your advice?

Context: More concretely, for algorithms that need to do inference during training (not for validation, but to support the training steps themselves), I think PL doesn't have good support today. I implemented a thin wrapper (subclass) over …

My wish: Now, here comes a question.
Could you help me understand this idea more concretely, possibly with an example? I would like to understand whether it would support our use case. A direction I had in mind, if I were part of the core PL dev team, is something like:

```python
class MyModel(LightningModule):
    def on_train_dataloader(self):
        the_singleton_trainer_ref = self.trainer  # or obtained some other way, if we want to remove the ref from the LM
        with new_trainer_context(the_singleton_trainer_ref) as new_trainer:
            # Do whatever we want. `new_trainer` would have a fresh state
            # but inherit all other data members.
            new_trainer.predict(model=self, datamodule=...)
```

We would need to make sure the work inside the context manager is side-effect free:

1. The original trainer state should be preserved exactly as it was before the context was entered.
2. No side effects on the model. For instance, teardown should not move the model from GPU to CPU if the user doesn't want that; this behavior is not configurable today.
3. No side effects on the data module.

The context manager could be implemented:
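As a minimal sketch, assuming a snapshot/restore approach over the trainer's mutable state (the `new_trainer_context` helper and the choice of what to snapshot are assumptions for illustration, not an existing PL API):

```python
import copy
from contextlib import contextmanager


@contextmanager
def new_trainer_context(trainer):
    """Hypothetical: yield a trainer for side work (e.g. predict),
    restoring the original state afterwards so the outer fit() is unaffected."""
    # Snapshot the mutable state we are about to disturb. Which attributes
    # need saving is an assumption here; a real implementation would have to
    # enumerate them carefully (loop progress, callbacks, dataloaders, ...).
    saved_state = copy.deepcopy(trainer.state)
    try:
        # Simplification: yield the same trainer object. A real implementation
        # might instead construct a second Trainer that shares the accelerator
        # and other heavyweight members but starts with fresh loop state.
        yield trainer
    finally:
        # Restore the snapshot so the outer training loop resumes cleanly,
        # satisfying requirement 1 above.
        trainer.state = saved_state
```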
I am not an expert in PL, but I have been reading its source code for a while. I'd love to hear your expert advice and to help brainstorm. Thank you for taking the time to read this.
My apologies for the slow reply. By Trainer Context, I do not mean a Python context manager. Rather, I mean a read-only view of the Trainer state. This would allow the LightningModule and DataModule to access information in a controlled manner without exposing the full Trainer object.
Fundamentally, I think this should be solved at the loop level, though I would like to better understand the constraints of your use case (e.g. why iterating over a dataloader inside the LightningModule is not feasible for you, versus needing to call trainer.predict). This is a huge challenge for precisely the reasons you mentioned: it is very difficult to manage storing and resuming trainer state appropriately. I am also curious whether there's another way to split your use case into a pipeline rather than doing everything at once.
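For reference, a minimal sketch of the alternative alluded to above: iterating over a dataloader by hand inside a LightningModule hook, without spinning up a second Trainer. The hook and dataloader choices here are illustrative assumptions, not a prescribed pattern:

```python
import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def on_train_epoch_start(self):
        # Illustrative only: run inference over an auxiliary dataloader
        # manually, so no second Trainer (and no trainer-state juggling)
        # is needed. Assumes the attached datamodule defines predict_dataloader().
        dataloader = self.trainer.datamodule.predict_dataloader()
        self.eval()
        with torch.no_grad():
            for batch in dataloader:
                # We handle device placement ourselves here; assumes the
                # batch is a single tensor for simplicity.
                batch = batch.to(self.device)
                preds = self(batch)
                # ... consume `preds` to set up the upcoming training steps ...
        self.train()  # restore training mode before the epoch proceeds
```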
Motivation: Remove trainer reference from LightningModule and LightningDataModule.
Benefits:
Possible solution: think about introducing a TrainerContext / TrainerState object to pass state, and deprecate references to the trainer. This would be mostly read-only data that the LightningModule could leverage for various settings, like progress (epoch/step count), the distributed setting (global rank, local rank, data parallel/DDP/DeepSpeed/etc.), and more.
The ambition is that the LightningModule has a tightly controlled view of the trainer, while the trainer has full insight into the module.
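A minimal sketch of what such a read-only view might look like; the class name, fields, and helper are assumptions for illustration, not an existing PL API:

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen: the module can read but never mutate trainer state
class TrainerContext:
    """Hypothetical read-only snapshot of Trainer state exposed to the module."""
    current_epoch: int
    global_step: int
    global_rank: int
    local_rank: int
    world_size: int
    strategy: str  # e.g. "ddp", "deepspeed", "dp"


# The Trainer would hand the module a fresh snapshot on request:
#     ctx = trainer._build_context()      # hypothetical helper
#     if ctx.global_rank == 0:
#         ...                             # the module reads, never writes
```

Making the object frozen is what enforces the "tightly controlled view": the module can branch on progress or distributed settings, but any attempt to write through the context raises immediately.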
cc @Borda @tchaton @justusschock @awaelchli