Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mul/ewise rules for basic arithmetic semiring #26

Merged
merged 20 commits into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ version = "0.4.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -15,8 +18,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
rayegun marked this conversation as resolved.
Show resolved Hide resolved
SSGraphBLAS_jll = "5.1.2"
CEnum = "0.4"
ContextVariablesX = "0.1"
MacroTools = "0.5"
SSGraphBLAS_jll = "5.1"
julia = "1.6"
CEnum = "0.4.1"
ContextVariablesX = "0.1.1"
MacroTools = "0.5.6"
ChainRulesCore = "0.10"
ChainRulesTestUtils = "0.7"
FiniteDifferences = "0.12"
9 changes: 7 additions & 2 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ include("operations/kronecker.jl")
include("print.jl")
include("import.jl")
include("export.jl")

#EXPERIMENTAL
include("options.jl")
#EXPERIMENTAL
include("chainrules/chainruleutils.jl")
include("chainrules/mulrules.jl")
include("chainrules/ewiserules.jl")
include("chainrules/maprules.jl")
include("chainrules/reducerules.jl")
include("chainrules/selectrules.jl")
#include("random.jl")
include("misc.jl")
export libgb
Expand Down
46 changes: 46 additions & 0 deletions src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import FiniteDifferences
import LinearAlgebra
import ChainRulesCore: frule, rrule
using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

#Required for ChainRulesTestUtils
function FiniteDifferences.to_vec(M::GBMatrix)
rayegun marked this conversation as resolved.
Show resolved Hide resolved
I, J, X = findnz(M)
function backtomat(xvec)
return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2))
end
return X, backtomat
end
rayegun marked this conversation as resolved.
Show resolved Hide resolved

function FiniteDifferences.to_vec(v::GBVector)
i, x = findnz(v)
function backtovec(xvec)
return GBVector(i, xvec; nrows=size(v, 1))
end
return x, backtovec
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBMatrix{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, J, _ = findnz(x)
return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2))
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBVector{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, _ = findnz(x)
return GBVector(I, v; nrows = size(x, 1))
end

FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent()
# LinearAlgebra.norm freaks over the nothings.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
71 changes: 71 additions & 0 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#emul TIMES
function frule(
(_, ΔA, ΔB, _),
::typeof(emul),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.TIMES)
)
Ω = emul(A, B, BinaryOps.TIMES)
∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES)
end

function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
function timespullback(ΔΩ)
∂A = emul(ΔΩ, B)
∂B = emul(ΔΩ, A)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return emul(A, B, BinaryOps.TIMES), timespullback
end

function rrule(::typeof(emul), A::GBArray, B::GBArray)
Ω, fullpb = rrule(emul, A, B, BinaryOps.TIMES)
emulpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, emulpb
end

############
# eadd rules
############

# PLUS
######

function frule(
(_, ΔA, ΔB, _),
::typeof(eadd),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.PLUS)
)
Ω = eadd(A, B, BinaryOps.PLUS)
∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS)
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
function pluspullback(ΔΩ)
return (
NoTangent(),
mask(ΔΩ, A; structural = true),
mask(ΔΩ, B; structural = true),
NoTangent()
)
end
return eadd(A, B, BinaryOps.PLUS), pluspullback
end

# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule.
rayegun marked this conversation as resolved.
Show resolved Hide resolved
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
Ω, fullpb = rrule(eadd, A, B, BinaryOps.PLUS)
eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, eaddpb
end
17 changes: 17 additions & 0 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
# than AbstractOp.
#function rrule(map, f, xs)
# # Rather than 3 maps really want 1 multimap
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
# ys = map(first, ys_and_pullbacks)
# pullbacks = map(last, ys_and_pullbacks)
# function map_pullback(dys)
# _call(f, x) = f(x)
# dfs_and_dxs = map(_call, pullbacks, dys)
# # but in your case you know it will be NoTangent() so can skip
# df = sum(first, dfs_and_dxs)
# dxs = map(last, dfs_and_dxs)
# return NoTangent(), df, dxs
# end
# return ys, map_pullback
#end
51 changes: 51 additions & 0 deletions src/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Standard arithmetic mul:
function frule(
(_, ΔA, ΔB),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES)
end
function frule(
(_, ΔA, ΔB, _),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
Ω = mul(A, B)
∂Ω = mul(ΔA, B) + mul(A, ΔB)
return Ω, ∂Ω
end
# Tests will not pass for this. For two reasons.
# First is #25, the output inference is not type stable.
# That's it's own issue.

# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings.
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.

function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B'; mask=A)
∂B = mul(A', ΔΩ; mask=B)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return mul(A, B), mulpullback
end


function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES)
pullback(ΔΩ) = mulpullback(ΔΩ)[1:3]
return Ω, pullback
end
Empty file added src/chainrules/reducerules.jl
Empty file.
Empty file added src/chainrules/selectrules.jl
Empty file.
22 changes: 11 additions & 11 deletions src/lib/LibGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ macro wraperror(code)
elseif info == GrB_NO_VALUE
return nothing
else
if info == GrB_UNINITIALIZED_OBJECT
if info == GrB_UNINITIALIZED_OBJECT
throw(UninitializedObjectError)
elseif info == GrB_INVALID_OBJECT
elseif info == GrB_INVALID_OBJECT
throw(InvalidObjectError)
elseif info == GrB_NULL_POINTER
elseif info == GrB_NULL_POINTER
throw(NullPointerError)
elseif info == GrB_INVALID_VALUE
elseif info == GrB_INVALID_VALUE
throw(InvalidValueError)
elseif info == GrB_INVALID_INDEX
elseif info == GrB_INVALID_INDEX
throw(InvalidIndexError)
elseif info == GrB_DOMAIN_MISMATCH
elseif info == GrB_DOMAIN_MISMATCH
throw(DomainError(nothing, "GraphBLAS Domain Mismatch"))
elseif info == GrB_DIMENSION_MISMATCH
throw(DimensionMismatch())
elseif info == GrB_OUTPUT_NOT_EMPTY
elseif info == GrB_OUTPUT_NOT_EMPTY
throw(OutputNotEmptyError)
elseif info == GrB_OUT_OF_MEMORY
elseif info == GrB_OUT_OF_MEMORY
throw(OutOfMemoryError())
elseif info == GrB_INSUFFICIENT_SPACE
elseif info == GrB_INSUFFICIENT_SPACE
throw(InsufficientSpaceError)
elseif info == GrB_INDEX_OUT_OF_BOUNDS
elseif info == GrB_INDEX_OUT_OF_BOUNDS
throw(BoundsError())
elseif info == GrB_PANIC
throw(PANIC)
Expand Down Expand Up @@ -843,7 +843,7 @@ for T ∈ valid_vec
nvals = GrB_Vector_nvals(v)
I = Vector{GrB_Index}(undef, nvals)
X = Vector{$type}(undef, nvals)
nvals = Ref{GrB_Index}()
nvals = Ref{GrB_Index}(nvals)
$func(I, X, nvals, v)
nvals[] == length(I) == length(X) || throw(DimensionMismatch())
return I .+ 1, X
Expand Down
9 changes: 5 additions & 4 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct
to `|` for booleans and `+` for nonbooleans.
"""
function GBMatrix(
I::Vector, J::Vector, X::Vector{T};
I::AbstractVector, J::AbstractVector, X::AbstractVector{T};
dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J)
) where {T}
A = GBMatrix{T}(nrows, ncols)
Expand All @@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x.
The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for
each index.
"""
function GBMatrix(I::Vector, J::Vector, x::T;
function GBMatrix(I::AbstractVector, J::AbstractVector, x::T;
nrows = maximum(I), ncols = maximum(J)) where {T}
A = GBMatrix{T}(nrows, ncols)
build(A, I, J, x)
return A
end

function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T}
function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T}
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements"))
length(I) == length(J) || DimensionMismatch("I, J and X must have the same length")
x = GBScalar(x)
Expand Down Expand Up @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix)
gxbprint(io, A)
end

SparseArrays.nonzeros(A::GBArray) = findnz(A)[3]
SparseArrays.nonzeros(A::GBArray) = findnz(A)[end]


# Indexing functions
####################
Expand Down
8 changes: 7 additions & 1 deletion src/operations/ewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function emul!(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)

size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
Expand Down Expand Up @@ -275,6 +274,13 @@ function eadd(
return eadd!(C, A, B, op; mask, accum, desc)
end

function Base.:+(A::GBArray, B::GBArray)
eadd(A, B, nothing)
end

function Base.:-(A::GBArray, B::GBArray)
eadd(A, B, BinaryOps.MINUS)
end
#Elementwise Broadcasts
#######################

Expand Down
1 change: 0 additions & 1 deletion src/operations/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ function LinearAlgebra.mul!(
return w
end


"""
mul(A::GBArray, B::GBArray; kwargs...)::GBArray
Expand Down
34 changes: 34 additions & 0 deletions src/operations/transpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,47 @@ function Base.copy!(
return gbtranspose!(C, A.parent; mask, accum, desc)
end

"""
mask!(C::GBArray, A::GBArray, mask::GBArray)

Apply a mask to matrix `A`, storing the results in C.

"""
function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false)
desc = Descriptors.T0
structural && (desc = desc + Descriptors.S)
complement && (desc = desc + Descriptors.C)
gbtranspose!(C, A; mask, desc)
return C
end

"""
mask(A::GBArray, mask::GBArray)

Apply a mask to matrix `A`.
"""
function mask(A::GBArray, mask::GBArray; structural = false, complement = false)
return mask!(similar(A), A, mask; structural, complement)
end

function Base.copy(
A::LinearAlgebra.Transpose{<:Any, <:GBMatrix};
mask = C_NULL, accum = C_NULL, desc::Descriptor = Descriptors.NULL
)
return gbtranspose(A.parent; mask, accum, desc)
end

function Base.copy(v::LinearAlgebra.Transpose{<:Any, <:GBVector})
A = GBMatrix{eltype(v)}(size(v, 1), size(v, 2))
nz = findnz(v.parent)
for i ∈ 1:length(nz[1])
println(i)
println(nz[1][i], ": ", nz[2][i])
A[1, nz[1][i]] = nz[2][i]
end
return A
end

function _handletranspose(
A::GBArray,
desc::Union{Descriptor, Nothing} = nothing,
Expand Down
Loading