-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathFFTOp.jl
96 lines (83 loc) · 2.74 KB
/
FFTOp.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
export FFTOp
import Base.copy
# Linear operator (subtype of LinearOperators' `AbstractLinearOperator{T}`) that applies a
# multidimensional FFT to the vectorised form of an array. Instances are created by the
# `FFTOp(T, shape, shift; unitary, cuda)` constructor in this file, which fills the fields
# positionally -- do not reorder or retype fields without updating that constructor.
mutable struct FFTOp{T} <: AbstractLinearOperator{T}
# operator dimensions; the constructor sets both to prod(shape)
nrow :: Int
ncol :: Int
# structural flags; the constructor sets both to false
symmetric :: Bool
hermitian :: Bool
# forward product closure `(res, x) -> ...` that applies the forward FFT plan
prod! :: Function
# no transpose product is provided (always `nothing`)
tprod! :: Nothing
# conjugate-transpose product closure `(res, x) -> ...` that applies the backward (bfft) plan
ctprod! :: Function
# product counters, initialised to 0 by the constructor
# (presumably incremented by LinearOperators when products are applied -- verify)
nprod :: Int
ntprod :: Int
nctprod :: Int
# 5-argument-mul bookkeeping expected by LinearOperators; the constructor passes
# (true, false, true) and two empty `T[]` buffers.
# NOTE(review): exact semantics are defined by LinearOperators, not by this file.
args5 :: Bool
use_prod5! :: Bool
allocated5 :: Bool
# the type of Mv5 is what `LinearOperators.storage_type` reports for this operator
Mv5 :: Vector{T}
Mtu5 :: Vector{T}
# FFT plans (untyped fields): `plan` from plan_fft!, `iplan` from plan_bfft!
plan
iplan
# true if ifftshift/fftshift are applied around the transform
shift::Bool
# true if both products are scaled by 1/sqrt(prod(shape))
unitary::Bool
end
"""
    LinearOperators.storage_type(op::FFTOp)

Report the storage type of `op`, defined as the concrete type of its `Mv5`
buffer field (declared as `Vector{T}`).
"""
function LinearOperators.storage_type(op::FFTOp)
    buffer = op.Mv5
    return typeof(buffer)
end
"""
FFTOp(T::Type, shape::Tuple, shift=true, unitary=true)
returns an operator which performs an FFT on Arrays of type T
# Arguments:
* `T::Type` - type of the array to transform
* `shape::Tuple` - size of the array to transform
* (`shift=true`) - if true, fftshifts are performed
* (`unitary=true`) - if true, FFT is normalized such that it is unitary
"""
function FFTOp(T::Type, shape::NTuple{D,Int64}, shift::Bool=true; unitary::Bool=true, cuda::Bool=false) where D
#tmpVec = cuda ? CuArray{T}(undef,shape) : Array{Complex{real(T)}}(undef, shape)
tmpVec = Array{Complex{real(T)}}(undef, shape)
plan = plan_fft!(tmpVec; flags=FFTW.MEASURE)
iplan = plan_bfft!(tmpVec; flags=FFTW.MEASURE)
if unitary
facF = T(1.0/sqrt(prod(shape)))
facB = T(1.0/sqrt(prod(shape)))
else
facF = T(1.0)
facB = T(1.0)
end
let shape_=shape, plan_=plan, iplan_=iplan, tmpVec_=tmpVec, facF_=facF, facB_=facB
if shift
return FFTOp{T}(prod(shape), prod(shape), false, false
, (res, x) -> fft_multiply_shift!(res, plan_, x, shape_, facF_, tmpVec_)
, nothing
, (res, x) -> fft_multiply_shift!(res, iplan_, x, shape_, facB_, tmpVec_)
, 0, 0, 0, true, false, true, T[], T[]
, plan
, iplan
, shift
, unitary)
else
return FFTOp{T}(prod(shape), prod(shape), false, false
, (res, x) -> fft_multiply!(res, plan_, x, facF_, tmpVec_)
, nothing
, (res, x) -> fft_multiply!(res, iplan_, x, facB_, tmpVec_)
, 0, 0, 0, true, false, true, T[], T[]
, plan
, iplan
, shift
, unitary)
end
end
end
"""
    fft_multiply!(res, plan, x, factor, tmpVec)

Load `x` into the work array `tmpVec`, apply `plan` to it, and write
`factor .* vec(tmpVec)` into `res`. Returns `res`; `tmpVec` is overwritten.

`plan` is applied as `plan * tmpVec` and its return value is discarded, so `plan`
must transform `tmpVec` in place (as the `plan_fft!`/`plan_bfft!` plans do).
"""
function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr},
                       factor::T, tmpVec::Array{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
    # flat alias of the work array (shares its memory)
    flat = vec(tmpVec)
    # copy the input in, converting from Tr to T element-wise
    flat .= x
    # in-place transform; the returned value is not needed
    plan * tmpVec
    @. res = factor * flat
    return res
end
"""
    fft_multiply_shift!(res, plan, x, shape, factor, tmpVec)

Shifted variant of `fft_multiply!`. It writes `ifftshift(reshape(x, shape))` into
`tmpVec`, applies `plan` to `tmpVec`, and writes `fftshift(tmpVec)` into `res`
(viewed with `shape`). It then scales `res` by `factor` and returns `res`;
`tmpVec` is overwritten.

`plan` must act in place on `tmpVec`, because the result of `plan * tmpVec` is
discarded.
"""
function fft_multiply_shift!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr},
                             shape::NTuple{D}, factor::T,
                             tmpVec::Array{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
    xgrid = reshape(x, shape)
    ifftshift!(tmpVec, xgrid)
    # in-place transform; the returned value is not needed
    plan * tmpVec
    # reshape shares memory with `res`, so fftshift! fills `res` directly
    resgrid = reshape(res, shape)
    fftshift!(resgrid, tmpVec)
    @. res *= factor
    return res
end
"""
    copy(S::FFTOp)

Build a new `FFTOp` with the same element type, `shift` and `unitary` settings as
`S`, using `size(S.plan)` as the transform size. New FFT plans and a new work
array are created, so the copy does not share buffers with `S`. The `cuda`
keyword is not stored in `S` and takes its default. Product counters start from
zero.
"""
function Base.copy(S::FFTOp)
    elty = eltype(S)
    dims = size(S.plan)
    return FFTOp(elty, dims, S.shift; unitary = S.unitary)
end