From 49d92d9e4bd701b7b997a7e7ff59e12fed36c1f0 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Thu, 27 Jun 2024 17:43:03 +0100 Subject: [PATCH] boundary gradient functions --- src/zygote.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/zygote.jl b/src/zygote.jl index ac356fee..3c8d0d9e 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -83,6 +83,14 @@ function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y:: RectangularBoundary(x.side_lengths .+ y; check_positive=false) end +function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}, y::SizedVector{3, T, Vector{T}}) where T + CubicBoundary(x.side_lengths .+ y; check_positive=false) +end + +function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y::SizedVector{2, T, Vector{T}}) where T + RectangularBoundary(x.side_lengths .+ y; check_positive=false) +end + function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{3, T, Vector{T}}}}, y::SVector{3, T}) where T CubicBoundary(SVector{3, T}(x.side_lengths .+ y); check_positive=false) end @@ -99,6 +107,14 @@ function Base.:+(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{2, T, Vector{ RectangularBoundary(SVector{2, T}(x.side_lengths .+ y.side_lengths); check_positive=false) end +function Base.:+(x::CubicBoundary{T}, y::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}) where T + CubicBoundary(SVector{3, T}(x.side_lengths .+ y.side_lengths); check_positive=false) +end + +function Base.:+(x::RectangularBoundary{T}, y::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}) where T + RectangularBoundary(SVector{2, T}(x.side_lengths .+ y.side_lengths); check_positive=false) +end + function Base.:+(x::SVector{3, T}, y::CubicBoundary{T}) where T CubicBoundary(x .+ y.side_lengths; check_positive=false) end