diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index aebad81a7b..1fc917f9ed 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -133,18 +133,17 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]: dict[str, Tensor]: Dictionary of channel-wise mean and std """ y_means = [] - teacher_outputs = [] means_distance = [] logger.info("Calculate teacher channel mean and std") for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): y = self.model.teacher(batch["image"].to(self.device)) y_means.append(torch.mean(y, dim=[0, 2, 3])) - teacher_outputs.append(y) channel_mean = torch.mean(torch.stack(y_means), dim=0)[None, :, None, None] - for y in tqdm.tqdm(teacher_outputs, desc="Calculate teacher channel std", position=0, leave=True): + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel std", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) distance = (y - channel_mean) ** 2 means_distance.append(torch.mean(distance, dim=[0, 2, 3]))