Skip to content

Commit

Permalink
Add rules for evalpoly (#190)
Browse files Browse the repository at this point in the history
* Add rules for evalpoly

* Rename pullback

* Apply suggestions from code review

Co-authored-by: Lyndon White <[email protected]>

* Make backx its own method for the future

* Update UniformScaling to-dos

* Add matrix poly tests

* Use correct indices

* Reorganize

* Use generated functions with fallbacks

* Reimplement rules for matrices

* Deactivate complex tests

Until complex conventions are clarified and FiniteDifferences v0.10.0 is supported.

* Add generated functions for tuple case

* Add comment

* Rename defs to exs

* Refactor and test fallbacks

* Simplify indexing

* Don't store output as an intermediate

* Support scalar x with matrix pi

* Make extensible for other ps

* Move fallback tests under rrule tests

* Reorder args and remove unnecessary product

* Eliminate unneeded mul and reorganize

* Remove unnecessary product

* Fix length of ys and wrap lines

* Place final ∂yi to in loop

This for some reason profiles much faster

* Keep other rules consistent with vector

* Unify tests

* Increment version number

* Try equality check outside of tuple

* Approximate check due to muladd

* Approximate check scalar output too

* Decrement version number

Co-authored-by: Lyndon White <[email protected]>
  • Loading branch information
sethaxen and oxinabox authored Jun 30, 2020
1 parent 1a3e34a commit 224e553
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.1"
version = "0.7.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
143 changes: 143 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,146 @@ function rrule(::typeof(identity), x)
end
return (x, identity_pullback)
end

#####
##### `evalpoly`
#####

if VERSION v"1.4"
function frule((_, Δx, Δp), ::typeof(evalpoly), x, p)
N = length(p)
@inbounds y = p[N]
Δy = Δp[N]
@inbounds for i in (N - 1):-1:1
Δy = muladd(Δx, y, muladd(x, Δy, Δp[i]))
y = muladd(x, y, p[i])
end
return y, Δy
end

function rrule(::typeof(evalpoly), x, p)
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NO_FIELDS, ∂x, ∂p
end
return y, evalpoly_pullback
end

# evalpoly but storing intermediates
function _evalpoly_intermediates(x, p::Tuple)
return if @generated
N = length(p.parameters)
exs = []
vars = []
ex = :(p[$N])
for i in 1:(N - 1)
yi = Symbol("y", i)
push!(vars, yi)
push!(exs, :($yi = $ex))
ex = :(muladd(x, $yi, p[$(N - i)]))
end
push!(exs, :(y = $ex))
Expr(:block, exs..., :(y, ($(vars...),)))
else
_evalpoly_intermediates_fallback(x, p)
end
end
function _evalpoly_intermediates_fallback(x, p::Tuple)
N = length(p)
y = p[N]
ys = (y, ntuple(N - 2) do i
return y = muladd(x, y, p[N - i])
end...)
y = muladd(x, y, p[1])
return y, ys
end
function _evalpoly_intermediates(x, p)
N = length(p)
@inbounds yn = one(x) * p[N]
ys = similar(p, typeof(yn), N - 1)
@inbounds ys[1] = yn
@inbounds for i in 2:(N - 1)
ys[i] = muladd(x, ys[i - 1], p[N - i + 1])
end
@inbounds y = muladd(x, ys[N - 1], p[1])
return y, ys
end

# TODO: Handle following cases
# 1) x is a UniformScaling, pᵢ is a matrix
# 2) x is a matrix, pᵢ is a UniformScaling
@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi'
@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x)
@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi))
@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x

@inline _evalpoly_backp(pi, ∂yi) = ∂yi

function _evalpoly_back(x, p::Tuple, ys, Δy)
return if @generated
exs = []
vars = []
N = length(p.parameters)
for i in 2:(N - 1)
∂pi = Symbol("∂p", i)
push!(vars, ∂pi)
push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi)))
push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi)))
push!(exs, :(∂yi = x′ * ∂yi))
end
push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN
Expr(
:block,
:(x′ = x'),
:(∂yi = Δy),
:(∂p1 = _evalpoly_backp(p[1], ∂yi)),
:(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)),
:(∂yi = x′ * ∂yi),
exs...,
:(∂p = (∂p1, $(vars...))),
:(∂x, Composite{typeof(p),typeof(∂p)}(∂p)),
)
else
_evalpoly_back_fallback(x, p, ys, Δy)
end
end
function _evalpoly_back_fallback(x, p::Tuple, ys, Δy)
x′ = x'
∂yi = Δy
N = length(p)
∂p1 = _evalpoly_backp(p[1], ∂yi)
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p = (
∂p1,
ntuple(N - 2) do i
∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi)
∂pi = _evalpoly_backp(p[i+1], ∂yi)
∂yi = x′ * ∂yi
return ∂pi
end...,
_evalpoly_backp(p[N], ∂yi), # ∂pN
)
return ∂x, Composite{typeof(p),typeof(∂p)}(∂p)
end
function _evalpoly_back(x, p, ys, Δy)
x′ = x'
∂yi = one(x′) * Δy
N = length(p)
@inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi)
∂p = similar(p, typeof(∂p1))
@inbounds begin
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p[1] = ∂p1
for i in 2:(N - 1)
∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi)
∂p[i] = _evalpoly_backp(p[i], ∂yi)
∂yi = x′ * ∂yi
end
∂p[N] = _evalpoly_backp(p[N], ∂yi)
end
return ∂x, ∂p
end
end
35 changes: 35 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,41 @@
)
end

VERSION v"1.4" && @testset "evalpoly" begin
# test fallbacks for when code generation fails
@testset "fallbacks for $T" for T in (Float64, ComplexF64)
x, p = randn(T), Tuple(randn(T, 10))
y_fb, ys_fb = ChainRules._evalpoly_intermediates_fallback(x, p)
y, ys = ChainRules._evalpoly_intermediates(x, p)
@test y_fb y
@test collect(ys_fb) collect(ys)

Δy, ys = randn(T), Tuple(randn(T, 9))
∂x_fb, ∂p_fb = ChainRules._evalpoly_back_fallback(x, p, ys, Δy)
∂x, ∂p = ChainRules._evalpoly_back(x, p, ys, Δy)
@test ∂x_fb ∂x
@test collect(∂p_fb) collect(∂p)
end

@testset "x dim: $(nx), pi dim: $(np), type: $T" for T in (Float64, ComplexF64), nx in (tuple(), 3), np in (tuple(), 3)
# skip x::Matrix, pi::Number case, which is not supported by evalpoly
isempty(np) && !isempty(nx) && continue
m = 5
sx = (nx..., nx...)
sp = (np..., np...)
x, ẋ, x̄ = randn(T, sx...), randn(T, sx...), randn(T, sx...)
p = [randn(T, sp...) for _ in 1:m]
= [randn(T, sp...) for _ in 1:m]
= [randn(T, sp...) for _ in 1:m]
Ω = evalpoly(x, p)
Ω̄ = randn(T, size(Ω)...)
frule_test(evalpoly, (x, ẋ), (p, ṗ))
frule_test(evalpoly, (x, ẋ), (Tuple(p), Tuple(ṗ)))
rrule_test(evalpoly, Ω̄, (x, x̄), (p, p̄))
rrule_test(evalpoly, Ω̄, (x, x̄), (Tuple(p), Tuple(p̄)))
end
end

@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
test_scalar(one, x)
test_scalar(zero, x)
Expand Down

2 comments on commit 224e553

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17266

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.2 -m "<description of version>" 224e553132f26979e7f5c34fccaf769788022299
git push origin v0.7.2

Please sign in to comment.