diff --git a/src/basicfuns.jl b/src/basicfuns.jl
index 638d2312..2bbced2a 100644
--- a/src/basicfuns.jl
+++ b/src/basicfuns.jl
@@ -232,19 +232,19 @@ That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.
 
 See the [Wikipedia entry](https://en.wikipedia.org/wiki/Softmax_function)
 """
-function softmax!(r::AbstractArray{R}, x::AbstractArray{T}) where {R<:AbstractFloat,T<:Real}
+function softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real})
     n = length(x)
     length(r) == n || throw(DimensionMismatch("Inconsistent array lengths."))
     u = maximum(x)
-    s = 0.
+    s = zero(eltype(r))
     @inbounds for i = 1:n
         s += (r[i] = exp(x[i] - u))
     end
-    invs = convert(R, inv(s))
+    invs = inv(s)
     @inbounds for i = 1:n
         r[i] *= invs
     end
-    r
+    return r
 end
 
 """
@@ -261,4 +261,4 @@ $(SIGNATURES)
 
 Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function) applied to `x`.
 """
-softmax(x::AbstractArray{<:Real}) = softmax!(similar(x, Float64), x)
+softmax(x::AbstractArray{<:Real}) = softmax!(similar(x, float(eltype(x))), x)
diff --git a/test/basicfuns.jl b/test/basicfuns.jl
index 431d3575..3b583824 100644
--- a/test/basicfuns.jl
+++ b/test/basicfuns.jl
@@ -149,7 +149,22 @@ end
 @testset "softmax" begin
     x = [1.0, 2.0, 3.0]
     r = exp.(x) ./ sum(exp.(x))
-    @test softmax([1.0, 2.0, 3.0]) ≈ r
+    @test softmax(x) ≈ r
     softmax!(x)
     @test x ≈ r
+
+    x = [1, 2, 3]
+    r = exp.(x) ./ sum(exp.(x))
+    @test softmax(x) ≈ r
+    @test eltype(softmax(x)) == Float64
+
+    x = [1//2, 2//3, 3//4]
+    r = exp.(x) ./ sum(exp.(x))
+    @test softmax(x) ≈ r
+    @test eltype(softmax(x)) == Float64
+
+    x = Float32[1, 2, 3]
+    r = exp.(x) ./ sum(exp.(x))
+    @test softmax(x) ≈ r
+    @test eltype(softmax(x)) == Float32
 end
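
For reviewers, a minimal usage sketch of the behavior this patch targets, mirroring the new tests: `softmax` now returns an array of element type `float(eltype(x))` instead of always `Float64`, and `softmax!` accepts any real-valued destination. The snippet is illustrative only and assumes the patched `softmax`/`softmax!` from this diff are already in scope; it is not part of the diff itself.

```julia
# Illustrative sketch only: assumes the patched `softmax` / `softmax!`
# are in scope (e.g. after loading the parent package with this change).

x32 = Float32[1, 2, 3]
y32 = softmax(x32)                  # element type is preserved: Vector{Float32}
@assert eltype(y32) == Float32
@assert sum(y32) ≈ 1                # result is normalized to sum to 1

xi = [1, 2, 3]                      # Int input promotes via float(eltype(x))
@assert eltype(softmax(xi)) == Float64

xq = [1//2, 2//3, 3//4]             # Rational input likewise promotes to Float64
@assert eltype(softmax(xq)) == Float64

# In-place form: `r` is overwritten with exp.(x), normalized to sum to 1,
# and returned.
r = similar(x32)
softmax!(r, x32)
@assert r ≈ y32
```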