-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mul/ewise rules for basic arithmetic semiring (#26)
* arithmetic groundwork * arithmetic rules for mul and elwise 1st pass * tests and a few fixes * Add mask function, fix eadd(PLUS) * correct mul rrules * test folder structure * mask and vector transpose v1 * Broken constructor rules * arithmetic groundwork * arithmetic rules for mul and elwise 1st pass * tests and a few fixes * Add mask function, fix eadd(PLUS) * correct mul rrules * test folder structure * Broken constructor rules * Move out constructor rules for now * compat * rm constructorrule includes
- Loading branch information
Will Kimmerer
authored
Jul 11, 2021
1 parent
952e7a0
commit f0dd5c9
Showing
18 changed files
with
305 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import FiniteDifferences | ||
import LinearAlgebra | ||
import ChainRulesCore: frule, rrule | ||
using ChainRulesCore | ||
const RealOrComplex = Union{Real, Complex} | ||
|
||
#Required for ChainRulesTestUtils | ||
"""
    FiniteDifferences.to_vec(M::GBMatrix)

Return the value vector produced by `findnz(M)` together with a closure that builds a
`GBMatrix` from a replacement value vector, reusing the indices from `findnz(M)` and the
dimensions of `M`.
"""
function FiniteDifferences.to_vec(M::GBMatrix)
    rowidx, colidx, vals = findnz(M)
    # Only the values vary on the way back; the index pattern is captured from `M`.
    frommat(xvec) = GBMatrix(rowidx, colidx, xvec; nrows=size(M, 1), ncols=size(M, 2))
    return vals, frommat
end
|
||
"""
    FiniteDifferences.to_vec(v::GBVector)

Return the value vector produced by `findnz(v)` together with a closure that builds a
`GBVector` from a replacement value vector, reusing the indices from `findnz(v)` and the
length of `v`.
"""
function FiniteDifferences.to_vec(v::GBVector)
    idx, vals = findnz(v)
    # Only the values vary on the way back; the index pattern is captured from `v`.
    fromvec(xvec) = GBVector(idx, xvec; nrows=size(v, 1))
    return vals, fromvec
end
|
||
"""
    FiniteDifferences.rand_tangent(rng::AbstractRNG, x::GBMatrix{T})

Build a random tangent for `x`: a `GBMatrix` with the dimensions of `x`, the indices
returned by `findnz(x)`, and `nnz(x)` values drawn from `-9:0.01:9` using `rng`.
"""
function FiniteDifferences.rand_tangent(
    rng::AbstractRNG, x::GBMatrix{T}
) where {T<:Union{AbstractFloat,Complex}}
    vals = rand(rng, -9:0.01:9, nnz(x))
    rowidx, colidx, _ = findnz(x)
    return GBMatrix(rowidx, colidx, vals; nrows=size(x, 1), ncols=size(x, 2))
end
|
||
"""
    FiniteDifferences.rand_tangent(rng::AbstractRNG, x::GBVector{T})

Build a random tangent for `x`: a `GBVector` with the length of `x`, the indices
returned by `findnz(x)`, and `nnz(x)` values drawn from `-9:0.01:9` using `rng`.
"""
function FiniteDifferences.rand_tangent(
    rng::AbstractRNG, x::GBVector{T}
) where {T<:Union{AbstractFloat,Complex}}
    vals = rand(rng, -9:0.01:9, nnz(x))
    idx, _ = findnz(x)
    return GBVector(idx, vals; nrows=size(x, 1))
end
|
||
# Any `AbstractOp` gets `NoTangent()` as its random tangent; `rng` and `x` are not used.
function FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp)
    return NoTangent()
end
# LinearAlgebra.norm freaks over the nothings.
# Compute the `p`-norm (default `p = 2`) over `nonzeros(A)` instead of over `A` itself.
function LinearAlgebra.norm(A::GBArray, p::Real=2)
    # Qualified so the call resolves even where `norm` itself has not been imported.
    return LinearAlgebra.norm(nonzeros(A), p)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#emul TIMES | ||
