Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending Scale to allow for multiple dimension inputs #40

Merged
merged 4 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.2"
version = "0.4.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
27 changes: 18 additions & 9 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ Create a Sparsely Connected Layer with a very specific structure (only Diagonal

## Arguments

* `dims`: number of input and output dimensions
* `dims`: size of the learnable scale and bias parameters.
* `activation`: activation function

## Keyword Arguments
Expand All @@ -702,17 +702,20 @@ Create a Sparsely Connected Layer with a very specific structure (only Diagonal

## Input

* `x` must be a Matrix of size `dims × B` or a Vector of length `dims`
* `x` must be an Array of size `(dims..., B)` or of size `(dims[1], ..., dims[k])` for `k ≤ length(dims)`

## Returns

* Matrix of size `dims × B` or a Vector of length `dims`
* Array of size `(dims..., B)` or of size `(dims[1], ..., dims[k])` for `k ≤ length(dims)`
* Empty `NamedTuple()`

## Parameters

* `weight`: Weight Vector of size `(dims,)`
* `bias`: Bias of size `(dims,)`
* `weight`: Weight Array of size `(dims...)`
* `bias`: Bias of size `(dims...)`

!!! compat "Lux 0.4.3"
`Scale` with multiple dimensions requires at least Lux 0.4.3.
"""
struct Scale{bias, F1, D, F2, F3} <: AbstractExplicitLayer
activation::F1
Expand All @@ -727,21 +730,27 @@ function Base.show(io::IO, d::Scale)
return print(io, ")")
end

function Scale(dims, activation=identity; init_weight=glorot_uniform,
# Primary constructor: `dims` is the full size of the learnable `weight`/`bias`
# arrays. The activation is run through `NNlib.fast_act` to pick a faster
# equivalent implementation where one exists.
function Scale(dims::Tuple{Vararg{Integer}}, activation=identity;
               init_weight=glorot_uniform,
               init_bias=zeros32, bias::Bool=true)
    act = NNlib.fast_act(activation)
    return Scale{bias, typeof(act), typeof(dims), typeof(init_weight),
                 typeof(init_bias)}(act, dims, init_weight, init_bias)
end

# Convenience method: accept integer sizes directly, so that
# `Scale(10, 5)` is equivalent to `Scale((10, 5))`.
function Scale(s1::Integer, s23::Integer...; _act=identity, kw...)
    return Scale((s1, s23...), _act; kw...)
end
# Trailing-activation form: `Scale(dims..., act; kw...)` splits off the last
# positional argument as the activation and forwards it via `_act`.
function Scale(size_act...; kw...)
    return Scale(size_act[1:(end - 1)]...; _act=size_act[end], kw...)
end

# Trainable parameters when `bias=true`: a weight and a bias, both of size
# `d.dims`, drawn from the layer's initializers.
function initialparameters(rng::AbstractRNG, d::Scale{true})
    w = d.init_weight(rng, d.dims...)
    b = d.init_bias(rng, d.dims...)
    return (weight=w, bias=b)
end
# Trainable parameters when `bias=false`: only the weight, of size `d.dims`.
# Explicit `return` added for consistency with the `Scale{true}` method above
# (Blue style: multi-line functions use explicit `return`).
function initialparameters(rng::AbstractRNG, d::Scale{false})
    return (weight=d.init_weight(rng, d.dims...),)
end

# Total trainable-scalar count: one array of `prod(dims)` elements, doubled
# when a bias array is present.
parameterlength(d::Scale{bias}) where {bias} = prod(d.dims) * (bias ? 2 : 1)
# `Scale` carries no non-trainable state.
statelength(::Scale) = 0

function (d::Scale{true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple},
Expand Down
47 changes: 47 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,50 @@ end
end == [10 20; 10 20]
end
end

@testset "Scale" begin
    # Construction via integer sizes and the trailing-activation form.
    @testset "constructors" begin
        l = Scale(10, 100)
        params, states = Lux.setup(rng, l)

        @test size(params.weight) == (10, 100)
        @test size(params.bias) == (10, 100)
        @test l.activation == identity

        l = Scale(10, 100, relu; bias=false)
        params, states = Lux.setup(rng, l)

        @test !haskey(params, :bias)
        @test l.activation == relu
    end

    # Broadcasting of a multi-dimensional scale against smaller/larger inputs.
    @testset "dimensions" begin
        l = Scale(10, 5)
        params, states = Lux.setup(rng, l)

        @test size(first(Lux.apply(l, randn(10), params, states))) == (10, 5)
        @test size(first(Lux.apply(l, randn(10, 5, 2), params, states))) == (10, 5, 2)
    end

    # Exact-value checks with deterministic (ones/zeros) initializers.
    @testset "zeros" begin
        l = Scale(10, 1, identity; init_weight=ones)
        @test first(Lux.apply(l, ones(10, 1), Lux.setup(rng, l)...)) == ones(10, 1)

        l = Scale(10, 1, identity; init_weight=ones)
        @test first(Lux.apply(l, ones(10, 2), Lux.setup(rng, l)...)) == ones(10, 2)

        l = Scale(2, identity; init_weight=ones, init_bias=ones)
        @test first(Lux.apply(l, [1 2; 3 4], Lux.setup(rng, l)...)) == [2.0 3.0; 4.0 5.0]

        l = Scale(2, tanh; bias=false, init_weight=zeros)
        @test first(Lux.apply(l, [1 2; 3 4], Lux.setup(rng, l)...)) == zeros(2, 2)
    end
end