diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 5f4e969e..245825f7 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -53,7 +53,7 @@ def compute_welch_spectrum(sig, fs=None): return spectogram, freqs -def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): +def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ Defines the complex Morelet wavelet kernel @@ -235,7 +235,7 @@ def _convolve_wavelet( """ if norm not in ["sss", "amp"]: raise ValueError("Given `norm` must be `sss` or `amp`") - morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) + morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) scale = scaling / (freq / fs) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py new file mode 100644 index 00000000..8da68921 --- /dev/null +++ b/tests/test_signal_processing.py @@ -0,0 +1,37 @@ +"""Tests of `signal_processing` for pynapple""" + +import numpy as np +import pytest + +import pynapple as nap + + +def test_compute_spectrum(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + r = nap.compute_spectrum(sig) + assert len(r[1]) == 1024 + assert len(r[0]) == 1024 + assert r[0].dtype == np.complex128 + assert r[1].dtype == np.float64 + + +def test_compute_welch_spectrum(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + r = nap.compute_welch_spectrum(sig) + assert r[0].dtype == np.float64 + assert r[1].dtype == np.float64 + + +def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10) + + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4)