
Expose hooks to process input args in precision plugins #13000

Closed
awaelchli opened this issue May 6, 2022 · 1 comment · Fixed by #18209

awaelchli (Contributor) commented May 6, 2022

Proposed refactor

Motivation

#12983 attempted to fix an issue where the evaluation step methods in the DDPStrategy would call the LightningModule.x_step methods directly instead of going through the LightningDoublePrecisionModule.x_step methods.

This introduced a typing issue where we can't guarantee that self.model implements the x_step methods.

Pitch

Treat #12983 as a temporary fix and refactor away the need for a precision wrapper.
All the precision wrapper does is process the input arguments.
This input processing could be done directly in the strategy.
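
For context, the input processing the double-precision wrapper performs amounts roughly to casting floating-point tensors in the step arguments to torch.float64 before delegating to the wrapped LightningModule. A minimal sketch of that behavior (illustrative only, not the actual LightningDoublePrecisionModule code; the helper and class names below are made up):

    import torch
    from torch import Tensor


    def _cast_to_double(data):
        # Recursively cast floating-point tensors to float64; leave everything else untouched.
        if isinstance(data, Tensor) and torch.is_floating_point(data):
            return data.double()
        if isinstance(data, (list, tuple)):
            return type(data)(_cast_to_double(x) for x in data)
        if isinstance(data, dict):
            return {k: _cast_to_double(v) for k, v in data.items()}
        return data


    class _DoubleWrapperSketch(torch.nn.Module):
        # Stand-in for the wrapper: all it does is process the inputs, then delegate.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def test_step(self, *args, **kwargs):
            return self.module.test_step(*_cast_to_double(args), **_cast_to_double(kwargs))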

Before

    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.test_step_context():
            return self.MAYBE_WRAPPER.test_step(*args, **kwargs)

After

    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.test_step_context():
            args, kwargs = self.precision_plugin.process_step_inputs(args, kwargs)
            return self.LIGHTNING_MODULE.test_step(*args, **kwargs)
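
Under this pitch, the processing would live on the precision plugin behind a hook like the process_step_inputs used above. A possible sketch, assuming the hook receives and returns the positional and keyword arguments; the hook name and signature are part of this proposal, not an existing PrecisionPlugin API, and the class names below are made up:

    import torch
    from torch import Tensor


    class _PrecisionPluginSketch:
        """Sketch of the base hook: by default, pass the step inputs through unchanged."""

        def process_step_inputs(self, args, kwargs):
            return args, kwargs


    class _DoublePrecisionPluginSketch(_PrecisionPluginSketch):
        """Sketch of the double-precision override: cast floating tensors to float64."""

        def process_step_inputs(self, args, kwargs):
            def cast(value):
                # Only top-level tensors are handled here, for brevity.
                if isinstance(value, Tensor) and torch.is_floating_point(value):
                    return value.double()
                return value

            return tuple(cast(a) for a in args), {k: cast(v) for k, v in kwargs.items()}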

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: Enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.

cc @justusschock @awaelchli @rohitgr7 @carmocca @akihironitta

@awaelchli added the needs triage, precision: double, strategy, and refactor labels and removed the needs triage label on May 6, 2022
@awaelchli added this to the 1.7 milestone on May 6, 2022
carmocca (Contributor) commented May 6, 2022

Sounds good to me. A relevant PR that still needs to land: #10079
