Skip to content

Commit

Permalink
simplified compute_wavelet_transform, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kippfreud committed Jul 17, 2024
1 parent ebdb64c commit 0f17883
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
28 changes: 12 additions & 16 deletions pynapple/process/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,19 @@ def compute_wavelet_transform(
output_shape = (sig.shape[0], len(freqs), *sig.shape[1:])
sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:])))

cwt = np.zeros(
[sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex
)

filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision)
for f_i, filter in enumerate(filter_bank):
convolved = sig.convolve(np.transpose(np.asarray([filter.real, filter.imag])))
convolved = convolved[:, :, 0].values + convolved[:, :, 1].values * 1j
coef = -np.diff(convolved, axis=0)
if norm == "sss":
coef *= -np.sqrt(scaling) / (freqs[f_i] / fs)
elif norm == "amp":
coef *= -scaling / (freqs[f_i] / fs)
coef = np.insert(
coef, 1, coef[0], axis=0
) # slightly hacky line, necessary to make output correct shape
cwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1)
convolved_real = sig.convolve(np.transpose(filter_bank.real))
convolved_imag = sig.convolve(np.transpose(filter_bank.imag))
convolved = convolved_real.values + convolved_imag.values * 1j
coef = -np.diff(convolved, axis=0)
if norm == "sss":
coef *= coef * (-np.sqrt(scaling) / (freqs / fs))
elif norm == "amp":
coef *= -scaling / (freqs / fs)
coef = np.insert(
coef, 1, coef[0, :], axis=0
) # slightly hacky line, necessary to make output correct shape
cwt = np.swapaxes(coef, 1, 2)

if len(output_shape) == 2:
return nap.TsdFrame(
Expand Down
21 changes: 16 additions & 5 deletions tests/test_signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ def test_compute_wavelet_transform():
assert mpf == 20
assert mwt.shape == (1000, 10)

t = np.linspace(0, 1, 1000)
sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t)
freqs = np.linspace(10, 100, 10)
mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="sss")
mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))]
assert mpf == 20
assert mwt.shape == (1000, 10)

t = np.linspace(0, 1, 1000)
sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t)
freqs = np.linspace(10, 100, 10)
mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="amp")
mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))]
assert mpf == 20
assert mwt.shape == (1000, 10)

t = np.linspace(0, 1, 1000)
sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t)
freqs = np.linspace(10, 100, 10)
Expand Down Expand Up @@ -125,8 +141,3 @@ def test_compute_wavelet_transform():
with pytest.raises(ValueError) as e_info:
nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5)
assert str(e_info.value) == "Number of cycles must be a positive number."


if __name__ == "__main__":
test_compute_wavelet_transform()
# test_compute_welch_spectogram()

0 comments on commit 0f17883

Please sign in to comment.