Skip to content

Commit

Permalink
Delete higher order function rules (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox authored Apr 20, 2020
1 parent 59f06bc commit 02e7857
Show file tree
Hide file tree
Showing 11 changed files with 7 additions and 213 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.4.1"
version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 0 additions & 3 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

include("helper_functions.jl")

include("rulesets/Base/base.jl")
include("rulesets/Base/array.jl")
include("rulesets/Base/broadcast.jl")
include("rulesets/Base/mapreduce.jl")

include("rulesets/Statistics/statistics.jl")
Expand Down
24 changes: 0 additions & 24 deletions src/helper_functions.jl

This file was deleted.

28 changes: 0 additions & 28 deletions src/rulesets/Base/broadcast.jl

This file was deleted.

64 changes: 6 additions & 58 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,3 @@
#####
##### `map`
#####

function rrule(::typeof(map), f, xs...)
y = map(f, xs...)
function map_pullback(ȳ)
ntuple(length(xs)+2) do full_i
full_i == 1 && return NO_FIELDS
full_i == 2 && return DoesNotExist()
i = full_i-2
@thunk map(ȳ, xs...) do ȳi, xis...
_, pullback = _checked_rrule(f, xis...)
∂xis = pullback(ȳi)
extern(∂xis[i+1]) #+1 to skp ∂self
end
end
end
return y, map_pullback
end

#####
##### `mapreduce`, `mapfoldl`, `mapfoldr`
#####

for mf in (:mapreduce, :mapfoldl, :mapfoldr)
sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real}))
call = :($mf(f, op, x))
if mf === :mapreduce
insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:))))
insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims)))
end
pullback_name = Symbol(mf, :_pullback)
body = quote
y = $call
function $pullback_name(ȳ)
∂x = @thunk broadcast(x, ȳ) do xi, ȳi
_, pullback_f = _checked_rrule(f, xi)
_, ∂xi = pullback_f(ȳi)
extern(∂xi)
end
(NO_FIELDS, DoesNotExist(), DoesNotExist(), ∂x)
end
return y, $pullback_name
end
eval(Expr(:function, sig, body))
end

#####
##### `sum`
#####
Expand All @@ -54,18 +6,14 @@ function frule((_, ẋ), ::typeof(sum), x)
return sum(x), sum(ẋ)
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
function sum_pullback(ȳ)
return NO_FIELDS, DoesNotExist(), last(mr_pullback(ȳ))
end
return y, sum_pullback
end

function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
y, inner_pullback = rrule(sum, identity, x; dims=dims)
y = sum(sum, x; dims=dims)
function sum_pullback(ȳ)
return NO_FIELDS, last(inner_pullback(ȳ))
# broadcasting the two works out the size no-matter `dims`
= @thunk broadcast(x, ȳ) do xi, ȳi
ȳi
end
return (NO_FIELDS, x̄)
end
return y, sum_pullback
end
Expand Down
13 changes: 0 additions & 13 deletions src/rulesets/Statistics/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,3 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
end
return y_sum / n, mean_pullback
end

function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
y_sum, sum_pullback = rrule(sum, f, x)
n = _denom(x, :)
function mean_pullback(ȳ)
∂x = Thunk() do
_, _, ∂sum_x = sum_pullback(ȳ)
extern(∂sum_x) / n
end
return (NO_FIELDS, DoesNotExist(), ∂x)
end
return y_sum / n, mean_pullback
end
13 changes: 0 additions & 13 deletions test/helper_functions.jl

This file was deleted.

27 changes: 0 additions & 27 deletions test/rulesets/Base/broadcast.jl

This file was deleted.

39 changes: 0 additions & 39 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,4 @@
@testset "Maps and Reductions" begin
@testset "map" begin
rng = MersenneTwister(42)
n = 10
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng, n)
rrule_test(map, ȳ, (sin, nothing), (x, vx))
rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
end
@testset "mapreduce" begin
rng = MersenneTwister(6)
n = 10
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng)
rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx))

# With keyword arguments (not yet supported in rrule_test)
X = randn(rng, n, n)
y, pullback = rrule(mapreduce, abs2, +, X; dims=2)
= randn(rng, size(y))
(_, _, _, x̄_ad) = pullback(ȳ)
x̄_fd = only(j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X))
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
end
@testset "$f" for f in (mapfoldl, mapfoldr)
rng = MersenneTwister(10)
n = 7
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng)
rrule_test(f, ȳ, (cos, nothing), (+, nothing), (x, vx))
end
@testset "sum" begin
@testset "Vector" begin
rng, M = MersenneTwister(123456), 3
Expand All @@ -48,12 +15,6 @@
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
end
@testset "function argument" begin
rng = MersenneTwister(1)
n = 8
rrule_test(sum, randn(rng), (cos, nothing), (randn(rng, n), randn(rng, n)))
rrule_test(sum, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
end
@testset "keyword arguments" begin
rng = MersenneTwister(33)
n = 4
Expand Down
4 changes: 0 additions & 4 deletions test/rulesets/Statistics/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
rrule_test(mean, randn(rng), (randn(rng, n), randn(rng, n)))
end

@testset "with function arg" begin
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
end

@testset "with dims kwargs" begin
X = randn(rng, n, n+1)
y, mean_pullback = rrule(mean, X; dims=1)
Expand Down
3 changes: 0 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@ Random.seed!(1) # Set seed that all testsets should reset to.

println("Testing ChainRules.jl")
@testset "ChainRules" begin
include("helper_functions.jl")
@testset "rulesets" begin

@testset "Base" begin
include(joinpath("rulesets", "Base", "base.jl"))
include(joinpath("rulesets", "Base", "array.jl"))
include(joinpath("rulesets", "Base", "mapreduce.jl"))
include(joinpath("rulesets", "Base", "broadcast.jl"))
end

print(" ")
Expand Down

2 comments on commit 02e7857

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/13359

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" 02e7857e34b5c01067a288262f69cfcb9fce069b
git push origin v0.5.0

Please sign in to comment.