Skip to content

Commit

Permalink
linting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kippfreud committed Jun 7, 2024
1 parent be3c68f commit 4d41ca8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pynapple/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
shift_timestamps,
shuffle_ts_intervals,
)
from .signal_processing import *
from .tuning_curves import (
compute_1d_mutual_info,
compute_1d_tuning_curves,
Expand All @@ -24,4 +25,3 @@
compute_2d_tuning_curves_continuous,
compute_discrete_tuning_curves,
)
from .signal_processing import *
43 changes: 28 additions & 15 deletions pynapple/process/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
"""

import numpy as np
from itertools import repeat

import numpy as np

import pynapple as nap


def compute_fft(sig, fs):
"""
"""
Performs numpy fft on sig, returns output
..todo: Make sig handle TsdFrame, TsdTensor
Expand Down Expand Up @@ -40,7 +43,7 @@ def morlet(M, ncycles=5.0, scaling=1.0):
Morelet wavelet kernel
"""
x = np.linspace(-scaling * 2 * np.pi, scaling * 2 * np.pi, M)
return np.exp(1j * ncycles * x) * (np.exp(-0.5 * (x ** 2)) * np.pi ** (-0.25))
return np.exp(1j * ncycles * x) * (np.exp(-0.5 * (x**2)) * np.pi ** (-0.25))

Check warning on line 46 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L45-L46

Added lines #L45 - L46 were not covered by tests


"""
Expand Down Expand Up @@ -70,14 +73,16 @@ def _check_n_cycles(n_cycles, len_cycles=None):
"""
if isinstance(n_cycles, (int, float, np.number)):
if n_cycles <= 0:
raise ValueError('Number of cycles must be a positive number.')
raise ValueError("Number of cycles must be a positive number.")
n_cycles = repeat(n_cycles)

Check warning on line 77 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L76-L77

Added lines #L76 - L77 were not covered by tests
elif isinstance(n_cycles, (tuple, list, np.ndarray)):
for cycle in n_cycles:
if cycle <= 0:
raise ValueError('Each number of cycles must be a positive number.')
raise ValueError("Each number of cycles must be a positive number.")

Check warning on line 81 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L81

Added line #L81 was not covered by tests
if len_cycles and len(n_cycles) != len_cycles:
raise ValueError('The length of number of cycles does not match other inputs.')
raise ValueError(

Check warning on line 83 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L83

Added line #L83 was not covered by tests
"The length of number of cycles does not match other inputs."
)
n_cycles = iter(n_cycles)
return n_cycles

Check warning on line 87 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L86-L87

Added lines #L86 - L87 were not covered by tests

Expand Down Expand Up @@ -105,7 +110,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1):
return np.arange(freq_start, freq_stop + freq_step, freq_step)

Check warning on line 110 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L110

Added line #L110 was not covered by tests


def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp'):
def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp"):
"""
Compute the time-frequency representation of a signal using morlet wavelets.
Expand Down Expand Up @@ -152,14 +157,20 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp
mwt[ind, :] = _convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm)
return nap.TsdFrame(t=sig.index, d=np.transpose(mwt))

Check warning on line 158 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L157-L158

Added lines #L157 - L158 were not covered by tests
else:
mwt = np.zeros([sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex)
mwt = np.zeros(

Check warning on line 160 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L160

Added line #L160 was not covered by tests
[sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex
)
for channel_i in range(sig.values.shape[1]):
for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)):
mwt[:, ind, channel_i] = _convolve_wavelet(sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm)
mwt[:, ind, channel_i] = _convolve_wavelet(

Check warning on line 165 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L165

Added line #L165 was not covered by tests
sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm
)
return nap.TsdTensor(t=sig.index, d=mwt)

Check warning on line 168 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L168

Added line #L168 was not covered by tests


def _convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm='sss'):
def _convolve_wavelet(
sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm="sss"
):
"""
Convolve a signal with a complex wavelet.
Expand Down Expand Up @@ -195,16 +206,18 @@ def _convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None,
* Taking np.abs() of output gives the analytic amplitude.
* Taking np.angle() of output gives the analytic phase. ..todo: this this still true?
"""
if norm not in ['sss', 'amp']:
raise ValueError('Given `norm` must be `sss` or `amp`')
if norm not in ["sss", "amp"]:
raise ValueError("Given `norm` must be `sss` or `amp`")

Check warning on line 210 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L210

Added line #L210 was not covered by tests
if wavelet_len is None:
wavelet_len = int(n_cycles * fs / freq)

Check warning on line 212 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L212

Added line #L212 was not covered by tests
if wavelet_len > sig.shape[-1]:
raise ValueError('The length of the wavelet is greater than the signal. Can not proceed.')
raise ValueError(

Check warning on line 214 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L214

Added line #L214 was not covered by tests
"The length of the wavelet is greater than the signal. Can not proceed."
)
morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling)

Check warning on line 217 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L217

Added line #L217 was not covered by tests
if norm == 'sss':
if norm == "sss":
morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2))

Check warning on line 219 in pynapple/process/signal_processing.py

View check run for this annotation

Codecov / codecov/patch

pynapple/process/signal_processing.py#L219

Added line #L219 was not covered by tests
elif norm == 'amp':
elif norm == "amp":
morlet_f = morlet_f / np.sum(np.abs(morlet_f))
mwt_real = sig.convolve(np.real(morlet_f))
mwt_imag = sig.convolve(np.imag(morlet_f))
Expand Down

0 comments on commit 4d41ca8

Please sign in to comment.