Skip to content

Commit

Permalink
add chainrules-extension for basic nfft-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
migrosser committed Jun 22, 2023
1 parent 18101be commit 8552b28
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
6 changes: 6 additions & 0 deletions AbstractNFFTs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,11 @@ version = "0.8.2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
AbstractNFFTsChainRulesCoreExt = "ChainRulesCore"

[compat]
julia = "1.6"
73 changes: 73 additions & 0 deletions AbstractNFFTs/ext/AbstractNFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module AbstractNFFTsChainRulesCoreExt

using AbstractNFFTs
import ChainRulesCore

###############
# mul-interface
###############
function ChainRulesCore.frule((_, Δx, _), ::typeof(*), plan::AbstractFTPlan, x::AbstractArray)
y = plan*x
Δy = plan*Δx
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), plan::AbstractFTPlan, x::AbstractArray)
y = plan*x
project_x = ChainRulesCore.ProjectTo(x)
function mul_pullback(ȳ)
= project_x( adjoint(plan)*ChainRulesCore.unthunk(ȳ) )
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
end
return y, mul_pullback
end

##################
# NFFT, NFCT, NFST
##################
for (op,trans) in zip([:nfft, :nfct, :nfst], [:adjoint, :transpose, :transpose])

func_trans = Symbol("$(op)_$(trans)")
pbfunc = Symbol("$(op)_pullback")
pbfunc_trans = Symbol("$(op)_$(trans)_pullback")
@eval begin

# direct trafo
function ChainRulesCore.frule((_, Δx, _), ::typeof($(op)), k::AbstractMatrix, x::AbstractArray)
y = $(op)(k,x)
Δy = $(op)(k,Δx)
return y, Δy
end
function ChainRulesCore.rrule(::typeof($(op)), k::AbstractMatrix, x::AbstractArray)
y = $(op)(k,x)
project_x = ChainRulesCore.ProjectTo(x)
function $(pbfunc)(ȳ)
= project_x($(func_trans)(k, size(y), ChainRulesCore.unthunk(ȳ)))
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
end
return y, nfft_pullback
end

# adjoint trafo
function ChainRulesCore.frule((_, Δx, _), ::typeof($(func_trans)), k::AbstractMatrix, N, x::AbstractArray)
y = $(func_trans)(k,N,x)
Δy = $(func_trans)(k,N,Δx)
return y, Δy
end
function ChainRulesCore.rrule(::typeof($(func_trans)), k::AbstractMatrix, N, x::AbstractArray)
y = $(func_trans)(k,N,x)
project_x = ChainRulesCore.ProjectTo(x)
function $(pbfunc_trans)(ȳ)
= project_x($(op)(k, ChainRulesCore.unthunk(ȳ)))
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
end
return y, $(pbfunc_trans)
end

end

end




end # module

0 comments on commit 8552b28

Please sign in to comment.