Skip to content

Commit

Permalink
basic tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
kippfreud committed Jun 28, 2024
1 parent ebdbe67 commit 75d3a46
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pynapple/process/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def compute_welch_spectrum(sig, fs=None):
return spectogram, freqs


def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8):
def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8):
"""
Defines the complex Morelet wavelet kernel
Expand Down Expand Up @@ -235,7 +235,7 @@ def _convolve_wavelet(
"""
if norm not in ["sss", "amp"]:
raise ValueError("Given `norm` must be `sss` or `amp`")
morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling)
morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling)
x = np.linspace(-8, 8, int(2**precision))
int_psi = np.conj(_integrate(morlet_f, x[1] - x[0]))
scale = scaling / (freq / fs)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_signal_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Tests of `signal_processing` for pynapple"""

import numpy as np
import pytest

import pynapple as nap


def test_compute_spectrum():
t = np.linspace(0, 1, 1024)
sig = nap.Tsd(d=np.random.random(1024), t=t)
r = nap.compute_spectrum(sig)
assert len(r[1]) == 1024
assert len(r[0]) == 1024
assert r[0].dtype == np.complex128
assert r[1].dtype == np.float64


def test_compute_welch_spectrum():
t = np.linspace(0, 1, 1024)
sig = nap.Tsd(d=np.random.random(1024), t=t)
r = nap.compute_welch_spectrum(sig)
assert r[0].dtype == np.float64
assert r[1].dtype == np.float64


def test_compute_wavelet_transform():
t = np.linspace(0, 1, 1024)
sig = nap.Tsd(d=np.random.random(1024), t=t)
freqs = np.linspace(1, 600, 10)
mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs)
assert mwt.shape == (1024, 10)

sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t)
freqs = np.linspace(1, 600, 10)
mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs)
assert mwt.shape == (1024, 10, 4)

0 comments on commit 75d3a46

Please sign in to comment.