From ee2304c4382b96919cd954a55fb3dc6d9db44689 Mon Sep 17 00:00:00 2001
From: blaz-r
Date: Thu, 1 Feb 2024 23:03:44 +0100
Subject: [PATCH 1/2] Fix scheduler num_steps for efficient ad

Signed-off-by: blaz-r
Signed-off-by: Blaz Rolih
---
 .../image/efficient_ad/lightning_model.py    | 22 +++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py
index 1f4026a8ac..29c7a9829c 100644
--- a/src/anomalib/models/image/efficient_ad/lightning_model.py
+++ b/src/anomalib/models/image/efficient_ad/lightning_model.py
@@ -240,10 +240,24 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
             lr=self.lr,
             weight_decay=self.weight_decay,
         )
-        num_steps = min(
-            self.trainer.max_steps,
-            self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()),
-        )
+
+        if self.trainer.max_epochs < 0 and self.trainer.max_steps < 0:
+            raise ValueError("A finite number of steps or epochs must be defined")
+
+        # lightning stops training when either 'max_steps' or 'max_epochs' is reached (earliest),
+        # so actual training steps need to be determined here
+        if self.trainer.max_epochs < 0:
+            # max_epochs not set
+            num_steps = self.trainer.max_steps
+        elif self.trainer.max_steps < 0:
+            # max_steps not set -> determine steps as 'max_epochs' * 'steps in a single training epoch'
+            num_steps = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())
+        else:
+            num_steps = min(
+                self.trainer.max_steps,
+                self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()),
+            )
+
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.95 * num_steps), gamma=0.1)
         return {"optimizer": optimizer, "lr_scheduler": scheduler}
 

From e5ae59d0ba9781b81998759e87cad2b971762e5e Mon Sep 17 00:00:00 2001
From: Blaz Rolih
Date: Mon, 5 Feb 2024 18:58:41 +0100
Subject: [PATCH 2/2] Address ruff issues

Signed-off-by: Blaz Rolih
---
 src/anomalib/models/image/efficient_ad/lightning_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py
index 29c7a9829c..4fcce26d1c 100644
--- a/src/anomalib/models/image/efficient_ad/lightning_model.py
+++ b/src/anomalib/models/image/efficient_ad/lightning_model.py
@@ -242,7 +242,8 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
         )
 
         if self.trainer.max_epochs < 0 and self.trainer.max_steps < 0:
-            raise ValueError("A finite number of steps or epochs must be defined")
+            msg = "A finite number of steps or epochs must be defined"
+            raise ValueError(msg)
 
         # lightning stops training when either 'max_steps' or 'max_epochs' is reached (earliest),
         # so actual training steps need to be determined here
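
For review, the step-resolution rule the first patch introduces can be exercised
outside the Trainer. Below is a minimal standalone sketch; `resolve_num_steps` and
its `steps_per_epoch` parameter are hypothetical names invented here, and the
negative-value checks assume Lightning's convention that `max_steps=-1` /
`max_epochs=-1` mean the corresponding bound is unset.

def resolve_num_steps(max_epochs: int, max_steps: int, steps_per_epoch: int) -> int:
    """Number of optimizer steps the trainer will actually run."""
    if max_epochs < 0 and max_steps < 0:
        msg = "A finite number of steps or epochs must be defined"
        raise ValueError(msg)
    if max_epochs < 0:
        # Only max_steps bounds training.
        return max_steps
    if max_steps < 0:
        # Only max_epochs bounds training.
        return max_epochs * steps_per_epoch
    # Both bounds set: training stops at whichever limit is reached first.
    return min(max_steps, max_epochs * steps_per_epoch)

# With num_steps resolved this way, StepLR(optimizer,
# step_size=int(0.95 * num_steps), gamma=0.1) decays the learning rate at
# 95% of the actual training duration, instead of the degenerate step_size
# the old min() produced when max_steps was left at its -1 default.
assert resolve_num_steps(max_epochs=-1, max_steps=1000, steps_per_epoch=50) == 1000
assert resolve_num_steps(max_epochs=20, max_steps=-1, steps_per_epoch=50) == 1000
assert resolve_num_steps(max_epochs=20, max_steps=700, steps_per_epoch=50) == 700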