Refactor of ViT models #135
Conversation
(force-pushed from a3d174b to 47b5678)
test/other.jl (outdated)
```julia
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
GC.gc()
```
Regarding OOMs, it seems there are a number of models which can accommodate smaller image sizes. I think we can get away with CIFAR-10's 32x32 for those. ImageNet size can be reserved for the smaller variants of each model type if need be, just to make sure the overall architecture is sound.
That messes up the config values for the depth sometimes, since CIFAR is a smaller dataset and the sizes may not match. They will for some of the models where the sizes are held constant throughout, though, I get that. My idea was basically to make sure the config dicts hold up the way they should; we could always shift to a smaller image size wherever we can. But gradtests are not even enabled and there are already OOMs, so that's definitely a problem if those have to be enabled in the future.
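For concreteness, a rough sketch of the kind of smaller-input smoke test being discussed; the `Chain` below is a stand-in model rather than anything from this package, just to show the shape of such a test.

```julia
using Test, Flux

# Stand-in model (not a Metalhead one) that accepts CIFAR-sized inputs.
m = Chain(Conv((3, 3), 3 => 8, relu; pad = 1),
          GlobalMeanPool(), Flux.flatten, Dense(8, 1000))

x = rand(Float32, 32, 32, 3, 2)   # 32x32 batch instead of 224x224
@test size(m(x)) == (1000, 2)     # architecture check without the ImageNet-sized memory cost
GC.gc()                           # free memory between model tests, as in the suite above
```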
(force-pushed from d70b692 to e047d53)
(force-pushed from bbc03fd to 04d59e2)
```julia
# One projection produces queries, keys and values, which are then split apart.
q, k, v = chunk(reshape(m.qkv_layer(x), B ÷ m.nheads, m.nheads, C, 3 * N), 3; dims = 4)
# Attention weights: scaled dot products of queries and keys, softmax, then dropout.
scale = convert(T, sqrt(size(q, 1) / m.nheads))
attn = m.attn_drop(softmax(NeuralAttentionlib.matmul(q, permutedims(k, (2, 1, 3, 4))) * scale))
# Weight the values by the attention map and project back to the embedding dimension.
x = m.projection(reshape(NeuralAttentionlib.matmul(attn, v), (B, C, N)))
```
Out of curiosity, why not use `NeuralAttentionlib.multihead_qkv_attention`?
Primarily for demonstration purposes. Here, of course, the library function works perfectly well because it's vanilla attention, but for other models there is more going on, so I wrote it this way to document the internals of self-attention for a ViT-like model as more of these models are added. Could always switch; is there a significant performance boost obtained?
> is there a significant performance boost obtained?
Probably not, since they work the same way. But I wonder where the performance gap (~100 ms) you mentioned comes from. The functions `NeuralAttentionlib` (currently) provides are meant to make attention functions composable, and `matmul` should be equivalent to `NNlib.batched_mul` + `reshape`. I'm not sure the performance gain really comes from `NeuralAttentionlib`.
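As a rough illustration of the `batched_mul` + `reshape` equivalence being described (the shapes here are arbitrary, and this is a sketch of the pattern rather than NeuralAttentionlib's actual code):

```julia
using NNlib

# Fold the head and batch dimensions of 4D q/k into a single batch dimension,
# do one batched matrix multiply, then unfold the result back.
q = rand(Float32, 64, 197, 12, 2)   # (head_dim, seq_len, nheads, batch)
k = rand(Float32, 64, 197, 12, 2)

q3 = reshape(q, 64, 197, :)         # (head_dim, seq_len, nheads * batch)
k3 = reshape(k, 64, 197, :)
scores = batched_mul(batched_transpose(q3), k3)   # (seq_len, seq_len, nheads * batch)
scores = reshape(scores, 197, 197, 12, 2)         # back to (seq_len, seq_len, nheads, batch)
```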
Maybe it's the multiple calls to `Attention` inside `Parallel`? This one simply does MHA directly, while the original implementation had `Parallel` with multiple `Attention`s zipped together. The main reason I chose this version is that it also leaves future options open for working with inputs with more than three dimensions if need be (as I understand it, `CollapsedDimArray` provides this functionality if necessary [cool library BTW, thanks so much for making this easy for beginners to work with 😁]).
> it also leaves future options open for working with inputs with more than three dimensions if need be (as I understand it CollapsedDimArray provides this functionality if necessary)
Yes, and actually `NeuralAttentionlib.multihead_qkv_attention` already allows inputs with more dimensions as long as the dimension order is correct.
> while the original implementation had Parallel with multiple Attentions zipped together.
So this seems to be the main reason for the performance gap. `Parallel` doesn't run in parallel; it sequentially executes each path and aggregates the results, so doing MHA directly would definitely be faster.
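To make the point about `Parallel` concrete, here is a minimal sketch with `Dense` layers standing in for per-head attention blocks (not the actual code from the earlier implementation): each branch is applied to the input one after another and the outputs are then combined, rather than being fused into one batched call.

```julia
using Flux

head1 = Dense(64, 64)                      # stand-ins for per-head attention layers
head2 = Dense(64, 64)
per_head = Parallel(vcat, head1, head2)    # evaluates head1(x), then head2(x), then vcat's them

x = rand(Float32, 64, 10)
y = per_head(x)                            # size (128, 10); the branches ran sequentially
```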
As I recall, the original PR had issues with `batched_mul` across channels, and it also performed the multiplications serially. Does `NeuralAttentionlib` parallelize them?
https://github.com/chengchingwen/NeuralAttentionlib.jl/blob/master/src/matmul/matmul.jl certainly seems to work similarly to the NNlib version (but with options for more dimensions), so I think there is parallelism? @chengchingwen could elaborate better, though.
I'm not sure which PR "the original PR" refers to. But as mentioned above, `NeuralAttentionlib.matmul` is just `batched_mul` + `reshape`; they work almost the same. On GPU, they all end up calling `CUBLAS.gemm_strided_batched!`, which is parallelized. However, the CPU path of `NeuralAttentionlib.matmul` is executed serially because, for some unknown reason, the multi-threaded version was slower than the serial one (tested on my machine with an old Julia version long ago), while `batched_mul` seems to be parallelized.
I meant #105 for reference. In that PR, 4D `batched_mul` was originally implemented as a sequential loop over the 3rd dimension (which ended up being faster via `Parallel`). I thought there was already some reshaping going on, but maybe we missed doing it correctly.
Either way, it looks like at least the GPU path is parallelized for sure. It's good to consolidate these primitives in NeuralAttentionlib.jl. In #105, part of the problem with 4D `batched_mul` was making assumptions about what the different dimensions mean; in NeuralAttentionlib.jl it will be much safer to do that.
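For reference, a sequential-loop version of a 4D batched multiply looks roughly like the sketch below (illustrative only, not the actual #105 code); every slice becomes a separate matrix multiply, which is exactly what a strided-batched GEMM, or `batched_mul` over a reshaped 3D array, avoids.

```julia
# Naive 4D "batched" multiply: loop over the trailing dimensions and multiply
# each 2D slice on its own.
function loop_batched_mul(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    m, _, h, b = size(A)
    n = size(B, 2)
    C = similar(A, m, n, h, b)
    for j in 1:b, i in 1:h
        @views C[:, :, i, j] = A[:, :, i, j] * B[:, :, i, j]
    end
    return C
end
```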
Perhaps I'm getting this wrong, but testing this PR on different versions of Julia has... surprising results. I can understand the 1.7.2 result, but master should be faster because of all the compiler improvements, right? (Also, holy sheesh, the 1.8 compiler seems to be running on a different plane of existence than 1.7.)
Compile-time regressions are not uncommon on master, but generally speaking they're addressed before an actual release is cut.
Bump?
src/convnets/convnext.jl (outdated)
```diff
@@ -11,11 +11,11 @@ Creates a single block of ConvNeXt.
 """
 function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
     layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
-                                  x -> permutedims(x, (3, 1, 2, 4)),
+                                  permute_dims((3, 1, 2, 4)),
```
Maybe `swapdims`, or something that isn't almost the same as a Base name? Also, would it make sense to have `swap_channels_spatial`/`swap_spatial_channels` (where the dimension indices are left unspecified)?
`swapdims` makes sense, yeah. I didn't really get the `swap_spatial_channels` idea, though; do you mean having convenience functions for interchanging very specific dimensions of the array? If so, wouldn't that be overly restrictive keeping future models in mind (and also given that inputs may be 4D or 3D)?
Is it any more restrictive than what's there right now? Both assume the ordering and number of dimensions. An explicit name would make the intent easier to see at a glance.
No, actually, right now `swapdims(perm)` is just `Base.Fix2(permutedims, perm)`, so it doesn't assume the ordering or the number of dimensions. I could change it to make it more restrictive. Is that something that would make more sense in this particular context?
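A minimal sketch of the definition and usage being described, assuming the `swapdims` name lands as discussed:

```julia
# swapdims(perm) returns a callable that permutes whatever array it's given;
# it makes no assumptions about the number of dimensions.
swapdims(perm) = Base.Fix2(permutedims, perm)

x = rand(Float32, 7, 7, 96, 4)    # (W, H, C, N)
y = swapdims((3, 1, 2, 4))(x)     # (C, W, H, N), the permutation used in convnextblock
size(y) == (96, 7, 7, 4)          # true
```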
I don't mean the definition of `swapdims`, which is fine as is. I mean the particular usage in this model (and others) is restricted. But it's okay, we can just stick with what we have.
Co-Authored-By: Kyle Daruwalla <[email protected]>
(force-pushed from 8a757a6 to 7c6880a)
Note: This PR is a logical follow-up to the many refinements made in #125 and thus should be merged only after that one. (Reviewing this after #125 is merged will also help, because the changes specific to this PR will then show up more clearly.)
This PR refactors the ViT models to make some performance improvements and to make it easier to write future PRs for attention-based models. First, it takes on NeuralAttentionlib as a dependency, because this allows for writing more general attention operations (which will be very helpful for future models where input dimensions may not be 3 or where there are more complicated hoops to jump through). This has the unexpected side effect of downgrading the Flux version to 0.12.8, however, because the library has a compat bound on NNlib v0.7 and Flux 0.12.9 is the first release to incorporate v0.8. A simple release should solve this, though. (cc @chengchingwen for this)
Overall, I found this version of attention to be significantly faster: inside the ViT model, benchmarking showed a gap of roughly 100 ms for the forward pass. I haven't checked gradient ops due to memory limitations, but it seems reasonable to assume an improvement there as well, because NeuralAttentionlib takes care of the gradients for the matmul ops, for example.
This also refines the ViT API to use the terminology originally used in the paper for scaling model size (tiny, small, base, large, huge, giant and gigantic).