Skip to content

Commit

Permalink
mul/ewise rules for basic arithmetic semiring (#26)
Browse files Browse the repository at this point in the history
* arithmetic groundwork

* arithmetic rules for mul and elwise 1st pass

* tests and a few fixes

* Add mask function, fix eadd(PLUS)

* correct mul rrules

* test folder structure

* mask and vector transpose v1

* Broken constructor rules

* arithmetic groundwork

* arithmetic rules for mul and elwise 1st pass

* tests and a few fixes

* Add mask function, fix eadd(PLUS)

* correct mul rrules

* test folder structure

* Broken constructor rules

* Move out constructor rules for now

* compat

* rm constructorrule includes
  • Loading branch information
Will Kimmerer authored Jul 11, 2021
1 parent 952e7a0 commit f0dd5c9
Show file tree
Hide file tree
Showing 18 changed files with 305 additions and 28 deletions.
14 changes: 10 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ version = "0.4.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -15,8 +18,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
SSGraphBLAS_jll = "5.1.2"
CEnum = "0.4"
ContextVariablesX = "0.1"
MacroTools = "0.5"
SSGraphBLAS_jll = "5.1"
julia = "1.6"
CEnum = "0.4.1"
ContextVariablesX = "0.1.1"
MacroTools = "0.5.6"
ChainRulesCore = "0.10"
ChainRulesTestUtils = "0.7"
FiniteDifferences = "0.12"
9 changes: 7 additions & 2 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ include("operations/kronecker.jl")
include("print.jl")
include("import.jl")
include("export.jl")

#EXPERIMENTAL
include("options.jl")
#EXPERIMENTAL
include("chainrules/chainruleutils.jl")
include("chainrules/mulrules.jl")
include("chainrules/ewiserules.jl")
include("chainrules/maprules.jl")
include("chainrules/reducerules.jl")
include("chainrules/selectrules.jl")
#include("random.jl")
include("misc.jl")
export libgb
Expand Down
46 changes: 46 additions & 0 deletions src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import FiniteDifferences
import LinearAlgebra
import ChainRulesCore: frule, rrule
using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

#Required for ChainRulesTestUtils
# Flatten a GBMatrix into the vector of its stored values, and return that vector together
# with a closure that rebuilds a matrix with the same stored indices and dimensions.
function FiniteDifferences.to_vec(M::GBMatrix)
    rowidx, colidx, vals = findnz(M)
    # Inverse mapping: place `xvec` back on the stored pattern captured above.
    rebuild(xvec) = GBMatrix(rowidx, colidx, xvec; nrows = size(M, 1), ncols = size(M, 2))
    return vals, rebuild
end

# Flatten a GBVector into the vector of its stored values, and return that vector together
# with a closure that rebuilds a GBVector with the same stored indices and length.
function FiniteDifferences.to_vec(v::GBVector)
    idx, vals = findnz(v)
    # Inverse mapping: place `xvec` back on the stored indices captured above.
    rebuild(xvec) = GBVector(idx, xvec; nrows=size(v, 1))
    return vals, rebuild
end

# Random tangent for a floating-point or complex GBMatrix: a GBMatrix with the same stored
# indices and dimensions as `x`, whose values are drawn with `rng` from -9:0.01:9.
function FiniteDifferences.rand_tangent(
    rng::AbstractRNG, x::GBMatrix{T}
) where {T <: Union{AbstractFloat, Complex}}
    vals = rand(rng, -9:0.01:9, nnz(x))
    rows, cols, _ = findnz(x)
    return GBMatrix(rows, cols, vals; nrows = size(x, 1), ncols = size(x, 2))
end

# Random tangent for a floating-point or complex GBVector: a GBVector with the same stored
# indices and length as `x`, whose values are drawn with `rng` from -9:0.01:9.
function FiniteDifferences.rand_tangent(
    rng::AbstractRNG, x::GBVector{T}
) where {T <: Union{AbstractFloat, Complex}}
    vals = rand(rng, -9:0.01:9, nnz(x))
    idx, _ = findnz(x)
    return GBVector(idx, vals; nrows = size(x, 1))
end

# Operators (`AbstractOp`) always get `NoTangent()` as their random tangent; `rng` is unused.
FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent()
# p-norm of a GraphBLAS array computed over its stored values only (`nonzeros(A)`).
# The generic LinearAlgebra.norm chokes on the `nothing`s it encounters otherwise.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
71 changes: 71 additions & 0 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#emul TIMES
# Forward-mode rule for `emul(A, B, TIMES)`. The output tangent follows the product rule:
# the elementwise product of ΔA with B plus the elementwise product of ΔB with A.
function frule(
    (_, ΔA, ΔB, _),
    ::typeof(emul),
    A::GBArray,
    B::GBArray,
    ::typeof(BinaryOps.TIMES)
)
    primal = emul(A, B, BinaryOps.TIMES)
    from_A = emul(ΔA, B, BinaryOps.TIMES)
    from_B = emul(ΔB, A, BinaryOps.TIMES)
    return primal, from_A + from_B
end
# Forward-mode rule for `emul(A, B)` without an explicit operator: forwards to the TIMES
# rule, padding the tangent tuple with a placeholder for the operator argument.
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
    tangents = (nothing, ΔA, ΔB, nothing)
    return frule(tangents, emul, A, B, BinaryOps.TIMES)
end

# Reverse-mode rule for `emul(A, B, TIMES)`. The pullback returns tangents for
# (emul, A, B, op): each input's cotangent is the elementwise product of ΔΩ with the
# other input; the function and the operator get `NoTangent()`.
function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
    Ω = emul(A, B, BinaryOps.TIMES)
    timespullback(ΔΩ) = (NoTangent(), emul(ΔΩ, B), emul(ΔΩ, A), NoTangent())
    return Ω, timespullback
end

# Reverse-mode rule for `emul(A, B)` without an explicit operator. Reuses the TIMES rule and
# keeps only the first three tangents (emul, A, B), since no operator argument was passed.
function rrule(::typeof(emul), A::GBArray, B::GBArray)
    Ω, timespb = rrule(emul, A, B, BinaryOps.TIMES)
    function emulpb(ΔΩ)
        ∂self, ∂A, ∂B, _ = timespb(ΔΩ)
        return ∂self, ∂A, ∂B
    end
    return Ω, emulpb
end

############
# eadd rules
############

# PLUS
######

# Forward-mode rule for `eadd(A, B, PLUS)`: the output tangent is the elementwise sum
# (via `eadd` with PLUS) of the two input tangents.
function frule(
    (_, ΔA, ΔB, _),
    ::typeof(eadd),
    A::GBArray,
    B::GBArray,
    ::typeof(BinaryOps.PLUS)
)
    primal = eadd(A, B, BinaryOps.PLUS)
    tangent = eadd(ΔA, ΔB, BinaryOps.PLUS)
    return primal, tangent
end
# Forward-mode rule for `eadd(A, B)` without an explicit operator: forwards to the PLUS
# rule, padding the tangent tuple with a placeholder for the operator argument.
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
    tangents = (nothing, ΔA, ΔB, nothing)
    return frule(tangents, eadd, A, B, BinaryOps.PLUS)
end

# Reverse-mode rule for `eadd(A, B, PLUS)`. The pullback returns tangents for
# (eadd, A, B, op): each input's cotangent is ΔΩ structurally masked by that input; the
# function and the operator get `NoTangent()`.
function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
    Ω = eadd(A, B, BinaryOps.PLUS)
    function pluspullback(ΔΩ)
        ∂A = mask(ΔΩ, A; structural = true)
        ∂B = mask(ΔΩ, B; structural = true)
        return NoTangent(), ∂A, ∂B, NoTangent()
    end
    return Ω, pluspullback
end

# Reverse-mode rule for `eadd(A, B)` without an explicit operator. The PLUS rule's pullback
# yields 4 tangents (it includes one for the operator); this call has only 3 arguments, so
# the trailing operator tangent is dropped.
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
    Ω, pluspb = rrule(eadd, A, B, BinaryOps.PLUS)
    function eaddpb(ΔΩ)
        ∂self, ∂A, ∂B, _ = pluspb(ΔΩ)
        return ∂self, ∂A, ∂B
    end
    return Ω, eaddpb
end
17 changes: 17 additions & 0 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
# than AbstractOp.
#function rrule(map, f, xs)
# # Rather than 3 maps really want 1 multimap
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
# ys = map(first, ys_and_pullbacks)
# pullbacks = map(last, ys_and_pullbacks)
# function map_pullback(dys)
# _call(f, x) = f(x)
# dfs_and_dxs = map(_call, pullbacks, dys)
# # but in your case you know it will be NoTangent() so can skip
# df = sum(first, dfs_and_dxs)
# dxs = map(last, dfs_and_dxs)
# return NoTangent(), df, dxs
# end
# return ys, map_pullback
#end
51 changes: 51 additions & 0 deletions src/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Standard arithmetic mul:
# Forward-mode rule for `mul(A, B)` without an explicit semiring: forwards to the
# PLUS_TIMES rule, padding the tangent tuple with a placeholder for the semiring argument.
function frule(
    (_, ΔA, ΔB),
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose
)
    tangents = (nothing, ΔA, ΔB, nothing)
    return frule(tangents, mul, A, B, Semirings.PLUS_TIMES)
end
# Forward-mode rule for `mul(A, B, PLUS_TIMES)`. The output tangent follows the product
# rule for matrix multiplication: ΔA * B + A * ΔB.
function frule(
    (_, ΔA, ΔB, _),
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES)
)
    primal = mul(A, B)
    from_A = mul(ΔA, B)
    from_B = mul(A, ΔB)
    return primal, from_A + from_B
