Skip to content

Commit

Permalink
Update to new API (#169)
Browse files Browse the repository at this point in the history
* Update to new API:

* Fixed bound

* Update Project.toml

Co-Authored-By: Nick Robinson <[email protected]>

* Update Project.toml

* Update Project.toml

Co-Authored-By: Nick Robinson <[email protected]>

* Update Project.toml

Co-Authored-By: Nick Robinson <[email protected]>

* Bump

* Sort out version

* Require tests pass on 1.3

* Re-disable 1.3 because #149

* Oops

Co-authored-by: Nick Robinson <[email protected]>
  • Loading branch information
willtebbutt and nickrobinson251 authored Feb 25, 2020
1 parent d3464c2 commit 71a9789
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 20 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.3.4"
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,12 +10,12 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.6.1"
ChainRulesTestUtils = "0.1.3"
FiniteDifferences = "^0.7"
ChainRulesCore = "0.7.0"
ChainRulesTestUtils = "0.2"
FiniteDifferences = "0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
julia = "^1.0"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@

# product rule requires special care for arguments where `mul` is non-commutative

function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
# Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more
# accurate on machines with FMA instructions, since there are only two
# rounding operations, one in `muladd/fma` and the other in `*`.
Expand All @@ -122,7 +122,7 @@ function rrule(::typeof(*), x::Number, y::Number)
return x * y, times_pullback
end

function frule(::typeof(identity), x, _, ẏ)
function frule((_, ẏ), ::typeof(identity), x)
return x, ẏ
end

Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ https://github.com/JuliaLang/julia/issues/22129.
function _cast_diff(f, x)
function element_rule(u)
dself = Zero()
fu, du = frule(f, u, dself, One())
fu, du = frule((dself, One()), f, u)
fu, extern(du)
end
results = broadcast(element_rule, x)
return first.(results), last.(results)
end

function frule(::typeof(broadcast), f, x, _, Δf, Δx)
function frule((_, Δf, Δx), ::typeof(broadcast), f, x)
Ω, ∂x = _cast_diff(f, x)
return Ω, Δx .* ∂x
end
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
##### `sum`
#####

function frule(::typeof(sum), x, _, ẋ)
function frule((_, ẋ), ::typeof(sum), x)
return sum(x), sum(ẋ)
end

Expand Down
6 changes: 3 additions & 3 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
##### `BLAS.dot`
#####

frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)
frule((Δself, Δx, Δy), ::typeof(BLAS.dot), x, y) = frule((Δself, Δx, Δy), dot, x, y)

rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)

Expand All @@ -35,7 +35,7 @@ end
##### `BLAS.nrm2`
#####

function frule(::typeof(BLAS.nrm2), x, _, Δx)
function frule((_, Δx), ::typeof(BLAS.nrm2), x)
Ω = BLAS.nrm2(x)
return Ω, sum(Δx .* @thunk(x * inv(Ω)))
end
Expand Down Expand Up @@ -67,7 +67,7 @@ end
##### `BLAS.asum`
#####

function frule(::typeof(BLAS.asum), x, _, Δx)
function frule((_, Δx), ::typeof(BLAS.asum), x)
return BLAS.asum(x), sum(sign.(x) .* Δx)
end

Expand Down
10 changes: 5 additions & 5 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
##### `dot`
#####

function frule(::typeof(dot), x, y, _, Δx, Δy)
function frule((_, Δx, Δy), ::typeof(dot), x, y)
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
end

Expand All @@ -23,7 +23,7 @@ end
##### `inv`
#####

function frule(::typeof(inv), x::AbstractArray, _, Δx)
function frule((_, Δx), ::typeof(inv), x::AbstractArray)
Ω = inv(x)
return Ω, -Ω * Δx * Ω
end
Expand All @@ -40,7 +40,7 @@ end
##### `det`
#####

function frule(::typeof(det), x, _, ẋ)
function frule((_, ẋ), ::typeof(det), x)
Ω = det(x)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
Expand All @@ -59,7 +59,7 @@ end
##### `logdet`
#####

function frule(::typeof(logdet), x, _, Δx)
function frule((_, Δx), ::typeof(logdet), x)
Ω = logdet(x)
return Ω, tr(inv(x) * Δx)
end
Expand All @@ -76,7 +76,7 @@ end
##### `trace`
#####

function frule(::typeof(tr), x, _, Δx)
function frule((_, Δx), ::typeof(tr), x)
return tr(x), tr(Δx)
end

Expand Down
2 changes: 1 addition & 1 deletion test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, ẏ = frule(sign, 0.0, Zero(), 10.5)
_, ẏ = frule((Zero(), 10.5), sign, 0.0)
@test extern(ẏ) == 0
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
end
@testset "frule" begin
x = rand(3, 3)
y, ẏ = frule(broadcast, sin, x, Zero(), Zero(), One())
y, ẏ = frule((Zero(), Zero(), One()), broadcast, sin, x)
@test y == sin.(x)
@test extern(ẏ) == cos.(x)
end
Expand Down

2 comments on commit 71a9789

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/10069

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 71a978948367109fa0a15c64ac521e13d7a27b8d
git push origin v0.4.0

Please sign in to comment.