diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py
index 1f4026a8ac..4fcce26d1c 100644
--- a/src/anomalib/models/image/efficient_ad/lightning_model.py
+++ b/src/anomalib/models/image/efficient_ad/lightning_model.py
@@ -240,10 +240,25 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
             lr=self.lr,
             weight_decay=self.weight_decay,
         )
-        num_steps = min(
-            self.trainer.max_steps,
-            self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()),
-        )
+
+        if self.trainer.max_epochs < 0 and self.trainer.max_steps < 0:
+            msg = "A finite number of steps or epochs must be defined"
+            raise ValueError(msg)
+
+        # lightning stops training when either 'max_steps' or 'max_epochs' is reached (earliest),
+        # so actual training steps need to be determined here
+        if self.trainer.max_epochs < 0:
+            # max_epochs not set
+            num_steps = self.trainer.max_steps
+        elif self.trainer.max_steps < 0:
+            # max_steps not set -> determine steps as 'max_epochs' * 'steps in a single training epoch'
+            num_steps = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())
+        else:
+            num_steps = min(
+                self.trainer.max_steps,
+                self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()),
+            )
+
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.95 * num_steps), gamma=0.1)
         return {"optimizer": optimizer, "lr_scheduler": scheduler}
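
For reference, a minimal standalone sketch (not part of the patch) of the same num_steps resolution logic, applied to plain integers instead of a Lightning Trainer. The function name resolve_num_steps and the example values are hypothetical; as in the patch, a negative value is treated as "not set".

# Minimal sketch, assuming the patch's convention that a negative
# max_steps / max_epochs means the limit was not set by the user.

def resolve_num_steps(max_steps: int, max_epochs: int, steps_per_epoch: int) -> int:
    """Mirror the branching added in configure_optimizers()."""
    if max_epochs < 0 and max_steps < 0:
        msg = "A finite number of steps or epochs must be defined"
        raise ValueError(msg)
    if max_epochs < 0:
        # max_epochs not set -> training length is bounded by max_steps only
        return max_steps
    if max_steps < 0:
        # max_steps not set -> training length is epochs * steps per epoch
        return max_epochs * steps_per_epoch
    # both set -> Lightning stops at whichever limit is reached first
    return min(max_steps, max_epochs * steps_per_epoch)


# Example: a 1000-step cap with 20 epochs of 80 batches resolves to 1000 steps,
# so the StepLR above would decay the learning rate at int(0.95 * 1000) == 950 steps.
print(resolve_num_steps(max_steps=1000, max_epochs=20, steps_per_epoch=80))  # 1000
print(resolve_num_steps(max_steps=-1, max_epochs=20, steps_per_epoch=80))    # 1600
print(resolve_num_steps(max_steps=1000, max_epochs=-1, steps_per_epoch=80))  # 1000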