end
# Tests will not pass for this, for two reasons.
# First is #25: the output inference is not type stable.
# That's its own issue.

# Second, to_vec currently works by mapping materialized values back and forth, i.e. it knows nothing about nothings.
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.

# Reverse-mode rule for `mul(A, B, PLUS_TIMES)`. The pullback returns tangents for
# (mul, A, B, semiring): ∂A = ΔΩ * B' masked by A and ∂B = A' * ΔΩ masked by B; the
# function and the semiring get `NoTangent()`.
function rrule(
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES)
)
    Ω = mul(A, B)
    mulpullback(ΔΩ) = (
        NoTangent(),
        mul(ΔΩ, B'; mask=A),
        mul(A', ΔΩ; mask=B),
        NoTangent(),
    )
    return Ω, mulpullback
end


# Reverse-mode rule for `mul(A, B)` without an explicit semiring. Reuses the PLUS_TIMES
# rule and keeps only the first three tangents (mul, A, B), since no semiring was passed.
function rrule(
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose
)
    Ω, fullpb = rrule(mul, A, B, Semirings.PLUS_TIMES)
    function pullback(ΔΩ)
        ∂self, ∂A, ∂B, _ = fullpb(ΔΩ)
        return ∂self, ∂A, ∂B
    end
    return Ω, pullback
end
Empty file added src/chainrules/reducerules.jl
Empty file.
Empty file added src/chainrules/selectrules.jl
Empty file.
22 changes: 11 additions & 11 deletions src/lib/LibGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ macro wraperror(code)
elseif info == GrB_NO_VALUE
return nothing
else
if info == GrB_UNINITIALIZED_OBJECT
if info == GrB_UNINITIALIZED_OBJECT
throw(UninitializedObjectError)
elseif info == GrB_INVALID_OBJECT
elseif info == GrB_INVALID_OBJECT
throw(InvalidObjectError)
elseif info == GrB_NULL_POINTER
elseif info == GrB_NULL_POINTER
throw(NullPointerError)
elseif info == GrB_INVALID_VALUE
elseif info == GrB_INVALID_VALUE
throw(InvalidValueError)
elseif info == GrB_INVALID_INDEX
elseif info == GrB_INVALID_INDEX
throw(InvalidIndexError)
elseif info == GrB_DOMAIN_MISMATCH
elseif info == GrB_DOMAIN_MISMATCH
throw(DomainError(nothing, "GraphBLAS Domain Mismatch"))
elseif info == GrB_DIMENSION_MISMATCH
throw(DimensionMismatch())
elseif info == GrB_OUTPUT_NOT_EMPTY
elseif info == GrB_OUTPUT_NOT_EMPTY
throw(OutputNotEmptyError)
elseif info == GrB_OUT_OF_MEMORY
elseif info == GrB_OUT_OF_MEMORY
throw(OutOfMemoryError())
elseif info == GrB_INSUFFICIENT_SPACE
elseif info == GrB_INSUFFICIENT_SPACE
throw(InsufficientSpaceError)
elseif info == GrB_INDEX_OUT_OF_BOUNDS
elseif info == GrB_INDEX_OUT_OF_BOUNDS
throw(BoundsError())
elseif info == GrB_PANIC
throw(PANIC)
Expand Down Expand Up @@ -843,7 +843,7 @@ for T ∈ valid_vec
nvals = GrB_Vector_nvals(v)
I = Vector{GrB_Index}(undef, nvals)
X = Vector{$type}(undef, nvals)
nvals = Ref{GrB_Index}()
nvals = Ref{GrB_Index}(nvals)
$func(I, X, nvals, v)
nvals[] == length(I) == length(X) || throw(DimensionMismatch())
return I .+ 1, X
Expand Down
9 changes: 5 additions & 4 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct
to `|` for booleans and `+` for nonbooleans.
"""
function GBMatrix(
I::Vector, J::Vector, X::Vector{T};
I::AbstractVector, J::AbstractVector, X::AbstractVector{T};
dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J)
) where {T}
A = GBMatrix{T}(nrows, ncols)
Expand All @@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x.
The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for
each index.
"""
function GBMatrix(I::Vector, J::Vector, x::T;
function GBMatrix(I::AbstractVector, J::AbstractVector, x::T;
nrows = maximum(I), ncols = maximum(J)) where {T}
A = GBMatrix{T}(nrows, ncols)
build(A, I, J, x)
return A
end

function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T}
function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T}
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements"))
length(I) == length(J) || DimensionMismatch("I, J and X must have the same length")
x = GBScalar(x)
Expand Down Expand Up @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix)
gxbprint(io, A)
end

