From cbdd9b83d490c62365356fb8d55a4af8387c739c Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 15 May 2024 13:53:22 +0200 Subject: [PATCH] Add safe track Signed-off-by: Ashwin Vaidya --- src/anomalib/engine/engine.py | 4 +- src/anomalib/loggers/mlflow.py | 4 -- .../components/sampling/k_center_greedy.py | 4 +- src/anomalib/utils/rich.py | 47 +++++++++++++++++++ 4 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 src/anomalib/utils/rich.py diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 8e7e679650..f2575771e0 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -9,7 +9,7 @@ from typing import Any import torch -from lightning.pytorch.callbacks import Callback +from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar from lightning.pytorch.loggers import Logger from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -406,7 +406,7 @@ def _setup_transform( def _setup_anomalib_callbacks(self) -> None: """Set up callbacks for the trainer.""" - _callbacks: list[Callback] = [] + _callbacks: list[Callback] = [RichProgressBar(), RichModelSummary()] # Add ModelCheckpoint if it is not in the callbacks list. has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"]) diff --git a/src/anomalib/loggers/mlflow.py b/src/anomalib/loggers/mlflow.py index 2bc3835072..64504893d4 100644 --- a/src/anomalib/loggers/mlflow.py +++ b/src/anomalib/loggers/mlflow.py @@ -7,12 +7,8 @@ from lightning.pytorch.utilities import rank_zero_only from matplotlib.figure import Figure -from anomalib.utils.exceptions.imports import try_import - from .base import ImageLoggerBase -try_import("mlflow") - class AnomalibMLFlowLogger(ImageLoggerBase, MLFlowLogger): """Logger for MLFlow. diff --git a/src/anomalib/models/components/sampling/k_center_greedy.py b/src/anomalib/models/components/sampling/k_center_greedy.py index 788f2e6683..2b0f495d28 100644 --- a/src/anomalib/models/components/sampling/k_center_greedy.py +++ b/src/anomalib/models/components/sampling/k_center_greedy.py @@ -8,10 +8,10 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from rich.progress import track from torch.nn import functional as F # noqa: N812 from anomalib.models.components.dimensionality_reduction import SparseRandomProjection +from anomalib.utils.rich import safe_track class KCenterGreedy: @@ -98,7 +98,7 @@ def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[in selected_coreset_idxs: list[int] = [] idx = int(torch.randint(high=self.n_observations, size=(1,)).item()) - for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."): + for _ in safe_track(sequence=range(self.coreset_size), description="Selecting Coreset Indices."): self.update_distances(cluster_centers=[idx]) idx = self.get_new_idx() if idx in selected_idxs: diff --git a/src/anomalib/utils/rich.py b/src/anomalib/utils/rich.py new file mode 100644 index 0000000000..99842b6b8b --- /dev/null +++ b/src/anomalib/utils/rich.py @@ -0,0 +1,47 @@ +"""Custom rich methods.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Generator, Iterable +from typing import TYPE_CHECKING, Any + +from rich import get_console +from rich.progress import track + +if TYPE_CHECKING: + from rich.live import Live + + +class CacheRichLiveState: + """Cache the live state of the console. + + Note: This is a bit dangerous as it accesses private attributes of the console. + Use this with caution. + """ + + def __init__(self) -> None: + self.console = get_console() + self.live: "Live" | None = None + + def __enter__(self) -> None: + """Save the live state of the console.""" + # Need to access private attribute to get the live state + with self.console._lock: # noqa: SLF001 + self.live = self.console._live # noqa: SLF001 + self.console.clear_live() + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa: ANN401 + """Restore the live state of the console.""" + if self.live: + self.console.clear_live() + self.console.set_live(self.live) + + +def safe_track(*args, **kwargs) -> Generator[Iterable, Any, Any]: + """Wraps ``rich.progress.track`` with a context manager to cache the live state. + + For parameters look at ``rich.progress.track``. + """ + with CacheRichLiveState(): + yield from track(*args, **kwargs)