# Forward rule for `emul(A, B, BinaryOps.TIMES)`.
# The output tangent is the sum of `emul(ΔA, B, TIMES)` and `emul(ΔB, A, TIMES)`.
# The tangents for `emul` itself and for the operator argument are ignored.
function frule(
    (_, ΔA, ΔB, _), ::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES)
)
    primal = emul(A, B, BinaryOps.TIMES)
    fromA = emul(ΔA, B, BinaryOps.TIMES)
    fromB = emul(ΔB, A, BinaryOps.TIMES)
    return primal, fromA + fromB
end
# Forward rule for `emul(A, B)` without an explicit operator: delegate to the rule for
# `BinaryOps.TIMES`, padding the tangent tuple with `nothing` for the function and
# operator slots.
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
    padded = (nothing, ΔA, ΔB, nothing)
    return frule(padded, emul, A, B, BinaryOps.TIMES)
end
|
||
# Reverse rule for `emul(A, B, BinaryOps.TIMES)`.
# The pullback returns `emul(ΔΩ, B)` for `A` and `emul(ΔΩ, A)` for `B`, and `NoTangent()`
# for `emul` itself and for the operator argument.
function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
    Ω = emul(A, B, BinaryOps.TIMES)
    times_pullback(ΔΩ) = (NoTangent(), emul(ΔΩ, B), emul(ΔΩ, A), NoTangent())
    return Ω, times_pullback
end
|
||
# Reverse rule for `emul(A, B)` without an explicit operator: reuse the rule for
# `BinaryOps.TIMES` and drop the trailing operator cotangent so that the pullback returns
# three values (function, `A`, `B`).
function rrule(::typeof(emul), A::GBArray, B::GBArray)
    Ω, pb_with_op = rrule(emul, A, B, BinaryOps.TIMES)
    function emulpb(ΔΩ)
        ∂self, ∂A, ∂B, _ = pb_with_op(ΔΩ)
        return (∂self, ∂A, ∂B)
    end
    return Ω, emulpb
end
|
||
############ | ||
# eadd rules | ||
############ | ||
|
||
# PLUS | ||
###### | ||
|
||
# Forward rule for `eadd(A, B, BinaryOps.PLUS)`.
# The output tangent is `eadd(ΔA, ΔB, BinaryOps.PLUS)`. The tangents for `eadd` itself
# and for the operator argument are ignored.
function frule(
    (_, ΔA, ΔB, _), ::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)
)
    primal = eadd(A, B, BinaryOps.PLUS)
    tangent = eadd(ΔA, ΔB, BinaryOps.PLUS)
    return primal, tangent
end
# Forward rule for `eadd(A, B)` without an explicit operator: delegate to the rule for
# `BinaryOps.PLUS`, padding the tangent tuple with `nothing` for the function and
# operator slots.
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
    padded = (nothing, ΔA, ΔB, nothing)
    return frule(padded, eadd, A, B, BinaryOps.PLUS)
end
|
||
# Reverse rule for `eadd(A, B, BinaryOps.PLUS)`.
# The pullback passes `ΔΩ` through `mask` once per input (with that input as the mask and
# `structural = true`), and returns `NoTangent()` for `eadd` itself and for the operator.
function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
    Ω = eadd(A, B, BinaryOps.PLUS)
    function pluspullback(ΔΩ)
        ∂A = mask(ΔΩ, A; structural=true)
        ∂B = mask(ΔΩ, B; structural=true)
        return NoTangent(), ∂A, ∂B, NoTangent()
    end
    return Ω, pluspullback
