
Consideration of Flash Attention in Generative Components #7944

Closed · KumoLiu opened this issue Jul 23, 2024 · 6 comments · Fixed by #7977

@KumoLiu (Contributor) commented Jul 23, 2024

Regarding the comments here: #7715 (comment)

We removed flash attention from the generative components when they were merged into the core. However, based on experiments conducted by @dongyang0122, there appears to be a significant difference between training with and without flash attention, so we should consider adding this option back. @dongyang0122 will share more detailed comparison results from the experiments.

@KumoLiu (Contributor, Author) commented Jul 23, 2024

cc @ericspod @virginiafdez

@dongyang0122 (Collaborator) commented Jul 25, 2024

We used the following Python script for the comparison. With flash attention enabled, we can train the diffusion model with batch size 1; with flash attention disabled, training runs out of memory. The experiments were run on an A100 GPU with 80 GB of memory, and with flash attention enabled more than 30 GB of memory is used.

verify_training.py.txt
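
For context on why the memory gap is so large, here is a minimal, hypothetical sketch of the two attention paths being compared (this is not the attached script; the tensor layout and the `use_flash` flag are illustrative assumptions):

```python
# Hypothetical sketch, not the attached verify_training.py: contrast a plain
# attention computation with xformers memory-efficient (flash) attention.
import torch

try:
    import xformers.ops as xops

    has_xformers = True
except ImportError:
    has_xformers = False


def attention(q, k, v, use_flash: bool):
    # q, k, v: (batch, seq_len, num_heads, head_dim)
    if use_flash and has_xformers:
        # memory-efficient kernel: never materializes the full
        # (seq_len x seq_len) attention matrix
        return xops.memory_efficient_attention(q, k, v)
    # naive path: builds the full attention matrix, which is what exhausts
    # GPU memory for the long token sequences of 3D diffusion models
    q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, H, L, D)
    scores = q_ @ k_.transpose(-2, -1) * q_.shape[-1] ** -0.5
    return (scores.softmax(dim=-1) @ v_).transpose(1, 2)
```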

@mingxin-zheng (Contributor) commented

Thanks @dongyang0122 for the verification.

It seems to me that xformers/flash attention brings clear benefits, and we should weigh including it in the dependency list.

On the other hand, installing the package could be challenging, because the supported range of torch, CUDA, and OS platforms seems narrow: https://github.com/facebookresearch/xformers/issues

Would it be acceptable if MONAI leaves the installation of xformers to the user but keeps the optional_import in the code base?
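
As a rough illustration of that proposal, a sketch of the gating pattern, assuming MONAI's existing `monai.utils.optional_import` helper; the class name and constructor argument below are hypothetical:

```python
# Sketch of the proposed pattern: xformers is not a required dependency, but
# modules use it when the user has installed it. optional_import is MONAI's
# existing helper; the class and argument names here are hypothetical.
import torch
from monai.utils import optional_import

xops, has_xformers = optional_import("xformers.ops")


class SelfAttentionBlock(torch.nn.Module):
    def __init__(self, use_flash_attention: bool = False) -> None:
        super().__init__()
        if use_flash_attention and not has_xformers:
            raise ValueError(
                "use_flash_attention=True requires the optional xformers package"
            )
        self.use_flash_attention = use_flash_attention
```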

@ericspod (Member) commented

We could also look at using the flash attention built into PyTorch: https://pytorch.org/blog/pytorch2-2/
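
For reference, a minimal sketch of that PyTorch-native route; the tensor shapes are illustrative, and the backend-selection context manager shown is the one available around PyTorch 2.0–2.2 (newer releases expose `torch.nn.attention.sdpa_kernel` instead):

```python
# Sketch of the built-in PyTorch path: scaled_dot_product_attention dispatches
# to a flash or memory-efficient kernel when device/dtype allow, with no extra
# dependency. Shapes below are illustrative.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# SDPA expects (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 4096, 64, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v)

# On CUDA the backend can be pinned to confirm the flash kernel is being used.
if device == "cuda":
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out_flash = F.scaled_dot_product_attention(q, k, v)
```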

@KumoLiu (Contributor, Author) commented Jul 29, 2024

Thanks to @dongyang0122 for the script. I ran it against both the original generative (xformers) implementation and the PyTorch implementation; the results are shown below. PyTorch achieves almost the same results as the xformers implementation, so I believe we can use PyTorch instead.

[Figure: training comparison between the xformers and PyTorch attention implementations]
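
To illustrate why near-identical results are expected, a hedged sketch (not the linked training code) of a numerical comparison between the two kernels; note that xformers uses a (batch, seq, heads, dim) layout while torch SDPA uses (batch, heads, seq, dim), and the shapes and tolerances here are assumptions:

```python
# Sketch only: check that xformers memory_efficient_attention and torch
# scaled_dot_product_attention agree on the same inputs, up to fp16 kernel
# tolerance. Shapes and tolerances are illustrative assumptions.
import torch
import torch.nn.functional as F
import xformers.ops as xops

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)  # (B, L, H, D)
k, v = torch.randn_like(q), torch.randn_like(q)

out_xformers = xops.memory_efficient_attention(q, k, v)
out_torch = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# fp16 kernels differ in accumulation order, so compare with a tolerance
print(torch.allclose(out_xformers, out_torch, atol=1e-3, rtol=1e-3))
```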

@KumoLiu (Contributor, Author) commented Jul 29, 2024

cc @guopengf for vis. Here is the code I use: https://github.com/KumoLiu/MONAI/tree/flash-atten

@virginiafdez mentioned this issue Aug 1, 2024

@KumoLiu closed this as completed in 6c23fd0 on Aug 6, 2024