
Commit

apply fixes
Borda committed Nov 13, 2024
1 parent b373a76 commit 5da2e43
Showing 5 changed files with 30 additions and 28 deletions.
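
These changes drop the deprecated typing aliases (List, Dict, Tuple, DefaultDict, Iterable, Iterator) in favour of the built-in and collections.abc generics standardized by PEP 585. Below is a minimal sketch of the annotation style the diff adopts, assuming Python >= 3.9 (where these generics are subscriptable at runtime); the function and variable names are illustrative only and do not appear in the commit.

# Sketch of the PEP 585 style used throughout this commit (assumes Python >= 3.9).
# All identifiers here are illustrative; none of them come from the diff.
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Union

def summarize(batches: list[Union[int, float]], extras: dict[str, Any]) -> tuple[str, ...]:
    # typing.DefaultDict[int, int] becomes collections.defaultdict[int, int]
    counts: defaultdict[int, int] = defaultdict(int)
    for idx, value in enumerate(batches):
        counts[idx] += int(value)
    keys: Iterable[str] = extras  # a dict[str, Any] iterates over its str keys
    return tuple(keys)

Union and Optional still come from typing in the diff; switching them to the X | Y syntax of PEP 604 would require Python 3.10 or a __future__ import, which presumably is why they are left untouched here.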
21 changes: 11 additions & 10 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -15,8 +15,9 @@
import shutil
import sys
from collections import ChainMap, OrderedDict, defaultdict
+from collections.abc import Iterable, Iterator
from dataclasses import dataclass
-from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
@@ -67,17 +68,17 @@ def __init__(
self.verbose = verbose
self.inference_mode = inference_mode
self.batch_progress = _BatchProgress() # across dataloaders
-self._max_batches: List[Union[int, float]] = []
+self._max_batches: list[Union[int, float]] = []

self._results = _ResultCollection(training=False)
-self._logged_outputs: List[_OUT_DICT] = []
+self._logged_outputs: list[_OUT_DICT] = []
self._has_run: bool = False
self._trainer_fn = trainer_fn
self._stage = stage
self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader")
self._combined_loader: Optional[CombinedLoader] = None
self._data_fetcher: Optional[_DataFetcher] = None
-self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
+self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int)
self._last_val_dl_reload_epoch = float("-inf")
self._module_mode = _ModuleMode()
self._restart_stage = RestartStage.NONE
@@ -90,7 +91,7 @@ def num_dataloaders(self) -> int:
return len(combined_loader.flattened)

@property
-def max_batches(self) -> List[Union[int, float]]:
+def max_batches(self) -> list[Union[int, float]]:
"""The max number of batches to run per dataloader."""
max_batches = self._max_batches
if not self.trainer.sanity_checking:
@@ -114,7 +115,7 @@ def _is_sequential(self) -> bool:
return self._combined_loader._mode == "sequential"

@_no_grad_context
-def run(self) -> List[_OUT_DICT]:
+def run(self) -> list[_OUT_DICT]:
self.setup_data()
if self.skip:
return []
@@ -280,7 +281,7 @@ def on_run_start(self) -> None:
self._on_evaluation_start()
self._on_evaluation_epoch_start()

-def on_run_end(self) -> List[_OUT_DICT]:
+def on_run_end(self) -> list[_OUT_DICT]:
"""Runs the ``_on_evaluation_epoch_end`` hook."""
# if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
self.trainer._logger_connector.epoch_end_reached()
@@ -508,7 +509,7 @@ def _verify_dataloader_idx_requirement(self) -> None:
)

@staticmethod
-def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
+def _get_keys(data: dict) -> Iterable[tuple[str, ...]]:
for k, v in data.items():
if isinstance(v, dict):
for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys):
@@ -527,7 +528,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]:
return _EvaluationLoop._find_value(result, rest)

@staticmethod
-def _print_results(results: List[_OUT_DICT], stage: str) -> None:
+def _print_results(results: list[_OUT_DICT], stage: str) -> None:
# remove the dl idx suffix
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys}
@@ -544,7 +545,7 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120
max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2))

-rows: List[List[Any]] = [[] for _ in metrics_paths]
+rows: list[list[Any]] = [[] for _ in metrics_paths]

for result in results:
for metric, row in zip(metrics_paths, rows):
8 changes: 4 additions & 4 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

import torch
from typing_extensions import override
@@ -104,7 +104,7 @@ def __init__(

self._data_source = _DataLoaderSource(None, "train_dataloader")
self._combined_loader: Optional[CombinedLoader] = None
-self._combined_loader_states_to_load: List[Dict[str, Any]] = []
+self._combined_loader_states_to_load: list[dict[str, Any]] = []
self._data_fetcher: Optional[_DataFetcher] = None
self._last_train_dl_reload_epoch = float("-inf")
self._restart_stage = RestartStage.NONE
@@ -504,14 +504,14 @@ def teardown(self) -> None:
self.epoch_loop.teardown()

@override
-def on_save_checkpoint(self) -> Dict:
+def on_save_checkpoint(self) -> dict:
state_dict = super().on_save_checkpoint()
if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
state_dict["combined_loader"] = loader_states
return state_dict

@override
-def on_load_checkpoint(self, state_dict: Dict) -> None:
+def on_load_checkpoint(self, state_dict: dict) -> None:
self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
super().on_load_checkpoint(state_dict)

12 changes: 6 additions & 6 deletions src/lightning/pytorch/loops/loop.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional
+from typing import Optional

import lightning.pytorch as pl
from lightning.pytorch.loops.progress import _BaseProgress
@@ -41,7 +41,7 @@ def restarting(self, restarting: bool) -> None:
def reset_restart_stage(self) -> None:
pass

-def on_save_checkpoint(self) -> Dict:
+def on_save_checkpoint(self) -> dict:
"""Called when saving a model checkpoint, use to persist loop state.
Returns:
@@ -50,10 +50,10 @@ def on_save_checkpoint(self) -> Dict:
"""
return {}

-def on_load_checkpoint(self, state_dict: Dict) -> None:
+def on_load_checkpoint(self, state_dict: dict) -> None:
"""Called when loading a model checkpoint, use to reload loop state."""

-def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
+def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict:
"""The state dict is determined by the state and progress of this loop and all its children.
Args:
@@ -77,7 +77,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:

def load_state_dict(
self,
-state_dict: Dict,
+state_dict: dict,
prefix: str = "",
) -> None:
"""Loads the state of this loop and all its children."""
@@ -88,7 +88,7 @@ def load_state_dict(
self.restarting = True
self._loaded_from_state_dict = True

-def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
+def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if key not in state_dict:
6 changes: 3 additions & 3 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -14,7 +14,7 @@
import math
from collections import OrderedDict
from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

from typing_extensions import override

@@ -390,13 +390,13 @@ def teardown(self) -> None:
self.val_loop.teardown()

@override
-def on_save_checkpoint(self) -> Dict:
+def on_save_checkpoint(self) -> dict:
state_dict = super().on_save_checkpoint()
state_dict["_batches_that_stepped"] = self._batches_that_stepped
return state_dict

@override
-def on_load_checkpoint(self, state_dict: Dict) -> None:
+def on_load_checkpoint(self, state_dict: dict) -> None:
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

def _accumulated_batches_reached(self) -> bool:
11 changes: 6 additions & 5 deletions tests/tests_pytorch/loops/test_loops.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from collections.abc import Iterator
from copy import deepcopy
from dataclasses import dataclass
-from typing import Any, Dict, Iterator
+from typing import Any
from unittest.mock import ANY, Mock

import pytest
@@ -87,10 +88,10 @@ def advance(self) -> None:

self.outputs.append(value)

-def state_dict(self) -> Dict:
+def state_dict(self) -> dict:
return {"iteration_count": self.iteration_count, "outputs": self.outputs}

-def load_state_dict(self, state_dict: Dict) -> None:
+def load_state_dict(self, state_dict: dict) -> None:
self.iteration_count = state_dict["iteration_count"]
self.outputs = state_dict["outputs"]

@@ -140,10 +141,10 @@ def advance(self) -> None:
return
loop.run()

-def on_save_checkpoint(self) -> Dict:
+def on_save_checkpoint(self) -> dict:
return {"a": self.a}

-def on_load_checkpoint(self, state_dict: Dict) -> None:
+def on_load_checkpoint(self, state_dict: dict) -> None:
self.a = state_dict["a"]

trainer = Trainer()
