Skip to content

Commit

Permalink
🐞 Fix Rich Progress with Patchcore Training (#2062)
Browse files Browse the repository at this point in the history
Add safe track

Signed-off-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
ashwinvaidya17 authored May 15, 2024
1 parent 5ff7f10 commit 849de79
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any

import torch
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
Expand Down Expand Up @@ -406,7 +406,7 @@ def _setup_transform(

def _setup_anomalib_callbacks(self) -> None:
"""Set up callbacks for the trainer."""
_callbacks: list[Callback] = []
_callbacks: list[Callback] = [RichProgressBar(), RichModelSummary()]

# Add ModelCheckpoint if it is not in the callbacks list.
has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"])
Expand Down
4 changes: 0 additions & 4 deletions src/anomalib/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@
from lightning.pytorch.utilities import rank_zero_only
from matplotlib.figure import Figure

from anomalib.utils.exceptions.imports import try_import

from .base import ImageLoggerBase

try_import("mlflow")


class AnomalibMLFlowLogger(ImageLoggerBase, MLFlowLogger):
"""Logger for MLFlow.
Expand Down
4 changes: 2 additions & 2 deletions src/anomalib/models/components/sampling/k_center_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from rich.progress import track
from torch.nn import functional as F # noqa: N812

from anomalib.models.components.dimensionality_reduction import SparseRandomProjection
from anomalib.utils.rich import safe_track


class KCenterGreedy:
Expand Down Expand Up @@ -98,7 +98,7 @@ def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[in

selected_coreset_idxs: list[int] = []
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."):
for _ in safe_track(sequence=range(self.coreset_size), description="Selecting Coreset Indices."):
self.update_distances(cluster_centers=[idx])
idx = self.get_new_idx()
if idx in selected_idxs:
Expand Down
47 changes: 47 additions & 0 deletions src/anomalib/utils/rich.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Custom rich methods."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Generator, Iterable
from typing import TYPE_CHECKING, Any

from rich import get_console
from rich.progress import track

if TYPE_CHECKING:
from rich.live import Live


class CacheRichLiveState:
"""Cache the live state of the console.
Note: This is a bit dangerous as it accesses private attributes of the console.
Use this with caution.
"""

def __init__(self) -> None:
self.console = get_console()
self.live: "Live" | None = None

def __enter__(self) -> None:
"""Save the live state of the console."""
# Need to access private attribute to get the live state
with self.console._lock: # noqa: SLF001
self.live = self.console._live # noqa: SLF001
self.console.clear_live()

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa: ANN401
"""Restore the live state of the console."""
if self.live:
self.console.clear_live()
self.console.set_live(self.live)


def safe_track(*args, **kwargs) -> Generator[Iterable, Any, Any]:
"""Wraps ``rich.progress.track`` with a context manager to cache the live state.
For parameters look at ``rich.progress.track``.
"""
with CacheRichLiveState():
yield from track(*args, **kwargs)

0 comments on commit 849de79

Please sign in to comment.