From e6cb80714fa8fc15bf45bbf0372ad180d9867f79 Mon Sep 17 00:00:00 2001 From: Mirco Date: Wed, 26 Jul 2023 13:24:36 +0200 Subject: [PATCH 1/3] implement teacher_channel_mean_std with lower mem. footprint Workaround loops over dataset twice: no need to save teacher_outputs (which needs much GPU memory for large datasets) --- .../models/efficient_ad/lightning_model.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index 0ec0f908e2..d7fe1f3bf5 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -137,16 +137,31 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]: means_distance = [] logger.info("Calculate teacher channel mean and std") - for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): - y = self.model.teacher(batch["image"].to(self.device)) - y_means.append(torch.mean(y, dim=[0, 2, 3])) - teacher_outputs.append(y) + try: + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) + y_means.append(torch.mean(y, dim=[0, 2, 3])) + teacher_outputs.append(y) + except RuntimeError as e: + teacher_outputs = [] + y_means = [] + torch.cuda.empty_cache() + logger.info("Recovering from OutOfMemory Exception by using workaround with longer runtime") + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) + y_means.append(torch.mean(y, dim=[0, 2, 3])) channel_mean = torch.mean(torch.stack(y_means), dim=0)[None, :, None, None] - for y in tqdm.tqdm(teacher_outputs, desc="Calculate teacher channel std", position=0, leave=True): - distance = (y - channel_mean) ** 2 - means_distance.append(torch.mean(distance, dim=[0, 2, 3])) + if len(teacher_outputs) > 0: + for y in tqdm.tqdm(teacher_outputs, desc="Calculate teacher channel std", position=0, leave=True): + distance = (y - channel_mean) ** 2 + means_distance.append(torch.mean(distance, dim=[0, 2, 3])) + else: + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel std", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) + distance = (y - channel_mean) ** 2 + means_distance.append(torch.mean(distance, dim=[0, 2, 3])) channel_var = torch.mean(torch.stack(means_distance), dim=0)[None, :, None, None] channel_std = torch.sqrt(channel_var) From 208ead7ca33a49ec04b4ee05644ac2317c28dafb Mon Sep 17 00:00:00 2001 From: Samet Date: Tue, 12 Sep 2023 10:21:34 +0100 Subject: [PATCH 2/3] Fix pre-commit --- src/anomalib/models/efficient_ad/lightning_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index 4d15bbfd86..96a2329d30 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -142,7 +142,7 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]: y = self.model.teacher(batch["image"].to(self.device)) y_means.append(torch.mean(y, dim=[0, 2, 3])) teacher_outputs.append(y) - except RuntimeError as e: + except RuntimeError: teacher_outputs = [] y_means = [] torch.cuda.empty_cache() From 9289a5922006ff94e462655984c4a22f35e6c598 Mon Sep 17 00:00:00 2001 From: Mirco Date: Tue, 12 Sep 2023 12:47:19 +0200 Subject: [PATCH 3/3] implement new mean/std calculation as standard --- .../models/efficient_ad/lightning_model.py | 30 +++++-------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index 96a2329d30..1fc917f9ed 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -133,35 +133,19 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]: dict[str, Tensor]: Dictionary of channel-wise mean and std """ y_means = [] - teacher_outputs = [] means_distance = [] logger.info("Calculate teacher channel mean and std") - try: - for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): - y = self.model.teacher(batch["image"].to(self.device)) - y_means.append(torch.mean(y, dim=[0, 2, 3])) - teacher_outputs.append(y) - except RuntimeError: - teacher_outputs = [] - y_means = [] - torch.cuda.empty_cache() - logger.info("Recovering from OutOfMemory Exception by using workaround with longer runtime") - for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): - y = self.model.teacher(batch["image"].to(self.device)) - y_means.append(torch.mean(y, dim=[0, 2, 3])) + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) + y_means.append(torch.mean(y, dim=[0, 2, 3])) channel_mean = torch.mean(torch.stack(y_means), dim=0)[None, :, None, None] - if len(teacher_outputs) > 0: - for y in tqdm.tqdm(teacher_outputs, desc="Calculate teacher channel std", position=0, leave=True): - distance = (y - channel_mean) ** 2 - means_distance.append(torch.mean(distance, dim=[0, 2, 3])) - else: - for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel std", position=0, leave=True): - y = self.model.teacher(batch["image"].to(self.device)) - distance = (y - channel_mean) ** 2 - means_distance.append(torch.mean(distance, dim=[0, 2, 3])) + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel std", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) + distance = (y - channel_mean) ** 2 + means_distance.append(torch.mean(distance, dim=[0, 2, 3])) channel_var = torch.mean(torch.stack(means_distance), dim=0)[None, :, None, None] channel_std = torch.sqrt(channel_var)