Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fft: satisfy Plancherel #1075

Merged
merged 12 commits into from
Jun 12, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Fixed
- `kit.fft`: fixed bug where Fourier coefficients were off by a scalar factor.

## [3.4.4]

### Added
Expand Down
31 changes: 19 additions & 12 deletions WrightTools/kit/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from .. import exceptions as wt_exceptions

from typing import Tuple


# --- define --------------------------------------------------------------------------------------

Expand Down Expand Up @@ -120,35 +122,40 @@ def diff(xi, yi, order=1) -> np.ndarray:
return yi


def fft(xi, yi, axis=0) -> tuple:
"""Take the 1D FFT of an N-dimensional array and return "sensible" properly shifted arrays.
def fft(xi, yi, axis=0) -> Tuple[np.ndarray, np.ndarray]:
"""Compute a discrete Fourier Transform along one axis of an N-dimensional
array and also compute the 1D frequency coordinates of the transform. The
Fourier coefficients and frequency coordinates are ordered so that the
coordinates are monotonic (i.e. uses `numpy.fft.fftshift`).

Parameters
----------
xi : numpy.ndarray
1D array over which the points to be FFT'ed are defined
yi : numpy.ndarray
ND array with values to FFT
ti : 1D numpy.ndarray
Independent variable specifying data coordinates. Must be monotonic,
linearly spaced data. `ti.size` must be equal to `yi.shape[axis]`
yi : n-dimensional numpy.ndarray
Dependent variable. ND array with values to FFT.
axis : int
axis of yi to perform FFT over

Returns
-------
xi : 1D numpy.ndarray
1D array. Conjugate to input xi. Example: if input xi is in the time
domain, output xi is in frequency domain.
yi : ND numpy.ndarray
FFT. Has the same shape as the input array (yi).
1D array. Conjugate coordinates to input xi. Example: if input `xi`
is time coordinates, output `xi` is (cyclic) frequency coordinates.
yi : complex numpy.ndarray
Transformed data. Has the same shape as the input array (yi).
"""
# xi must be 1D
if xi.ndim != 1:
raise wt_exceptions.DimensionalityError(1, xi.ndim)
# xi must be evenly spaced
spacing = np.diff(xi)
if not np.allclose(spacing, spacing.mean()):
spacing_mean = spacing.mean()
if not np.allclose(spacing, spacing_mean):
raise RuntimeError("WrightTools.kit.fft: argument xi must be evenly spaced")
# fft
yi = np.fft.fft(yi, axis=axis)
yi = np.fft.fft(yi, axis=axis) * spacing_mean
d = (xi.max() - xi.min()) / (xi.size - 1)
xi = np.fft.fftfreq(xi.size, d=d)
# shift
Expand Down
20 changes: 16 additions & 4 deletions tests/kit/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,24 @@
# --- test ----------------------------------------------------------------------------------------


def test_1_sin():
def test_analytic_fft():
a = 1 - 1j
t = np.linspace(0, 10, 10000)
z = np.heaviside(t, 0.5) * np.exp(-a * t)
wi, zi = wt.kit.fft(t, z)
zi_analytical = 1 / (a + 1j * 2 * np.pi * wi)
assert np.allclose(zi.real, zi_analytical.real, atol=1e-3)
assert np.allclose(zi.imag, zi_analytical.imag, atol=1e-3)


def test_plancherel():
t = np.linspace(-10, 10, 10000)
z = np.sin(2 * np.pi * t)
wi, zi = wt.kit.fft(t, z)
freq = np.abs(wi[np.argmax(zi)])
assert np.isclose(freq, 1, rtol=1e-3, atol=1e-3)
intensity_time = (z**2).sum() * (t[1] - t[0])
intensity_freq = (zi * zi.conjugate()).real.sum() * (wi[1] - wi[0])
rel_error = np.abs(intensity_time - intensity_freq) / (intensity_time + intensity_freq)
assert rel_error < 1e-12


def test_5_sines():
Expand All @@ -28,7 +40,7 @@ def test_5_sines():
z = np.sin(2 * np.pi * freqs[None, :] * t[:, None])
wi, zi = wt.kit.fft(t, z, axis=0)
freq = np.abs(wi[np.argmax(zi, axis=0)])
assert np.all(np.isclose(freq, freqs, rtol=1e-3, atol=1e-3))
assert np.allclose(freq, freqs, rtol=1e-3, atol=1e-3)


def test_dimensionality_error():
Expand Down