end
|
||
# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. | ||
# Reverse rule for `eadd(A, B)` without an explicit operator: reuse the rule for
# `BinaryOps.PLUS` and keep only the first three cotangents (function, `A`, `B`), since
# this call has no operator argument.
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
    Ω, pb_with_op = rrule(eadd, A, B, BinaryOps.PLUS)
    function eaddpb(ΔΩ)
        ∂self, ∂A, ∂B, _ = pb_with_op(ΔΩ)
        return (∂self, ∂A, ∂B)
    end
    return Ω, eaddpb
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather | ||
# than AbstractOp. | ||
#function rrule(map, f, xs) | ||
# # Rather than 3 maps really want 1 multimap | ||
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x) | ||
# ys = map(first, ys_and_pullbacks) | ||
# pullbacks = map(last, ys_and_pullbacks) | ||
# function map_pullback(dys) | ||
# _call(f, x) = f(x) | ||
# dfs_and_dxs = map(_call, pullbacks, dys) | ||
# # but in your case you know it will be NoTangent() so can skip | ||
# df = sum(first, dfs_and_dxs) | ||
# dxs = map(last, dfs_and_dxs) | ||
# return NoTangent(), df, dxs | ||
# end | ||
# return ys, map_pullback | ||
#end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Standard arithmetic mul: | ||
# Forward rule for `mul(A, B)` without an explicit semiring: delegate to the rule for
# `Semirings.PLUS_TIMES`, padding the tangent tuple with `nothing` for the function and
# semiring slots.
function frule((_, ΔA, ΔB), ::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose)
    padded = (nothing, ΔA, ΔB, nothing)
    return frule(padded, mul, A, B, Semirings.PLUS_TIMES)
end
# Forward rule for `mul(A, B, Semirings.PLUS_TIMES)`.
# The output tangent is `mul(ΔA, B) + mul(A, ΔB)`. The tangents for `mul` itself and for
# the semiring argument are ignored.
function frule(
    (_, ΔA, ΔB, _),
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES),
)
    primal = mul(A, B)
    fromA = mul(ΔA, B)
    fromB = mul(A, ΔB)
    return primal, fromA + fromB
end
# Tests will not pass for this, for two reasons. | ||
# First is #25: the output inference is not type stable. | ||
# That's its own issue. | ||
|
||
# Second, to_vec currently works by mapping materialized values back and forth, i.e., it knows nothing about nothings. | ||
# This means the finite-difference estimate and the rule give different answers. FiniteDifferences is probably "incorrect", but I have no proof. | ||
|
||
# Reverse rule for `mul(A, B, Semirings.PLUS_TIMES)`.
# The pullback computes `mul(ΔΩ, B'; mask=A)` for `A` and `mul(A', ΔΩ; mask=B)` for `B`,
# and returns `NoTangent()` for `mul` itself and for the semiring argument.
function rrule(
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES),
)
    Ω = mul(A, B)
    mulpullback(ΔΩ) = (
        NoTangent(), mul(ΔΩ, B'; mask=A), mul(A', ΔΩ; mask=B), NoTangent()
    )
    return Ω, mulpullback
end
|
||
|
||
# Reverse rule for `mul(A, B)` without an explicit semiring: reuse the rule for
# `Semirings.PLUS_TIMES` and drop the trailing semiring cotangent so that the pullback
# returns three values (function, `A`, `B`).
function rrule(::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose)
    Ω, pb_with_semiring = rrule(mul, A, B, Semirings.PLUS_TIMES)
    function pullback(ΔΩ)
        ∂self, ∂A, ∂B, _ = pb_with_semiring(ΔΩ)
        return (∂self, ∂A, ∂B)
    end
    return Ω, pullback
end
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,7 +59,6 @@ function LinearAlgebra.mul!( | |
return w | ||
end | ||
|
||
|
||
""" | ||
mul(A::GBArray, B::GBArray; kwargs...)::GBArray | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
using FiniteDifferences | ||
# Check the `FiniteDifferences.to_vec` round trip for `x`:
#   * the flattened form is a `Vector` whose elements are all `Real`;
#   * rebuilding from the flattened form compares `==` to `x`;
#   * when `check_inferred` is true, both directions must pass `@inferred`.
function test_to_vec(x::T; check_inferred=true) where {T}
    if check_inferred
        @inferred FiniteDifferences.to_vec(x)
    end
    flat, rebuild = FiniteDifferences.to_vec(x)
    @test flat isa Vector
    @test all(el -> el isa Real, flat)
    if check_inferred
        @inferred rebuild(flat)
    end
    @test x == rebuild(flat)
    return nothing
end
|
||
@testset "chainrulesutils" begin
    # Round-trip one random sparse matrix and one random sparse vector through `to_vec`.
    test_to_vec(GBMatrix(sprand(10, 10, 0.5)))
    test_to_vec(GBVector(sprand(10, 0.5)))
end
Oops, something went wrong.