Refactor of ViT models #135

Merged: 17 commits merged into FluxML:master on Mar 26, 2022
Conversation

theabhirath (Member):

Note: This PR is a logical follow-up to the many refinements made in #125 and thus should be merged only after that one is merged. (Reviewing this after #125 is merged will also help, because the changes specific to this PR will then be visible more clearly.)

This PR refactors the ViT models to make some performance improvements and to make it easier to write future PRs for attention-based models. First, it takes on NeuralAttentionlib as a dependency, because it allows for writing more general attention operations (which will be very helpful for future models where the inputs may not be 3-dimensional or where there are more complicated hoops to jump through). This has the unexpected side effect of downgrading Flux to 0.12.8, however, because NeuralAttentionlib has a compat bound on NNlib v0.7 and Flux 0.12.9 is the first release to incorporate NNlib v0.8. A simple release should solve this, though. (cc @chengchingwen for this)

Overall, I found this version of attention to be significantly faster: inside the ViT model, benchmarking showed a gap of roughly 100 ms for the forward pass. I haven't checked gradient ops due to memory limitations, but it seems reasonable to assume a performance improvement there as well, because NeuralAttentionlib takes care of the gradient for matmul ops, for example.

This also refines the ViT API to use the terminology originally used in the paper for scaling model size (tiny, small, base, large, huge, giant and gigantic).
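For illustration, a minimal usage sketch of the size-scaled configs; the exact constructor signature below is an assumption for illustration, not necessarily this PR's final API:

using Metalhead

# Hypothetical usage of the paper-style size names; treat the constructor
# signature as an assumption rather than this PR's final API.
m = ViT(:base)                         # one of :tiny, :small, :base, :large, :huge, :giant, :gigantic
y = m(rand(Float32, 224, 224, 3, 1))   # class scores of size (nclasses, batchsize)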

@theabhirath force-pushed the vit-refactor branch 2 times, most recently from a3d174b to 47b5678 on March 15, 2022
(Review threads on src/layers/embeddings.jl, src/layers/mlp.jl, and src/other/mlpmixer.jl were marked outdated and resolved.)
test/other.jl (outdated), comment on lines 29 to 32:
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
GC.gc()
Member:

Regarding OOMs, it seems there are a number of models which can accommodate smaller image sizes. I think we can get away with CIFAR-10's 32x32 for those. ImageNet size can be reserved for the smaller variants of each model type if need be, just to make sure the overall architecture is sound.
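For concreteness, a sketch of what such a CIFAR-sized smoke test could look like, assuming the model under test accepts 32×32 inputs and keeps the default 1000 classes (illustrative only; the config/depth caveats below still apply):

# Hypothetical smaller-input smoke test; only valid for models whose
# architecture admits 32×32 inputs.
@test size(m(rand(Float32, 32, 32, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 32, 32, 3, 2))
GC.gc()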

theabhirath (Member Author):

That sometimes messes up the config values for the depth, since CIFAR images are smaller and the sizes may not match (they will for some of the models where the sizes are held constant throughout, though, I get that). My idea was basically to make sure the config dicts hold up the way they should; we could always shift to a smaller image size wherever we can. But the gradtests are not even enabled and there are already OOMs, so that's definitely a bit of a problem in case those have to be enabled in the future.

Comment on lines +44 to +47
q, k, v = chunk(reshape(m.qkv_layer(x), B ÷ m.nheads, m.nheads, C, 3 * N), 3; dims = 4)
scale = convert(T, sqrt(size(q, 1) / m.nheads))
attn = m.attn_drop(softmax(NeuralAttentionlib.matmul(q, permutedims(k, (2, 1, 3, 4))) * scale))
x = m.projection(reshape(NeuralAttentionlib.matmul(attn, v), (B, C, N)))
Member:

Out of curiosity, why not use NeuralAttentionlib.multihead_qkv_attention?

theabhirath (Member Author):

Primarily for demonstration purposes. Here, of course, the library function works perfectly well because it's vanilla attention, but for other models there is more going on, so I wrote it this way to document the internals of self-attention for a ViT-like model as more of these models are added. I could always switch, though. Is there a significant performance boost obtained?
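For reference, a rough sketch of the library call being discussed, assuming multihead_qkv_attention takes the head count followed by q, k, and v in (features, length, batch) layout; the exact signature is an assumption here, not something verified in this PR:

using NeuralAttentionlib

# Hypothetical drop-in for vanilla multi-head self-attention; the layout and
# argument order are assumptions, not taken from this PR.
q = k = v = rand(Float32, 64, 49, 2)                        # (features, sequence length, batch)
y = NeuralAttentionlib.multihead_qkv_attention(8, q, k, v)  # 8 heads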

Member:

is there a significant performance boost obtained?

Probably not, since they work in the same way. But I wonder where the performance gap (~100 ms) you mentioned is coming from. The functions NeuralAttentionlib (currently) provides are meant to make attention functions composable, and matmul should be equivalent to NNlib.batched_mul + reshape. I'm not sure the performance gain really comes from NeuralAttentionlib.
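To illustrate the equivalence being described, a sketch (under assumed shapes, not code from either library) of a per-head matmul written as a reshape plus NNlib.batched_mul, with the heads folded into the batch dimension:

using NNlib

# Per-head attention scores via batched_mul; shapes are illustrative assumptions.
headdim, len, nheads, batch = 16, 49, 8, 2
q = rand(Float32, headdim, len, nheads, batch)
k = rand(Float32, headdim, len, nheads, batch)

q3 = reshape(q, headdim, len, nheads * batch)
k3 = reshape(k, headdim, len, nheads * batch)
scores = batched_mul(batched_transpose(q3), k3)   # (len, len, nheads * batch)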

theabhirath (Member Author), Mar 16, 2022:

Maybe it's the multiple calls to Attention inside Parallel? This one simply does MHA directly, while the original implementation had Parallel with multiple Attentions zipped together. The main reason I chose this version is that it also leaves future options open for working with inputs with more than three dimensions if need be (as I understand it CollapsedDimArray provides this functionality if necessary; cool library BTW, thanks so much for making this easy for beginners to work with 😁).

Member:

it also leaves future options open for working with inputs with more than three dimensions if need be (as I understand it CollapsedDimArray provides this functionality if necessary

Yes, and actually NeuralAttentionlib.multihead_qkv_attention already allows inputs with more dimensions as long as the dimension order is correct.

while the original implementation had Parallel with multiple Attentions zipped together.

So this seems to be the main reason for the performance gap. Parallel doesn't actually run in parallel; it executes each path sequentially and aggregates the results, so doing MHA directly would definitely be faster.
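A tiny sketch of the Flux.Parallel behaviour being described (the Dense layers here are placeholders): each branch receives the same input one after another, and the connection combines the outputs.

using Flux

# Branches are applied to the same input in turn; `+` then combines the outputs.
p = Parallel(+, Dense(4, 3), Dense(4, 3))
x = rand(Float32, 4, 2)
p(x) ≈ p.layers[1](x) + p.layers[2](x)   # true: sequential application, then reduction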

Member:

As I recall, the original PR had issues with batched_mul across channels, and it also performed the multiplications serially. Does NeuralAttentionlib parallelize them?

theabhirath (Member Author):

https://github.com/chengchingwen/NeuralAttentionlib.jl/blob/master/src/matmul/matmul.jl certainly seems to work similarly to the NNlib version (but with options for more dimensions), so I think there is parallelism? @chengchingwen could elaborate better, though.

Member:

I'm not sure which PR "the original PR" refers to. But as mentioned above, NeuralAttentionlib.matmul is just batched_mul + reshape; they work almost the same. On GPU, both end up calling CUBLAS.gemm_strided_batched!, which is parallelized. However, the CPU path of NeuralAttentionlib.matmul is executed serially, because for some unknown reason the multi-threaded version was slower than the serial one (tested on my machine with an old Julia version long ago), while batched_mul seems to be parallelized.

Member:

I meant #105 for reference. In that PR, 4D batched_mul was originally implemented as a sequential loop over the 3rd dimension (and going via Parallel ended up being faster than that). I thought there was already some reshaping going on, but maybe we missed doing it correctly.

Either way, looks like at least the GPU path is parallelized for sure. It's good to consolidate these primitives in NeuralAttentionlib.jl. In #105, part of the problem with 4D batched_mul was making assumptions about what the different dimensions mean. In NeuralAttentionlib.jl, it will be much safer to do that.

theabhirath (Member Author) commented on Mar 17, 2022:

Perhaps I'm getting this wrong, but testing this PR on different versions of Julia has...surprising results.

1.8.0-beta2: (benchmark screenshot)

1.7.2: (benchmark screenshot)

1.9.0-dev (master): (benchmark screenshot)

I can understand the 1.7.2 result, but master should be faster because of all the compiler improvements, right? (Also holy sheesh the 1.8 compiler seems to be running on a different plane of existence than 1.7)

ToucheSir (Member):

Compile time regressions are not uncommon on master, but generally speaking they're addressed before an actual release is cut.

@theabhirath requested a review from darsnack on March 20, 2022
theabhirath (Member Author):

Bump?

@@ -11,11 +11,11 @@ Creates a single block of ConvNeXt.
"""
function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
-    x -> permutedims(x, (3, 1, 2, 4)),
+    permute_dims((3, 1, 2, 4)),
Member:

Maybe swapdims, or something else that isn't almost the same as a Base name? Also, would it make sense to have swap_channels_spatial/swap_spatial_channels (where the dimension indices are left unspecified)?

theabhirath (Member Author):

swapdims makes sense, yeah. I didn't really get swap_spatial_channels, though. Do you mean having convenience functions for interchanging very specific dimensions of the array? If so, wouldn't that be overly restrictive with future models in mind (and also given that inputs may be 4D or 3D)?

Member:

Is it any more restrictive than what's there right now? Both assume the ordering and number of dimensions. The explicit name would make the intent easier to see at a glance.

theabhirath (Member Author):

No, actually: right now swapdims(perm) is just Base.Fix2(permutedims, perm), so it doesn't assume the ordering or the number of dimensions. I could change it to make it more restrictive. Is that something that would make more sense in this particular context?
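For reference, a minimal sketch of that definition and what the ConvNeXt usage above expands to (array sizes are illustrative):

# A partially-applied permutedims, agnostic to the number of input dimensions.
swapdims(perm) = Base.Fix2(permutedims, perm)

x = rand(Float32, 7, 7, 32, 2)       # (W, H, C, N)
size(swapdims((3, 1, 2, 4))(x))      # (32, 7, 7, 2): channels moved ahead of the spatial dims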

Member:

I don't mean the definition of swapdims, which is fine as is. I mean that the particular usage in this model (and others) is restricted. But it's okay, we can just stick with what we have.

(Review threads on src/layers/attention.jl, src/layers/embeddings.jl, and test/vit-based.jl were marked outdated and resolved.)
Co-Authored-By: Kyle Daruwalla <[email protected]>
(A review thread on src/utilities.jl was marked outdated and resolved.)
@darsnack merged commit c8f0a88 into FluxML:master on Mar 26, 2022
@theabhirath deleted the vit-refactor branch on March 27, 2022
@theabhirath mentioned this pull request on Mar 27, 2022