Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dispatch on StaticArray instead of SArray #7

Merged
merged 1 commit into from
Feb 12, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/DiffResults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ Return `r::DiffResult`, with output value storage provided by `value` and output
storage provided by `derivs`.

In reality, `DiffResult` is an abstract supertype of two concrete types, `MutableDiffResult`
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `SArray`s, then `r`
will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `StaticArray`s,
then `r` will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
(i.e. `r::MutableDiffResult`).

Note that `derivs` can be provide in splatted form, i.e. `DiffResult(value, derivs...)`.
"""
DiffResult

DiffResult(value::Number, derivs::Tuple{Vararg{Number}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::Number, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::SArray, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::Number, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::StaticArray, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::Number, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
DiffResult(value::AbstractArray, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
DiffResult(value::Union{Number,AbstractArray}, derivs::Union{Number,AbstractArray}...) = DiffResult(value, derivs)
Expand All @@ -65,7 +65,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
constructor instead.
"""
GradientResult(x::AbstractArray) = DiffResult(first(x), similar(x))
GradientResult(x::SArray) = DiffResult(first(x), x)
GradientResult(x::StaticArray) = DiffResult(first(x), x)

"""
JacobianResult(x::AbstractArray)
Expand All @@ -79,7 +79,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
constructor instead.
"""
JacobianResult(x::AbstractArray) = DiffResult(similar(x), similar(x, length(x), length(x)))
JacobianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(x, zeros(SMatrix{L,L,T}))
JacobianResult(x::StaticArray) = DiffResult(x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x)))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will the length(x) here cause the size parameter of the resulting array to be uninferrable, or does constant prop take care of that for us?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constant prop seems to take care of it (as I assumed without worrying much 😄):

using DiffResults, StaticArrays, Test

v = SVector([1, 2.5, 4])
@inferred DiffResults.JacobianResult(v)
@code_warntype DiffResults.JacobianResult(v)

outputs

Body::DiffResults.ImmutableDiffResult{1,SArray{Tuple{3},Float64,1,3},Tuple{SArray{Tuple{3,3},Float64,2,9}}}
1%1 = %new(DiffResults.ImmutableDiffResult{1,SArray{Tuple{3},Float64,1,3},Tuple{SArray{Tuple{3,3},Float64,2,9}}}, x, ([0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0],))::DiffResults.ImmutableDiffResult{1,SArray{Tuple{3},Float64,1,3},Tuple{SArray{Tuple{3,3},Float64,2,9}}}
└──      return %1


"""
JacobianResult(y::AbstractArray, x::AbstractArray)
Expand All @@ -92,7 +92,7 @@ Like the single argument version, `y` and `x` are only used for type and
shape information and are not stored in the returned `DiffResult`.
"""
JacobianResult(y::AbstractArray, x::AbstractArray) = DiffResult(similar(y), similar(y, length(y), length(x)))
JacobianResult(y::SArray{<:Any,<:Any,<:Any,Y}, x::SArray{<:Any,T,<:Any,X}) where {T,Y,X} = DiffResult(y, zeros(SMatrix{Y,X,T}))
JacobianResult(y::StaticArray, x::StaticArray) = DiffResult(y, zeros(StaticArrays.similar_type(typeof(x), Size(length(y),length(x)))))

"""
HessianResult(x::AbstractArray)
Expand All @@ -105,7 +105,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
constructor instead.
"""
HessianResult(x::AbstractArray) = DiffResult(first(x), similar(x), similar(x, length(x), length(x)))
HessianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(first(x), x, zeros(SMatrix{L,L,T}))
HessianResult(x::StaticArray) = DiffResult(first(x), x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x)))))

#############
# Interface #
Expand Down Expand Up @@ -203,7 +203,7 @@ function derivative!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Va
return r
end

function derivative!(r::ImmutableDiffResult, x::Union{Number,SArray}, ::Type{Val{i}} = Val{1}) where {i}
function derivative!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i}
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i}))
end

Expand Down Expand Up @@ -232,7 +232,7 @@ function derivative!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{
return derivative!(r, f(x), Val{i})
end

function derivative!(f, r::ImmutableDiffResult, x::SArray, ::Type{Val{i}} = Val{1}) where {i}
function derivative!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i}
return derivative!(r, map(f, x), Val{i})
end

Expand Down