From 564309905882229e5d374fb67cd1ec8e749ad644 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Jan 2025 16:09:16 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/callbacks/test_weight_averaging.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py index b59359933eb82..bf49ff68cfb2e 100644 --- a/tests/tests_pytorch/callbacks/test_weight_averaging.py +++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py @@ -18,8 +18,8 @@ import pytest import torch from torch import Tensor, nn -from torch.utils.data import DataLoader from torch.optim.swa_utils import get_swa_avg_fn +from torch.utils.data import DataLoader from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import WeightAveraging @@ -170,9 +170,8 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: def test_weight_averaging_deepcopy(tmp_path): - """Ensure that WeightAveraging callback doesn't deepcopy the data loaders or the data module and consume memory more - than necessary. - """ + """Ensure that WeightAveraging callback doesn't deepcopy the data loaders or the data module and consume memory + more than necessary.""" class TestCallback(WeightAveraging): def __init__(self, *args, **kwargs):