-
Notifications
You must be signed in to change notification settings - Fork 210
Fix DDP support #1182
Fix DDP support #1182
Changes from 14 commits
c6cb7ef
d5e775e
c4cbb3b
06ab125
7189de5
8112f96
30b4003
08aaae2
6634759
2fe708c
dc9fce2
c734949
5c8db6e
373ee6f
5f7fafb
ba71adc
be3da37
d598948
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import functools | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning import Callback | ||
|
||
from flash.core.data.io.output import Output | ||
from flash.core.data.io.output_transform import OutputTransform | ||
|
||
|
||
class TransformPredictions(Callback): | ||
"""``TransformPredictions`` is a :class:`~pytorch_lightning.callbacks.base.Callback` which can be used to apply an | ||
:class:`~flash.core.data.io.output_transform.OutputTransform` and an :class:`~flash.core.data.io.output.Output` to | ||
model predictions. | ||
|
||
Args: | ||
output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply. | ||
output: The :class:`~flash.core.data.io.output.Output` to apply. | ||
""" | ||
|
||
def __init__(self, output_transform: OutputTransform, output: Output): | ||
super().__init__() | ||
|
||
self.output_transform = output_transform | ||
self.output = output | ||
|
||
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
predict_step = pl_module.predict_step | ||
|
||
@functools.wraps(predict_step) | ||
def wrapper(*args, **kwargs): | ||
predictions = predict_step(*args, **kwargs) | ||
if predictions is not None: | ||
predictions = self.output_transform(predictions) | ||
predictions = [self.output(prediction) for prediction in predictions] | ||
return predictions | ||
Comment on lines
+44
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just curious, when do you think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there are some cases where it can be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll also see if there is a possibility that |
||
|
||
pl_module.predict_step = wrapper | ||
|
||
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
pl_module.predict_step = pl_module.predict_step.__wrapped__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
|
||
import pytest | ||
|
||
|
||
def forked(callable): | ||
# PyTest forked not available in Windows | ||
if os.name == "nt": | ||
return callable | ||
return pytest.mark.forked(callable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@krshrimali This is the main fix. We used to have a bug where the data was accidentally included in the checkpoint. We patched that by adding this overrides. But then DDP spawn needs to pickle the data to send it to each process so this causes problems. We refactored away the bit that got this included in the checkpoint so now can be safely removed 😃
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thank you so much for the explanation, @ethanwharris!