diff --git a/Project.toml b/Project.toml
index 4c1769440..56c3544bb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "0.4.2"
+version = "0.4.3"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 1a6ca4bf1..008238ecc 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -691,7 +691,7 @@ Create a Sparsely Connected Layer with a very specific structure (only Diagonal
 
 ## Arguments
 
-* `dims`: number of input and output dimensions
+* `dims`: size of the learnable scale and bias parameters.
 * `activation`: activation function
 
 ## Keyword Arguments
@@ -702,17 +702,20 @@
 
 ## Input
 
-* `x` must be a Matrix of size `dims × B` or a Vector of length `dims`
+* `x` must be an Array of size `(dims..., B)` or `(dims[1], ..., dims[k])` for `k ≤ length(dims)`
 
 ## Returns
 
-* Matrix of size `dims × B` or a Vector of length `dims`
+* Array of size `(dims..., B)` or `(dims[1], ..., dims[k])` for `k ≤ length(dims)`
 * Empty `NamedTuple()`
 
 ## Parameters
 
-* `weight`: Weight Vector of size `(dims,)`
-* `bias`: Bias of size `(dims,)`
+* `weight`: Weight Array of size `(dims...)`
+* `bias`: Bias of size `(dims...)`
+
+!!! compat "Lux 0.4.3"
+    `Scale` with multiple dimensions requires at least Lux 0.4.3.
 """
 struct Scale{bias, F1, D, F2, F3} <: AbstractExplicitLayer
     activation::F1
@@ -727,21 +730,27 @@ function Base.show(io::IO, d::Scale)
     return print(io, ")")
 end
 
-function Scale(dims, activation=identity; init_weight=glorot_uniform,
+function Scale(dims::Tuple{Vararg{Integer}}, activation=identity;
+               init_weight=glorot_uniform,
                init_bias=zeros32, bias::Bool=true)
     activation = NNlib.fast_act(activation)
     return Scale{bias, typeof(activation), typeof(dims), typeof(init_weight),
                  typeof(init_bias)}(activation, dims, init_weight, init_bias)
 end
 
+function Scale(s1::Integer, s23::Integer...; _act=identity, kw...)
+    Scale(tuple(s1, s23...), _act; kw...)
+end
+Scale(size_act...; kw...) = Scale(size_act[1:(end - 1)]...; _act=size_act[end], kw...)
+
 function initialparameters(rng::AbstractRNG, d::Scale{true})
-    return (weight=d.init_weight(rng, d.dims), bias=d.init_bias(rng, d.dims))
+    return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
 end
 function initialparameters(rng::AbstractRNG, d::Scale{false})
-    (weight=d.init_weight(rng, d.dims),)
+    (weight=d.init_weight(rng, d.dims...),)
 end
 
-parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * d.dims
+parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * prod(d.dims)
 statelength(d::Scale) = 0
 
 function (d::Scale{true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple},
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index c7f3428ec..f49611661 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -268,3 +268,50 @@ end
         end == [10 20; 10 20]
     end
 end
+
+@testset "Scale" begin
+    @testset "constructors" begin
+        layer = Scale(10, 100)
+        ps, st = Lux.setup(rng, layer)
+
+        @test size(ps.weight) == (10, 100)
+        @test size(ps.bias) == (10, 100)
+        @test layer.activation == identity
+
+        layer = Scale(10, 100, relu; bias=false)
+        ps, st = Lux.setup(rng, layer)
+
+        @test !haskey(ps, :bias)
+        @test layer.activation == relu
+    end
+
+    @testset "dimensions" begin
+        layer = Scale(10, 5)
+        ps, st = Lux.setup(rng, layer)
+
+        @test size(first(Lux.apply(layer, randn(10), ps, st))) == (10, 5)
+        @test size(first(Lux.apply(layer, randn(10, 5, 2), ps, st))) == (10, 5, 2)
+    end
+
+    @testset "zeros" begin
+        @test begin
+            layer = Scale(10, 1, identity; init_weight=ones)
+            first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...))
+        end == ones(10, 1)
+
+        @test begin
+            layer = Scale(10, 1, identity; init_weight=ones)
+            first(Lux.apply(layer, ones(10, 2), Lux.setup(rng, layer)...))
+        end == ones(10, 2)
+
+        @test begin
+            layer = Scale(2, identity; init_weight=ones, init_bias=ones)
+            first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...))
+        end == [2.0 3.0; 4.0 5.0]
+
+        @test begin
+            layer = Scale(2, tanh; bias=false, init_weight=zeros)
+            first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...))
+        end == zeros(2, 2)
+    end
+end
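For context, a minimal usage sketch of the multi-dimensional `Scale` behavior this patch enables. It only relies on the constructors and the `Lux.setup`/`Lux.apply` calls exercised in the tests above; the `rng`, input sizes, and `Float32` element type are illustrative choices, not part of the change.

```julia
using Lux, Random

rng = Random.default_rng()

# Scale(10, 5) now carries a (10, 5) weight array (and a (10, 5) bias, since bias=true by default).
layer = Scale(10, 5, relu)
ps, st = Lux.setup(rng, layer)

# The forward pass broadcasts `weight .* x .+ bias`, so a trailing batch
# dimension (here of size 2) is handled by broadcasting.
x = randn(Float32, 10, 5, 2)
y, _ = Lux.apply(layer, x, ps, st)

size(ps.weight)  # (10, 5)
size(y)          # (10, 5, 2)
```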