SparseArrays.nonzeros(A::GBArray) = findnz(A)[3]
SparseArrays.nonzeros(A::GBArray) = findnz(A)[end]


# Indexing functions
####################
Expand Down
8 changes: 7 additions & 1 deletion src/operations/ewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function emul!(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)

size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
Expand Down Expand Up @@ -275,6 +274,13 @@ function eadd(
return eadd!(C, A, B, op; mask, accum, desc)
end

# Elementwise addition of two GraphBLAS arrays, delegated to `eadd`.
# `nothing` is passed as the operator, leaving the choice to `eadd`
# (presumably its default is PLUS — confirm against `eadd`).
function Base.:+(A::GBArray, B::GBArray)
eadd(A, B, nothing)
end

# Elementwise subtraction of two GraphBLAS arrays, delegated to `eadd` with
# `BinaryOps.MINUS`.
# NOTE(review): GraphBLAS eWiseAdd applies the operator only where both operands
# have a stored entry and copies lone entries through unchanged. If `eadd` follows
# that, an entry stored only in `B` appears in the result un-negated — confirm
# whether that is intended for `A - B`.
function Base.:-(A::GBArray, B::GBArray)
eadd(A, B, BinaryOps.MINUS)
end
#Elementwise Broadcasts
#######################

Expand Down
1 change: 0 additions & 1 deletion src/operations/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ function LinearAlgebra.mul!(
return w
end


"""
mul(A::GBArray, B::GBArray; kwargs...)::GBArray
Expand Down
1 change: 1 addition & 0 deletions src/operations/transpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ end
mask!(C::GBArray, A::GBArray, mask::GBArray)
Apply a mask to matrix `A`, storing the results in C.
"""
function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false)
desc = Descriptors.T0
Expand Down
8 changes: 4 additions & 4 deletions src/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...)
Create a GBVector from a vector of indices `I` and a vector of values `X`.
"""
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T}
x = GBVector{T}(maximum(I))
function GBVector(I::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T}
x = GBVector{T}(nrows)
build(x, I, X, dup = dup)
return x
end
Expand All @@ -27,14 +27,14 @@ Create an nrows length GBVector v such that M[I[k]] = x.
The resulting vector is "iso-valued" such that it only stores `x` once rather than once for
each index.
"""
function GBVector(I::Vector, x::T;
function GBVector(I::AbstractVector, x::T;
nrows = maximum(I)) where {T}
A = GBVector{T}(nrows)
build(A, I, x)
return A
end

function build(A::GBVector{T}, I::Vector, x::T) where {T}
function build(A::GBVector{T}, I::AbstractVector, x::T) where {T}
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build vector with existing elements"))
x = GBScalar(x)

Expand Down
17 changes: 17 additions & 0 deletions test/chainrules/chainrulesutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using FiniteDifferences
# Check the `FiniteDifferences.to_vec` contract for `x`: the flattened form is a `Vector`
# of `Real`s and the returned closure round-trips back to something equal to `x`.
# With `check_inferred=true`, both directions must also be type-inferrable.
function test_to_vec(x::T; check_inferred=true) where {T}
    if check_inferred
        @inferred FiniteDifferences.to_vec(x)
    end
    x_vec, back = FiniteDifferences.to_vec(x)
    @test x_vec isa Vector
    @test all(el -> el isa Real, x_vec)
    if check_inferred
        @inferred back(x_vec)
    end
    @test x == back(x_vec)
    return nothing
end

# Round-trip `to_vec` checks on a random sparse matrix and a random sparse vector.
@testset "chainrulesutils" begin
    test_to_vec(GBMatrix(sprand(10, 10, 0.5)))
    test_to_vec(GBVector(sprand(10, 0.5)))
end
Loading

0 comments on commit f0dd5c9

Please sign in to comment.