[Task]: EfficientAD OutOfMemory exception (fix included) #1301

Closed
MG109 opened this issue Aug 25, 2023 · 4 comments · Fixed by #1340

Comments

@MG109
Contributor

MG109 commented Aug 25, 2023

What is the motivation for this task?

I ran into an OutOfMemory exception when calculating the teacher channel mean and standard deviation while training the EfficientAD model.

Describe the solution you'd like

I fixed this by not saving the teacher outputs to a list but instead iterating the dataloader twice.
This increases the runtime of the teacher_channel_mean_std() function, but only if the teacher outputs don't fit into memory.

I could create a pull request from my branch, but I didn't manage to get the pre-commit hooks running.

    @torch.no_grad()
    def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]:
        """Calculate the mean and std of the teacher model's activations.

        Args:
            dataloader (DataLoader): Dataloader of the respective dataset.

        Returns:
            dict[str, Tensor]: Dictionary of channel-wise mean and std
        """
        y_means = []
        teacher_outputs = []
        means_distance = []

        logger.info("Calculate teacher channel mean and std")
        try:
            # Fast path: keep all teacher outputs in memory so the std pass can reuse them.
            for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True):
                y = self.model.teacher(batch["image"].to(self.device))
                y_means.append(torch.mean(y, dim=[0, 2, 3]))
                teacher_outputs.append(y)
        except RuntimeError:
            # Out of memory: drop the cached outputs and recompute the means without storing them.
            teacher_outputs = []
            y_means = []
            torch.cuda.empty_cache()
            logger.info("Recovering from OutOfMemory exception by using workaround with longer runtime")
            for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True):
                y = self.model.teacher(batch["image"].to(self.device))
                y_means.append(torch.mean(y, dim=[0, 2, 3]))

        channel_mean = torch.mean(torch.stack(y_means), dim=0)[None, :, None, None]

        if len(teacher_outputs) > 0:
            # Outputs fit into memory: reuse them for the std computation.
            for y in tqdm.tqdm(teacher_outputs, desc="Calculate teacher channel std", position=0, leave=True):
                distance = (y - channel_mean) ** 2
                means_distance.append(torch.mean(distance, dim=[0, 2, 3]))
        else:
            # Outputs did not fit into memory: iterate the dataloader a second time instead.
            for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel std", position=0, leave=True):
                y = self.model.teacher(batch["image"].to(self.device))
                distance = (y - channel_mean) ** 2
                means_distance.append(torch.mean(distance, dim=[0, 2, 3]))

        channel_var = torch.mean(torch.stack(means_distance), dim=0)[None, :, None, None]
        channel_std = torch.sqrt(channel_var)
        return {"mean": channel_mean, "std": channel_std}

Additional context

No response

MG109 added the Task label Aug 25, 2023
@alexriedel1
Contributor

Thanks for the fix idea. It works, but you still have to store the values before processing them. The best approach would probably be an "online" mean and variance calculation. The mean is trivial, and the standard deviation can be calculated with a streaming update as described here: https://math.stackexchange.com/questions/198336/how-to-calculate-standard-deviation-with-streaming-inputs
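
For illustration, a minimal sketch of such an online computation could look like the following, using a batch-wise parallel mean/variance update. Note that `online_channel_mean_std`, `teacher`, `dataloader`, and `device` are hypothetical stand-ins for the model, loader, and device used in the snippet above, not the actual anomalib API:

    import torch
    import tqdm
    from torch import Tensor
    from torch.utils.data import DataLoader


    @torch.no_grad()
    def online_channel_mean_std(teacher, dataloader: DataLoader, device) -> dict[str, Tensor]:
        """Single-pass, per-channel mean/std without storing the teacher outputs."""
        n = 0          # number of values seen so far per channel
        mean = None    # running per-channel mean
        m2 = None      # running per-channel sum of squared deviations

        for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean and std"):
            y = teacher(batch["image"].to(device))        # shape (B, C, H, W)
            n_b = y.shape[0] * y.shape[2] * y.shape[3]    # values per channel in this batch
            mean_b = torch.mean(y, dim=[0, 2, 3])
            m2_b = torch.sum((y - mean_b[None, :, None, None]) ** 2, dim=[0, 2, 3])

            if mean is None:
                n, mean, m2 = n_b, mean_b, m2_b
            else:
                # Combine the running statistics with the batch statistics.
                delta = mean_b - mean
                total = n + n_b
                mean = mean + delta * n_b / total
                m2 = m2 + m2_b + delta ** 2 * n * n_b / total
                n = total

        channel_mean = mean[None, :, None, None]
        channel_std = torch.sqrt(m2 / n)[None, :, None, None]
        return {"mean": channel_mean, "std": channel_std}

This needs only one pass over the dataloader and constant memory, at the cost of the precision/stability caveats of streaming updates.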

@blaz-r
Contributor

blaz-r commented Aug 26, 2023

I agree with @alexriedel1, it would be worth exploring the idea of using such statistical algorithms. I'm not that familiar with the topic, but I see that there can be some problems with precision and stability, so it'd be worth exploring this and maybe offering it as a config option if it turns out to be unstable in some cases (though I'm not sure how to verify that).
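
One rough way to check the stability concern would be to compare the online update against a naive full-tensor computation on synthetic data; the snippet below is only a hypothetical sanity check, not part of anomalib:

    import torch

    torch.manual_seed(0)
    chunks = [torch.randn(8, 16, 4, 4) for _ in range(10)]  # stand-ins for teacher output batches

    n, mean, m2 = 0, torch.zeros(16), torch.zeros(16)
    for y in chunks:
        n_b = y.shape[0] * y.shape[2] * y.shape[3]
        mean_b = y.mean(dim=[0, 2, 3])
        m2_b = ((y - mean_b[None, :, None, None]) ** 2).sum(dim=[0, 2, 3])
        delta = mean_b - mean
        total = n + n_b
        mean = mean + delta * n_b / total
        m2 = m2 + m2_b + delta ** 2 * n * n_b / total
        n = total

    # Reference: compute the statistics over all data at once and compare.
    reference = torch.cat(chunks)
    assert torch.allclose(mean, reference.mean(dim=[0, 2, 3]), atol=1e-4)
    assert torch.allclose(torch.sqrt(m2 / n), reference.std(dim=[0, 2, 3], unbiased=False), atol=1e-4)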

@ZHIZIHUABU

Thanks, it works, but will this affect training accuracy?

@samet-akcay
Contributor

> I could create a pull request from my branch, but I didn't manage to get the pre-commit hooks running.

@MG109, that's alright. If you create a PR, we could sort out the pre-commit stuff for you. You are the one proposing this, so it would be good for you to become a contributor.
