diff --git a/Project.toml b/Project.toml
index cc23f0551..1be4e0cf9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,12 +8,14 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 BSON = "0.3.2"
 Flux = "0.12"
+MLUtils = "0.1.2"
 Functors = "0.2"
 NNlib = "0.7.34"
 julia = "1.4"
diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index 977b6183f..aab12aa11 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -6,11 +6,12 @@ using Functors
 using BSON
 using Artifacts, LazyArtifacts
 using Statistics
+using MLUtils
 
 import Functors
 
 include("utilities.jl")
-include("layers.jl")
+include("layers.jl")
 
 # CNN models
 include("convnets/alexnet.jl")
@@ -26,6 +27,9 @@ include("convnets/mobilenet.jl")
 
 # Other models
 include("other/mlpmixer.jl")
 
+# ViT-based models
+include("vit-based/vit.jl")
+
 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
@@ -33,11 +37,12 @@ export AlexNet,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        ResNeXt,
        MobileNetv2, MobileNetv3,
-       MLPMixer
+       MLPMixer,
+       ViT
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
-          :MobileNetv2, :MobileNetv3, :MLPMixer)
+          :MobileNetv2, :MobileNetv3, :MLPMixer, :ViT)
   @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl
index faf27a831..ea3962c2a 100644
--- a/src/convnets/alexnet.jl
+++ b/src/convnets/alexnet.jl
@@ -17,7 +17,7 @@ function alexnet(; nclasses = 1000)
                         Conv((3, 3), 256 => 256, relu, pad = (1, 1)),
                         MaxPool((3, 3), stride = (2, 2)),
                         AdaptiveMeanPool((6,6))),
-                  Chain(flatten,
+                  Chain(MLUtils.flatten,
                         Dropout(0.5),
                         Dense(256 * 6 * 6, 4096, relu),
                         Dropout(0.5),
diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl
index 3e4960f92..fa85fe548 100644
--- a/src/convnets/densenet.jl
+++ b/src/convnets/densenet.jl
@@ -75,7 +75,7 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
 
     return Chain(Chain(layers...),
                  Chain(AdaptiveMeanPool((1, 1)),
-                       flatten,
+                       MLUtils.flatten,
                        Dense(outplanes, nclasses)))
 end
 
diff --git a/src/convnets/googlenet.jl b/src/convnets/googlenet.jl
index 152199469..4de47a0ef 100644
--- a/src/convnets/googlenet.jl
+++ b/src/convnets/googlenet.jl
@@ -56,7 +56,7 @@ function googlenet(; nclasses = 1000)
                        _inceptionblock(832, 256, 160, 320, 32, 128, 128),
                        _inceptionblock(832, 384, 192, 384, 48, 128, 128)),
                  Chain(AdaptiveMeanPool((1, 1)),
-                       flatten,
+                       MLUtils.flatten,
                        Dropout(0.4),
                        Dense(1024, nclasses)))
 
diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl
index 8ab57f201..a9a33ed50 100644
--- a/src/convnets/inception.jl
+++ b/src/convnets/inception.jl
@@ -170,7 +170,7 @@ function inception3(; nclasses = 1000)
                        inception_e(2048)),
                  Chain(AdaptiveMeanPool((1, 1)),
                        Dropout(0.2),
-                       flatten,
+                       MLUtils.flatten,
                        Dense(2048, nclasses)))
 
     return layer
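Reviewer note: every hunk above makes the same substitution, replacing the unqualified `flatten` with `MLUtils.flatten`. Both collapse all dimensions except the last (batch) one, so the classifier heads see the same shapes as before. A minimal sanity check of that assumption, not part of the patch:

```julia
using MLUtils

x = rand(Float32, 6, 6, 256, 8)   # H × W × C × N feature map
y = MLUtils.flatten(x)            # reshapes to (features, batch)
@assert size(y) == (6 * 6 * 256, 8)
```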
diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl
index ee554eba1..1c9d574f9 100644
--- a/src/convnets/mobilenet.jl
+++ b/src/convnets/mobilenet.jl
@@ -77,7 +77,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
 
     return Chain(Chain(layers...,
                        conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)...),
-                 Chain(AdaptiveMeanPool((1, 1)), flatten, Dense(outplanes, nclasses)))
+                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses)))
 end
 
 # Layer configurations for MobileNetv2
@@ -221,7 +221,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
 
     return Chain(Chain(layers...,
                        conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),
-                 Chain(AdaptiveMeanPool((1, 1)), flatten, classifier...))
+                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier...))
 end
 
 # Configurations for small and large mode for MobileNetv3
diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl
index c96bd991d..a86f87252 100644
--- a/src/convnets/resnet.jl
+++ b/src/convnets/resnet.jl
@@ -79,7 +79,7 @@ function resnet(block, residuals::NTuple{2, Any}, connection = addrelu;
     end
 
     return Chain(Chain(layers...),
-                 Chain(AdaptiveMeanPool((1, 1)), flatten, Dense(inplanes, nclasses)))
+                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses)))
 end
 
 """
diff --git a/src/convnets/resnext.jl b/src/convnets/resnext.jl
index bbb431895..e3ebd21b3 100644
--- a/src/convnets/resnext.jl
+++ b/src/convnets/resnext.jl
@@ -63,7 +63,7 @@ function resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @.
     end
 
     return Chain(Chain(layers...),
-                 Chain(AdaptiveMeanPool((1, 1)), flatten, Dense(inplanes, nclasses)))
+                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses)))
 end
 
 """
diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl
index 4f61765e0..169ad2e86 100644
--- a/src/convnets/squeezenet.jl
+++ b/src/convnets/squeezenet.jl
@@ -43,7 +43,7 @@ function squeezenet()
                            Dropout(0.5),
                            Conv((1, 1), 512 => 1000, relu)),
                      AdaptiveMeanPool((1, 1)),
-                     flatten)
+                     MLUtils.flatten)
 
     return layers
 end
diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl
index 55ebe928f..2849ce07a 100644
--- a/src/convnets/vgg.jl
+++ b/src/convnets/vgg.jl
@@ -63,7 +63,7 @@ Create VGG classifier (fully connected) layers
 """
 function vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
     layers = []
-    push!(layers, flatten)
+    push!(layers, MLUtils.flatten)
     push!(layers, Dense(Int(prod(imsize)), fcsize, relu))
     push!(layers, Dropout(dropout))
     push!(layers, Dense(fcsize, fcsize, relu))
diff --git a/src/layers.jl b/src/layers.jl
index d04ff212f..34d99a3e0 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -91,31 +91,95 @@ end
 skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)
 
 """
-    addrelu(x, y)
+    mlpblock(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
 
-Convenience function for `(x, y) -> @. relu(x + y)`.
-Useful as the `connection` argument for [`resnet`](#).
-See also [`reluadd`](#).
+Feedforward block used in many vision transformer-like models.
+
+# Arguments
+- `planes`: Number of dimensions in the input and output.
+- `hidden_planes`: Number of dimensions in the intermediate layer.
+- `dropout`: Dropout rate.
+- `dense`: Type of dense layer to use in the feedforward block.
+- `activation`: Activation function to use.
 """
-addrelu(x, y) = @. relu(x + y)
+function mlpblock(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
+    Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
+          dense(hidden_planes, planes, activation), Dropout(dropout))
+end
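The rewritten `mlpblock` moves `dropout`, `dense`, and `activation` behind the keyword barrier, so callers pass the hidden width positionally and everything else by name. A minimal shape check, assuming the patch is applied and the unexported helper is reached through `Metalhead`:

```julia
using Flux
using Metalhead: mlpblock   # unexported helper from src/layers.jl

block = mlpblock(192, 768; dropout = 0.1)   # planes = 192, hidden_planes = 768
x = rand(Float32, 192, 64, 4)               # (planes, npatches, batch)
@assert size(block(x)) == size(x)           # the feedforward block preserves shape
```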
 
 """
-    reluadd(x, y)
+    Attention(in => out)
+    Attention(qkvlayer)
 
-Convenience function for `(x, y) -> @. relu(x) + relu(y)`.
-Useful as the `connection` argument for [`resnet`](#).
-See also [`addrelu`](#).
+Self attention layer used by transformer models. Specify the `in` and `out` dimensions,
+or directly provide a `qkvlayer` that maps an input to the queries, keys, and values.
 """
-reluadd(x, y) = @. relu(x) + relu(y)
+struct Attention{T}
+    qkv::T
+end
+
+Attention(dims::Pair{Int, Int}) = Attention(Dense(dims.first, dims.second * 3; bias = false))
+
+@functor Attention
+
+function (attn::Attention)(x::AbstractArray{T, 3}) where T
+    q, k, v = chunk(attn.qkv(x), 3, dims = 1)
+    scale = convert(T, sqrt(size(q, 1)))
+    score = softmax(batched_mul(batched_transpose(q), k) / scale)
+    attention = batched_mul(v, score)
+
+    return attention
+end
 
-# Patching layer used by many vision transformer-like models
-struct Patching{T <: Integer}
-    patch_height::T
-    patch_width::T
+struct MHAttention{S, T}
+    heads::S
+    projection::T
 end
 
-Patching(patch_size) = Patching(patch_size, patch_size)
-function (p::Patching)(x)
+"""
+    MHAttention(in, hidden, nheads; dropout = 0.0)
+
+Multi-head self-attention layer used in many vision transformer-like models.
+
+# Arguments
+- `in`: Number of dimensions in the input.
+- `hidden`: Number of dimensions in the intermediate layer.
+- `nheads`: Number of attention heads.
+- `dropout`: Dropout rate for the projection layer.
+"""
+function MHAttention(in, hidden, nheads; dropout = 0.)
+    if (nheads == 1 && hidden == in)
+        return Attention(in => in)
+    end
+
+    inheads, innerheads = chunk(1:in, nheads), chunk(1:hidden, nheads)
+    heads = Parallel(vcat, [Attention(length(i) => length(o)) for (i, o) in zip(inheads, innerheads)]...)
+    projection = Chain(Dense(hidden, in), Dropout(dropout))
+
+    MHAttention(heads, projection)
+end
+
+@functor MHAttention
+
+function (mha::MHAttention)(x)
+    nheads = length(mha.heads.layers)
+    xhead = chunk(x, nheads, dims = 1)
+
+    return mha.projection(mha.heads(xhead...))
+end
+
+"""
+    PatchEmbedding(patch_size)
+    PatchEmbedding(patch_height, patch_width)
+
+Patch embedding layer used by many vision transformer-like models to split the input image into patches.
+"""
+struct PatchEmbedding
+    patch_height::Int
+    patch_width::Int
+end
+
+PatchEmbedding(patch_size) = PatchEmbedding(patch_size, patch_size)
+
+function (p::PatchEmbedding)(x)
     h, w, c, n = size(x)
     hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
     xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n)
@@ -124,21 +188,38 @@ function (p::Patching)(x)
                    hp * wp, n)
 end
 
-@functor Patching
+@functor PatchEmbedding
 
 """
-    mlpblock(planes, expansion_factor = 4, dropout = 0., dense = Dense)
+    ViPosEmbedding(embedsize, npatches; init = (dims) -> rand(Float32, dims))
 
-Feedforward block used in many vision transformer-like models.
+Positional embedding layer used by many vision transformer-like models.
+"""
+struct ViPosEmbedding{T}
+    vectors::T
+end
+
+ViPosEmbedding(embedsize, npatches; init = (dims::NTuple{2, Int}) -> rand(Float32, dims)) =
+    ViPosEmbedding(init((embedsize, npatches)))
+
+(p::ViPosEmbedding)(x) = x .+ p.vectors
+
+@functor ViPosEmbedding
 
-# Arguments
-- `planes`: Number of dimensions in the input and output.
-- `hidden_planes`: Number of dimensions in the intermediate layer.
-- `dropout`: Dropout rate.
-- `dense`: Type of dense layer to use in the feedforward block.
-- `activation`: Activation function to use.
 """
-function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu)
-    Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
-          dense(hidden_planes, planes, activation), Dropout(dropout))
+    ClassTokens(dim; init = Flux.zeros32)
+
+Prepends a class token to an input with embedding dimension `dim` for use in many vision transformer models.
+"""
+struct ClassTokens{T}
+    token::T
 end
+
+ClassTokens(dim::Integer; init = Flux.zeros32) = ClassTokens(init(dim, 1, 1))
+
+function (m::ClassTokens)(x)
+    tokens = repeat(m.token, 1, 1, size(x, 3))
+    return hcat(tokens, x)
+end
+
+@functor ClassTokens
diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl
index 2f6c0da3d..2bf7309f2 100644
--- a/src/other/mlpmixer.jl
+++ b/src/other/mlpmixer.jl
@@ -36,14 +36,14 @@ function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 1
     num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size)
 
     layers = []
-    push!(layers, Patching(patch_size))
+    push!(layers, PatchEmbedding(patch_size))
     push!(layers, Dense((patch_size ^ 2) * inchannels, planes))
     append!(layers, [Chain(_residualprenorm(planes, mlpblock(num_patches,
-                                                             expansion_factor * num_patches,
-                                                             dropout, token_mix)),
+                                                             expansion_factor * num_patches;
+                                                             dropout, dense = token_mix)),
                            _residualprenorm(planes, mlpblock(planes,
-                                                             expansion_factor * planes, dropout,
-                                                             channel_mix)),) for _ in 1:depth])
+                                                             expansion_factor * planes; dropout,
+                                                             dense = channel_mix)),) for _ in 1:depth])
 
     classification_head = Chain(_seconddimmean, Dense(planes, nclasses))
 
diff --git a/src/utilities.jl b/src/utilities.jl
index bd64dbcf1..57ea8eba9 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -1,5 +1,23 @@
 # Utility function for classifier head of vision transformer-like models
-_seconddimmean(x) = mean(x, dims = 2)[:, 1, :]
+_seconddimmean(x) = dropdims(mean(x, dims = 2); dims = 2)
+
+"""
+    addrelu(x, y)
+
+Convenience function for `(x, y) -> @. relu(x + y)`.
+Useful as the `connection` argument for [`resnet`](#).
+See also [`reluadd`](#).
+"""
+addrelu(x, y) = @. relu(x + y)
+
+"""
+    reluadd(x, y)
+
+Convenience function for `(x, y) -> @. relu(x) + relu(y)`.
+Useful as the `connection` argument for [`resnet`](#).
+See also [`addrelu`](#).
+"""
+reluadd(x, y) = @. relu(x) + relu(y)
 
 """
     weights(model)
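`Attention` and `MHAttention` both map `(features, patches, batch)` arrays to outputs of the same shape; `MHAttention` splits the feature dimension across heads with `chunk` and merges the per-head outputs back with `Parallel(vcat, ...)`. A quick sketch using the unexported names, assuming the patch is applied:

```julia
using Flux
using Metalhead: MHAttention   # unexported, defined in src/layers.jl above

attn = MHAttention(64, 64, 8)          # in = hidden = 64, split over 8 heads of width 8
x = rand(Float32, 64, 50, 2)           # (features, patches, batch)
@assert size(attn(x)) == (64, 50, 2)   # self-attention preserves the shape
```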
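Taken together, the embedding layers added to `src/layers.jl` compose into the stem that `vit` assembles in the new file below. A hedged shape walk-through with the unexported names, mirroring the defaults `vit` uses:

```julia
using Flux
using Metalhead: PatchEmbedding, ClassTokens, ViPosEmbedding

img = rand(Float32, 256, 256, 3, 2)   # H × W × C × N
p = PatchEmbedding(16)(img)           # (16 * 16 * 3, 256, 2): one column per patch
e = Dense(16 * 16 * 3, 1024)(p)       # (1024, 256, 2): linear patch embedding
t = ClassTokens(1024)(e)              # (1024, 257, 2): class token prepended
z = ViPosEmbedding(1024, 257)(t)      # same shape, plus learned position vectors
@assert size(z) == (1024, 257, 2)
```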
diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl
new file mode 100644
index 000000000..09960da66
--- /dev/null
+++ b/src/vit-based/vit.jl
@@ -0,0 +1,114 @@
+# Utility function for applying LayerNorm before a block
+prenorm(planes, fn) = Chain(LayerNorm(planes), fn)
+
+"""
+    transformer_encoder(planes, depth, heads, headplanes, mlpplanes; dropout = 0.)
+
+Transformer as used in the base ViT architecture.
+([reference](https://arxiv.org/abs/2010.11929)).
+
+# Arguments
+- `planes`: number of input channels
+- `depth`: number of attention blocks
+- `heads`: number of attention heads
+- `headplanes`: number of hidden channels per head
+- `mlpplanes`: number of hidden channels in the MLP block
+- `dropout`: dropout rate
+"""
+function transformer_encoder(planes, depth, heads, headplanes, mlpplanes; dropout = 0.)
+    layers = [Chain(SkipConnection(prenorm(planes, MHAttention(planes, headplanes, heads; dropout)), +),
+                    SkipConnection(prenorm(planes, mlpblock(planes, mlpplanes; dropout)), +))
+              for _ in 1:depth]
+
+    Chain(layers...)
+end
+
+"""
+    vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024,
+        depth = 6, heads = 16, mlpplanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1,
+        pool = :class, nclasses = 1000)
+
+Creates a Vision Transformer (ViT) model.
+([reference](https://arxiv.org/abs/2010.11929)).
+
+# Arguments
+- `imsize`: image size
+- `inchannels`: number of input channels
+- `patch_size`: size of the patches
+- `planes`: the number of channels fed into the main model
+- `depth`: number of blocks in the transformer
+- `heads`: number of attention heads in the transformer
+- `mlpplanes`: number of hidden channels in the MLP block in the transformer
+- `headplanes`: number of hidden channels per head in the transformer
+- `dropout`: dropout rate
+- `emb_dropout`: dropout rate for the positional embedding layer
+- `pool`: pooling type, either :class or :mean
+- `nclasses`: number of classes in the output
+"""
+function vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024,
+             depth = 6, heads = 16, mlpplanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1,
+             pool = :class, nclasses = 1000)
+
+    im_height, im_width = imsize
+    patch_height, patch_width = patch_size
+
+    @assert((im_height % patch_height == 0) && (im_width % patch_width == 0),
+            "Image dimensions must be divisible by the patch size.")
+    @assert(pool in [:class, :mean],
+            "Pool type must be either :class (class token) or :mean (mean pooling)")
+
+    npatches = (im_height ÷ patch_height) * (im_width ÷ patch_width)
+    patchplanes = inchannels * patch_height * patch_width
+
+    return Chain(Chain(PatchEmbedding(patch_height, patch_width),
+                       Dense(patchplanes, planes),
+                       ClassTokens(planes),
+                       ViPosEmbedding(planes, npatches + 1),
+                       Dropout(emb_dropout),
+                       transformer_encoder(planes, depth, heads, headplanes, mlpplanes; dropout),
+                       (pool == :class) ? x -> x[:, 1, :] : _seconddimmean),
+                 Chain(LayerNorm(planes), Dense(planes, nclasses)))
+end
+
+struct ViT
+    layers
+end
+
+"""
+    ViT(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024,
+        depth = 6, heads = 16, mlpplanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1,
+        pool = :class, nclasses = 1000)
+
+Creates a Vision Transformer (ViT) model.
+([reference](https://arxiv.org/abs/2010.11929)).
+
+# Arguments
+- `imsize`: image size
+- `inchannels`: number of input channels
+- `patch_size`: size of the patches
+- `planes`: the number of channels fed into the main model
+- `depth`: number of blocks in the transformer
+- `heads`: number of attention heads in the transformer
+- `mlpplanes`: number of hidden channels in the MLP block in the transformer
+- `headplanes`: number of hidden channels per head in the transformer
+- `dropout`: dropout rate
+- `emb_dropout`: dropout rate for the positional embedding layer
+- `pool`: pooling type, either :class or :mean
+- `nclasses`: number of classes in the output
+"""
+function ViT(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024,
+             depth = 6, heads = 16, mlpplanes = 2048, headplanes = 64,
+             dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000)
+
+    layers = vit(imsize; inchannels, patch_size, planes, depth, heads, mlpplanes, headplanes,
+                 dropout, emb_dropout, pool, nclasses)
+
+    ViT(layers)
+end
+
+(m::ViT)(x) = m.layers(x)
+
+backbone(m::ViT) = m.layers[1]
+classifier(m::ViT) = m.layers[2]
+
+@functor ViT
diff --git a/test/runtests.jl b/test/runtests.jl
index 7d2c9abf0..074dfc972 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -19,3 +19,8 @@ end
 @testset verbose = true "Other" begin
     include("other.jl")
 end
+
+# ViT tests
+@testset verbose = true "ViTs" begin
+    include("vit-based.jl")
+end
diff --git a/test/vit-based.jl b/test/vit-based.jl
new file mode 100644
index 000000000..4e5451fb6
--- /dev/null
+++ b/test/vit-based.jl
@@ -0,0 +1,7 @@
+using Metalhead, Test
+using Flux
+
+@testset "ViT" begin
+    @test size(ViT()(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
+    @test_skip gradtest(ViT(), rand(Float32, 256, 256, 3, 2))
+end
\ No newline at end of file
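For reviewers who want to poke at the new model interactively, a usage sketch that matches the test above; the `backbone`/`classifier` split mirrors the other Metalhead models:

```julia
using Metalhead, Flux

model = ViT()                          # defaults: 256×256 input, 16×16 patches, 1000 classes
x = rand(Float32, 256, 256, 3, 2)
@assert size(model(x)) == (1000, 2)

feats = Metalhead.backbone(model)(x)   # (1024, 2) pooled class-token features
head = Metalhead.classifier(model)
@assert size(head(feats)) == (1000, 2)
```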