Future 1/n: package in src/ folder #13293

Merged (19 commits) on Jun 15, 2022
@@ -22,6 +22,8 @@
if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler
+else:
+    OSS = ShardedGradScaler = object


class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
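The two lines added here follow the guarded-import pattern used throughout this PR: when the optional dependency is missing, the imported names are bound to object so that module-level references such as type hints and isinstance checks still import cleanly. Below is a minimal sketch of the idea, assuming an availability flag computed with importlib.util.find_spec; the is_oss_optimizer helper is hypothetical and not part of the PR.

import importlib.util

# Availability flag for the sketch only; PyTorch Lightning computes such flags
# with its own utilities.
_FAIRSCALE_AVAILABLE = importlib.util.find_spec("fairscale") is not None

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
else:
    # Placeholder binding so this module (and any OSS annotations) imports
    # even when fairscale is not installed.
    OSS = object


def is_oss_optimizer(optimizer) -> bool:
    # Hypothetical helper: isinstance(x, object) is always True, so check the
    # availability flag before trusting the placeholder type.
    return _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)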
2 changes: 2 additions & 0 deletions src/pytorch_lightning/strategies/collaborative.py
@@ -26,6 +26,8 @@

if _HIVEMIND_AVAILABLE:
    import hivemind
+else:
+    hivemind = None

log = logging.getLogger(__name__)

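For hivemind the optional import is a whole module, so the fallback binds the name to None rather than object; any real use has to be gated on the availability flag. A minimal sketch under the same assumption about the flag; the start_dht helper and its error message are illustrative, not taken from the PR.

import importlib.util

_HIVEMIND_AVAILABLE = importlib.util.find_spec("hivemind") is not None

if _HIVEMIND_AVAILABLE:
    import hivemind
else:
    # Keeps module-level references importable; callers must check the flag
    # before touching the placeholder.
    hivemind = None


def start_dht(initial_peers=None):
    # Illustrative helper, not from the PR: fail loudly when the optional
    # dependency is missing instead of raising AttributeError on None.
    if not _HIVEMIND_AVAILABLE:
        raise ModuleNotFoundError("hivemind is required for the collaborative strategy")
    # hivemind.DHT(start=True, ...) launches a local DHT node.
    return hivemind.DHT(start=True, initial_peers=initial_peers)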
2 changes: 2 additions & 0 deletions src/pytorch_lightning/strategies/ddp.py
@@ -64,6 +64,8 @@

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
+else:
+    OSS = object
if _TORCH_GREATER_EQUAL_1_10 and torch.distributed.is_available():
    from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

2 changes: 2 additions & 0 deletions src/pytorch_lightning/strategies/sharded.py
@@ -33,6 +33,8 @@
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
+else:
+    OSS = ShardedDataParallel = object


class DDPShardedStrategy(DDPStrategy):
2 changes: 2 additions & 0 deletions src/pytorch_lightning/strategies/sharded_spawn.py
@@ -31,6 +31,8 @@
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
+else:
+    OSS = ShardedDataParallel = object


class DDPSpawnShardedStrategy(DDPSpawnStrategy):