From a1bea000661a7af4b8a5c9a0d09537f7b855efb3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 24 Nov 2022 16:52:48 +0100
Subject: [PATCH] Minor LightningLite clean up (#15780)

---
 src/lightning_lite/lite.py                    |  3 +--
 src/lightning_lite/strategies/fsdp.py         |  6 ++++++
 src/lightning_lite/strategies/fsdp_native.py  | 20 --------------------
 .../strategies/fully_sharded_native.py        |  2 +-
 4 files changed, 8 insertions(+), 23 deletions(-)
 delete mode 100644 src/lightning_lite/strategies/fsdp_native.py

diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py
index 5e41f15121acb..a198efd1ceab8 100644
--- a/src/lightning_lite/lite.py
+++ b/src/lightning_lite/lite.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import inspect
 import os
-from abc import ABC
 from contextlib import contextmanager, nullcontext
 from functools import partial
 from pathlib import Path
@@ -55,7 +54,7 @@
 from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 
 
-class LightningLite(ABC):
+class LightningLite:
     """Lite accelerates your PyTorch training or inference code with minimal changes required.
 
     - Automatic placement of models and data onto the device.
diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py
index 8053992d18525..46a36bf95b763 100644
--- a/src/lightning_lite/strategies/fsdp.py
+++ b/src/lightning_lite/strategies/fsdp.py
@@ -306,3 +306,9 @@ def no_backward_sync(self, module: Module) -> Generator:
             )
         with module.no_sync():
             yield
+
+
+def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
+    from torch.distributed.fsdp import FlatParameter
+
+    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
diff --git a/src/lightning_lite/strategies/fsdp_native.py b/src/lightning_lite/strategies/fsdp_native.py
deleted file mode 100644
index 9e70400e476dc..0000000000000
--- a/src/lightning_lite/strategies/fsdp_native.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from torch.optim import Optimizer
-
-
-def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
-    from torch.distributed.fsdp import FlatParameter
-
-    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index c628f2a653a79..69110db45507f 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -20,7 +20,7 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params
+from lightning_lite.strategies.fsdp import _optimizer_has_flat_params
 from lightning_lite.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
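
Note (not part of the patch): a minimal, self-contained sketch of what the relocated `_optimizer_has_flat_params` helper detects. Wrapping a module in `FullyShardedDataParallel` flattens its parameters into `FlatParameter` instances, so an optimizer constructed from the wrapped module's parameters carries `FlatParameter`s in its param groups. The single-process CPU/gloo process group below is an assumption for illustration only; real FSDP jobs typically run multi-rank on CUDA with the NCCL backend.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _optimizer_has_flat_params(optimizer: torch.optim.Optimizer) -> bool:
    # Same check as the helper moved into fsdp.py: an optimizer created
    # *after* FSDP wrapping holds FlatParameters instead of plain Parameters.
    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])


if __name__ == "__main__":
    # Single-process gloo group purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = FSDP(torch.nn.Linear(4, 4))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    assert _optimizer_has_flat_params(optimizer)

    dist.destroy_process_group()
```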