From 26c2b25e5b451886bda2cdea2a0e1ca383f25609 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 6 Jun 2024 17:02:16 +0100 Subject: [PATCH 01/71] inital commit for signal proc - complex morelet and fft v0 --- pynapple/process/__init__.py | 1 + pynapple/process/signal_processing.py | 201 ++++++++++++++++++++++++++ 2 files changed, 202 insertions(+) create mode 100644 pynapple/process/signal_processing.py diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 2e1af412..db2581d5 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -24,3 +24,4 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) +from .signal_processing import * diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py new file mode 100644 index 00000000..6afda102 --- /dev/null +++ b/pynapple/process/signal_processing.py @@ -0,0 +1,201 @@ +import numpy as np +from itertools import repeat +import pynapple as nap +from tqdm import tqdm +import matplotlib.pyplot as plt + + +# -------------------------------------------------------------------------------- + +def compute_fft(sig, fs): + """ + Performs numpy fft on sig, returns output + ..todo: Make sig handle TsdFrame, TsdTensor + + :param sig: :param fs: :return: + """ + if not isinstance(sig, nap.Tsd): + raise TypeError("Currently compute_fft is only implemented for Tsd") + fft_result = np.fft.fft(sig.values) + fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) + return fft_result, fft_freq + + +def morlet(M, ncycles=5.0, scaling=1.0): + """ + Defines the complex Morelet wavelet + :param M: Length of the wavelet. :param ncycles: number of wavelet cycles to use. Default is 5 :param scaling: Scaling factor. Default is 1. :return: (M,) ndarray Morelet wavelet + """ + x = np.linspace(-scaling * 2 * np.pi, scaling * 2 * np.pi, M) + return np.exp(1j * ncycles * x) * (np.exp(-0.5 * (x ** 2)) * np.pi ** (-0.25)) + + +""" +The following code has been adapted from functions in the neurodsp package: +https://github.com/neurodsp-tools/neurodsp + +..todo: reference licence in LICENCE directory +""" + + +def check_n_cycles(n_cycles, len_cycles=None): + """Check an input as a number of cycles definition, and make it iterable. + + Parameters ---------- n_cycles : float or list Definition of number of cycles. If a single value, the same number of cycles is used for each frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. len_cycles : int, optional What the length of `n_cycles` should, if it's a list. + Returns ------- n_cycles : iterable An iterable version of the number of cycles. """ + if isinstance(n_cycles, (int, float, np.number)): + + if n_cycles <= 0: + raise ValueError('Number of cycles must be a positive number.') + + n_cycles = repeat(n_cycles) + + elif isinstance(n_cycles, (tuple, list, np.ndarray)): + + for cycle in n_cycles: + if cycle <= 0: + raise ValueError('Each number of cycles must be a positive number.') + + if len_cycles and len(n_cycles) != len_cycles: + raise ValueError('The length of number of cycles does not match other inputs.') + + n_cycles = iter(n_cycles) + + return n_cycles + + +def create_freqs(freq_start, freq_stop, freq_step=1): + """Create an array of frequencies. + + Parameters ---------- freq_start : float Starting value for the frequency definition. freq_stop : float Stopping value for the frequency definition, inclusive. 
freq_step : float, optional, default: 1 Step value, for linearly spaced values between start and stop. + Returns ------- freqs : 1d array Frequency indices. """ + return np.arange(freq_start, freq_stop + freq_step, freq_step) + + +def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp'): + """Compute the time-frequency representation of a signal using morlet wavelets. + + Parameters + ---------- + sig : 1d array + Time series. + fs : float + Sampling rate, in Hz. + freqs : 1d array or list of float + If array, frequency values to estimate with morlet wavelets. + If list, define the frequency range, as [freq_start, freq_stop, freq_step]. + The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + n_cycles : float or 1d array + Length of the filter, as the number of cycles for each frequency. + If 1d array, this defines n_cycles for each frequency. + scaling : float + Scaling factor. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + mwt : 2d array + Time frequency representation of the input signal. + + Notes + ----- + This computes the continuous wavelet transform at specified frequencies across time. + + Examples + -------- + Compute a Morlet wavelet time-frequency representation of a signal: + + >>> from neurodsp.sim import sim_combined + >>> sig = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) + >>> mwt = compute_wavelet_transform(sig, fs=500, freqs=[1, 30]) + """ + if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): + raise TypeError("`sig` must be instance of Tsd or TsdFrame") + + if isinstance(freqs, (tuple, list)): + freqs = create_freqs(*freqs) + n_cycles = check_n_cycles(n_cycles, len(freqs)) + if isinstance(sig, nap.Tsd): + mwt = np.zeros([len(freqs), len(sig)], dtype=complex) + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + wav = convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) + mwt[ind, :] = wav + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt)) + else: + mwt = np.zeros([sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex) + for channel_i in tqdm(range(sig.values.shape[1])): + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + wav = convolve_wavelet(sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm) + mwt[:, ind, channel_i] = wav + return nap.TsdTensor(t=sig.index, d=mwt) + + +def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm='sss'): + """Convolve a signal with a complex wavelet. + + Parameters + ---------- + sig : 1d array + Time series to filter. + fs : float + Sampling rate, in Hz. + freq : float + Center frequency of bandpass filter. + n_cycles : float, optional, default: 7 + Length of the filter, as the number of cycles of the oscillation with specified frequency. + scaling : float, optional, default: 0.5 + Scaling factor for the morlet wavelet. + wavelet_len : int, optional + Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + array + Complex time series. + + Notes + ----- + + * The real part of the returned array is the filtered signal. 
+ * Taking np.abs() of output gives the analytic amplitude. + * Taking np.angle() of output gives the analytic phase. + + Examples + -------- + Convolve a complex wavelet with a simulated signal: + + >>> from neurodsp.sim import sim_combined + >>> sig = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) + >>> cts = convolve_wavelet(sig, fs=500, freq=10) + """ + if norm not in ['sss', 'amp']: + raise ValueError('Given `norm` must be `sss` or `amp`') + + if wavelet_len is None: + wavelet_len = int(n_cycles * fs / freq) + + if wavelet_len > sig.shape[-1]: + raise ValueError('The length of the wavelet is greater than the signal. Can not proceed.') + + morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) + + if norm == 'sss': + morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) + elif norm == 'amp': + morlet_f = morlet_f / np.sum(np.abs(morlet_f)) + + mwt_real = sig.convolve(np.real(morlet_f)) + mwt_imag = sig.convolve(np.imag(morlet_f)) + + return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file From f080d4f38114fa2540b38a9e87beb8625d517262 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 27 Jun 2024 00:33:58 +0100 Subject: [PATCH 02/71] basic pywavelets functionality matched --- pynapple/process/__init__.py | 6 +- pynapple/process/signal_processing.py | 297 +++++++++++++------------- 2 files changed, 149 insertions(+), 154 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index db2581d5..08d58648 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -24,4 +24,8 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) -from .signal_processing import * +from .signal_processing import ( + compute_wavelet_transform, + compute_spectrum, + compute_welch_spectrum +) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 6afda102..c49e729d 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -1,100 +1,72 @@ +""" +Signal processing tools for Pynapple. + +Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. +""" + import numpy as np -from itertools import repeat import pynapple as nap -from tqdm import tqdm -import matplotlib.pyplot as plt +from math import ceil, floor +import json +from scipy.signal import welch +with open('wavelets.json') as f: + WAVELET_DICT = json.load(f) -# -------------------------------------------------------------------------------- -def compute_fft(sig, fs): +def compute_spectrum(sig, fs=None): """ Performs numpy fft on sig, returns output ..todo: Make sig handle TsdFrame, TsdTensor - :param sig: :param fs: :return: + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal """ if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) fft_result = np.fft.fft(sig.values) fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) return fft_result, fft_freq -def morlet(M, ncycles=5.0, scaling=1.0): - """ - Defines the complex Morelet wavelet - :param M: Length of the wavelet. :param ncycles: number of wavelet cycles to use. Default is 5 :param scaling: Scaling factor. Default is 1. 
:return: (M,) ndarray Morelet wavelet +def compute_welch_spectrum(sig, fs=None): """ - x = np.linspace(-scaling * 2 * np.pi, scaling * 2 * np.pi, M) - return np.exp(1j * ncycles * x) * (np.exp(-0.5 * (x ** 2)) * np.pi ** (-0.25)) - - -""" -The following code has been adapted from functions in the neurodsp package: -https://github.com/neurodsp-tools/neurodsp - -..todo: reference licence in LICENCE directory -""" - + Performs scipy Welch's decomposition on sig, returns output + ..todo: Make sig handle TsdFrame, TsdTensor -def check_n_cycles(n_cycles, len_cycles=None): - """Check an input as a number of cycles definition, and make it iterable. - - Parameters ---------- n_cycles : float or list Definition of number of cycles. If a single value, the same number of cycles is used for each frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. len_cycles : int, optional What the length of `n_cycles` should, if it's a list. - Returns ------- n_cycles : iterable An iterable version of the number of cycles. """ - if isinstance(n_cycles, (int, float, np.number)): - - if n_cycles <= 0: - raise ValueError('Number of cycles must be a positive number.') - - n_cycles = repeat(n_cycles) - - elif isinstance(n_cycles, (tuple, list, np.ndarray)): - - for cycle in n_cycles: - if cycle <= 0: - raise ValueError('Each number of cycles must be a positive number.') - - if len_cycles and len(n_cycles) != len_cycles: - raise ValueError('The length of number of cycles does not match other inputs.') - - n_cycles = iter(n_cycles) - - return n_cycles - - -def create_freqs(freq_start, freq_stop, freq_step=1): - """Create an array of frequencies. - - Parameters ---------- freq_start : float Starting value for the frequency definition. freq_stop : float Stopping value for the frequency definition, inclusive. freq_step : float, optional, default: 1 Step value, for linearly spaced values between start and stop. - Returns ------- freqs : 1d array Frequency indices. """ - return np.arange(freq_start, freq_stop + freq_step, freq_step) + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal + """ + if not isinstance(sig, nap.Tsd): + raise TypeError("Currently compute_fft is only implemented for Tsd") + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + freqs, spectogram = welch(sig.values, fs=fs) + return spectogram, freqs -def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp'): - """Compute the time-frequency representation of a signal using morlet wavelets. +def compute_wavelet_transform(sig, freqs, fs=None): + """ + Compute the time-frequency representation of a signal using morlet wavelets. Parameters ---------- - sig : 1d array + sig : pynapple.Tsd or pynapple.TsdFrame Time series. - fs : float - Sampling rate, in Hz. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. If list, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. - n_cycles : float or 1d array - Length of the filter, as the number of cycles for each frequency. - If 1d array, this defines n_cycles for each frequency. - scaling : float - Scaling factor. 
- norm : {'sss', 'amp'}, optional - Normalization method: - - * 'sss' - divide by the square root of the sum of squares - * 'amp' - divide by the sum of amplitudes + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal Returns ------- @@ -104,98 +76,117 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp Notes ----- This computes the continuous wavelet transform at specified frequencies across time. - - Examples - -------- - Compute a Morlet wavelet time-frequency representation of a signal: - - >>> from neurodsp.sim import sim_combined - >>> sig = sim_combined(n_seconds=10, fs=500, - ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) - >>> mwt = compute_wavelet_transform(sig, fs=500, freqs=[1, 30]) """ if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): - raise TypeError("`sig` must be instance of Tsd or TsdFrame") - - if isinstance(freqs, (tuple, list)): - freqs = create_freqs(*freqs) - n_cycles = check_n_cycles(n_cycles, len(freqs)) + raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + assert fs/2 > np.max(freqs), "`freqs` contains values over the Nyquist frequency." if isinstance(sig, nap.Tsd): - mwt = np.zeros([len(freqs), len(sig)], dtype=complex) - for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - wav = convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) - mwt[ind, :] = wav - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt)) - else: - mwt = np.zeros([sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex) - for channel_i in tqdm(range(sig.values.shape[1])): - for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - wav = convolve_wavelet(sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm) - mwt[:, ind, channel_i] = wav - return nap.TsdTensor(t=sig.index, d=mwt) - - -def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm='sss'): - """Convolve a signal with a complex wavelet. + mwt, f = _cwt(sig, + freqs=freqs, + wavelet="cmor1.5-1.0", + sampling_period=1/fs) + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) + elif isinstance(sig, nap.TsdFrame): + mwt = np.zeros( + [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex + ) + for channel_i in range(sig.values.shape[1]): + mwt[:, :, channel_i] = np.transpose(_cwt(sig[:, channel_i], + freqs=freqs, + wavelet="cmor1.5-1.0", + sampling_period=1/fs)[0]) + return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) + elif isinstance(sig, nap.TsdTensor): + raise NotImplemented("cwt for TsdTensor is not yet implemented") + + +def _cwt(data, freqs, wavelet, sampling_period, axis=-1): + """ + cwt(data, scales, wavelet) + + One dimensional Continuous Wavelet Transform. Parameters ---------- - sig : 1d array - Time series to filter. - fs : float - Sampling rate, in Hz. - freq : float - Center frequency of bandpass filter. - n_cycles : float, optional, default: 7 - Length of the filter, as the number of cycles of the oscillation with specified frequency. - scaling : float, optional, default: 0.5 - Scaling factor for the morlet wavelet. - wavelet_len : int, optional - Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. 
- norm : {'sss', 'amp'}, optional - Normalization method: - - * 'sss' - divide by the square root of the sum of squares - * 'amp' - divide by the sum of amplitudes + data : pynapple.Tsd or pynapple.TsdFrame + Input time series signal. + freqs : 1d array + Frequency values to estimate with morlet wavelets. + wavelet : Wavelet object or name + Wavelet to use, only implemented for "cmor1.5-1.0". + sampling_period : float + Sampling period for the frequencies output. + The values computed for ``coefs`` are independent of the choice of + ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling + period). + axis: int, optional + Axis over which to compute the CWT. If not given, the last axis is + used. Returns ------- - array - Complex time series. - - Notes - ----- - - * The real part of the returned array is the filtered signal. - * Taking np.abs() of output gives the analytic amplitude. - * Taking np.angle() of output gives the analytic phase. - - Examples - -------- - Convolve a complex wavelet with a simulated signal: - - >>> from neurodsp.sim import sim_combined - >>> sig = sim_combined(n_seconds=10, fs=500, - ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) - >>> cts = convolve_wavelet(sig, fs=500, freq=10) + coefs : array_like + Continuous wavelet transform of the input signal for the given scales + and wavelet. The first axis of ``coefs`` corresponds to the scales. + The remaining axes match the shape of ``data``. + frequencies : array_like + If the unit of sampling period are seconds and given, then frequencies + are in hertz. Otherwise, a sampling period of 1 is assumed. + + ..todo:: This should use pynapple convolve but currently that cannot handle imaginary numbers as it uses scipy convolve """ - if norm not in ['sss', 'amp']: - raise ValueError('Given `norm` must be `sss` or `amp`') - - if wavelet_len is None: - wavelet_len = int(n_cycles * fs / freq) - - if wavelet_len > sig.shape[-1]: - raise ValueError('The length of the wavelet is greater than the signal. 
Can not proceed.') - - morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) - - if norm == 'sss': - morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) - elif norm == 'amp': - morlet_f = morlet_f / np.sum(np.abs(morlet_f)) - - mwt_real = sig.convolve(np.real(morlet_f)) - mwt_imag = sig.convolve(np.imag(morlet_f)) - - return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file + int_psi = np.array(WAVELET_DICT[wavelet]['int_psi_real'])*1j + np.array(WAVELET_DICT[wavelet]['int_psi_imag']) + x = np.array(WAVELET_DICT[wavelet]["x"]) + central_freq = WAVELET_DICT[wavelet]["central_freq"] + scales = central_freq/(freqs*sampling_period) + out = np.empty((np.size(scales),) + data.shape, dtype=np.complex128) + + if data.ndim > 1: + # move axis to be transformed last (so it is contiguous) + data = data.swapaxes(-1, axis) + # reshape to (n_batch, data.shape[-1]) + data_shape_pre = data.shape + data = data.reshape((-1, data.shape[-1])) + + for i, scale in enumerate(scales): + step = x[1] - x[0] + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + + if data.ndim == 1: + conv = np.convolve(data, int_psi_scale) + else: + # batch convolution via loop + conv_shape = list(data.shape) + conv_shape[-1] += int_psi_scale.size - 1 + conv_shape = tuple(conv_shape) + conv = np.empty(conv_shape, dtype=np.complex128) + for n in range(data.shape[0]): + conv[n, :] = np.convolve(data[n], int_psi_scale) + + coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + if out.dtype.kind != 'c': + coef = coef.real + # transform axis is always -1 due to the data reshape above + d = (coef.shape[-1] - data.shape[-1]) / 2. + if d > 0: + coef = coef[..., floor(d):-ceil(d)] + elif d < 0: + raise ValueError( + f"Selected scale of {scale} too small.") + if data.ndim > 1: + # restore original data shape and axis position + coef = coef.reshape(data_shape_pre) + coef = coef.swapaxes(axis, -1) + out[i, ...] 
= coef + + frequencies = central_freq/scales + if np.isscalar(frequencies): + frequencies = np.array([frequencies]) + frequencies /= sampling_period + return out, frequencies From 9aa05ac3755b97363fe54d353c882c328c7593ca Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 27 Jun 2024 19:50:15 +0100 Subject: [PATCH 03/71] different wavelet definition --- pynapple/process/__init__.py | 1 + pynapple/process/signal_processing.py | 212 ++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 08d58648..120a3363 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -26,6 +26,7 @@ ) from .signal_processing import ( compute_wavelet_transform, + compute_wavelet_transform_og, compute_spectrum, compute_welch_spectrum ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index c49e729d..e2b7bf9f 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -9,6 +9,7 @@ from math import ceil, floor import json from scipy.signal import welch +from itertools import repeat with open('wavelets.json') as f: WAVELET_DICT = json.load(f) @@ -190,3 +191,214 @@ def _cwt(data, freqs, wavelet, sampling_period, axis=-1): frequencies = np.array([frequencies]) frequencies /= sampling_period return out, frequencies + + + + + + +# ------------------------------------------------------------------------------- + +def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): + """ + Defines the complex Morelet wavelet kernel + + Parameters + ---------- + M : int + Length of the wavelet + ncycles : float + number of wavelet cycles to use. Default is 5 + scaling: float + Scaling factor. Default is 1.5 + precision: int + Precision of wavelet to use + + Returns + ------- + np.ndarray + Morelet wavelet kernel + """ + x = np.linspace(-precision, precision, M) + return ((np.pi*ncycles) ** (-0.25)) * np.exp(-x**2 / ncycles) * np.exp(1j * 2*np.pi * scaling * x) + +""" +The following code has been adapted from functions in the neurodsp package: +https://github.com/neurodsp-tools/neurodsp + +..todo: reference licence in LICENCE directory +""" + +def _check_n_cycles(n_cycles, len_cycles=None): + """ + Check an input as a number of cycles, and make it iterable. + + Parameters + ---------- + n_cycles : float or list + Definition of number of cycles to check. If a single value, the same number of cycles is used for each + frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. + len_cycles: int, optional + What the length of `n_cycles` should be, if it's a list. + + Returns + ------- + iter + An iterable version of the number of cycles. + """ + if isinstance(n_cycles, (int, float, np.number)): + if n_cycles <= 0: + raise ValueError("Number of cycles must be a positive number.") + n_cycles = repeat(n_cycles) + elif isinstance(n_cycles, (tuple, list, np.ndarray)): + for cycle in n_cycles: + if cycle <= 0: + raise ValueError("Each number of cycles must be a positive number.") + if len_cycles and len(n_cycles) != len_cycles: + raise ValueError( + "The length of number of cycles does not match other inputs." + ) + n_cycles = iter(n_cycles) + return n_cycles + + +def _create_freqs(freq_start, freq_stop, freq_step=1): + """ + Creates an array of frequencies. + + ..todo:: Implement log scaling + + Parameters + ---------- + freq_start : float + Starting value for the frequency definition. 
+ freq_stop: float + Stopping value for the frequency definition, inclusive. + freq_step: float, optional + Step value, for linearly spaced values between start and stop. + + Returns + ------- + freqs: 1d array + Frequency indices. + """ + return np.arange(freq_start, freq_stop + freq_step, freq_step) + + +def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp"): + """ + Compute the time-frequency representation of a signal using morlet wavelets. + + ..todo:: better normalization between channels + + Parameters + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float + Sampling rate, in Hz. + freqs : 1d array or list of float + If array, frequency values to estimate with morlet wavelets. + If list, define the frequency range, as [freq_start, freq_stop, freq_step]. + The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + n_cycles : float or 1d array + Length of the filter, as the number of cycles for each frequency. + If 1d array, this defines n_cycles for each frequency. + scaling : float + Scaling factor. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + mwt : 2d array + Time frequency representation of the input signal. + + Notes + ----- + This computes the continuous wavelet transform at specified frequencies across time. + """ + if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): + raise TypeError("`sig` must be instance of Tsd or TsdFrame") + if isinstance(freqs, (tuple, list)): + freqs = _create_freqs(*freqs) + n_cycles = _check_n_cycles(n_cycles, len(freqs)) + if isinstance(sig, nap.Tsd): + mwt = np.zeros([len(freqs), len(sig)], dtype=complex) + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + mwt[ind, :] = _convolve_wavelet(sig, + fs, + freq, + n_cycle, + scaling, + norm=norm) + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) + else: + mwt = np.zeros( + [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex + ) + for channel_i in range(sig.values.shape[1]): + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + mwt[:, ind, channel_i] = _convolve_wavelet( + sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm + ) + return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) + + +def _convolve_wavelet( + sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm="sss" +): + """ + Convolve a signal with a complex wavelet. + + Parameters + ---------- + sig : pynapple.Tsd + Time series to filter. + fs : float + Sampling rate, in Hz. + freq : float + Center frequency of bandpass filter. + n_cycles : float, optional, default: 7 + Length of the filter, as the number of cycles of the oscillation with specified frequency. + scaling : float, optional, default: 0.5 + Scaling factor for the morlet wavelet. + wavelet_len : int, optional + Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + array + Complex- valued time series. + + Notes + ----- + + * The real part of the returned array is the filtered signal. + * Taking np.abs() of output gives the analytic amplitude. + * Taking np.angle() of output gives the analytic phase. 
..todo: this this still true? + """ + if norm not in ["sss", "amp"]: + raise ValueError("Given `norm` must be `sss` or `amp`") + if wavelet_len is None: + wavelet_len = int(n_cycles * fs / freq) + if wavelet_len > sig.shape[-1]: + raise ValueError( + "The length of the wavelet is greater than the signal. Can not proceed." + ) + morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) + if norm == "sss": + morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) + elif norm == "amp": + morlet_f = morlet_f / np.sum(np.abs(morlet_f)) + mwt_real = sig.convolve(np.real(morlet_f)) + mwt_imag = sig.convolve(np.imag(morlet_f)) + return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file From 4af8d90dc7013df6be7753a02cbbcbe4adf5f053 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 27 Jun 2024 22:48:45 +0100 Subject: [PATCH 04/71] wavlet workaround fixed, tutorial added --- docs/examples/tutorial_signal_processing.py | 235 ++++++++++++++++++++ pynapple/process/__init__.py | 1 - pynapple/process/signal_processing.py | 200 +++-------------- 3 files changed, 263 insertions(+), 173 deletions(-) create mode 100644 docs/examples/tutorial_signal_processing.py diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py new file mode 100644 index 00000000..c0d585ba --- /dev/null +++ b/docs/examples/tutorial_signal_processing.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +""" +Signal Processing Local Field Potentials +============ + +Working with Local Field Potential data. + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Kipp Freud. + +""" +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests` +# +# mkdocs_gallery_thumbnail_number = 1 +# +# Now, import the necessary libraries: + +import numpy as np +import pynapple as nap +import pandas as pd +import os +#import requests +from zipfile import ZipFile +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("TkAgg") + +# %% +# *** +# Downloading the data +# ------------------ +# First things first: Let's download and extract the data +# path = "data/A2929-200711" +# extract_to = "data" +# if extract_to not in os.listdir("."): +# os.mkdir(extract_to) +# if path not in os.listdir("."): +# # Download the file +# response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") +# zip_path = os.path.join(extract_to, '/downloaded_file.zip') +# # Write the zip file to disk +# with open(zip_path, 'wb') as f: +# f.write(response.content) +# # Unzip the file +# with ZipFile(zip_path, 'r') as zip_ref: +# zip_ref.extractall(extract_to) + + +# %% +# *** +# Parsing the data +# ------------------ +# Now that we have the data, we must append the 2kHz LFP recording to the .nwb file +# eeg_path = "data/A2929-200711/A2929-200711.dat" +# frequency = 20000 # Hz +# n_channels = 16 +# f = open(eeg_path, 'rb') +# startoffile = f.seek(0, 0) +# endoffile = f.seek(0, 2) +# f.close() +# bytes_size = 2 +# n_samples = int((endoffile-startoffile)/n_channels/bytes_size) +# duration = n_samples/frequency +# interval = 1/frequency +# fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) +# timestep = np.arange(0, n_samples)/frequency +# eeg = nap.TsdFrame(t=timestep, d=fp) +# nap.append_NWB_LFP("data/A2929-200711/", +# eeg) + + +# %% +# Let's save the RoiResponseSeries as a 
variable called 'transients' and print it +FS = 1250 +# data = nap.load_file("data/A2929-200711/pynapplenwb/A2929-200711.nwb") +data = nap.load_file("data/stable.nwb") +print(data["ElectricalSeries"]) +# normed_electrical_series = data["ElectricalSeries"].values +# normed_electrical_series = normed_electrical_series[:, :] +# normed_electrical_series[:, :10] = normed_electrical_series[:, :10] - np.expand_dims(np.mean(normed_electrical_series[:, :10], axis=1), axis=1) +# normed_electrical_series[:, 10:] = normed_electrical_series[:, 10:] - np.expand_dims(np.mean(normed_electrical_series[:, 10:], axis=1), axis=1) +NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) + +# %% +# *** +# Selecting slices +# ----------------------------------- +# Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake +wake_minute = NES.value_from(NES, nap.IntervalSet(900,960)) +sleep_minute = NES.value_from(NES, nap.IntervalSet(0,60)) + +# %% +# *** +# Plotting the LFP activity of one slices +# ----------------------------------- +# Let's plot +fig, ax = plt.subplots(2) +for channel in range(sleep_minute.shape[1]): + ax[0].plot(sleep_minute.index.values, sleep_minute[:, channel], alpha=0.5, label="Sleep Data") +ax[0].set_title("Sleep ephys") +for channel in range(wake_minute.shape[1]): + ax[1].plot(wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data") +ax[1].set_title("Wake ephys") +plt.show() + + +# %% +# There is much shared information between channels, and wake and sleep don't seem visibly different. +# Let's take the Fourier transforms of one channel for both and see if differences are present +channel = 1 +fig, ax = plt.subplots(1) +fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], + fs=int(FS)) +ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") +ax.set_xlim((0, FS/2 )) +fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], + fs=int(FS)) +ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Wake Data") +ax.set_title(f"Fourier Decomposition for channel {channel}") +ax.legend() +plt.show() + + +# %% +# There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? 
+# Let's explore further with a wavelet decomposition + +def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs): + if np.iscomplexobj(powers): + powers = abs(powers) + ax.imshow(powers, aspect='auto', **kwargs) + ax.invert_yaxis() + ax.set_xlabel('Time (s)') + ax.set_ylabel('Frequency (Hz)') + if isinstance(x_ticks, int): + x_tick_pos = np.linspace(0, times.size, x_ticks) + x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) + else: + x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] + ax.set(xticks=x_tick_pos, xticklabels=x_ticks) + if isinstance(y_ticks, int): + y_ticks_pos = np.linspace(0, freqs.size, y_ticks) + y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2) + else: + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + +fig, ax = plt.subplots(2) +freqs = np.array([2.59, 3.66, 5.18, 8.0, 10.36, 14.65, 20.72, 29.3, 41.44, 58.59, 82.88, 117.19, + 165.75, 234.38, 331.5, 468.75, 624., ]) +mwt_sleep = nap.compute_wavelet_transform( + sleep_minute[:, channel], + fs=None, + freqs=freqs + ) +plot_timefrequency(sleep_minute.index.values[:], freqs[:], np.transpose(mwt_sleep[:,:].values), ax=ax[0]) +ax[0].set_title(f"Sleep Data Wavelet Decomposition: Channel {channel}") +mwt_wake = nap.compute_wavelet_transform( + wake_minute[:, channel], + fs=None, + freqs=freqs + ) +plot_timefrequency(wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:,:].values), ax=ax[1]) +ax[1].set_title(f"Wake Data Wavelet Decomposition: Channel {channel}") +plt.margins(0) +plt.show() + +# %% +# Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data +freq = 3 +interval = (937, 939) +wake_second = wake_minute.value_from(wake_minute, nap.IntervalSet(interval[0],interval[1])) +fig, ax = plt.subplots(1) +ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(wake_second.index.values, + mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Theta oscillations") +ax.set_title(f"{freqs[freq]}Hz oscillation power.") +plt.show() + + +# %% +# Let's focus on the sleeping data. 
Let's see if we can isolate the slow wave oscillations from the data +freq = 0 +# interval = (10, 15) +interval = (20, 25) +sleep_second = sleep_minute.value_from(sleep_minute, nap.IntervalSet(interval[0],interval[1])) +_, ax = plt.subplots(1) +ax.plot(sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(sleep_second.index.values, + mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Slow Wave Oscillations") +ax.set_title(f"{freqs[freq]}Hz oscillation power") +plt.show() + +# %% +# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep + +_, ax = plt.subplots(20, figsize=(10, 50)) +mwt_sleep = np.transpose(mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))) +ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) +plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) + +ax[2].plot(sleep_second.index, sleep_second.values[:, 0]) +ax[2].plot(sleep_second.index, mwt_sleep[freq, :].real) +ax[2].set_title(f"{freqs[freq]}Hz") + +ax[3].plot(sleep_second.index, np.abs(mwt_sleep[freq, :])) +# ax[3].plot(lfp.index, lfp.values[:,0]) +ax[4].plot(sleep_second.index, np.angle(mwt_sleep[freq, :])) + +spikes = {} +for i in data["units"].index: + spikes[i] = data["units"][i].times()[ + (data["units"][i].times() > interval[0]) & (data["units"][i].times() < interval[1])] + +phase = {} +for i in spikes.keys(): + phase_i = [] + for spike in spikes[i]: + phase_i.append(np.angle(mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))])) + phase[i] = np.array(phase_i) + +for i in range(15): + ax[5 + i].scatter(spikes[i], phase[i]) + ax[5 + i].set_xlim(interval[0], interval[1]) + ax[5 + i].set_ylim(-np.pi, np.pi) + ax[5 + i].set_xlabel("time (s)") + ax[5 + i].set_ylabel("phase") + +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 120a3363..08d58648 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -26,7 +26,6 @@ ) from .signal_processing import ( compute_wavelet_transform, - compute_wavelet_transform_og, compute_spectrum, compute_welch_spectrum ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index e2b7bf9f..cc9ab1bc 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -54,151 +54,6 @@ def compute_welch_spectrum(sig, fs=None): return spectogram, freqs -def compute_wavelet_transform(sig, freqs, fs=None): - """ - Compute the time-frequency representation of a signal using morlet wavelets. - - Parameters - ---------- - sig : pynapple.Tsd or pynapple.TsdFrame - Time series. - freqs : 1d array or list of float - If array, frequency values to estimate with morlet wavelets. - If list, define the frequency range, as [freq_start, freq_stop, freq_step]. - The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. - fs : float, optional - Sampling rate, in Hz. If None, will be calculated from the given signal - - Returns - ------- - mwt : 2d array - Time frequency representation of the input signal. - - Notes - ----- - This computes the continuous wavelet transform at specified frequencies across time. 
- """ - if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): - raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") - if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) - assert fs/2 > np.max(freqs), "`freqs` contains values over the Nyquist frequency." - if isinstance(sig, nap.Tsd): - mwt, f = _cwt(sig, - freqs=freqs, - wavelet="cmor1.5-1.0", - sampling_period=1/fs) - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) - elif isinstance(sig, nap.TsdFrame): - mwt = np.zeros( - [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex - ) - for channel_i in range(sig.values.shape[1]): - mwt[:, :, channel_i] = np.transpose(_cwt(sig[:, channel_i], - freqs=freqs, - wavelet="cmor1.5-1.0", - sampling_period=1/fs)[0]) - return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) - elif isinstance(sig, nap.TsdTensor): - raise NotImplemented("cwt for TsdTensor is not yet implemented") - - -def _cwt(data, freqs, wavelet, sampling_period, axis=-1): - """ - cwt(data, scales, wavelet) - - One dimensional Continuous Wavelet Transform. - - Parameters - ---------- - data : pynapple.Tsd or pynapple.TsdFrame - Input time series signal. - freqs : 1d array - Frequency values to estimate with morlet wavelets. - wavelet : Wavelet object or name - Wavelet to use, only implemented for "cmor1.5-1.0". - sampling_period : float - Sampling period for the frequencies output. - The values computed for ``coefs`` are independent of the choice of - ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling - period). - axis: int, optional - Axis over which to compute the CWT. If not given, the last axis is - used. - - Returns - ------- - coefs : array_like - Continuous wavelet transform of the input signal for the given scales - and wavelet. The first axis of ``coefs`` corresponds to the scales. - The remaining axes match the shape of ``data``. - frequencies : array_like - If the unit of sampling period are seconds and given, then frequencies - are in hertz. Otherwise, a sampling period of 1 is assumed. 
- - ..todo:: This should use pynapple convolve but currently that cannot handle imaginary numbers as it uses scipy convolve - """ - int_psi = np.array(WAVELET_DICT[wavelet]['int_psi_real'])*1j + np.array(WAVELET_DICT[wavelet]['int_psi_imag']) - x = np.array(WAVELET_DICT[wavelet]["x"]) - central_freq = WAVELET_DICT[wavelet]["central_freq"] - scales = central_freq/(freqs*sampling_period) - out = np.empty((np.size(scales),) + data.shape, dtype=np.complex128) - - if data.ndim > 1: - # move axis to be transformed last (so it is contiguous) - data = data.swapaxes(-1, axis) - # reshape to (n_batch, data.shape[-1]) - data_shape_pre = data.shape - data = data.reshape((-1, data.shape[-1])) - - for i, scale in enumerate(scales): - step = x[1] - x[0] - j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) - j = j.astype(int) # floor - if j[-1] >= int_psi.size: - j = np.extract(j < int_psi.size, j) - int_psi_scale = int_psi[j][::-1] - - if data.ndim == 1: - conv = np.convolve(data, int_psi_scale) - else: - # batch convolution via loop - conv_shape = list(data.shape) - conv_shape[-1] += int_psi_scale.size - 1 - conv_shape = tuple(conv_shape) - conv = np.empty(conv_shape, dtype=np.complex128) - for n in range(data.shape[0]): - conv[n, :] = np.convolve(data[n], int_psi_scale) - - coef = - np.sqrt(scale) * np.diff(conv, axis=-1) - if out.dtype.kind != 'c': - coef = coef.real - # transform axis is always -1 due to the data reshape above - d = (coef.shape[-1] - data.shape[-1]) / 2. - if d > 0: - coef = coef[..., floor(d):-ceil(d)] - elif d < 0: - raise ValueError( - f"Selected scale of {scale} too small.") - if data.ndim > 1: - # restore original data shape and axis position - coef = coef.reshape(data_shape_pre) - coef = coef.swapaxes(axis, -1) - out[i, ...] = coef - - frequencies = central_freq/scales - if np.isscalar(frequencies): - frequencies = np.array([frequencies]) - frequencies /= sampling_period - return out, frequencies - - - - - - -# ------------------------------------------------------------------------------- - def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ Defines the complex Morelet wavelet kernel @@ -222,13 +77,6 @@ def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): x = np.linspace(-precision, precision, M) return ((np.pi*ncycles) ** (-0.25)) * np.exp(-x**2 / ncycles) * np.exp(1j * 2*np.pi * scaling * x) -""" -The following code has been adapted from functions in the neurodsp package: -https://github.com/neurodsp-tools/neurodsp - -..todo: reference licence in LICENCE directory -""" - def _check_n_cycles(n_cycles, len_cycles=None): """ Check an input as a number of cycles, and make it iterable. @@ -285,12 +133,10 @@ def _create_freqs(freq_start, freq_stop, freq_step=1): return np.arange(freq_start, freq_stop + freq_step, freq_step) -def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp"): +def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="amp"): """ Compute the time-frequency representation of a signal using morlet wavelets. 
- ..todo:: better normalization between channels - Parameters ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -325,6 +171,8 @@ def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm=" raise TypeError("`sig` must be instance of Tsd or TsdFrame") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): mwt = np.zeros([len(freqs), len(sig)], dtype=complex) @@ -349,7 +197,7 @@ def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm=" def _convolve_wavelet( - sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm="sss" + sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm="sss" ): """ Convolve a signal with a complex wavelet. @@ -366,8 +214,6 @@ def _convolve_wavelet( Length of the filter, as the number of cycles of the oscillation with specified frequency. scaling : float, optional, default: 0.5 Scaling factor for the morlet wavelet. - wavelet_len : int, optional - Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. norm : {'sss', 'amp'}, optional Normalization method: @@ -384,21 +230,31 @@ def _convolve_wavelet( * The real part of the returned array is the filtered signal. * Taking np.abs() of output gives the analytic amplitude. - * Taking np.angle() of output gives the analytic phase. ..todo: this this still true? + * Taking np.angle() of output gives the analytic phase. """ if norm not in ["sss", "amp"]: raise ValueError("Given `norm` must be `sss` or `amp`") - if wavelet_len is None: - wavelet_len = int(n_cycles * fs / freq) - if wavelet_len > sig.shape[-1]: + morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) + x = np.linspace(-8, 8, int(2**precision)) + int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + scale = scaling / (freq/fs) + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + conv = np.convolve(sig, int_psi_scale) + coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + # transform axis is always -1 due to the data reshape above + d = (coef.shape[-1] - sig.shape[-1]) / 2. + if d > 0: + coef = coef[..., floor(d):-ceil(d)] + elif d < 0: raise ValueError( - "The length of the wavelet is greater than the signal. Can not proceed." 
- ) - morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) - if norm == "sss": - morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) - elif norm == "amp": - morlet_f = morlet_f / np.sum(np.abs(morlet_f)) - mwt_real = sig.convolve(np.real(morlet_f)) - mwt_imag = sig.convolve(np.imag(morlet_f)) - return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file + f"Selected scale of {scale} too small.") + return coef + +def _integrate(arr, step): + integral = np.cumsum(arr) + integral *= step + return integral From 01c5435d0fd5be04f5c83ee812e242a0193a9d09 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:36:51 +0100 Subject: [PATCH 05/71] tutorial cleaning --- docs/examples/tutorial_signal_processing.py | 69 ++++++++++----------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index c0d585ba..6955df5c 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -22,9 +22,8 @@ import numpy as np import pynapple as nap -import pandas as pd import os -#import requests +import requests from zipfile import ZipFile import matplotlib.pyplot as plt import matplotlib @@ -35,20 +34,20 @@ # Downloading the data # ------------------ # First things first: Let's download and extract the data -# path = "data/A2929-200711" -# extract_to = "data" -# if extract_to not in os.listdir("."): -# os.mkdir(extract_to) -# if path not in os.listdir("."): -# # Download the file -# response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") -# zip_path = os.path.join(extract_to, '/downloaded_file.zip') -# # Write the zip file to disk -# with open(zip_path, 'wb') as f: -# f.write(response.content) -# # Unzip the file -# with ZipFile(zip_path, 'r') as zip_ref: -# zip_ref.extractall(extract_to) +path = "data/A2929-200711" +extract_to = "data" +if extract_to not in os.listdir("."): + os.mkdir(extract_to) +if path not in os.listdir("."): +# Download the file + response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") + zip_path = os.path.join(extract_to, '/downloaded_file.zip') + # Write the zip file to disk + with open(zip_path, 'wb') as f: + f.write(response.content) + # Unzip the file + with ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(extract_to) # %% @@ -56,22 +55,22 @@ # Parsing the data # ------------------ # Now that we have the data, we must append the 2kHz LFP recording to the .nwb file -# eeg_path = "data/A2929-200711/A2929-200711.dat" -# frequency = 20000 # Hz -# n_channels = 16 -# f = open(eeg_path, 'rb') -# startoffile = f.seek(0, 0) -# endoffile = f.seek(0, 2) -# f.close() -# bytes_size = 2 -# n_samples = int((endoffile-startoffile)/n_channels/bytes_size) -# duration = n_samples/frequency -# interval = 1/frequency -# fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) -# timestep = np.arange(0, n_samples)/frequency -# eeg = nap.TsdFrame(t=timestep, d=fp) -# nap.append_NWB_LFP("data/A2929-200711/", -# eeg) +eeg_path = "data/A2929-200711/A2929-200711.dat" +frequency = 20000 # Hz +n_channels = 16 +f = open(eeg_path, 'rb') +startoffile = f.seek(0, 0) +endoffile = f.seek(0, 2) +f.close() +bytes_size = 2 +n_samples = int((endoffile-startoffile)/n_channels/bytes_size) +duration = n_samples/frequency +interval = 1/frequency +fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) +timestep = np.arange(0, 
n_samples)/frequency +eeg = nap.TsdFrame(t=timestep, d=fp) +nap.append_NWB_LFP("data/A2929-200711/", + eeg) # %% @@ -80,17 +79,13 @@ # data = nap.load_file("data/A2929-200711/pynapplenwb/A2929-200711.nwb") data = nap.load_file("data/stable.nwb") print(data["ElectricalSeries"]) -# normed_electrical_series = data["ElectricalSeries"].values -# normed_electrical_series = normed_electrical_series[:, :] -# normed_electrical_series[:, :10] = normed_electrical_series[:, :10] - np.expand_dims(np.mean(normed_electrical_series[:, :10], axis=1), axis=1) -# normed_electrical_series[:, 10:] = normed_electrical_series[:, 10:] - np.expand_dims(np.mean(normed_electrical_series[:, 10:], axis=1), axis=1) -NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) # %% # *** # Selecting slices # ----------------------------------- # Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake +NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) wake_minute = NES.value_from(NES, nap.IntervalSet(900,960)) sleep_minute = NES.value_from(NES, nap.IntervalSet(0,60)) From 027878e71e6c6bb290c614a7e26b4e0f370451c9 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:41:44 +0100 Subject: [PATCH 06/71] linting --- docs/examples/tutorial_signal_processing.py | 168 +++++++++++++------- pynapple/process/signal_processing.py | 54 ++++--- 2 files changed, 141 insertions(+), 81 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 6955df5c..b3e2f9af 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -20,14 +20,14 @@ # # Now, import the necessary libraries: -import numpy as np -import pynapple as nap import os -import requests from zipfile import ZipFile + import matplotlib.pyplot as plt -import matplotlib -matplotlib.use("TkAgg") +import numpy as np +import requests + +import pynapple as nap # %% # *** @@ -39,14 +39,16 @@ if extract_to not in os.listdir("."): os.mkdir(extract_to) if path not in os.listdir("."): -# Download the file - response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") - zip_path = os.path.join(extract_to, '/downloaded_file.zip') + # Download the file + response = requests.get( + "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" + ) + zip_path = os.path.join(extract_to, "/downloaded_file.zip") # Write the zip file to disk - with open(zip_path, 'wb') as f: + with open(zip_path, "wb") as f: f.write(response.content) # Unzip the file - with ZipFile(zip_path, 'r') as zip_ref: + with ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_to) @@ -56,21 +58,20 @@ # ------------------ # Now that we have the data, we must append the 2kHz LFP recording to the .nwb file eeg_path = "data/A2929-200711/A2929-200711.dat" -frequency = 20000 # Hz +frequency = 20000 # Hz n_channels = 16 -f = open(eeg_path, 'rb') +f = open(eeg_path, "rb") startoffile = f.seek(0, 0) endoffile = f.seek(0, 2) f.close() bytes_size = 2 -n_samples = int((endoffile-startoffile)/n_channels/bytes_size) -duration = n_samples/frequency -interval = 1/frequency -fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) -timestep = np.arange(0, n_samples)/frequency +n_samples = int((endoffile - startoffile) / n_channels / bytes_size) 
+duration = n_samples / frequency +interval = 1 / frequency +fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) +timestep = np.arange(0, n_samples) / frequency eeg = nap.TsdFrame(t=timestep, d=fp) -nap.append_NWB_LFP("data/A2929-200711/", - eeg) +nap.append_NWB_LFP("data/A2929-200711/", eeg) # %% @@ -85,9 +86,13 @@ # Selecting slices # ----------------------------------- # Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake -NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) -wake_minute = NES.value_from(NES, nap.IntervalSet(900,960)) -sleep_minute = NES.value_from(NES, nap.IntervalSet(0,60)) +NES = nap.TsdFrame( + t=data["ElectricalSeries"].index.values, + d=data["ElectricalSeries"].values, + time_support=data["ElectricalSeries"].time_support, +) +wake_minute = NES.value_from(NES, nap.IntervalSet(900, 960)) +sleep_minute = NES.value_from(NES, nap.IntervalSet(0, 60)) # %% # *** @@ -96,10 +101,17 @@ # Let's plot fig, ax = plt.subplots(2) for channel in range(sleep_minute.shape[1]): - ax[0].plot(sleep_minute.index.values, sleep_minute[:, channel], alpha=0.5, label="Sleep Data") + ax[0].plot( + sleep_minute.index.values, + sleep_minute[:, channel], + alpha=0.5, + label="Sleep Data", + ) ax[0].set_title("Sleep ephys") for channel in range(wake_minute.shape[1]): - ax[1].plot(wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data") + ax[1].plot( + wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data" + ) ax[1].set_title("Wake ephys") plt.show() @@ -109,12 +121,10 @@ # Let's take the Fourier transforms of one channel for both and see if differences are present channel = 1 fig, ax = plt.subplots(1) -fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], - fs=int(FS)) +fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], fs=int(FS)) ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") -ax.set_xlim((0, FS/2 )) -fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], - fs=int(FS)) +ax.set_xlim((0, FS / 2)) +fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], fs=int(FS)) ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Wake Data") ax.set_title(f"Fourier Decomposition for channel {channel}") ax.legend() @@ -125,13 +135,14 @@ # There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? 
# Let's explore further with a wavelet decomposition + def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) - ax.imshow(powers, aspect='auto', **kwargs) + ax.imshow(powers, aspect="auto", **kwargs) ax.invert_yaxis() - ax.set_xlabel('Time (s)') - ax.set_ylabel('Frequency (Hz)') + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") if isinstance(x_ticks, int): x_tick_pos = np.linspace(0, times.size, x_ticks) x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) @@ -145,22 +156,43 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + fig, ax = plt.subplots(2) -freqs = np.array([2.59, 3.66, 5.18, 8.0, 10.36, 14.65, 20.72, 29.3, 41.44, 58.59, 82.88, 117.19, - 165.75, 234.38, 331.5, 468.75, 624., ]) +freqs = np.array( + [ + 2.59, + 3.66, + 5.18, + 8.0, + 10.36, + 14.65, + 20.72, + 29.3, + 41.44, + 58.59, + 82.88, + 117.19, + 165.75, + 234.38, + 331.5, + 468.75, + 624.0, + ] +) mwt_sleep = nap.compute_wavelet_transform( - sleep_minute[:, channel], - fs=None, - freqs=freqs - ) -plot_timefrequency(sleep_minute.index.values[:], freqs[:], np.transpose(mwt_sleep[:,:].values), ax=ax[0]) + sleep_minute[:, channel], fs=None, freqs=freqs +) +plot_timefrequency( + sleep_minute.index.values[:], + freqs[:], + np.transpose(mwt_sleep[:, :].values), + ax=ax[0], +) ax[0].set_title(f"Sleep Data Wavelet Decomposition: Channel {channel}") -mwt_wake = nap.compute_wavelet_transform( - wake_minute[:, channel], - fs=None, - freqs=freqs - ) -plot_timefrequency(wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:,:].values), ax=ax[1]) +mwt_wake = nap.compute_wavelet_transform(wake_minute[:, channel], fs=None, freqs=freqs) +plot_timefrequency( + wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:, :].values), ax=ax[1] +) ax[1].set_title(f"Wake Data Wavelet Decomposition: Channel {channel}") plt.margins(0) plt.show() @@ -169,11 +201,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the waking data. 
Let's see if we can isolate the theta oscillations from the data freq = 3 interval = (937, 939) -wake_second = wake_minute.value_from(wake_minute, nap.IntervalSet(interval[0],interval[1])) +wake_second = wake_minute.value_from( + wake_minute, nap.IntervalSet(interval[0], interval[1]) +) fig, ax = plt.subplots(1) ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot(wake_second.index.values, - mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Theta oscillations") +ax.plot( + wake_second.index.values, + mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0], interval[1]))[ + :, freq + ].values.real, + label="Theta oscillations", +) ax.set_title(f"{freqs[freq]}Hz oscillation power.") plt.show() @@ -183,11 +222,20 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw freq = 0 # interval = (10, 15) interval = (20, 25) -sleep_second = sleep_minute.value_from(sleep_minute, nap.IntervalSet(interval[0],interval[1])) +sleep_second = sleep_minute.value_from( + sleep_minute, nap.IntervalSet(interval[0], interval[1]) +) _, ax = plt.subplots(1) -ax.plot(sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot(sleep_second.index.values, - mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Slow Wave Oscillations") +ax.plot( + sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data" +) +ax.plot( + sleep_second.index.values, + mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1]))[ + :, freq + ].values.real, + label="Slow Wave Oscillations", +) ax.set_title(f"{freqs[freq]}Hz oscillation power") plt.show() @@ -195,7 +243,9 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep _, ax = plt.subplots(20, figsize=(10, 50)) -mwt_sleep = np.transpose(mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))) +mwt_sleep = np.transpose( + mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1])) +) ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) @@ -210,13 +260,19 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw spikes = {} for i in data["units"].index: spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > interval[0]) & (data["units"][i].times() < interval[1])] + (data["units"][i].times() > interval[0]) + & (data["units"][i].times() < interval[1]) + ] phase = {} for i in spikes.keys(): phase_i = [] for spike in spikes[i]: - phase_i.append(np.angle(mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))])) + phase_i.append( + np.angle( + mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))] + ) + ) phase[i] = np.array(phase_i) for i in range(15): @@ -227,4 +283,4 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ax[5 + i].set_ylabel("phase") plt.tight_layout() -plt.show() \ No newline at end of file +plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index cc9ab1bc..21ab311c 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,21 +4,23 @@ Contains functionality for signal processing pynapple 
object; fourier transforms and wavelet decomposition. """ -import numpy as np -import pynapple as nap -from math import ceil, floor import json -from scipy.signal import welch from itertools import repeat +from math import ceil, floor + +import numpy as np +from scipy.signal import welch -with open('wavelets.json') as f: +import pynapple as nap + +with open("wavelets.json") as f: WAVELET_DICT = json.load(f) def compute_spectrum(sig, fs=None): - """ - Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor + """ + Performs numpy fft on sig, returns output + ..todo: Make sig handle TsdFrame, TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -29,7 +31,7 @@ def compute_spectrum(sig, fs=None): if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) fft_result = np.fft.fft(sig.values) fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) return fft_result, fft_freq @@ -49,7 +51,7 @@ def compute_welch_spectrum(sig, fs=None): if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) freqs, spectogram = welch(sig.values, fs=fs) return spectogram, freqs @@ -75,7 +77,12 @@ def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): Morelet wavelet kernel """ x = np.linspace(-precision, precision, M) - return ((np.pi*ncycles) ** (-0.25)) * np.exp(-x**2 / ncycles) * np.exp(1j * 2*np.pi * scaling * x) + return ( + ((np.pi * ncycles) ** (-0.25)) + * np.exp(-(x**2) / ncycles) + * np.exp(1j * 2 * np.pi * scaling * x) + ) + def _check_n_cycles(n_cycles, len_cycles=None): """ @@ -172,18 +179,15 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): mwt = np.zeros([len(freqs), len(sig)], dtype=complex) for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - mwt[ind, :] = _convolve_wavelet(sig, - fs, - freq, - n_cycle, - scaling, - norm=norm) - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) + mwt[ind, :] = _convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) + return nap.TsdFrame( + t=sig.index, d=np.transpose(mwt), time_support=sig.time_support + ) else: mwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex @@ -237,23 +241,23 @@ def _convolve_wavelet( morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) - scale = scaling / (freq/fs) + scale = scaling / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) j = j.astype(int) # floor if j[-1] >= int_psi.size: j = np.extract(j < int_psi.size, j) int_psi_scale = int_psi[j][::-1] conv = np.convolve(sig, int_psi_scale) - coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + coef = -np.sqrt(scale) * np.diff(conv, axis=-1) # transform axis is always -1 due to the data reshape above - d = (coef.shape[-1] - sig.shape[-1]) / 2. 
+ d = (coef.shape[-1] - sig.shape[-1]) / 2.0 if d > 0: - coef = coef[..., floor(d):-ceil(d)] + coef = coef[..., floor(d) : -ceil(d)] elif d < 0: - raise ValueError( - f"Selected scale of {scale} too small.") + raise ValueError(f"Selected scale of {scale} too small.") return coef + def _integrate(arr, step): integral = np.cumsum(arr) integral *= step From 18b3bc1da0561c7e43ba3ddd5a1b5fa3031b2a0a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:44:59 +0100 Subject: [PATCH 07/71] more linting --- pynapple/process/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 08d58648..1cb2f735 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -15,6 +15,11 @@ shift_timestamps, shuffle_ts_intervals, ) +from .signal_processing import ( + compute_spectrum, + compute_wavelet_transform, + compute_welch_spectrum, +) from .tuning_curves import ( compute_1d_mutual_info, compute_1d_tuning_curves, @@ -24,8 +29,3 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) -from .signal_processing import ( - compute_wavelet_transform, - compute_spectrum, - compute_welch_spectrum -) From ebdbe67320d23b5b23f6b52590c59033926e7463 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:50:05 +0100 Subject: [PATCH 08/71] json removal --- pynapple/process/signal_processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 21ab311c..5f4e969e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -13,9 +13,6 @@ import pynapple as nap -with open("wavelets.json") as f: - WAVELET_DICT = json.load(f) - def compute_spectrum(sig, fs=None): """ From 75d3a460525683896de25e3b67e41a2d65b12377 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 17:52:16 +0100 Subject: [PATCH 09/71] basic tests added --- pynapple/process/signal_processing.py | 4 +-- tests/test_signal_processing.py | 37 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 tests/test_signal_processing.py diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 5f4e969e..245825f7 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -53,7 +53,7 @@ def compute_welch_spectrum(sig, fs=None): return spectogram, freqs -def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): +def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ Defines the complex Morelet wavelet kernel @@ -235,7 +235,7 @@ def _convolve_wavelet( """ if norm not in ["sss", "amp"]: raise ValueError("Given `norm` must be `sss` or `amp`") - morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) + morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) scale = scaling / (freq / fs) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py new file mode 100644 index 00000000..8da68921 --- /dev/null +++ b/tests/test_signal_processing.py @@ -0,0 +1,37 @@ +"""Tests of `signal_processing` for pynapple""" + +import numpy as np +import pytest + +import pynapple as nap + + +def test_compute_spectrum(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + r = nap.compute_spectrum(sig) + assert 
len(r[1]) == 1024 + assert len(r[0]) == 1024 + assert r[0].dtype == np.complex128 + assert r[1].dtype == np.float64 + + +def test_compute_welch_spectrum(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + r = nap.compute_welch_spectrum(sig) + assert r[0].dtype == np.float64 + assert r[1].dtype == np.float64 + + +def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10) + + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4) From 334e785fceaf4ba022718dd7899a4dea64da953d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 1 Jul 2024 15:08:10 +0100 Subject: [PATCH 10/71] remove unused import --- pynapple/process/signal_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 245825f7..15da4906 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,7 +4,6 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ -import json from itertools import repeat from math import ceil, floor From a3ab81cffe06cd3d7e3ae086f155e7f6a35925d3 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 2 Jul 2024 19:43:30 +0100 Subject: [PATCH 11/71] minor notebook changes --- docs/examples/tutorial_signal_processing.py | 59 ++++++++++----------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index b3e2f9af..c7da10e9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -57,21 +57,21 @@ # Parsing the data # ------------------ # Now that we have the data, we must append the 2kHz LFP recording to the .nwb file -eeg_path = "data/A2929-200711/A2929-200711.dat" -frequency = 20000 # Hz -n_channels = 16 -f = open(eeg_path, "rb") -startoffile = f.seek(0, 0) -endoffile = f.seek(0, 2) -f.close() -bytes_size = 2 -n_samples = int((endoffile - startoffile) / n_channels / bytes_size) -duration = n_samples / frequency -interval = 1 / frequency -fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) -timestep = np.arange(0, n_samples) / frequency -eeg = nap.TsdFrame(t=timestep, d=fp) -nap.append_NWB_LFP("data/A2929-200711/", eeg) +# eeg_path = "data/A2929-200711/A2929-200711.dat" +# frequency = 20000 # Hz +# n_channels = 16 +# f = open(eeg_path, "rb") +# startoffile = f.seek(0, 0) +# endoffile = f.seek(0, 2) +# f.close() +# bytes_size = 2 +# n_samples = int((endoffile - startoffile) / n_channels / bytes_size) +# duration = n_samples / frequency +# interval = 1 / frequency +# fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) +# timestep = np.arange(0, n_samples) / frequency +# eeg = nap.TsdFrame(t=timestep, d=fp) +# nap.append_NWB_LFP("data/A2929-200711/", eeg) # %% @@ -91,8 +91,9 @@ d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support, ) -wake_minute = NES.value_from(NES, nap.IntervalSet(900, 960)) -sleep_minute = NES.value_from(NES, nap.IntervalSet(0, 60)) +wake_minute = NES.restrict(nap.IntervalSet(900, 960)) +sleep_minute = NES.restrict(nap.IntervalSet(0, 60)) 
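+# restrict() keeps only the samples whose timestamps fall inside the given IntervalSet
+# (times are in seconds), so each slice above is a one-minute TsdFrame sampled at FS.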
+channel = 1 # %% # *** @@ -102,7 +103,6 @@ fig, ax = plt.subplots(2) for channel in range(sleep_minute.shape[1]): ax[0].plot( - sleep_minute.index.values, sleep_minute[:, channel], alpha=0.5, label="Sleep Data", @@ -110,7 +110,9 @@ ax[0].set_title("Sleep ephys") for channel in range(wake_minute.shape[1]): ax[1].plot( - wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data" + wake_minute[:, channel], + alpha=0.5, + label="Wake Data" ) ax[1].set_title("Wake ephys") plt.show() @@ -119,7 +121,6 @@ # %% # There is much shared information between channels, and wake and sleep don't seem visibly different. # Let's take the Fourier transforms of one channel for both and see if differences are present -channel = 1 fig, ax = plt.subplots(1) fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], fs=int(FS)) ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") @@ -201,14 +202,13 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 interval = (937, 939) -wake_second = wake_minute.value_from( - wake_minute, nap.IntervalSet(interval[0], interval[1]) -) +wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( wake_second.index.values, - mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0], interval[1]))[ + mwt_wake_second[ :, freq ].values.real, label="Theta oscillations", @@ -222,16 +222,15 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw freq = 0 # interval = (10, 15) interval = (20, 25) -sleep_second = sleep_minute.value_from( - sleep_minute, nap.IntervalSet(interval[0], interval[1]) -) +sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) ax.plot( - sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data" + sleep_second[:, channel], alpha=0.5, label="Wake Data" ) ax.plot( sleep_second.index.values, - mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1]))[ + mwt_sleep_second[ :, freq ].values.real, label="Slow Wave Oscillations", @@ -244,7 +243,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw _, ax = plt.subplots(20, figsize=(10, 50)) mwt_sleep = np.transpose( - mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1])) + mwt_sleep_second ) ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) From 80c12d76a6c0ae3a7168eda74e14fe9dd7c7d558 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 2 Jul 2024 21:04:15 +0100 Subject: [PATCH 12/71] spectogram now takes tdsframe --- pynapple/process/signal_processing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 15da4906..ad1a9b8e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -8,6 +8,7 @@ from math import ceil, floor import numpy as np +import pandas as pd from scipy.signal import welch import pynapple as nap @@ -16,7 +17,7 @@ def 
compute_spectrum(sig, fs=None): """ Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor + ..todo: Make sig handle TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -24,13 +25,13 @@ def compute_spectrum(sig, fs=None): fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal """ - if not isinstance(sig, nap.Tsd): - raise TypeError("Currently compute_fft is only implemented for Tsd") + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("Currently compute_fft is only implemented for Tsd or TsdFrame") if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - fft_result = np.fft.fft(sig.values) + fft_result = np.fft.fft(sig.values, axis=0) fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) - return fft_result, fft_freq + return pd.DataFrame(fft_result, fft_freq) def compute_welch_spectrum(sig, fs=None): From c1a5a26eb31e058f277375fc81e3cebe4abe4084 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 16:49:26 +0100 Subject: [PATCH 13/71] review changes --- docs/examples/tutorial_signal_processing.py | 144 ++++++++++---------- pynapple/process/__init__.py | 4 +- pynapple/process/signal_processing.py | 12 +- 3 files changed, 81 insertions(+), 79 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index c7da10e9..a4d58520 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -20,12 +20,8 @@ # # Now, import the necessary libraries: -import os -from zipfile import ZipFile - import matplotlib.pyplot as plt import numpy as np -import requests import pynapple as nap @@ -33,51 +29,28 @@ # *** # Downloading the data # ------------------ -# First things first: Let's download and extract the data -path = "data/A2929-200711" -extract_to = "data" -if extract_to not in os.listdir("."): - os.mkdir(extract_to) -if path not in os.listdir("."): - # Download the file - response = requests.get( - "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" - ) - zip_path = os.path.join(extract_to, "/downloaded_file.zip") - # Write the zip file to disk - with open(zip_path, "wb") as f: - f.write(response.content) - # Unzip the file - with ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(extract_to) - - -# %% -# *** -# Parsing the data -# ------------------ -# Now that we have the data, we must append the 2kHz LFP recording to the .nwb file -# eeg_path = "data/A2929-200711/A2929-200711.dat" -# frequency = 20000 # Hz -# n_channels = 16 -# f = open(eeg_path, "rb") -# startoffile = f.seek(0, 0) -# endoffile = f.seek(0, 2) -# f.close() -# bytes_size = 2 -# n_samples = int((endoffile - startoffile) / n_channels / bytes_size) -# duration = n_samples / frequency -# interval = 1 / frequency -# fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) -# timestep = np.arange(0, n_samples) / frequency -# eeg = nap.TsdFrame(t=timestep, d=fp) -# nap.append_NWB_LFP("data/A2929-200711/", eeg) +# First things first: Let's download and extract the data - currently commented as correct NWB is not online +# path = "data/A2929-200711" +# extract_to = "data" +# if extract_to not in os.listdir("."): +# os.mkdir(extract_to) +# if path not in os.listdir("."): +# # Download the file +# response = requests.get( +# "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" +# ) +# zip_path = os.path.join(extract_to, "/downloaded_file.zip") +# # Write the zip file 
to disk +# with open(zip_path, "wb") as f: +# f.write(response.content) +# # Unzip the file +# with ZipFile(zip_path, "r") as zip_ref: +# zip_ref.extractall(extract_to) # %% # Let's save the RoiResponseSeries as a variable called 'transients' and print it FS = 1250 -# data = nap.load_file("data/A2929-200711/pynapplenwb/A2929-200711.nwb") data = nap.load_file("data/stable.nwb") print(data["ElectricalSeries"]) @@ -109,29 +82,66 @@ ) ax[0].set_title("Sleep ephys") for channel in range(wake_minute.shape[1]): - ax[1].plot( - wake_minute[:, channel], - alpha=0.5, - label="Wake Data" - ) + ax[1].plot(wake_minute[:, channel], alpha=0.5, label="Wake Data") ax[1].set_title("Wake ephys") plt.show() # %% -# There is much shared information between channels, and wake and sleep don't seem visibly different. -# Let's take the Fourier transforms of one channel for both and see if differences are present -fig, ax = plt.subplots(1) -fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], fs=int(FS)) -ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") -ax.set_xlim((0, FS / 2)) -fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], fs=int(FS)) -ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Wake Data") -ax.set_title(f"Fourier Decomposition for channel {channel}") -ax.legend() +# Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present +channel = 1 +fig, ax = plt.subplots(2) +fft = nap.compute_spectogram(sleep_minute, fs=int(FS)) +ax[0].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", c="blue" +) +ax[0].set_xlim((0, FS / 2)) +ax[0].set_xlabel("Freq (Hz)") +ax[0].set_ylabel("Frequency Power") + +ax[0].set_title("Sleep LFP Decomposition") +fft = nap.compute_spectogram(wake_minute, fs=int(FS)) +ax[1].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Wake Data", c="orange" +) +ax[1].set_xlim((0, FS / 2)) +fig.suptitle(f"Fourier Decomposition for channel {channel}") +ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_xlabel("Freq (Hz)") +ax[1].set_ylabel("Frequency Power") + +# ax.legend() plt.show() +# %% +# Let's now consider the Welch spectograms of waking and sleeping data... + +fig, ax = plt.subplots(2) +fft = nap.compute_welch_spectogram(sleep_minute, fs=int(FS)) +ax[0].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", color="blue" +) +ax[0].set_xlim((0, FS / 2)) +ax[0].set_title("Sleep LFP Decomposition") +ax[0].set_xlabel("Freq (Hz)") +ax[0].set_ylabel("Frequency Power") +welch = nap.compute_welch_spectogram(wake_minute, fs=int(FS)) +ax[1].plot( + welch.index, + np.abs(welch.iloc[:, channel]), + alpha=0.5, + label="Wake Data", + color="orange", +) +ax[1].set_xlim((0, FS / 2)) +fig.suptitle(f"Welch Decomposition for channel {channel}") +ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_xlabel("Freq (Hz)") +ax[1].set_ylabel("Frequency Power") +# ax.legend() +plt.show() + # %% # There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? 
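 # nap.compute_welch_spectogram wraps scipy.signal.welch, which averages periodograms over
 # overlapping segments and so trades frequency resolution for a much smoother estimate
 # than the raw FFT above. One way to make that high-frequency bump concrete is to compare
 # band power between conditions; a sketch using the wake Welch estimate computed above
 # (the 30-100 Hz band is chosen purely for illustration):
 # band = welch[(welch.index >= 30) & (welch.index <= 100)]
 # wake_band_power = band.iloc[:, channel].sum()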
# Let's explore further with a wavelet decomposition @@ -208,9 +218,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( wake_second.index.values, - mwt_wake_second[ - :, freq - ].values.real, + mwt_wake_second[:, freq].values.real, label="Theta oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power.") @@ -225,14 +233,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) -ax.plot( - sleep_second[:, channel], alpha=0.5, label="Wake Data" -) +ax.plot(sleep_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( sleep_second.index.values, - mwt_sleep_second[ - :, freq - ].values.real, + mwt_sleep_second[:, freq].values.real, label="Slow Wave Oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power") @@ -242,9 +246,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep _, ax = plt.subplots(20, figsize=(10, 50)) -mwt_sleep = np.transpose( - mwt_sleep_second -) +mwt_sleep = np.transpose(mwt_sleep_second) ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 1cb2f735..fb7e22b9 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -16,9 +16,9 @@ shuffle_ts_intervals, ) from .signal_processing import ( - compute_spectrum, + compute_spectogram, compute_wavelet_transform, - compute_welch_spectrum, + compute_welch_spectogram, ) from .tuning_curves import ( compute_1d_mutual_info, diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index ad1a9b8e..3bb4be8e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -14,7 +14,7 @@ import pynapple as nap -def compute_spectrum(sig, fs=None): +def compute_spectogram(sig, fs=None): """ Performs numpy fft on sig, returns output ..todo: Make sig handle TsdTensor @@ -34,7 +34,7 @@ def compute_spectrum(sig, fs=None): return pd.DataFrame(fft_result, fft_freq) -def compute_welch_spectrum(sig, fs=None): +def compute_welch_spectogram(sig, fs=None): """ Performs scipy Welch's decomposition on sig, returns output ..todo: Make sig handle TsdFrame, TsdTensor @@ -45,12 +45,12 @@ def compute_welch_spectrum(sig, fs=None): fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal """ - if not isinstance(sig, nap.Tsd): + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): raise TypeError("Currently compute_fft is only implemented for Tsd") if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - freqs, spectogram = welch(sig.values, fs=fs) - return spectogram, freqs + freqs, spectogram = welch(sig.values, fs=fs, axis=0) + return pd.DataFrame(spectogram, freqs) def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): @@ -145,7 +145,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. - fs : float + fs : float or None Sampling rate, in Hz. 
freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. From dfe6e411e8f2fed704a770f7da339d85ac754e38 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 16:55:25 +0100 Subject: [PATCH 14/71] updated function names in test --- tests/test_signal_processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 8da68921..b2d3c8b9 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -6,20 +6,20 @@ import pynapple as nap -def test_compute_spectrum(): +def test_compute_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) - r = nap.compute_spectrum(sig) + r = nap.compute_spectogram(sig) assert len(r[1]) == 1024 assert len(r[0]) == 1024 assert r[0].dtype == np.complex128 assert r[1].dtype == np.float64 -def test_compute_welch_spectrum(): +def test_ccompute_welch_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) - r = nap.compute_welch_spectrum(sig) + r = nap.compute_welch_spectogram(sig) assert r[0].dtype == np.float64 assert r[1].dtype == np.float64 From 3a9173a9149ac7f760c3fb05642221ae4c942acc Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 17:59:09 +0100 Subject: [PATCH 15/71] updated tests --- tests/test_signal_processing.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index b2d3c8b9..85ed3b8e 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,6 +1,7 @@ """Tests of `signal_processing` for pynapple""" import numpy as np +import pandas as pd import pytest import pynapple as nap @@ -10,18 +11,15 @@ def test_compute_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) r = nap.compute_spectogram(sig) - assert len(r[1]) == 1024 - assert len(r[0]) == 1024 - assert r[0].dtype == np.complex128 - assert r[1].dtype == np.float64 + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 1024 -def test_ccompute_welch_spectogram(): +def test_compute_welch_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) r = nap.compute_welch_spectogram(sig) - assert r[0].dtype == np.float64 - assert r[1].dtype == np.float64 + assert isinstance(r, pd.DataFrame) def test_compute_wavelet_transform(): From cfb606630b37cb97d828a338d83d2032280987ae Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 19:01:32 +0100 Subject: [PATCH 16/71] expanded test coverage --- pynapple/process/signal_processing.py | 8 +++- tests/test_signal_processing.py | 56 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 3bb4be8e..a38b56ab 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -26,7 +26,9 @@ def compute_spectogram(sig, fs=None): Sampling rate, in Hz. 
If None, will be calculated from the given signal """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError("Currently compute_fft is only implemented for Tsd or TsdFrame") + raise TypeError( + "Currently compute_spectogram is only implemented for Tsd or TsdFrame" + ) if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) fft_result = np.fft.fft(sig.values, axis=0) @@ -46,7 +48,9 @@ def compute_welch_spectogram(sig, fs=None): Sampling rate, in Hz. If None, will be calculated from the given signal """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError("Currently compute_fft is only implemented for Tsd") + raise TypeError( + "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" + ) if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) freqs, spectogram = welch(sig.values, fs=fs, axis=0) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 85ed3b8e..fd728efe 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -14,6 +14,18 @@ def test_compute_spectogram(): assert isinstance(r, pd.DataFrame) assert r.shape[0] == 1024 + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + r = nap.compute_spectogram(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1024, 4) + + with pytest.raises(TypeError) as e_info: + nap.compute_spectogram("a_string") + assert ( + str(e_info.value) + == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" + ) + def test_compute_welch_spectogram(): t = np.linspace(0, 1, 1024) @@ -21,6 +33,18 @@ def test_compute_welch_spectogram(): r = nap.compute_welch_spectogram(sig) assert isinstance(r, pd.DataFrame) + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + r = nap.compute_welch_spectogram(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[1] == 4 + + with pytest.raises(TypeError) as e_info: + nap.compute_welch_spectogram("a_string") + assert ( + str(e_info.value) + == "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" + ) + def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1024) @@ -29,7 +53,39 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10) + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = (1, 51, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 6) + + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 5)) + ) + assert mwt.shape == (1024, 10) + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) + + with pytest.raises(ValueError) as e_info: + nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) + assert str(e_info.value) == "Number of cycles must be a positive number." + + with pytest.raises(ValueError) as e_info: + nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, -2.5), 5)) + ) + assert str(e_info.value) == "Each number of cycles must be a positive number." 
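+    # The next check passes a 4-element n_cycles tuple (np.repeat((1.5, 2.5), 2))
+    # against 10 requested frequencies, which should trigger the length-mismatch error.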
+ + with pytest.raises(ValueError) as e_info: + nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 2)) + ) + assert ( + str(e_info.value) + == "The length of number of cycles does not match other inputs." + ) From 4148d6d83ded7a3a382e742bdf67dd139302c441 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 8 Jul 2024 17:55:58 +0100 Subject: [PATCH 17/71] notebook various changes --- docs/examples/tutorial_signal_processing.py | 76 ++++++++++++--------- pynapple/process/signal_processing.py | 2 - 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index a4d58520..8684ed02 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -19,53 +19,54 @@ # mkdocs_gallery_thumbnail_number = 1 # # Now, import the necessary libraries: - +import matplotlib +matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np import pynapple as nap +from examples_utils import data, plotting # %% # *** # Downloading the data # ------------------ -# First things first: Let's download and extract the data - currently commented as correct NWB is not online -# path = "data/A2929-200711" -# extract_to = "data" -# if extract_to not in os.listdir("."): -# os.mkdir(extract_to) -# if path not in os.listdir("."): -# # Download the file -# response = requests.get( -# "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" -# ) -# zip_path = os.path.join(extract_to, "/downloaded_file.zip") -# # Write the zip file to disk -# with open(zip_path, "wb") as f: -# f.write(response.content) -# # Unzip the file -# with ZipFile(zip_path, "r") as zip_ref: -# zip_ref.extractall(extract_to) +# First things first: Let's download and extract the data - download section currently commented as correct NWB +# is not online +# path = data.download_data( +# "Achilles_10252013.nwb", "https://osf.io/hu5ma/download", "../data" +# ) +# data = nap.load_file(path) -# %% -# Let's save the RoiResponseSeries as a variable called 'transients' and print it -FS = 1250 -data = nap.load_file("data/stable.nwb") -print(data["ElectricalSeries"]) +data = nap.load_file("../data/Achillies_ephys.nwb") +FS = len(data["LFP"].index[:]) / (data["LFP"].index[-1] - data["LFP"].index[0]) +print(data) # %% # *** # Selecting slices # ----------------------------------- -# Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake -NES = nap.TsdFrame( - t=data["ElectricalSeries"].index.values, - d=data["ElectricalSeries"].values, - time_support=data["ElectricalSeries"].time_support, +# Let's consider two 60-second slices of data, one from the sleep epoch and one from wake + +wake_minute_interval = nap.IntervalSet( + data["epochs"]["MazeEpoch"]["start"] + 60., + data["epochs"]["MazeEpoch"]["start"] + 120., +) +sleep_minute_interval = nap.IntervalSet( + data["epochs"]["POSTEpoch"]["start"] + 60., + data["epochs"]["POSTEpoch"]["start"] + 120., +) +wake_minute = nap.TsdFrame( + t=data["LFP"].restrict(wake_minute_interval).index.values, + d=data["LFP"].restrict(wake_minute_interval).values, + time_support=data["LFP"].restrict(wake_minute_interval).time_support, +) +sleep_minute = nap.TsdFrame( + t=data["LFP"].restrict(sleep_minute_interval).index.values, + d=data["LFP"].restrict(sleep_minute_interval).values, + time_support=data["LFP"].restrict(sleep_minute_interval).time_support, ) -wake_minute = 
NES.restrict(nap.IntervalSet(900, 960)) -sleep_minute = NES.restrict(nap.IntervalSet(0, 60)) channel = 1 # %% @@ -211,7 +212,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # %% # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 -interval = (937, 939) +interval = ( + wake_minute_interval["start"], + wake_minute_interval["start"]+2 +) wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) @@ -229,7 +233,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data freq = 0 # interval = (10, 15) -interval = (20, 25) +interval = ( + sleep_minute_interval["start"]+30, + sleep_minute_interval["start"]+35 +) sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) @@ -276,8 +283,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ) phase[i] = np.array(phase_i) +spikes = {k: v for k,v in spikes.items() if len(v) > 0} +phase = {k: v for k,v in phase.items() if len(v) > 0} + for i in range(15): - ax[5 + i].scatter(spikes[i], phase[i]) + ax[5 + i].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) ax[5 + i].set_xlim(interval[0], interval[1]) ax[5 + i].set_ylim(-np.pi, np.pi) ax[5 + i].set_xlabel("time (s)") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index a38b56ab..72fc650f 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -17,7 +17,6 @@ def compute_spectogram(sig, fs=None): """ Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -39,7 +38,6 @@ def compute_spectogram(sig, fs=None): def compute_welch_spectogram(sig, fs=None): """ Performs scipy Welch's decomposition on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame From 19c91835cd29171f7e67a66421b1e63b022dcc13 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 8 Jul 2024 21:00:01 +0100 Subject: [PATCH 18/71] compute_wavelet_transform can now handle TsdTensor --- docs/examples/tutorial_signal_processing.py | 23 ++++++-------- pynapple/process/signal_processing.py | 34 ++++++++++++--------- tests/test_signal_processing.py | 14 +++++++++ 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 8684ed02..0bb17fcb 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -20,6 +20,7 @@ # # Now, import the necessary libraries: import matplotlib + matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np @@ -50,12 +51,12 @@ # Let's consider two 60-second slices of data, one from the sleep epoch and one from wake wake_minute_interval = nap.IntervalSet( - data["epochs"]["MazeEpoch"]["start"] + 60., - data["epochs"]["MazeEpoch"]["start"] + 120., + data["epochs"]["MazeEpoch"]["start"] + 60.0, + data["epochs"]["MazeEpoch"]["start"] + 120.0, ) sleep_minute_interval = nap.IntervalSet( - data["epochs"]["POSTEpoch"]["start"] + 
60., - data["epochs"]["POSTEpoch"]["start"] + 120., + data["epochs"]["POSTEpoch"]["start"] + 60.0, + data["epochs"]["POSTEpoch"]["start"] + 120.0, ) wake_minute = nap.TsdFrame( t=data["LFP"].restrict(wake_minute_interval).index.values, @@ -212,10 +213,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # %% # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 -interval = ( - wake_minute_interval["start"], - wake_minute_interval["start"]+2 -) +interval = (wake_minute_interval["start"], wake_minute_interval["start"] + 2) wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) @@ -233,10 +231,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data freq = 0 # interval = (10, 15) -interval = ( - sleep_minute_interval["start"]+30, - sleep_minute_interval["start"]+35 -) +interval = (sleep_minute_interval["start"] + 30, sleep_minute_interval["start"] + 35) sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) @@ -283,8 +278,8 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ) phase[i] = np.array(phase_i) -spikes = {k: v for k,v in spikes.items() if len(v) > 0} -phase = {k: v for k,v in phase.items() if len(v) > 0} +spikes = {k: v for k, v in spikes.items() if len(v) > 0} +phase = {k: v for k, v in phase.items() if len(v) > 0} for i in range(15): ax[5 + i].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 72fc650f..ca61f2fe 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -173,30 +173,34 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a ----- This computes the continuous wavelet transform at specified frequencies across time. 
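 
     Examples
     --------
     A minimal usage sketch, mirroring the shapes asserted in the unit tests
     (``np`` is numpy and ``nap`` is pynapple):
 
     >>> t = np.linspace(0, 1, 1024)
     >>> sig = nap.Tsd(d=np.random.random(1024), t=t)
     >>> mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(1, 600, 10))
     >>> mwt.shape
     (1024, 10)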
""" - if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): - raise TypeError("`sig` must be instance of Tsd or TsdFrame") + if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): + raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): - mwt = np.zeros([len(freqs), len(sig)], dtype=complex) + sig = sig.reshape((sig.shape[0], 1)) + output_shape = (sig.shape[0], len(freqs)) + else: + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + mwt = np.zeros( + [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex + ) + for channel_i in range(sig.values.shape[1]): for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - mwt[ind, :] = _convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) + mwt[:, ind, channel_i] = _convolve_wavelet( + sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm + ) + if len(output_shape) == 2: return nap.TsdFrame( - t=sig.index, d=np.transpose(mwt), time_support=sig.time_support + t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support ) - else: - mwt = np.zeros( - [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex - ) - for channel_i in range(sig.values.shape[1]): - for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - mwt[:, ind, channel_i] = _convolve_wavelet( - sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm - ) - return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) + return nap.TsdTensor( + t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support + ) def _convolve_wavelet( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index fd728efe..c50860ba 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -47,6 +47,14 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): + + ##..todo put there + t = np.linspace(0, 1, 1024) # can remove this when we move it + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4, 2) + t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) freqs = np.linspace(1, 600, 10) @@ -71,6 +79,8 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) + #..todo: here + with pytest.raises(ValueError) as e_info: nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) assert str(e_info.value) == "Number of cycles must be a positive number." @@ -89,3 +99,7 @@ def test_compute_wavelet_transform(): str(e_info.value) == "The length of number of cycles does not match other inputs." 
) + + +if __name__ == "__main__": + test_compute_wavelet_transform() \ No newline at end of file From 63f52b2e32f540525d73fa50455f5c0687bfa893 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 12 Jul 2024 20:36:07 +0100 Subject: [PATCH 19/71] PR comment changes --- docs/examples/tutorial_signal_processing.py | 471 ++++++++++++++++---- pynapple/process/signal_processing.py | 91 +++- tests/test_signal_processing.py | 40 +- 3 files changed, 472 insertions(+), 130 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 0bb17fcb..dece1036 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -14,86 +14,138 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests` +# You can install all with `pip install matplotlib requests tqdm` # # mkdocs_gallery_thumbnail_number = 1 # # Now, import the necessary libraries: import matplotlib - matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np - +import os +import requests +import tqdm +import math import pynapple as nap -from examples_utils import data, plotting +import scipy # %% # *** # Downloading the data # ------------------ -# First things first: Let's download and extract the data - download section currently commented as correct NWB -# is not online +# First things first: Let's download the data and save it locally + +path = "Achilles_10252013_EEG.nwb" +if path not in os.listdir("."): + r = requests.get(f"https://osf.io/2dfvp/download", stream=True) + block_size = 1024*1024 + with open(path, 'wb') as f: + for data in tqdm.tqdm(r.iter_content(block_size), unit='MB', unit_scale=True, + total=math.ceil(int(r.headers.get('content-length', 0))//block_size)): + f.write(data) -# path = data.download_data( -# "Achilles_10252013.nwb", "https://osf.io/hu5ma/download", "../data" -# ) -# data = nap.load_file(path) -data = nap.load_file("../data/Achillies_ephys.nwb") -FS = len(data["LFP"].index[:]) / (data["LFP"].index[-1] - data["LFP"].index[0]) +# %% +# *** +# Loading the data +# ------------------ +# Loading the data, calculating the sampling frequency + +data = nap.load_file(path) +FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) print(data) + # %% # *** # Selecting slices # ----------------------------------- # Let's consider two 60-second slices of data, one from the sleep epoch and one from wake -wake_minute_interval = nap.IntervalSet( - data["epochs"]["MazeEpoch"]["start"] + 60.0, - data["epochs"]["MazeEpoch"]["start"] + 120.0, +REM_minute_interval = nap.IntervalSet( + data["rem"]["start"][0] + 60.0, + data["rem"]["start"][0] + 120.0, ) -sleep_minute_interval = nap.IntervalSet( - data["epochs"]["POSTEpoch"]["start"] + 60.0, - data["epochs"]["POSTEpoch"]["start"] + 120.0, + +SWS_minute_interval = nap.IntervalSet( + data["nrem"]["start"][0] + 10.0, + data["nrem"]["start"][0] + 70.0, ) -wake_minute = nap.TsdFrame( - t=data["LFP"].restrict(wake_minute_interval).index.values, - d=data["LFP"].restrict(wake_minute_interval).values, - time_support=data["LFP"].restrict(wake_minute_interval).time_support, + +RUN_minute_interval = nap.IntervalSet( + data["forward_ep"]["start"][-18] + 0., + data["forward_ep"]["start"][-18] + 60., ) -sleep_minute = nap.TsdFrame( - t=data["LFP"].restrict(sleep_minute_interval).index.values, - d=data["LFP"].restrict(sleep_minute_interval).values, - 
time_support=data["LFP"].restrict(sleep_minute_interval).time_support, + +REM_minute = nap.TsdFrame( + t=data["eeg"].restrict(REM_minute_interval).index.values, + d=data["eeg"].restrict(REM_minute_interval).values, + time_support=data["eeg"].restrict(REM_minute_interval).time_support, ) -channel = 1 + +SWS_minute = nap.TsdFrame( + t=data["eeg"].restrict(SWS_minute_interval).index.values, + d=data["eeg"].restrict(SWS_minute_interval).values, + time_support=data["eeg"].restrict(SWS_minute_interval).time_support, +) + +RUN_minute = nap.TsdFrame( + t=data["eeg"].restrict(RUN_minute_interval).index.values, + d=data["eeg"].restrict(RUN_minute_interval).values, + time_support=data["eeg"].restrict(RUN_minute_interval).time_support, +) +# RUN_position = nap.TsdFrame( +# t=data["position"].restrict(RUN_minute_interval).index.values[1:], +# d=np.diff(data['position'].restrict(RUN_minute_interval)), +# time_support=data["position"].restrict(RUN_minute_interval).time_support, +# ) +RUN_position = nap.TsdFrame( + t=data["position"].restrict(RUN_minute_interval).index.values[:], + d=data['position'].restrict(RUN_minute_interval), + time_support=data["position"].restrict(RUN_minute_interval).time_support, +) + +channel = 0 # %% # *** # Plotting the LFP activity of one slices # ----------------------------------- # Let's plot -fig, ax = plt.subplots(2) -for channel in range(sleep_minute.shape[1]): + +fig, ax = plt.subplots(3) + +for channel in range(SWS_minute.shape[1]): ax[0].plot( - sleep_minute[:, channel], + SWS_minute[:, channel], alpha=0.5, label="Sleep Data", ) -ax[0].set_title("Sleep ephys") -for channel in range(wake_minute.shape[1]): - ax[1].plot(wake_minute[:, channel], alpha=0.5, label="Wake Data") -ax[1].set_title("Wake ephys") +ax[0].set_title("non-REM ephys") +ax[0].set_ylabel("LFP (v)") +ax[0].set_xlabel("time (s)") +ax[0].margins(0) +for channel in range(REM_minute.shape[1]): + ax[1].plot(REM_minute[:, channel], alpha=0.5, label="Wake Data", color="orange") +ax[1].set_ylabel("LFP (v)") +ax[1].set_xlabel("time (s)") +ax[1].set_title("REM ephys") +ax[1].margins(0) +for channel in range(RUN_minute.shape[1]): + ax[2].plot(RUN_minute[:, channel], alpha=0.5, label="Wake Data", color="green") +ax[2].set_ylabel("LFP (v)") +ax[2].set_xlabel("time (s)") +ax[2].set_title("Running ephys") +ax[2].margins(0) plt.show() # %% # Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present -channel = 1 -fig, ax = plt.subplots(2) -fft = nap.compute_spectogram(sleep_minute, fs=int(FS)) +channel = 0 +fig, ax = plt.subplots(3) +fft = nap.compute_spectogram(SWS_minute, fs=int(FS)) ax[0].plot( fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", c="blue" ) @@ -101,17 +153,27 @@ ax[0].set_xlabel("Freq (Hz)") ax[0].set_ylabel("Frequency Power") -ax[0].set_title("Sleep LFP Decomposition") -fft = nap.compute_spectogram(wake_minute, fs=int(FS)) +ax[0].set_title("non-REM LFP Decomposition") +fft = nap.compute_spectogram(REM_minute, fs=int(FS)) ax[1].plot( fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Wake Data", c="orange" ) ax[1].set_xlim((0, FS / 2)) fig.suptitle(f"Fourier Decomposition for channel {channel}") -ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_title("REM LFP Decomposition") ax[1].set_xlabel("Freq (Hz)") ax[1].set_ylabel("Frequency Power") +fft = nap.compute_spectogram(RUN_minute, fs=int(FS)) +ax[2].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Running Data", c="green" +) +ax[2].set_xlim((0, FS 
/ 2)) +fig.suptitle(f"Fourier Decomposition for channel {channel}") +ax[2].set_title("Running LFP Decomposition") +ax[2].set_xlabel("Freq (Hz)") +ax[2].set_ylabel("Frequency Power") + # ax.legend() plt.show() @@ -119,28 +181,46 @@ # %% # Let's now consider the Welch spectograms of waking and sleeping data... -fig, ax = plt.subplots(2) -fft = nap.compute_welch_spectogram(sleep_minute, fs=int(FS)) +fig, ax = plt.subplots(3) +welch = nap.compute_welch_spectogram(SWS_minute, fs=int(FS)) ax[0].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", color="blue" + welch.index, + np.abs(welch.iloc[:, channel]), + alpha=0.5, + label="non-REM Data", + color="blue" ) ax[0].set_xlim((0, FS / 2)) -ax[0].set_title("Sleep LFP Decomposition") +ax[0].set_title("non-REM LFP Decomposition") ax[0].set_xlabel("Freq (Hz)") ax[0].set_ylabel("Frequency Power") -welch = nap.compute_welch_spectogram(wake_minute, fs=int(FS)) +welch = nap.compute_welch_spectogram(REM_minute, fs=int(FS)) ax[1].plot( welch.index, np.abs(welch.iloc[:, channel]), alpha=0.5, - label="Wake Data", + label="REM Data", color="orange", ) ax[1].set_xlim((0, FS / 2)) fig.suptitle(f"Welch Decomposition for channel {channel}") -ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_title("REM LFP Decomposition") ax[1].set_xlabel("Freq (Hz)") ax[1].set_ylabel("Frequency Power") + +welch = nap.compute_welch_spectogram(RUN_minute, fs=int(FS)) +ax[2].plot( + welch.index, + np.abs(welch.iloc[:, channel]), + alpha=0.5, + label="Running Data", + color="green", +) +ax[2].set_xlim((0, FS / 2)) +fig.suptitle(f"Welch Decomposition for channel {channel}") +ax[2].set_title("Running LFP Decomposition") +ax[2].set_xlabel("Freq (Hz)") +ax[2].set_ylabel("Frequency Power") # ax.legend() plt.show() @@ -149,7 +229,7 @@ # Let's explore further with a wavelet decomposition -def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs): +def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=None, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) ax.imshow(powers, aspect="auto", **kwargs) @@ -166,11 +246,24 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw y_ticks_pos = np.linspace(0, freqs.size, y_ticks) y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2) else: + y_ticks = freqs y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - -fig, ax = plt.subplots(2) +fig = plt.figure(constrained_layout=True, figsize=(10, 50)) +num_cells = 10 +axd = fig.subplot_mosaic( + [ + ["wd_sws"], + ["lfp_sws"], + ["wd_rem"], + ["lfp_rem"], + ["wd_run"], + ["lfp_run"], + ["pos_run"] + ], + height_ratios=[1, .2, 1, .2, 1, .2, .2] +) freqs = np.array( [ 2.59, @@ -185,80 +278,192 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw 58.59, 82.88, 117.19, - 165.75, + 150.00, + 190.00, 234.38, + 270.00, 331.5, - 468.75, - 624.0, + 390.00, + # 468.75, + # 520.00, + # 570.00, + # 624.0, ] ) -mwt_sleep = nap.compute_wavelet_transform( - sleep_minute[:, channel], fs=None, freqs=freqs +mwt_SWS = nap.compute_wavelet_transform( + SWS_minute[:, channel], fs=None, freqs=freqs ) plot_timefrequency( - sleep_minute.index.values[:], + SWS_minute.index.values[:], freqs[:], - np.transpose(mwt_sleep[:, :].values), - ax=ax[0], + np.transpose(mwt_SWS[:, :].values), + ax=axd["wd_sws"], ) -ax[0].set_title(f"Sleep Data Wavelet Decomposition: Channel {channel}") -mwt_wake = 
nap.compute_wavelet_transform(wake_minute[:, channel], fs=None, freqs=freqs) +axd["wd_sws"].set_title(f"non-REM Data Wavelet Decomposition: Channel {channel}") + +mwt_REM = nap.compute_wavelet_transform(REM_minute[:, channel], fs=None, freqs=freqs) plot_timefrequency( - wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:, :].values), ax=ax[1] + REM_minute.index.values[:], freqs[:], np.transpose(mwt_REM[:, :].values), ax=axd["wd_rem"] ) -ax[1].set_title(f"Wake Data Wavelet Decomposition: Channel {channel}") -plt.margins(0) +axd["wd_rem"].set_title(f"REM Data Wavelet Decomposition: Channel {channel}") + +mwt_RUN = nap.compute_wavelet_transform(RUN_minute[:, channel], fs=None, freqs=freqs) +plot_timefrequency( + RUN_minute.index.values[:], freqs[:], np.transpose(mwt_RUN[:, :].values), ax=axd["wd_run"] +) +axd["wd_run"].set_title(f"Running Data Wavelet Decomposition: Channel {channel}") + +axd["lfp_sws"].plot(SWS_minute) +axd["lfp_rem"].plot(REM_minute) +axd["lfp_run"].plot(RUN_minute) +axd["pos_run"].plot(RUN_position) +axd["pos_run"].margins(0) +for k in ["lfp_sws", "lfp_rem", "lfp_run"]: + axd[k].margins(0) + axd[k].set_ylabel("LFP (v)") + axd[k].get_xaxis().set_visible(False) + axd[k].spines['top'].set_visible(False) + axd[k].spines['right'].set_visible(False) + axd[k].spines['bottom'].set_visible(False) + axd[k].spines['left'].set_visible(False) plt.show() -# %% -# Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data +# %%g freq = 3 -interval = (wake_minute_interval["start"], wake_minute_interval["start"] + 2) -wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) +interval = (REM_minute_interval["start"] + 0, REM_minute_interval["start"] + 5) +REM_second = REM_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_REM_second = mwt_REM.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) -ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(REM_second.index.values, REM_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( - wake_second.index.values, - mwt_wake_second[:, freq].values.real, + REM_second.index.values, + mwt_REM_second[:, freq].values.real, label="Theta oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power.") plt.show() +# %% +# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation + +fig = plt.figure(constrained_layout=True, figsize=(10, 50)) +num_cells = 10 +axd = fig.subplot_mosaic( + [ + ["raw_lfp"]*2, + ["wavelet"]*2, + ["fit_wavelet"]*2, + ["wavelet_power"]*2, + ["wavelet_phase"]*2 + ] + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], +) + + +# _, ax = plt.subplots(25, figsize=(10, 50)) +mwt_REM = np.transpose(mwt_REM_second) +axd["raw_lfp"].plot(REM_second.index, REM_second.values[:, 0]) +axd["raw_lfp"].margins(0) +plot_timefrequency(REM_second.index, freqs, np.abs(mwt_REM[:, :]), ax=axd["wavelet"]) + +axd["fit_wavelet"].plot(REM_second.index, REM_second.values[:, 0]) +axd["fit_wavelet"].plot(REM_second.index, mwt_REM[freq, :].real) +axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") +axd["fit_wavelet"].margins(0) + +axd["wavelet_power"].plot(REM_second.index, np.abs(mwt_REM[freq, :])) +axd["wavelet_power"].margins(0) +# ax[3].plot(lfp.index, lfp.values[:,0]) +axd["wavelet_phase"].plot(REM_second.index, 
np.angle(mwt_REM[freq, :])) +axd["wavelet_phase"].margins(0) + +spikes = {} +for i in data["units"].index: + spikes[i] = data["units"][i].times()[ + (data["units"][i].times() > interval[0]) + & (data["units"][i].times() < interval[1]) + ] + +phase = {} +for i in spikes.keys(): + phase_i = [] + for spike in spikes[i]: + phase_i.append( + np.angle( + mwt_REM[freq, np.argmin(np.abs(REM_second.index.values - spike))] + ) + ) + phase[i] = np.array(phase_i) + +spikes = {k: v for k, v in spikes.items() if len(v) > 20} +phase = {k: v for k, v in phase.items() if len(v) > 20} + +variances = {key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) for key, value in phase.items()} +spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) +phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) + +for i in range(num_cells): + axd[f"spikes_phasetime_{i}"].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) + axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) + axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) + axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") + axd[f"spikes_phasetime_{i}"].set_ylabel("phase") + + axd[f"spikephase_hist_{i}"].hist(phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10)) + axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) + +plt.tight_layout() +plt.show() + # %% # Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data freq = 0 # interval = (10, 15) -interval = (sleep_minute_interval["start"] + 30, sleep_minute_interval["start"] + 35) -sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) +interval = (SWS_minute_interval["start"] + 30, SWS_minute_interval["start"] + 50) +SWS_second = SWS_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_SWS_second = mwt_SWS.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) -ax.plot(sleep_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(SWS_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( - sleep_second.index.values, - mwt_sleep_second[:, freq].values.real, + SWS_second.index.values, + mwt_SWS_second[:, freq].values.real, label="Slow Wave Oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power") plt.show() # %% -# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep +# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation -_, ax = plt.subplots(20, figsize=(10, 50)) -mwt_sleep = np.transpose(mwt_sleep_second) -ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) -plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) +fig = plt.figure(constrained_layout=True, figsize=(10, 50)) +num_cells = 5 +axd = fig.subplot_mosaic( + [ + ["raw_lfp"]*2, + ["wavelet"]*2, + ["fit_wavelet"]*2, + ["wavelet_power"]*2, + ["wavelet_phase"]*2 + ] + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], +) -ax[2].plot(sleep_second.index, sleep_second.values[:, 0]) -ax[2].plot(sleep_second.index, mwt_sleep[freq, :].real) -ax[2].set_title(f"{freqs[freq]}Hz") -ax[3].plot(sleep_second.index, np.abs(mwt_sleep[freq, :])) -# ax[3].plot(lfp.index, lfp.values[:,0]) -ax[4].plot(sleep_second.index, np.angle(mwt_sleep[freq, :])) +# _, ax = plt.subplots(25, figsize=(10, 50)) +mwt_SWS = 
np.transpose(mwt_SWS_second) +axd["raw_lfp"].plot(SWS_second.index, SWS_second.values[:, 0]) +axd["raw_lfp"].margins(0) + +plot_timefrequency(SWS_second.index, freqs, np.abs(mwt_SWS[:, :]), ax=axd["wavelet"]) + +axd["fit_wavelet"].plot(SWS_second.index, SWS_second.values[:, 0]) +axd["fit_wavelet"].plot(SWS_second.index, mwt_SWS[freq, :].real) +axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") +axd["fit_wavelet"].margins(0) + +axd["wavelet_power"].plot(SWS_second.index, np.abs(mwt_SWS[freq, :])) +axd["wavelet_power"].margins(0) +axd["wavelet_phase"].plot(SWS_second.index, np.angle(mwt_SWS[freq, :])) +axd["wavelet_phase"].margins(0) spikes = {} for i in data["units"].index: @@ -273,7 +478,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw for spike in spikes[i]: phase_i.append( np.angle( - mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))] + mwt_SWS[freq, np.argmin(np.abs(SWS_second.index.values - spike))] ) ) phase[i] = np.array(phase_i) @@ -281,12 +486,88 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw spikes = {k: v for k, v in spikes.items() if len(v) > 0} phase = {k: v for k, v in phase.items() if len(v) > 0} -for i in range(15): - ax[5 + i].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) - ax[5 + i].set_xlim(interval[0], interval[1]) - ax[5 + i].set_ylim(-np.pi, np.pi) - ax[5 + i].set_xlabel("time (s)") - ax[5 + i].set_ylabel("phase") +for i in range(num_cells): + axd[f"spikes_phasetime_{i}"].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) + axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) + axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) + axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") + axd[f"spikes_phasetime_{i}"].set_ylabel("phase") + + axd[f"spikephase_hist_{i}"].hist(phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10)) + axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) plt.tight_layout() plt.show() + +# %% +# Let's focus on the sleeping data. 
Let's see if we can isolate the slow wave oscillations from the data +# interval = (10, 15) + +# for run in [-16, -15, -13, -20]: +# interval = ( +# data["forward_ep"]["start"][run], +# data["forward_ep"]["end"][run]+3., +# ) +# print(interval) +# RUN_second_r = RUN_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +# RUN_position_r = RUN_position.restrict(nap.IntervalSet(interval[0], interval[1])) +# mwt_RUN_second_r = mwt_RUN.restrict(nap.IntervalSet(interval[0], interval[1])) +# _, ax = plt.subplots(3) +# plot_timefrequency( +# RUN_second_r.index.values[:], freqs[:], np.transpose(mwt_RUN_second_r[:, :].values), ax=ax[0] +# ) +# ax[1].plot(RUN_second_r[:, channel], alpha=0.5, label="Wake Data") +# ax[1].margins(0) +# +# ax[2].plot(RUN_position, alpha=0.5, label="Wake Data") +# ax[2].set_xlim(RUN_second_r[:, channel].index.min(), RUN_second_r[:, channel].index.max()) +# ax[2].margins(0) +# plt.show() + + +RUN_minute_interval = nap.IntervalSet( + data["forward_ep"]["start"][0], + data["forward_ep"]["end"][-1] +) + +RUN_minute = nap.TsdFrame( + t=data["eeg"].restrict(RUN_minute_interval).index.values, + d=data["eeg"].restrict(RUN_minute_interval).values, + time_support=data["eeg"].restrict(RUN_minute_interval).time_support, +) + +RUN_position = nap.TsdFrame( + t=data["position"].restrict(RUN_minute_interval).index.values[:], + d=data['position'].restrict(RUN_minute_interval), + time_support=data["position"].restrict(RUN_minute_interval).time_support, +) + +mwt_RUN = nap.compute_wavelet_transform(RUN_minute[:, channel], + freqs=freqs, + fs=None, + norm=None, + n_cycles=3.5, + scaling=1) + +for run in range(len(data["forward_ep"]["start"])): + interval = ( + data["forward_ep"]["start"][run], + data["forward_ep"]["end"][run]+5., + ) + if interval[1] - interval[0] < 6: + continue + print(interval) + RUN_second_r = RUN_minute.restrict(nap.IntervalSet(interval[0], interval[1])) + RUN_position_r = RUN_position.restrict(nap.IntervalSet(interval[0], interval[1])) + mwt_RUN_second_r = mwt_RUN.restrict(nap.IntervalSet(interval[0], interval[1])) + _, ax = plt.subplots(3) + plot_timefrequency( + RUN_second_r.index.values[:], freqs[:], np.transpose(mwt_RUN_second_r[:, :].values), ax=ax[0] + ) + ax[1].plot(RUN_second_r[:, channel], alpha=0.5, label="Wake Data") + ax[1].margins(0) + + ax[2].plot(RUN_position, alpha=0.5, label="Wake Data") + ax[2].set_xlim(RUN_second_r[:, channel].index.min(), RUN_second_r[:, channel].index.max()) + ax[2].margins(0) + plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index ca61f2fe..f7b1d51e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -14,7 +14,7 @@ import pynapple as nap -def compute_spectogram(sig, fs=None): +def compute_spectogram(sig, fs=None, ep=None, full_range=False): """ Performs numpy fft on sig, returns output @@ -23,22 +23,41 @@ def compute_spectogram(sig, fs=None): Time series. fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal + ep : pynapple.IntervalSet or None, optional + The epoch to calculate the fft on. Must be length 1. 
+ full_range : bool, optional + If true, will return full fft frequency range, otherwise will return only positive values """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): raise TypeError( "Currently compute_spectogram is only implemented for Tsd or TsdFrame" ) + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + if len(ep) != 1: + raise ValueError("Given epoch (or signal time_support) must have length 1") if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - fft_result = np.fft.fft(sig.values, axis=0) - fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) - return pd.DataFrame(fft_result, fft_freq) - + fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) + fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + if not full_range: + return ret.loc[ret.index >= 0] + return ret def compute_welch_spectogram(sig, fs=None): """ - Performs scipy Welch's decomposition on sig, returns output + Performs scipy Welch's decomposition on sig, returns output. + Estimates the power spectral density of a signal by segmenting it into overlapping sections, applying a + window function to each segment, computing their FFTs, and averaging the resulting periodograms to reduce noise. + + ..todo: remove this or add binsize parameter + ..todo: be careful of border artifacts + Parameters ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. @@ -57,16 +76,16 @@ def compute_welch_spectogram(sig, fs=None): def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ - Defines the complex Morelet wavelet kernel + Defines the complex Morlet wavelet kernel Parameters ---------- M : int Length of the wavelet ncycles : float - number of wavelet cycles to use. Default is 5 + number of wavelet cycles to use. Default is 1.5 scaling: float - Scaling factor. Default is 1.5 + Scaling factor. Default is 1.0 precision: int Precision of wavelet to use @@ -139,7 +158,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1): return np.arange(freq_start, freq_stop + freq_step, freq_step) -def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="amp"): +def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=10, norm=None): """ Compute the time-frequency representation of a signal using morlet wavelets. @@ -147,20 +166,21 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. - fs : float or None - Sampling rate, in Hz. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. If list, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + fs : float or None + Sampling rate, in Hz. Defaults to sig.rate if None is given n_cycles : float or 1d array Length of the filter, as the number of cycles for each frequency. If 1d array, this defines n_cycles for each frequency. scaling : float Scaling factor. 
- norm : {'sss', 'amp'}, optional + norm : {None, 'sss', 'amp'}, optional Normalization method: + * None - no normalization * 'sss' - divide by the square root of the sum of squares * 'amp' - divide by the sum of amplitudes @@ -178,7 +198,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) if fs is None: - fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) + fs = sig.rate n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): sig = sig.reshape((sig.shape[0], 1)) @@ -192,7 +212,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a for channel_i in range(sig.values.shape[1]): for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): mwt[:, ind, channel_i] = _convolve_wavelet( - sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm + sig[:, channel_i], fs, freq, n_cycle, scaling, precision=precision, norm=norm ) if len(output_shape) == 2: return nap.TsdFrame( @@ -204,7 +224,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a def _convolve_wavelet( - sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm="sss" + sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm=None ): """ Convolve a signal with a complex wavelet. @@ -221,17 +241,22 @@ def _convolve_wavelet( Length of the filter, as the number of cycles of the oscillation with specified frequency. scaling : float, optional, default: 0.5 Scaling factor for the morlet wavelet. - norm : {'sss', 'amp'}, optional + precision: int, optional, defaul: 10 + Precision of wavelet - higher number will lead to higher resolution wavelet (i.e. a longer filter bank + to be convolved with the signal) + norm : {'sss', 'amp', None}, optional Normalization method: * 'sss' - divide by the square root of the sum of squares * 'amp' - divide by the sum of amplitudes + * None - no normalization Returns ------- array - Complex- valued time series. + Complex-valued time series. + ..todo: fix scaling Notes ----- @@ -239,19 +264,27 @@ def _convolve_wavelet( * Taking np.abs() of output gives the analytic amplitude. * Taking np.angle() of output gives the analytic phase. """ - if norm not in ["sss", "amp"]: - raise ValueError("Given `norm` must be `sss` or `amp`") + if norm not in ["sss", "amp", None]: + raise ValueError("Given `norm` must be None, `sss` or `amp`") morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + scale = scaling / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) j = j.astype(int) # floor if j[-1] >= int_psi.size: j = np.extract(j < int_psi.size, j) int_psi_scale = int_psi[j][::-1] + conv = np.convolve(sig, int_psi_scale) - coef = -np.sqrt(scale) * np.diff(conv, axis=-1) + if norm == "sss": + coef = -np.sqrt(scale) * np.diff(conv, axis=-1) + elif norm == "amp": + coef = -scale * np.diff(conv, axis=-1) + else: + coef = np.diff(conv, axis=-1) #No normalization seems to be most effective... take others out? Why scale? ..todo + # transform axis is always -1 due to the data reshape above d = (coef.shape[-1] - sig.shape[-1]) / 2.0 if d > 0: @@ -262,6 +295,22 @@ def _convolve_wavelet( def _integrate(arr, step): + """ + Integrates an array with respect to some step param. Used for integrating complex wavelets. 
+ + Parameters + ---------- + arr : np.ndarray + wave function to be integrated + step : float + Step size of vgiven wave function array + + Returns + ------- + array + Complex-valued integrated wavelet + + """ integral = np.cumsum(arr) integral *= step return integral diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index c50860ba..e302c446 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -8,16 +8,31 @@ def test_compute_spectogram(): - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) + with pytest.raises(ValueError) as e_info: + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t, + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81])) + r = nap.compute_spectogram(sig) + assert ( + str(e_info.value) + == "Given epoch (or signal time_support) must have length 1" + ) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) r = nap.compute_spectogram(sig) assert isinstance(r, pd.DataFrame) - assert r.shape[0] == 1024 + assert r.shape[0] == 500 - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) r = nap.compute_spectogram(sig) assert isinstance(r, pd.DataFrame) - assert r.shape == (1024, 4) + assert r.shape == (500, 4) + + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_spectogram(sig, full_range=True) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1000, 4) with pytest.raises(TypeError) as e_info: nap.compute_spectogram("a_string") @@ -48,13 +63,6 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): - ##..todo put there - t = np.linspace(0, 1, 1024) # can remove this when we move it - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10, 4, 2) - t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) freqs = np.linspace(1, 600, 10) @@ -79,7 +87,11 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) - #..todo: here + t = np.linspace(0, 1, 1024) # can remove this when we move it + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4, 2) with pytest.raises(ValueError) as e_info: nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) @@ -102,4 +114,4 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_wavelet_transform() \ No newline at end of file + test_compute_spectogram() \ No newline at end of file From dda072f2630f52f805563e0bff54e6b20c1d60ad Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 12 Jul 2024 21:20:25 +0100 Subject: [PATCH 20/71] filterbank changes --- pynapple/process/signal_processing.py | 40 +++++++++++++++++++++++++-- tests/test_signal_processing.py | 2 +- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index f7b1d51e..629f0bd1 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -199,7 +199,7 @@ def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, pr freqs = _create_freqs(*freqs) if fs is None: fs = sig.rate - n_cycles = 
_check_n_cycles(n_cycles, len(freqs)) + # n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): sig = sig.reshape((sig.shape[0], 1)) output_shape = (sig.shape[0], len(freqs)) @@ -209,6 +209,18 @@ def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, pr mwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex ) + + filter_bank = _generate_morelet_filterbank(freqs, fs, n_cycles, scaling, precision) + # + import matplotlib + matplotlib.use("TkAgg") + import matplotlib.pyplot as plt + plt.clf() + for f in filter_bank: + plt.plot(f) + plt.show() + conv = np.convolve(sig, filter_bank) + for channel_i in range(sig.values.shape[1]): for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): mwt[:, ind, channel_i] = _convolve_wavelet( @@ -222,6 +234,30 @@ def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, pr t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support ) +def _generate_morelet_filterbank(freqs, fs, n_cycles, scaling, precision): + """ + Make docsting #..todo: + + :param freqs: + :param n_cycles: + :param scaling: + :param precision: + :return: + """ + filter_bank = [] + morlet_f = _morlet(int(2 ** precision), ncycles=n_cycles, scaling=scaling) + x = np.linspace(-8, 8, int(2 ** precision)) + int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + for freq in freqs: + scale = scaling / (freq / fs) + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + filter_bank.append(int_psi_scale) + return filter_bank + def _convolve_wavelet( sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm=None @@ -256,7 +292,6 @@ def _convolve_wavelet( array Complex-valued time series. - ..todo: fix scaling Notes ----- @@ -276,6 +311,7 @@ def _convolve_wavelet( if j[-1] >= int_psi.size: j = np.extract(j < int_psi.size, j) int_psi_scale = int_psi[j][::-1] + print(len(int_psi_scale)) conv = np.convolve(sig, int_psi_scale) if norm == "sss": diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index e302c446..7be3eef9 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -114,4 +114,4 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_spectogram() \ No newline at end of file + test_compute_wavelet_transform() \ No newline at end of file From d0c0ddd349810d111dfff9847699e16649ce90e5 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Sat, 13 Jul 2024 00:37:07 +0100 Subject: [PATCH 21/71] fixed test --- pynapple/process/signal_processing.py | 42 ++++----------------------- tests/test_signal_processing.py | 32 +++++--------------- 2 files changed, 13 insertions(+), 61 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index acc711ce..0ec96d10 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -102,39 +102,6 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ) -def _check_n_cycles(n_cycles, len_cycles=None): - """ - Check an input as a number of cycles, and make it iterable. - - Parameters - ---------- - n_cycles : float or list - Definition of number of cycles to check. If a single value, the same number of cycles is used for each - frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. 
- len_cycles: int, optional - What the length of `n_cycles` should be, if it's a list. - - Returns - ------- - iter - An iterable version of the number of cycles. - """ - if isinstance(n_cycles, (int, float, np.number)): - if n_cycles <= 0: - raise ValueError("Number of cycles must be a positive number.") - n_cycles = repeat(n_cycles) - elif isinstance(n_cycles, (tuple, list, np.ndarray)): - for cycle in n_cycles: - if cycle <= 0: - raise ValueError("Each number of cycles must be a positive number.") - if len_cycles and len(n_cycles) != len_cycles: - raise ValueError( - "The length of number of cycles does not match other inputs." - ) - n_cycles = iter(n_cycles) - return n_cycles - - def _create_freqs(freq_start, freq_stop, freq_step=1): """ Creates an array of frequencies. @@ -199,6 +166,9 @@ def compute_wavelet_transform( if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") + if isinstance(n_cycles, (int, float, np.number)): + if n_cycles <= 0: + raise ValueError("Number of cycles must be a positive number.") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) @@ -206,8 +176,6 @@ def compute_wavelet_transform( if fs is None: fs = sig.rate - # n_cycles = _check_n_cycles(n_cycles, len(freqs)) - if isinstance(sig, nap.Tsd): sig = sig.reshape((sig.shape[0], 1)) output_shape = (sig.shape[0], len(freqs)) @@ -229,9 +197,9 @@ def compute_wavelet_transform( elif norm == "amp": coef *= -scaling / (freqs[f_i] / fs) coef = np.insert( - coef, 1, coef[0] + coef, 1, coef[0], axis=0 ) # slightly hacky line, necessary to make output correct shape - mwt[:, f_i, :] = np.expand_dims(coef, axis=1) + mwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) if len(output_shape) == 2: return nap.TsdFrame( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index d9159416..68347c3a 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -10,12 +10,14 @@ def test_compute_spectogram(): with pytest.raises(ValueError) as e_info: t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.random.random(1000), t=t, - time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81])) + sig = nap.Tsd( + d=np.random.random(1000), + t=t, + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ) r = nap.compute_spectogram(sig) assert ( - str(e_info.value) - == "Given epoch (or signal time_support) must have length 1" + str(e_info.value) == "Given epoch (or signal time_support) must have length 1" ) t = np.linspace(0, 1, 1000) @@ -75,13 +77,6 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 6) - sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 5)) - ) - assert mwt.shape == (1024, 10) - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) @@ -97,17 +92,6 @@ def test_compute_wavelet_transform(): nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) assert str(e_info.value) == "Number of cycles must be a positive number." 
- with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, -2.5), 5)) - ) - assert str(e_info.value) == "Each number of cycles must be a positive number." - with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 2)) - ) - assert ( - str(e_info.value) - == "The length of number of cycles does not match other inputs." - ) +if __name__ == "__main__": + test_compute_wavelet_transform() From 9c53af536b791b52b94044234b6f0e78688aef49 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Sat, 13 Jul 2024 00:49:17 +0100 Subject: [PATCH 22/71] unused import removed --- pynapple/process/signal_processing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 0ec96d10..e82c1901 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,8 +4,6 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ -from itertools import repeat - import numpy as np import pandas as pd from scipy.signal import welch From 2b4bc86ab1466299acc6e610193b17071073341e Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 15 Jul 2024 17:38:18 +0100 Subject: [PATCH 23/71] logspacing --- pynapple/process/signal_processing.py | 15 ++++++++++----- tests/test_signal_processing.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index e82c1901..8c0a0fc7 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -48,7 +48,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): def compute_welch_spectogram(sig, fs=None): """ - Performs scipy Welch's decomposition on sig, returns output. + Performs Welch's decomposition on sig, returns output. Estimates the power spectral density of a signal by segmenting it into overlapping sections, applying a window function to each segment, computing their FFTs, and averaging the resulting periodograms to reduce noise. @@ -100,27 +100,32 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, freq_step=1): +def _create_freqs(freq_start, freq_stop, log_scaling=False, freq_step=1, log_base=np.e): """ Creates an array of frequencies. - ..todo:: Implement log scaling - Parameters ---------- freq_start : float Starting value for the frequency definition. freq_stop: float Stopping value for the frequency definition, inclusive. + log_scaling: Bool + If True, will use log spacing with base log_base for frequency spacing. Default False. freq_step: float, optional Step value, for linearly spaced values between start and stop. + log_base: float + If log_scaling==True, this defines the base of the log to use. Returns ------- freqs: 1d array Frequency indices. 
""" - return np.arange(freq_start, freq_stop + freq_step, freq_step) + if not log_scaling: + return np.arange(freq_start, freq_stop + freq_step, freq_step) + else: + return np.logspace(freq_start, freq_stop, base=log_base) def compute_wavelet_transform( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 68347c3a..5055cdbf 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -45,6 +45,16 @@ def test_compute_spectogram(): def test_compute_welch_spectogram(): + t = np.linspace(0, 1, 10000) + sig = nap.TsdFrame( + d=np.random.random((10000, 4)), + t=t, + time_support=nap.IntervalSet(start=[0.1, 0.4], end=[0.2, 0.525]), + ) + r = nap.compute_welch_spectogram(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[1] == 4 + t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) r = nap.compute_welch_spectogram(sig) @@ -94,4 +104,4 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_wavelet_transform() + test_compute_welch_spectogram() From f110d1e0e3250b0e30e8ca6b856cb39fc3e796a5 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 16:44:18 +0100 Subject: [PATCH 24/71] better tests, finished notebook 1 --- docs/examples/tutorial_signal_processing.py | 705 ++++++++------------ pynapple/process/signal_processing.py | 32 +- tests/test_signal_processing.py | 27 +- 3 files changed, 310 insertions(+), 454 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 80296e65..7f2383a9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,17 +1,16 @@ # -*- coding: utf-8 -*- """ -Signal Processing Local Field Potentials +Grosmark & Buzsáki (2016) Tutorial ============ +This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. +We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). -Working with Local Field Potential data. - -See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. +Specifically, we will examine Local Field Potential data from a period of active traversal of a linear track. This tutorial was made by Kipp Freud. """ -import math -import os + # %% # !!! warning @@ -19,13 +18,14 @@ # # You can install all with `pip install matplotlib requests tqdm` # -# mkdocs_gallery_thumbnail_number = 1 -# -# Now, import the necessary libraries: +# First, import the necessary libraries: + +import math +import os + import matplotlib.pyplot as plt import numpy as np import requests -import scipy import tqdm import pynapple as nap @@ -34,7 +34,7 @@ # *** # Downloading the data # ------------------ -# First things first: Let's download the data and save it locally +# Let's download the data and save it locally path = "Achilles_10252013_EEG.nwb" if path not in os.listdir("."): @@ -54,7 +54,7 @@ # *** # Loading the data # ------------------ -# Loading the data, calculating the sampling frequency +# Let's load and print the full dataset. 
data = nap.load_file(path) FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) @@ -65,172 +65,125 @@ # *** # Selecting slices # ----------------------------------- -# Let's consider two 60-second slices of data, one from the sleep epoch and one from wake - -REM_minute_interval = nap.IntervalSet( - data["rem"]["start"][0] + 60.0, - data["rem"]["start"][0] + 120.0, +# We will consider a single run of the experiment - where the rodent completes a full traversal of the linear track, +# followed by 4 seconds of post-traversal activity. + +# Define the run to use for this Analysis +run_index = 7 +# Define the IntervalSet for this run and instantiate both LFP and +# Position TsdFrame objects +RUN_interval = nap.IntervalSet( + data["forward_ep"]["start"][run_index], + data["forward_ep"]["end"][run_index] + 4.0, ) - -SWS_minute_interval = nap.IntervalSet( - data["nrem"]["start"][0] + 10.0, - data["nrem"]["start"][0] + 70.0, +RUN_Tsd = nap.TsdFrame( + t=data["eeg"].restrict(RUN_interval).index.values - + data["forward_ep"]["start"][run_index], + d=data["eeg"].restrict(RUN_interval).values, ) - -RUN_minute_interval = nap.IntervalSet( - data["forward_ep"]["start"][-18] + 0.0, - data["forward_ep"]["start"][-18] + 60.0, +RUN_pos = nap.TsdFrame( + t=data["position"].restrict(RUN_interval).index.values - + data["forward_ep"]["start"][run_index], + d=data["position"].restrict(RUN_interval).asarray(), ) +# The given dataset has only one channel, so we set channel = 0 here +channel = 0 -REM_minute = nap.TsdFrame( - t=data["eeg"].restrict(REM_minute_interval).index.values, - d=data["eeg"].restrict(REM_minute_interval).values, - time_support=data["eeg"].restrict(REM_minute_interval).time_support, -) +# %% +# *** +# Plotting the LFP and Behavioural Activity +# ----------------------------------- -SWS_minute = nap.TsdFrame( - t=data["eeg"].restrict(SWS_minute_interval).index.values, - d=data["eeg"].restrict(SWS_minute_interval).values, - time_support=data["eeg"].restrict(SWS_minute_interval).time_support, +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +axd = fig.subplot_mosaic( + [ + ["ephys"], + ["pos"] + ], + height_ratios=[1, 0.2], ) -RUN_minute = nap.TsdFrame( - t=data["eeg"].restrict(RUN_minute_interval).index.values, - d=data["eeg"].restrict(RUN_minute_interval).values, - time_support=data["eeg"].restrict(RUN_minute_interval).time_support, +axd["ephys"].plot( + RUN_Tsd[:, channel].restrict( + nap.IntervalSet( + 0.0, + data["forward_ep"]["end"][run_index] - + data["forward_ep"]["start"][run_index] + ) + ), + label="Traversal LFP Data", + color="green" ) - -RUN_position = nap.TsdFrame( - t=data["position"].restrict(RUN_minute_interval).index.values[:], - d=data["position"].restrict(RUN_minute_interval), - time_support=data["position"].restrict(RUN_minute_interval).time_support, +axd["ephys"].plot( + RUN_Tsd[:, channel].restrict( + nap.IntervalSet( + data["forward_ep"]["end"][run_index] - + data["forward_ep"]["start"][run_index], + data["forward_ep"]["end"][run_index] - + data["forward_ep"]["start"][run_index] + 5.0, + ) + ), + label="Post Traversal LFP Data", + color="blue" ) +axd["ephys"].set_title("Traversal & Post Traversal LFP") +axd["ephys"].set_ylabel("LFP (v)") +axd["ephys"].set_xlabel("time (s)") +axd["ephys"].margins(0) +axd["ephys"].legend() +axd["pos"].plot(RUN_pos, color="black") +axd["pos"].margins(0) +axd["pos"].set_xlabel("time (s)") +axd["pos"].set_ylabel("Linearized Position") +axd["pos"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -channel = 0 
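+
+# As a quick sanity check (a rough sketch; the exact numbers depend on the run chosen
+# above), we can confirm that the restricted LFP and position objects cover roughly
+# the same stretch of time, and that the local LFP sampling rate is close to FS.
+print("LFP duration (s):", RUN_Tsd.index[-1] - RUN_Tsd.index[0])
+print("Position duration (s):", RUN_pos.index[-1] - RUN_pos.index[0])
+print(
+    "Approximate LFP sampling rate (Hz):",
+    len(RUN_Tsd.index[:]) / (RUN_Tsd.index[-1] - RUN_Tsd.index[0]),
+)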
# %% # *** -# Plotting the LFP activity of one slices +# Getting the LFP Spectogram # ----------------------------------- -# Let's plot - -fig, ax = plt.subplots(3) - -for channel in range(SWS_minute.shape[1]): - ax[0].plot( - SWS_minute[:, channel], - alpha=0.5, - label="Sleep Data", - ) -ax[0].set_title("non-REM ephys") -ax[0].set_ylabel("LFP (v)") -ax[0].set_xlabel("time (s)") -ax[0].margins(0) -for channel in range(REM_minute.shape[1]): - ax[1].plot(REM_minute[:, channel], alpha=0.5, label="Wake Data", color="orange") -ax[1].set_ylabel("LFP (v)") -ax[1].set_xlabel("time (s)") -ax[1].set_title("REM ephys") -ax[1].margins(0) -for channel in range(RUN_minute.shape[1]): - ax[2].plot(RUN_minute[:, channel], alpha=0.5, label="Wake Data", color="green") -ax[2].set_ylabel("LFP (v)") -ax[2].set_xlabel("time (s)") -ax[2].set_title("Running ephys") -ax[2].margins(0) -plt.show() - - -# %% # Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present -channel = 0 -fig, ax = plt.subplots(3) -fft = nap.compute_spectogram(SWS_minute, fs=int(FS)) -ax[0].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", c="blue" -) -ax[0].set_xlim((0, FS / 2)) -ax[0].set_xlabel("Freq (Hz)") -ax[0].set_ylabel("Frequency Power") - -ax[0].set_title("non-REM LFP Decomposition") -fft = nap.compute_spectogram(REM_minute, fs=int(FS)) -ax[1].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Wake Data", c="orange" -) -ax[1].set_xlim((0, FS / 2)) -fig.suptitle(f"Fourier Decomposition for channel {channel}") -ax[1].set_title("REM LFP Decomposition") -ax[1].set_xlabel("Freq (Hz)") -ax[1].set_ylabel("Frequency Power") - -fft = nap.compute_spectogram(RUN_minute, fs=int(FS)) -ax[2].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Running Data", c="green" -) -ax[2].set_xlim((0, FS / 2)) -fig.suptitle(f"Fourier Decomposition for channel {channel}") -ax[2].set_title("Running LFP Decomposition") -ax[2].set_xlabel("Freq (Hz)") -ax[2].set_ylabel("Frequency Power") -# ax.legend() -plt.show() +fft = nap.compute_spectogram(RUN_Tsd, fs=int(FS)) - -# %% -# Let's now consider the Welch spectograms of waking and sleeping data... 
- -fig, ax = plt.subplots(3) -welch = nap.compute_welch_spectogram(SWS_minute, fs=int(FS)) -ax[0].plot( - welch.index, - np.abs(welch.iloc[:, channel]), - alpha=0.5, - label="non-REM Data", - color="blue", -) -ax[0].set_xlim((0, FS / 2)) -ax[0].set_title("non-REM LFP Decomposition") -ax[0].set_xlabel("Freq (Hz)") -ax[0].set_ylabel("Frequency Power") -welch = nap.compute_welch_spectogram(REM_minute, fs=int(FS)) -ax[1].plot( - welch.index, - np.abs(welch.iloc[:, channel]), - alpha=0.5, - label="REM Data", - color="orange", -) -ax[1].set_xlim((0, FS / 2)) -fig.suptitle(f"Welch Decomposition for channel {channel}") -ax[1].set_title("REM LFP Decomposition") -ax[1].set_xlabel("Freq (Hz)") -ax[1].set_ylabel("Frequency Power") - -welch = nap.compute_welch_spectogram(RUN_minute, fs=int(FS)) -ax[2].plot( - welch.index, - np.abs(welch.iloc[:, channel]), - alpha=0.5, - label="Running Data", - color="green", +# Now we will plot it +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, + label="LFP Frequency Power", c="blue", linewidth=2 ) -ax[2].set_xlim((0, FS / 2)) -fig.suptitle(f"Welch Decomposition for channel {channel}") -ax[2].set_title("Running LFP Decomposition") -ax[2].set_xlabel("Freq (Hz)") -ax[2].set_ylabel("Frequency Power") -# ax.legend() -plt.show() +ax.set_xlabel("Freq (Hz)") +ax.set_ylabel("Frequency Power") +ax.set_title("LFP Fourier Decomposition") +ax.set_xlim(1, 30) +ax.axvline(9.36, c="orange", label="9.36Hz", alpha=0.5) +ax.axvline(18.74, c="green", label="18.74Hz", alpha=0.5) +ax.legend() +# ax.set_yscale('log') +# ax.set_xscale('log') # %% -# There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? -# Let's explore further with a wavelet decomposition +# *** +# Getting the Wavelet Decomposition +# ----------------------------------- +# It looks like the prominent frequencies in the data may vary over time. For example, it looks like the +# LFP characteristics may be different while the animal is running along the track, and when it is finished. +# Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. +# We must define the frequency set that we'd like to use for our decomposition; these +# have been manually selected based on the frequencies used in Frey et. 
al (2021), but +# could also be defined as a linspace or logspace +freqs = np.array( + [ + 2.59, 3.66, 5.18, 8.0, 10.36, 20.72, 29.3, 41.44, 58.59, 82.88, + 117.19, 152.35, 192.19, 200., 234.38, 270.00, 331.5, 390.00, + ] +) +mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, channel], fs=None, freqs=freqs) +#Define wavelet decomposition plotting function def plot_timefrequency( - times, freqs, powers, x_ticks=5, y_ticks=None, ax=None, **kwargs + times, freqs, powers, x_ticks=5, ax=None, **kwargs ): if np.iscomplexobj(powers): powers = abs(powers) @@ -244,325 +197,197 @@ def plot_timefrequency( else: x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - if isinstance(y_ticks, int): - y_ticks_pos = np.linspace(0, freqs.size, y_ticks) - y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2) - else: - y_ticks = freqs - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + y_ticks = freqs + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - - -fig = plt.figure(constrained_layout=True, figsize=(10, 50)) -num_cells = 10 + +# And plot +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [ - ["wd_sws"], - ["lfp_sws"], - ["wd_rem"], - ["lfp_rem"], ["wd_run"], ["lfp_run"], ["pos_run"], ], - height_ratios=[1, 0.2, 1, 0.2, 1, 0.2, 0.2], -) -freqs = np.array( - [ - 2.59, - 3.66, - 5.18, - 8.0, - 10.36, - 14.65, - 20.72, - 29.3, - 41.44, - 58.59, - 82.88, - 117.19, - 150.00, - 190.00, - 234.38, - 270.00, - 331.5, - 390.00, - # 468.75, - # 520.00, - # 570.00, - # 624.0, - ] -) -mwt_SWS = nap.compute_wavelet_transform(SWS_minute[:, channel], fs=None, freqs=freqs) -plot_timefrequency( - SWS_minute.index.values[:], - freqs[:], - np.transpose(mwt_SWS[:, :].values), - ax=axd["wd_sws"], -) -axd["wd_sws"].set_title(f"non-REM Data Wavelet Decomposition: Channel {channel}") - -mwt_REM = nap.compute_wavelet_transform(REM_minute[:, channel], fs=None, freqs=freqs) -plot_timefrequency( - REM_minute.index.values[:], - freqs[:], - np.transpose(mwt_REM[:, :].values), - ax=axd["wd_rem"], + height_ratios=[1, 0.2, 0.4], ) -axd["wd_rem"].set_title(f"REM Data Wavelet Decomposition: Channel {channel}") - -mwt_RUN = nap.compute_wavelet_transform(RUN_minute[:, channel], fs=None, freqs=freqs) plot_timefrequency( - RUN_minute.index.values[:], + RUN_Tsd.index.values[:], freqs[:], np.transpose(mwt_RUN[:, :].values), ax=axd["wd_run"], ) -axd["wd_run"].set_title(f"Running Data Wavelet Decomposition: Channel {channel}") - -axd["lfp_sws"].plot(SWS_minute) -axd["lfp_rem"].plot(REM_minute) -axd["lfp_run"].plot(RUN_minute) -axd["pos_run"].plot(RUN_position) -axd["pos_run"].margins(0) -for k in ["lfp_sws", "lfp_rem", "lfp_run"]: +axd["wd_run"].set_title(f"Wavelet Decomposition") +axd["lfp_run"].plot(RUN_Tsd) +axd["pos_run"].plot(RUN_pos) +axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos_run"].set_ylabel("Lin. 
Position (cm)") +for k in ["lfp_run", "pos_run"]: axd[k].margins(0) - axd[k].set_ylabel("LFP (v)") + if k != "pos_run": + axd[k].set_ylabel("LFP (v)") axd[k].get_xaxis().set_visible(False) - axd[k].spines["top"].set_visible(False) - axd[k].spines["right"].set_visible(False) - axd[k].spines["bottom"].set_visible(False) - axd[k].spines["left"].set_visible(False) -plt.show() - -# %%g -freq = 3 -interval = (REM_minute_interval["start"] + 0, REM_minute_interval["start"] + 5) -REM_second = REM_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_REM_second = mwt_REM.restrict(nap.IntervalSet(interval[0], interval[1])) -fig, ax = plt.subplots(1) -ax.plot(REM_second.index.values, REM_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot( - REM_second.index.values, - mwt_REM_second[:, freq].values.real, - label="Theta oscillations", -) -ax.set_title(f"{freqs[freq]}Hz oscillation power.") -plt.show() + for spine in ["top", "right", "bottom", "left"]: + axd[k].spines[spine].set_visible(False) # %% -# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation - -fig = plt.figure(constrained_layout=True, figsize=(10, 50)) -num_cells = 10 +# *** +# Visualizing Theta Band Power +# ----------------------------------- +# There seems to be a strong theta frequency present in the data during the maze traversal. +# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well +# they match up +theta_freq_index = 3 +theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real +theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [ - ["raw_lfp"] * 2, - ["wavelet"] * 2, - ["fit_wavelet"] * 2, - ["wavelet_power"] * 2, - ["wavelet_phase"] * 2, - ] - + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], + ["lfp_run"], + ["pos_run"], + ], + height_ratios=[1, 0.3], ) +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], + alpha=0.5, + label="LFP Data") +axd["lfp_run"].plot( + RUN_Tsd.index.values, + theta_band_reconstruction, + label=f"{freqs[theta_freq_index]}Hz oscillations", +) +axd["lfp_run"].plot( + RUN_Tsd.index.values, + theta_band_power_envelope, + label=f"{freqs[theta_freq_index]}Hz power envelope", +) -# _, ax = plt.subplots(25, figsize=(10, 50)) -mwt_REM = np.transpose(mwt_REM_second) -axd["raw_lfp"].plot(REM_second.index, REM_second.values[:, 0]) -axd["raw_lfp"].margins(0) -plot_timefrequency(REM_second.index, freqs, np.abs(mwt_REM[:, :]), ax=axd["wavelet"]) - -axd["fit_wavelet"].plot(REM_second.index, REM_second.values[:, 0]) -axd["fit_wavelet"].plot(REM_second.index, mwt_REM[freq, :].real) -axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") -axd["fit_wavelet"].margins(0) - -axd["wavelet_power"].plot(REM_second.index, np.abs(mwt_REM[freq, :])) -axd["wavelet_power"].margins(0) -# ax[3].plot(lfp.index, lfp.values[:,0]) -axd["wavelet_phase"].plot(REM_second.index, np.angle(mwt_REM[freq, :])) -axd["wavelet_phase"].margins(0) - -spikes = {} -for i in data["units"].index: - spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > interval[0]) - & (data["units"][i].times() < interval[1]) - ] - -phase = {} -for i in spikes.keys(): - phase_i = [] - for spike in spikes[i]: - phase_i.append( - np.angle(mwt_REM[freq, np.argmin(np.abs(REM_second.index.values - spike))]) - ) - phase[i] = np.array(phase_i) - -spikes = {k: v for k, v in 
spikes.items() if len(v) > 20} -phase = {k: v for k, v in phase.items() if len(v) > 20} - -variances = { - key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) - for key, value in phase.items() -} -spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) -phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) - -for i in range(num_cells): - axd[f"spikes_phasetime_{i}"].scatter( - spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]] - ) - axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) - axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) - axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") - axd[f"spikes_phasetime_{i}"].set_ylabel("phase") - - axd[f"spikephase_hist_{i}"].hist( - phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10) - ) - axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["pos_run"].plot(RUN_pos) +[axd[k].margins(0) for k in ["lfp_run", "pos_run"]] +[axd["pos_run"].spines[sp].set_visible(False) for sp in ["top", "right", "bottom", "left"]] +axd["pos_run"].get_xaxis().set_visible(False) +axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos_run"].set_ylabel("Lin. Position (cm)") +axd["lfp_run"].legend() -plt.tight_layout() -plt.show() +# %% +# *** +# Visualizing Sharp Wave Ripple Power +# ----------------------------------- +# There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. +# Let's plot the LFP, along with the 200Hz frequency power, to see if we can isolate these peaks and +# see what's going on. +ripple_freq_idx = 13 +ripple_power = np.abs(mwt_RUN[:, ripple_freq_idx].values) -# %% -# Let's focus on the sleeping data. 
Let's see if we can isolate the slow wave oscillations from the data -freq = 12 -# interval = (10, 15) -interval = (SWS_minute_interval["start"] + 30, SWS_minute_interval["start"] + 50) -SWS_second = SWS_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_SWS_second = mwt_SWS.restrict(nap.IntervalSet(interval[0], interval[1])) -_, ax = plt.subplots(1) -ax.plot(SWS_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot( - SWS_second.index.values, - mwt_SWS_second[:, freq].values.real, - label="Slow Wave Oscillations", +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +axd = fig.subplot_mosaic( + [ + ["lfp_run"], + ["rip_pow"], + ], + height_ratios=[1, 0.4], ) -ax.set_title(f"{freqs[freq]}Hz oscillation power") -plt.show() +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].margins(0) +axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") +axd["rip_pow"].plot(RUN_Tsd.index.values, + ripple_power + ) +axd["rip_pow"].margins(0) +axd["rip_pow"].get_xaxis().set_visible(False) +axd["rip_pow"].spines["top"].set_visible(False) +axd["rip_pow"].spines["right"].set_visible(False) +axd["rip_pow"].spines["bottom"].set_visible(False) +axd["rip_pow"].spines["left"].set_visible(False) +axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") # %% -# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation - -fig = plt.figure(constrained_layout=True, figsize=(10, 50)) -num_cells = 5 +# *** +# Isolating Ripple Times +# ----------------------------------- +# We can see one significant peak in the 200Hz frequency power. Let's smooth the power curve and threshold +# to try to isolate this event. 
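+
+# As an aside, here is a more compact alternative as a rough sketch: assuming the
+# threshold method of pynapple Tsd objects is available in the version you are using,
+# we can wrap the smoothed power in a Tsd and threshold it directly, in which case the
+# supra-threshold epochs come back as an IntervalSet through its time_support. The
+# explicit loop below does the same bookkeeping by hand, and is what the rest of this
+# section uses.
+smoothed_power = np.convolve(
+    np.abs(mwt_RUN[:, ripple_freq_idx].values), np.ones(51) / 51, mode="same"
+)
+candidate_swr_epochs = (
+    nap.Tsd(t=RUN_Tsd.index.values, d=smoothed_power)
+    .threshold(100, method="above")
+    .time_support
+)
+print(candidate_swr_epochs)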
+ +# define our threshold +threshold = 100 +# smooth our wavelet power +window_size = 51 +window = np.ones(window_size) / window_size +smoother_swr_power = np.convolve(np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode='same') +# isolate our ripple periods +is_ripple = smoother_swr_power > threshold +start_idx = None +ripple_periods = [] +for i in range(len(RUN_Tsd.index.values)): + if is_ripple[i] and start_idx is None: + start_idx = i + elif not is_ripple[i] and start_idx is not None: + axd["rip_pow"].plot(RUN_Tsd.index.values[start_idx:i], smoother_swr_power[start_idx:i], color='red', linewidth=2) + ripple_periods.append( (start_idx, i) ) + start_idx = None + +# plot of captured ripple periods +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( [ - ["raw_lfp"] * 2, - ["wavelet"] * 2, - ["fit_wavelet"] * 2, - ["wavelet_power"] * 2, - ["wavelet_phase"] * 2, - ] - + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], + ["lfp_run"], + ["rip_pow"], + ], + height_ratios=[1, 0.4], ) - - -# _, ax = plt.subplots(25, figsize=(10, 50)) -mwt_SWS = np.transpose(mwt_SWS_second) -axd["raw_lfp"].plot(SWS_second.index, SWS_second.values[:, 0]) -axd["raw_lfp"].margins(0) - -plot_timefrequency(SWS_second.index, freqs, np.abs(mwt_SWS[:, :]), ax=axd["wavelet"]) - -axd["fit_wavelet"].plot(SWS_second.index, SWS_second.values[:, 0]) -axd["fit_wavelet"].plot(SWS_second.index, mwt_SWS[freq, :].real) -axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") -axd["fit_wavelet"].margins(0) - -axd["wavelet_power"].plot(SWS_second.index, np.abs(mwt_SWS[freq, :])) -axd["wavelet_power"].margins(0) -axd["wavelet_phase"].plot(SWS_second.index, np.angle(mwt_SWS[freq, :])) -axd["wavelet_phase"].margins(0) - -spikes = {} -for i in data["units"].index: - spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > interval[0]) - & (data["units"][i].times() < interval[1]) - ] - -phase = {} -for i in spikes.keys(): - phase_i = [] - for spike in spikes[i]: - phase_i.append( - np.angle(mwt_SWS[freq, np.argmin(np.abs(SWS_second.index.values - spike))]) - ) - phase[i] = np.array(phase_i) - -spikes = {k: v for k, v in spikes.items() if len(v) > 0} -phase = {k: v for k, v in phase.items() if len(v) > 0} - -for i in range(num_cells): - axd[f"spikes_phasetime_{i}"].scatter( - spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]] - ) - axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) - axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) - axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") - axd[f"spikes_phasetime_{i}"].set_ylabel("phase") - - axd[f"spikephase_hist_{i}"].hist( - phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10) - ) - axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) - -plt.tight_layout() -plt.show() +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") +axd["rip_pow"].plot(RUN_Tsd.index.values, + smoother_swr_power + ) +for r in ripple_periods: + axd["rip_pow"].plot(RUN_Tsd.index.values[r[0]:r[1]], smoother_swr_power[r[0]:r[1]], color='red', linewidth=2) +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].set_title(f"{freqs[ripple_freq_idx]}Hz oscillation power.") +axd["rip_pow"].axhline(threshold) +[axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] +[axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] +axd["rip_pow"].get_xaxis().set_visible(False) +axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) 
+axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") # %% -# Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data -# interval = (10, 15) - -RUN_minute_interval = nap.IntervalSet( - data["forward_ep"]["start"][0], data["forward_ep"]["end"][-1] -) +# *** +# Plotting a Sharp Wave Ripple +# ----------------------------------- +# Let's zoom in on out detected ripples and have a closer look! -RUN_minute = nap.TsdFrame( - t=data["eeg"].restrict(RUN_minute_interval).index.values, - d=data["eeg"].restrict(RUN_minute_interval).values, - time_support=data["eeg"].restrict(RUN_minute_interval).time_support, -) +# Filter out ripples which do not last long enough +ripple_periods = [r for r in ripple_periods if r[1]-r[0] > 20] -RUN_position = nap.TsdFrame( - t=data["position"].restrict(RUN_minute_interval).index.values[:], - d=data["position"].restrict(RUN_minute_interval), - time_support=data["position"].restrict(RUN_minute_interval).time_support, +# And plot! +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +buffer = 200 +ax.plot( + RUN_Tsd.index.values[r[0]-buffer:r[1]+buffer], + RUN_Tsd[r[0]-buffer:r[1]+buffer], + color="blue", + label="Non-SWR LFP" ) - -mwt_RUN = nap.compute_wavelet_transform( - RUN_minute[:, channel], freqs=freqs, fs=None, norm=None, n_cycles=3.5, scaling=1 +ax.plot( + RUN_Tsd.index.values[r[0]:r[1]], + RUN_Tsd[r[0]:r[1]], + color="red", + label="SWR", + linewidth=2 ) - -for run in range(len(data["forward_ep"]["start"])): - interval = ( - data["forward_ep"]["start"][run], - data["forward_ep"]["end"][run] + 5.0, - ) - if interval[1] - interval[0] < 6: - continue - print(interval) - RUN_second_r = RUN_minute.restrict(nap.IntervalSet(interval[0], interval[1])) - RUN_position_r = RUN_position.restrict(nap.IntervalSet(interval[0], interval[1])) - mwt_RUN_second_r = mwt_RUN.restrict(nap.IntervalSet(interval[0], interval[1])) - _, ax = plt.subplots(3) - plot_timefrequency( - RUN_second_r.index.values[:], - freqs[:], - np.transpose(mwt_RUN_second_r[:, :].values), - ax=ax[0], - ) - ax[1].plot(RUN_second_r[:, channel], alpha=0.5, label="Wake Data") - ax[1].margins(0) - - ax[2].plot(RUN_position, alpha=0.5, label="Wake Data") - ax[2].set_xlim( - RUN_second_r[:, channel].index.min(), RUN_second_r[:, channel].index.max() - ) - ax[2].margins(0) - plt.show() +ax.margins(0) +ax.set_xlabel("Time (s)") +ax.set_ylabel("LFP (v)") +ax.legend() +ax.set_title("Sharp Wave Ripple Visualization") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 8c0a0fc7..3dcd06db 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -13,7 +13,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): """ - Performs numpy fft on sig, returns output + Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -24,6 +24,11 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): The epoch to calculate the fft on. Must be length 1. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + + Notes + ----- + compute_spectogram computes fft on only a single epoch of data. This epoch be given with the ep + parameter otherwise will be sig.time_support, but it must only be a single epoch. 
""" if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): raise TypeError( @@ -100,7 +105,7 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, log_scaling=False, freq_step=1, log_base=np.e): +def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_base=np.e): """ Creates an array of frequencies. @@ -110,10 +115,10 @@ def _create_freqs(freq_start, freq_stop, log_scaling=False, freq_step=1, log_bas Starting value for the frequency definition. freq_stop: float Stopping value for the frequency definition, inclusive. - log_scaling: Bool - If True, will use log spacing with base log_base for frequency spacing. Default False. freq_step: float, optional Step value, for linearly spaced values between start and stop. + log_scaling: Bool + If True, will use log spacing with base log_base for frequency spacing. Default False. log_base: float If log_scaling==True, this defines the base of the log to use. @@ -136,7 +141,7 @@ def compute_wavelet_transform( Parameters ---------- - sig : pynapple.Tsd or pynapple.TsdFrame + sig : pynapple.Tsd, pynapple.TsdFrame or pynapple.TsdTensor Time series. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. @@ -159,7 +164,7 @@ def compute_wavelet_transform( Returns ------- - mwt : 2d array + pynapple.TsdFrame or pynapple.TsdTensor : 2d array Time frequency representation of the input signal. Notes @@ -186,11 +191,11 @@ def compute_wavelet_transform( output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - mwt = np.zeros( + cwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex ) - filter_bank = _generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) + filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) for f_i, filter in enumerate(filter_bank): convolved = sig.convolve(np.transpose(np.asarray([filter.real, filter.imag]))) convolved = convolved[:, :, 0].values + convolved[:, :, 1].values * 1j @@ -202,19 +207,19 @@ def compute_wavelet_transform( coef = np.insert( coef, 1, coef[0], axis=0 ) # slightly hacky line, necessary to make output correct shape - mwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) + cwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) if len(output_shape) == 2: return nap.TsdFrame( - t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support ) return nap.TsdTensor( - t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support ) -def _generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision): +def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=10): """ Parameters @@ -253,7 +258,8 @@ def _generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision): if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) filter_bank.append(int_psi_scale) - return filter_bank + filter_bank = [np.pad(arr, ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), constant_values=0.0) for arr in filter_bank] + return np.array(filter_bank) def _integrate(arr, step): diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 5055cdbf..87686d18 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py 
@@ -75,6 +75,30 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t*50*np.pi*2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 50 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 70 + assert mwt.shape == (1000, 10) + t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) freqs = np.linspace(1, 600, 10) @@ -104,4 +128,5 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_welch_spectogram() + test_compute_wavelet_transform() + # test_compute_welch_spectogram() From b00bf23a22de6321e680a51a8ff0b838c6ac31da Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 16:45:43 +0100 Subject: [PATCH 25/71] linting --- docs/examples/tutorial_signal_processing.py | 121 ++++++++++++-------- pynapple/process/signal_processing.py | 9 +- tests/test_signal_processing.py | 2 +- 3 files changed, 84 insertions(+), 48 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 7f2383a9..6379c9be 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -77,13 +77,13 @@ data["forward_ep"]["end"][run_index] + 4.0, ) RUN_Tsd = nap.TsdFrame( - t=data["eeg"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], + t=data["eeg"].restrict(RUN_interval).index.values + - data["forward_ep"]["start"][run_index], d=data["eeg"].restrict(RUN_interval).values, ) RUN_pos = nap.TsdFrame( - t=data["position"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], + t=data["position"].restrict(RUN_interval).index.values + - data["forward_ep"]["start"][run_index], d=data["position"].restrict(RUN_interval).asarray(), ) # The given dataset has only one channel, so we set channel = 0 here @@ -96,10 +96,7 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( - [ - ["ephys"], - ["pos"] - ], + [["ephys"], ["pos"]], height_ratios=[1, 0.2], ) @@ -107,24 +104,25 @@ RUN_Tsd[:, channel].restrict( nap.IntervalSet( 0.0, - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index] + data["forward_ep"]["end"][run_index] + - data["forward_ep"]["start"][run_index], ) ), label="Traversal LFP Data", - color="green" + color="green", ) axd["ephys"].plot( RUN_Tsd[:, channel].restrict( nap.IntervalSet( - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index], - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index] + 5.0, + data["forward_ep"]["end"][run_index] + - data["forward_ep"]["start"][run_index], + data["forward_ep"]["end"][run_index] + - data["forward_ep"]["start"][run_index] + + 5.0, ) ), label="Post Traversal LFP Data", - color="blue" + color="blue", ) axd["ephys"].set_title("Traversal & Post 
Traversal LFP") axd["ephys"].set_ylabel("LFP (v)") @@ -149,8 +147,12 @@ # Now we will plot it fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, - label="LFP Frequency Power", c="blue", linewidth=2 + fft.index, + np.abs(fft.iloc[:, channel]), + alpha=0.5, + label="LFP Frequency Power", + c="blue", + linewidth=2, ) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") @@ -175,16 +177,31 @@ # could also be defined as a linspace or logspace freqs = np.array( [ - 2.59, 3.66, 5.18, 8.0, 10.36, 20.72, 29.3, 41.44, 58.59, 82.88, - 117.19, 152.35, 192.19, 200., 234.38, 270.00, 331.5, 390.00, + 2.59, + 3.66, + 5.18, + 8.0, + 10.36, + 20.72, + 29.3, + 41.44, + 58.59, + 82.88, + 117.19, + 152.35, + 192.19, + 200.0, + 234.38, + 270.00, + 331.5, + 390.00, ] ) mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, channel], fs=None, freqs=freqs) -#Define wavelet decomposition plotting function -def plot_timefrequency( - times, freqs, powers, x_ticks=5, ax=None, **kwargs -): + +# Define wavelet decomposition plotting function +def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) ax.imshow(powers, aspect="auto", **kwargs) @@ -200,7 +217,8 @@ def plot_timefrequency( y_ticks = freqs y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - + + # And plot fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( @@ -250,9 +268,9 @@ def plot_timefrequency( height_ratios=[1, 0.3], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], - alpha=0.5, - label="LFP Data") +axd["lfp_run"].plot( + RUN_Tsd.index.values, RUN_Tsd[:, channel], alpha=0.5, label="LFP Data" +) axd["lfp_run"].plot( RUN_Tsd.index.values, theta_band_reconstruction, @@ -269,7 +287,10 @@ def plot_timefrequency( axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") axd["pos_run"].plot(RUN_pos) [axd[k].margins(0) for k in ["lfp_run", "pos_run"]] -[axd["pos_run"].spines[sp].set_visible(False) for sp in ["top", "right", "bottom", "left"]] +[ + axd["pos_run"].spines[sp].set_visible(False) + for sp in ["top", "right", "bottom", "left"] +] axd["pos_run"].get_xaxis().set_visible(False) axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) axd["pos_run"].set_ylabel("Lin. 
Position (cm)") @@ -299,9 +320,7 @@ def plot_timefrequency( axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") -axd["rip_pow"].plot(RUN_Tsd.index.values, - ripple_power - ) +axd["rip_pow"].plot(RUN_Tsd.index.values, ripple_power) axd["rip_pow"].margins(0) axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].spines["top"].set_visible(False) @@ -323,7 +342,9 @@ def plot_timefrequency( # smooth our wavelet power window_size = 51 window = np.ones(window_size) / window_size -smoother_swr_power = np.convolve(np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode='same') +smoother_swr_power = np.convolve( + np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode="same" +) # isolate our ripple periods is_ripple = smoother_swr_power > threshold start_idx = None @@ -332,8 +353,13 @@ def plot_timefrequency( if is_ripple[i] and start_idx is None: start_idx = i elif not is_ripple[i] and start_idx is not None: - axd["rip_pow"].plot(RUN_Tsd.index.values[start_idx:i], smoother_swr_power[start_idx:i], color='red', linewidth=2) - ripple_periods.append( (start_idx, i) ) + axd["rip_pow"].plot( + RUN_Tsd.index.values[start_idx:i], + smoother_swr_power[start_idx:i], + color="red", + linewidth=2, + ) + ripple_periods.append((start_idx, i)) start_idx = None # plot of captured ripple periods @@ -346,11 +372,14 @@ def plot_timefrequency( height_ratios=[1, 0.4], ) axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") -axd["rip_pow"].plot(RUN_Tsd.index.values, - smoother_swr_power - ) +axd["rip_pow"].plot(RUN_Tsd.index.values, smoother_swr_power) for r in ripple_periods: - axd["rip_pow"].plot(RUN_Tsd.index.values[r[0]:r[1]], smoother_swr_power[r[0]:r[1]], color='red', linewidth=2) + axd["rip_pow"].plot( + RUN_Tsd.index.values[r[0] : r[1]], + smoother_swr_power[r[0] : r[1]], + color="red", + linewidth=2, + ) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].set_title(f"{freqs[ripple_freq_idx]}Hz oscillation power.") @@ -368,23 +397,23 @@ def plot_timefrequency( # Let's zoom in on out detected ripples and have a closer look! # Filter out ripples which do not last long enough -ripple_periods = [r for r in ripple_periods if r[1]-r[0] > 20] +ripple_periods = [r for r in ripple_periods if r[1] - r[0] > 20] # And plot! 
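# The detected ripple_periods are (start, stop) pairs of sample indices into RUN_Tsd.
# If the events are needed in seconds instead (for example to build a nap.IntervalSet
# and restrict other objects to the ripples), a minimal sketch of the conversion:
ripple_times = [
    (RUN_Tsd.index.values[start], RUN_Tsd.index.values[stop])
    for start, stop in ripple_periods
]

# With the indices and times in hand, here is one of the detected ripples in context: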
fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) buffer = 200 ax.plot( - RUN_Tsd.index.values[r[0]-buffer:r[1]+buffer], - RUN_Tsd[r[0]-buffer:r[1]+buffer], + RUN_Tsd.index.values[r[0] - buffer : r[1] + buffer], + RUN_Tsd[r[0] - buffer : r[1] + buffer], color="blue", - label="Non-SWR LFP" + label="Non-SWR LFP", ) ax.plot( - RUN_Tsd.index.values[r[0]:r[1]], - RUN_Tsd[r[0]:r[1]], + RUN_Tsd.index.values[r[0] : r[1]], + RUN_Tsd[r[0] : r[1]], color="red", label="SWR", - linewidth=2 + linewidth=2, ) ax.margins(0) ax.set_xlabel("Time (s)") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 3dcd06db..fe000c24 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -258,7 +258,14 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) filter_bank.append(int_psi_scale) - filter_bank = [np.pad(arr, ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), constant_values=0.0) for arr in filter_bank] + filter_bank = [ + np.pad( + arr, + ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), + constant_values=0.0, + ) + for arr in filter_bank + ] return np.array(filter_bank) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 87686d18..81c6b024 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -76,7 +76,7 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t*50*np.pi*2), t=t) + sig = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] From 81a076e28c5606897f1c9add3191308a2c8ad668 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:23:38 +0100 Subject: [PATCH 26/71] phase preference notebook added --- docs/examples/tutorial_phase_preferences.py | 359 ++++++++++++++++++++ docs/examples/tutorial_signal_processing.py | 2 +- 2 files changed, 360 insertions(+), 1 deletion(-) create mode 100644 docs/examples/tutorial_phase_preferences.py diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py new file mode 100644 index 00000000..ce3f35ba --- /dev/null +++ b/docs/examples/tutorial_phase_preferences.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +""" +Grosmark & Buzsáki (2016) Tutorial 2 +============ + +In the previous [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/) tutorial, +we learned how to use Pynapple's signal processing tools with Local Field Potential data. Specifically, we +used wavelet decompositions to isolate Theta band activity during active traversal of a linear track, +as well as to find Sharp Wave Ripples which occurred after traversal. + +In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it +with spiking data, to find phase preferences of spiking units. + +Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track. + +This tutorial was made by Kipp Freud. +""" + +# %% +# !!! 
warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm` +# +# First, import the necessary libraries: + +import math +import os + +# ..todo: remove +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import requests +import scipy +import tqdm + +import pynapple as nap + +matplotlib.use("TkAgg") + +# %% +# *** +# Downloading the data +# ------------------ +# Let's download the data and save it locally + +path = "Achilles_10252013_EEG.nwb" +if path not in os.listdir("."): + r = requests.get(f"https://osf.io/2dfvp/download", stream=True) + block_size = 1024 * 1024 + with open(path, "wb") as f: + for data in tqdm.tqdm( + r.iter_content(block_size), + unit="MB", + unit_scale=True, + total=math.ceil(int(r.headers.get("content-length", 0)) // block_size), + ): + f.write(data) + + +# %% +# *** +# Loading the data +# ------------------ +# Let's load and print the full dataset. + +data = nap.load_file(path) +FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) +print(data) + + +# %% +# *** +# Selecting slices +# ----------------------------------- +# Let's consider a 10-second slice of data taken during REM sleep + +# Define the IntervalSet for this run and instantiate both LFP and +# Position TsdFrame objects +REM_minute_interval = nap.IntervalSet( + data["rem"]["start"][0] + 90.0, + data["rem"]["start"][0] + 100.0, +) +REM_Tsd = nap.TsdFrame( + t=data["eeg"].restrict(REM_minute_interval).index.values + - data["eeg"].restrict(REM_minute_interval).index.values.min(), + d=data["eeg"].restrict(REM_minute_interval).values, +) + +# We will also extract spike times from all units in our dataset +# which occur during our specified interval +spikes = {} +for i in data["units"].index: + spikes[i] = ( + data["units"][i].times()[ + (data["units"][i].times() > REM_minute_interval["start"][0]) + & (data["units"][i].times() < REM_minute_interval["end"][0]) + ] + - data["eeg"].restrict(REM_minute_interval).index.values.min() + ) + +# The given dataset has only one channel, so we set channel = 0 here +channel = 0 + +# %% +# *** +# Plotting the LFP Activity +# ----------------------------------- +# We should first plot our REM Local Field Potential data. + +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) + +ax.plot( + REM_Tsd[:, channel], + label="REM LFP Data", + color="green", +) +ax.set_title("REM Local Field Potential") +ax.set_ylabel("LFP (v)") +ax.set_xlabel("time (s)") +ax.margins(0) +ax.legend() +plt.show() + +# %% +# *** +# Getting the Wavelet Decomposition +# ----------------------------------- +# As we would expect, it looks like we have a very strong theta oscillation within our data +# - this is a common feature of REM sleep. Let's perform a wavelet decomposition, +# as we did in the last tutorial, to see get a more informative breakdown of the +# frequencies present in the data. + +# We must define the frequency set that we'd like to use for our decomposition; +# these have been manually selected based on the frequencies used in +# Frey et. 
al (2021), but could also be defined as a linspace or logspace +freqs = np.array( + [ + 2.59, + 3.66, + 5.18, + 8.0, + 10.36, + 20.72, + 29.3, + 41.44, + 58.59, + 82.88, + 117.19, + 152.35, + 192.19, + 200.0, + 234.38, + 270.00, + 331.5, + 390.00, + ] +) +mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, channel], fs=None, freqs=freqs) + + +# Define wavelet decomposition plotting function +def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): + if np.iscomplexobj(powers): + powers = abs(powers) + ax.imshow(powers, aspect="auto", **kwargs) + ax.invert_yaxis() + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") + if isinstance(x_ticks, int): + x_tick_pos = np.linspace(0, times.size, x_ticks) + x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) + else: + x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] + ax.set(xticks=x_tick_pos, xticklabels=x_ticks) + y_ticks = freqs + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + + +# And plot it +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +axd = fig.subplot_mosaic( + [ + ["wd_rem"], + ["lfp_rem"], + ], + height_ratios=[1, 0.2], +) +plot_timefrequency( + REM_Tsd.index.values[:], + freqs[:], + np.transpose(mwt_REM[:, :].values), + ax=axd["wd_rem"], +) +axd["wd_rem"].set_title(f"Wavelet Decomposition") +axd["lfp_rem"].plot(REM_Tsd) +axd["lfp_rem"].margins(0) +axd["lfp_rem"].set_ylabel("LFP (v)") +axd["lfp_rem"].get_xaxis().set_visible(False) +for spine in ["top", "right", "bottom", "left"]: + axd["lfp_rem"].spines[spine].set_visible(False) +plt.show() + +# %% +# *** +# Visualizing Theta Band Power and Phase +# ----------------------------------- +# There seems to be a strong theta frequency present in the data during the maze traversal. +# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well +# they match up. We will also extract and plot the phase of the 8Hz wavelet from the decomposition. 
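# As a cross-check on the wavelet-derived phase, the instantaneous theta phase can also
# be estimated by band-pass filtering the LFP and taking the angle of its analytic
# (Hilbert-transformed) signal. A minimal sketch, assuming a 6-10 Hz band; the band
# edges and filter order are illustrative choices, not values used elsewhere here:
import numpy as np
from scipy.signal import butter, filtfilt, hilbert


def hilbert_theta_phase(lfp, fs, low=6.0, high=10.0, order=4):
    # band-pass around theta, then analytic signal -> instantaneous phase
    b, a = butter(order, [low, high], btype="band", fs=fs)
    filtered = filtfilt(b, a, lfp)
    return np.angle(hilbert(filtered))


# e.g. hilbert_theta_phase(REM_Tsd[:, channel].values, FS) should broadly agree with
# the wavelet phase extracted below. Back to the wavelet decomposition: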
+theta_freq_index = 3 +theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real +# calculating phase here +theta_band_phase = np.angle(mwt_REM[:, theta_freq_index].values) + +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +axd = fig.subplot_mosaic( + [ + ["theta_pow"], + ["phase"], + ], + height_ratios=[0.4, 0.2], +) + +axd["theta_pow"].plot( + REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" +) +axd["theta_pow"].plot( + REM_Tsd.index.values, + theta_band_reconstruction, + label=f"{freqs[theta_freq_index]}Hz oscillations", +) +axd["theta_pow"].set_ylabel("LFP (v)") +axd["theta_pow"].set_xlabel("Time (s)") +axd["theta_pow"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") # +axd["theta_pow"].legend() +axd["phase"].plot(theta_band_phase) +[axd[k].margins(0) for k in ["theta_pow", "phase"]] +axd["phase"].set_ylabel("Phase") +plt.show() + + +# %% +# *** +# Finding Phase of Spikes +# ----------------------------------- +# Now that we have the phase of our theta wavelet, and our spike times, we can find the theta phase at which every +# spike occurs + +# We will start by throwing away cells which do not have enough +# spikes during our interval +spikes = {k: v for k, v in spikes.items() if len(v) > 20} +# Get phase of each spike +phase = {} +for i in spikes.keys(): + phase_i = [] + for spike in spikes[i]: + phase_i.append( + np.angle( + mwt_REM[ + np.argmin(np.abs(REM_Tsd.index.values - spike)), theta_freq_index + ] + ) + ) + phase[i] = np.array(phase_i) + +# Let's plot phase histograms for the first six units to see if there's +# any obvious preferences +fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) +for ri in range(2): + for ci in range(3): + ax[ri, ci].hist( + phase[list(phase.keys())[ri * 3 + ci]], + bins=np.linspace(-np.pi, np.pi, 10), + density=True, + ) + ax[ri, ci].set_xlabel("Phase (rad)") + ax[ri, ci].set_ylabel("Density") + ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") +fig.suptitle("Phase Preference Histograms of First 6 Units") +plt.show() + +# %% +# *** +# Isolating Strong Phase Preferences +# ----------------------------------- +# It looks like there could be some phase preferences happening here, but there's a lot of cells to go through. +# Now that we have our phases of firing for each unit, we can sort the units by the circular variance of the phase +# of their spikes, to isolate the cells with the strongest phase preferences without manual inspection. + +variances = { + key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) + for key, value in phase.items() +} +spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) +phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) + +# Now let's plot phase histograms for the six units with the least +# varied phase of spikes. +fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) +for ri in range(2): + for ci in range(3): + ax[ri, ci].hist( + phase[list(phase.keys())[ri * 3 + ci]], + bins=np.linspace(-np.pi, np.pi, 10), + density=True, + ) + ax[ri, ci].set_xlabel("Phase (rad)") + ax[ri, ci].set_ylabel("Density") + ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") +fig.suptitle( + "Phase Preference Histograms of 6 Units with " + "Highest Phase Preference" +) +plt.show() + +# %% +# *** +# Visualizing Phase Preferences +# ----------------------------------- +# There is definitely some strong phase preferences happening here. 
Let's visualize the firing preferences +# of the 6 cells we've isolated to get an impression of just how striking these preferences are. + +fig = plt.figure(constrained_layout=True, figsize=(10, 12)) +axd = fig.subplot_mosaic( + [ + ["lfp_run"], + ["phase_0"], + ["phase_1"], + ["phase_2"], + ["phase_3"], + ["phase_4"], + ["phase_5"], + ], + height_ratios=[0.4, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], +) +[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(6)]] +axd["lfp_run"].plot( + REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" +) +axd["lfp_run"].plot( + REM_Tsd.index.values, + theta_band_reconstruction, + label=f"{freqs[theta_freq_index]}Hz oscillations", +) +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["lfp_run"].legend() +for i in range(6): + axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) + axd[f"phase_{i}"].scatter( + spikes[list(spikes.keys())[i]], phase[list(spikes.keys())[i]] + ) + axd[f"phase_{i}"].set_ylabel("Phase") + axd[f"phase_{i}"].set_title(f"Unit {list(spikes.keys())[i]}") +fig.suptitle("Phase Preference Visualizations") +plt.show() diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 6379c9be..34941449 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Grosmark & Buzsáki (2016) Tutorial +Grosmark & Buzsáki (2016) Tutorial 1 ============ This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). 
From ebdb64ca26f5f447a79923738156357a06b4d848 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:27:28 +0100 Subject: [PATCH 27/71] remove unused import --- docs/examples/tutorial_phase_preferences.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index ce3f35ba..59ba685b 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -27,8 +27,6 @@ import math import os -# ..todo: remove -import matplotlib import matplotlib.pyplot as plt import numpy as np import requests @@ -37,8 +35,6 @@ import pynapple as nap -matplotlib.use("TkAgg") - # %% # *** # Downloading the data From 0f17883b52325e716371363e17a885bf6dab676a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:50:25 +0100 Subject: [PATCH 28/71] simplified compute_wavelet_transform, added tests --- pynapple/process/signal_processing.py | 28 ++++++++++++--------------- tests/test_signal_processing.py | 21 +++++++++++++++----- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index fe000c24..936ae6a3 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -191,23 +191,19 @@ def compute_wavelet_transform( output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - cwt = np.zeros( - [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex - ) - filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) - for f_i, filter in enumerate(filter_bank): - convolved = sig.convolve(np.transpose(np.asarray([filter.real, filter.imag]))) - convolved = convolved[:, :, 0].values + convolved[:, :, 1].values * 1j - coef = -np.diff(convolved, axis=0) - if norm == "sss": - coef *= -np.sqrt(scaling) / (freqs[f_i] / fs) - elif norm == "amp": - coef *= -scaling / (freqs[f_i] / fs) - coef = np.insert( - coef, 1, coef[0], axis=0 - ) # slightly hacky line, necessary to make output correct shape - cwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) + convolved_real = sig.convolve(np.transpose(filter_bank.real)) + convolved_imag = sig.convolve(np.transpose(filter_bank.imag)) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = -np.diff(convolved, axis=0) + if norm == "sss": + coef *= coef * (-np.sqrt(scaling) / (freqs / fs)) + elif norm == "amp": + coef *= -scaling / (freqs / fs) + coef = np.insert( + coef, 1, coef[0, :], axis=0 + ) # slightly hacky line, necessary to make output correct shape + cwt = np.swapaxes(coef, 1, 2) if len(output_shape) == 2: return nap.TsdFrame( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 81c6b024..8edc4bd7 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -91,6 +91,22 @@ def test_compute_wavelet_transform(): assert mpf == 20 assert mwt.shape == (1000, 10) + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="sss") + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, 
freqs=freqs, norm="amp") + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) @@ -125,8 +141,3 @@ def test_compute_wavelet_transform(): with pytest.raises(ValueError) as e_info: nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) assert str(e_info.value) == "Number of cycles must be a positive number." - - -if __name__ == "__main__": - test_compute_wavelet_transform() - # test_compute_welch_spectogram() From 5af389de2aadeb40f43dbdc2c64a14d50238c16b Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:56:02 +0100 Subject: [PATCH 29/71] removed time zeroing for doc examples --- docs/examples/tutorial_phase_preferences.py | 17 +++++---------- docs/examples/tutorial_signal_processing.py | 24 ++++++--------------- 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 59ba685b..b4f1f6ec 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -78,23 +78,16 @@ data["rem"]["start"][0] + 90.0, data["rem"]["start"][0] + 100.0, ) -REM_Tsd = nap.TsdFrame( - t=data["eeg"].restrict(REM_minute_interval).index.values - - data["eeg"].restrict(REM_minute_interval).index.values.min(), - d=data["eeg"].restrict(REM_minute_interval).values, -) +REM_Tsd = data["eeg"].restrict(REM_minute_interval) # We will also extract spike times from all units in our dataset # which occur during our specified interval spikes = {} for i in data["units"].index: - spikes[i] = ( - data["units"][i].times()[ - (data["units"][i].times() > REM_minute_interval["start"][0]) - & (data["units"][i].times() < REM_minute_interval["end"][0]) - ] - - data["eeg"].restrict(REM_minute_interval).index.values.min() - ) + spikes[i] = data["units"][i].times()[ + (data["units"][i].times() > REM_minute_interval["start"][0]) + & (data["units"][i].times() < REM_minute_interval["end"][0]) + ] # The given dataset has only one channel, so we set channel = 0 here channel = 0 diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 34941449..b5d786fe 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -76,16 +76,9 @@ data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] + 4.0, ) -RUN_Tsd = nap.TsdFrame( - t=data["eeg"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], - d=data["eeg"].restrict(RUN_interval).values, -) -RUN_pos = nap.TsdFrame( - t=data["position"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], - d=data["position"].restrict(RUN_interval).asarray(), -) +RUN_Tsd = data["eeg"].restrict(RUN_interval) +RUN_pos = data["position"].restrict(RUN_interval) + # The given dataset has only one channel, so we set channel = 0 here channel = 0 @@ -103,9 +96,7 @@ axd["ephys"].plot( RUN_Tsd[:, channel].restrict( nap.IntervalSet( - 0.0, - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index], + data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] ) ), label="Traversal LFP Data", @@ -114,11 +105,8 @@ axd["ephys"].plot( RUN_Tsd[:, channel].restrict( nap.IntervalSet( - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index], - 
data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index] - + 5.0, + data["forward_ep"]["end"][run_index], + data["forward_ep"]["end"][run_index] + 5.0, ) ), label="Post Traversal LFP Data", From 072fbe396473ec5ca99a0a4405890b24247c52e7 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 20:57:17 +0100 Subject: [PATCH 30/71] minor changes, wavelet API v0 --- docs/examples/tutorial_wavelet_api.py | 178 ++++++++++++++++++++++++++ pynapple/process/__init__.py | 1 + pynapple/process/signal_processing.py | 6 +- 3 files changed, 182 insertions(+), 3 deletions(-) create mode 100644 docs/examples/tutorial_wavelet_api.py diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py new file mode 100644 index 00000000..d5552249 --- /dev/null +++ b/docs/examples/tutorial_wavelet_api.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +""" +Wavelet API tutorial +============ + +Working with Wavelets. + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Kipp Freud. + +""" + +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm` +# +# Now, import the necessary libraries: + +import matplotlib +matplotlib.use("TkAgg") +import matplotlib.pyplot as plt +import numpy as np + +import pynapple as nap + +# %% +# *** +# Generating a dummy signal +# ------------------ +# Let's generate a dummy signal to analyse with wavelets! + +# Our dummy dataset will contain two components, a low frequency 3Hz sinusoid combined +# with a weaker 25Hz sinusoid. +t = np.linspace(0,10, 10000) +sig = nap.Tsd(d=np.sin(t * (5+t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) +# Plot it +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) +ax.plot(sig) +ax.margins(0) +plt.show() + +# %% +# *** +# Getting our Morlet wavelet filter bank +# ------------------ + +freqs = np.linspace(1,25, num=25) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +offset = 0.2 +for f_i in range(filter_bank.shape[0]): + ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i,:]+offset*f_i) + ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') +ax.margins(0) +ax.yaxis.set_visible(False) +ax.spines['left'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['top'].set_visible(False) +ax.set_xlim(-2, 2) +ax.set_xlabel("Time (s)") +ax.set_title("Morlet Wavelet Filter Bank") +plt.show() + +# %% +# *** +# Effect of n_cycles +# ------------------ + +freqs = np.linspace(1,25, num=25) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +offset = 0.2 +for f_i in range(filter_bank.shape[0]): + ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i,:]+offset*f_i) + ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') +ax.margins(0) +ax.yaxis.set_visible(False) +ax.spines['left'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['top'].set_visible(False) +ax.set_xlim(-2, 2) +ax.set_xlabel("Time (s)") +ax.set_title("Morlet Wavelet Filter Bank") +plt.show() + +# %% +# *** +# Effect of scaling +# ------------------ + +freqs = np.linspace(1,25, num=25) 
+filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +offset = 0.2 +for f_i in range(filter_bank.shape[0]): + ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i,:]+offset*f_i) + ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') +ax.margins(0) +ax.yaxis.set_visible(False) +ax.spines['left'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['top'].set_visible(False) +ax.set_xlim(-2, 2) +ax.set_xlabel("Time (s)") +ax.set_title("Morlet Wavelet Filter Bank") +plt.show() + + +# %% +# *** +# Decomposing the dummy signal +# ------------------ + +mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15) + +def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): + if np.iscomplexobj(powers): + powers = abs(powers) + ax.imshow(powers, aspect="auto", **kwargs) + ax.invert_yaxis() + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") + if isinstance(x_ticks, int): + x_tick_pos = np.linspace(0, times.size, x_ticks) + x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) + else: + x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] + ax.set(xticks=x_tick_pos, xticklabels=x_ticks) + y_ticks = freqs + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + +fig, ax = plt.subplots(1) +plot_timefrequency( + mwt.index.values[:], + freqs[:], + np.transpose(mwt[:, :].values), + ax=ax, +) +plt.show() + + + +# %% +# *** +# Increasing n_cycles increases resolution of decomposition +# ------------------ + +mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +plot_timefrequency( + mwt.index.values[:], + freqs[:], + np.transpose(mwt[:, :].values), + ax=ax, +) +plt.show() + +# %% +# *** +# Increasing n_cycles increases resolution of decomposition +# ------------------ + +mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +plot_timefrequency( + mwt.index.values[:], + freqs[:], + np.transpose(mwt[:, :].values), + ax=ax, +) +plt.show() diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index fb7e22b9..3db9771d 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -19,6 +19,7 @@ compute_spectogram, compute_wavelet_transform, compute_welch_spectogram, + generate_morlet_filterbank ) from .tuning_curves import ( compute_1d_mutual_info, diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 936ae6a3..28329092 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -224,7 +224,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 If array, frequency values to estimate with morlet wavelets. If list, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. - fs : float or None + fs : float Sampling rate, in Hz. n_cycles : float or 1d array Length of the filter, as the number of cycles for each frequency. 
@@ -236,8 +236,8 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 Returns ------- - filter_bank : list[np.ndarray] - list of morlet wavelet filters of the frequencies given + filter_bank : np.ndarray + list of Morlet wavelet filters of the frequencies given """ filter_bank = [] morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) From ffdc3d972e24850d42442ecfcc47fcc51cec5619 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 20:58:59 +0100 Subject: [PATCH 31/71] linting --- docs/examples/tutorial_wavelet_api.py | 77 ++++++++++++++++----------- pynapple/process/__init__.py | 2 +- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index d5552249..55221265 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -20,6 +20,7 @@ # Now, import the necessary libraries: import matplotlib + matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np @@ -34,8 +35,8 @@ # Our dummy dataset will contain two components, a low frequency 3Hz sinusoid combined # with a weaker 25Hz sinusoid. -t = np.linspace(0,10, 10000) -sig = nap.Tsd(d=np.sin(t * (5+t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) +t = np.linspace(0, 10, 10000) +sig = nap.Tsd(d=np.sin(t * (5 + t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) # Plot it fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) ax.plot(sig) @@ -47,19 +48,22 @@ # Getting our Morlet wavelet filter bank # ------------------ -freqs = np.linspace(1,25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10) +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 for f_i in range(filter_bank.shape[0]): - ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i,:]+offset*f_i) - ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i + ) + ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) -ax.spines['left'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['top'].set_visible(False) +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title("Morlet Wavelet Filter Bank") @@ -70,19 +74,22 @@ # Effect of n_cycles # ------------------ -freqs = np.linspace(1,25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10) +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 for f_i in range(filter_bank.shape[0]): - ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i,:]+offset*f_i) - ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i + ) + ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) 
ax.yaxis.set_visible(False) -ax.spines['left'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['top'].set_visible(False) +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title("Morlet Wavelet Filter Bank") @@ -93,19 +100,22 @@ # Effect of scaling # ------------------ -freqs = np.linspace(1,25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10) +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 for f_i in range(filter_bank.shape[0]): - ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i,:]+offset*f_i) - ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i + ) + ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) -ax.spines['left'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['top'].set_visible(False) +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title("Morlet Wavelet Filter Bank") @@ -117,7 +127,10 @@ # Decomposing the dummy signal # ------------------ -mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15) +mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15 +) + def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): @@ -136,6 +149,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + fig, ax = plt.subplots(1) plot_timefrequency( mwt.index.values[:], @@ -146,13 +160,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): plt.show() - # %% # *** # Increasing n_cycles increases resolution of decomposition # ------------------ -mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10) +mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], @@ -167,7 +182,9 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Increasing n_cycles increases resolution of decomposition # ------------------ -mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10) +mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 3db9771d..0986cc6d 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -19,7 +19,7 @@ compute_spectogram, compute_wavelet_transform, compute_welch_spectogram, - generate_morlet_filterbank + generate_morlet_filterbank, ) from 
.tuning_curves import ( compute_1d_mutual_info, From 54d8cb6ebaf2b0828974aa44edc84c70276e0d4d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:20:56 +0100 Subject: [PATCH 32/71] better wavelet API notebook --- docs/examples/tutorial_wavelet_api.py | 306 +++++++++++++++++++------- pynapple/process/signal_processing.py | 2 +- 2 files changed, 223 insertions(+), 85 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index 55221265..e37e6c40 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -29,109 +29,92 @@ # %% # *** -# Generating a dummy signal +# Generating a Dummy Signal # ------------------ # Let's generate a dummy signal to analyse with wavelets! +# +# Our dummy dataset will contain two components, a low frequency 2Hz sinusoid combined +# with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. -# Our dummy dataset will contain two components, a low frequency 3Hz sinusoid combined -# with a weaker 25Hz sinusoid. -t = np.linspace(0, 10, 10000) -sig = nap.Tsd(d=np.sin(t * (5 + t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) -# Plot it -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) -ax.plot(sig) -ax.margins(0) -plt.show() +Fs = 2000 +t = np.linspace(0, 5, Fs*5) +two_hz_phase = t * 2 * np.pi * 2 +two_hz_component = np.sin(two_hz_phase) +increasing_freq_component = np.sin(t * (5+t) * np.pi * 2) +sig = nap.Tsd(d=two_hz_component + increasing_freq_component + + np.random.normal(0,0.1,10000), t=t) # %% -# *** -# Getting our Morlet wavelet filter bank -# ------------------ - -freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10 -) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -offset = 0.2 -for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i - ) - ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") -ax.margins(0) -ax.yaxis.set_visible(False) -ax.spines["left"].set_visible(False) -ax.spines["right"].set_visible(False) -ax.spines["top"].set_visible(False) -ax.set_xlim(-2, 2) -ax.set_xlabel("Time (s)") -ax.set_title("Morlet Wavelet Filter Bank") +# Lets plot it. +fig, ax = plt.subplots(3, constrained_layout=True, figsize=(10, 5)) +ax[0].plot(t, two_hz_component) +ax[1].plot(t, increasing_freq_component) +ax[2].plot(sig) +ax[0].set_title("2Hz Component") +ax[1].set_title("Increasing Frequency Component") +ax[2].set_title("Dummy Signal") +[ax[i].margins(0) for i in range(3)] +[ax[i].set_ylim(-2.5,2.5) for i in range(3)] +[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(3)] +[ax[i].set_xlabel("Time (s)") for i in range(3)] +[ax[i].set_ylabel("Signal") for i in range(3)] +[ax[i].set_ylim(-2.5,2.5) for i in range(3)] plt.show() + # %% # *** -# Effect of n_cycles +# Getting our Morlet Wavelet Filter Bank # ------------------ +# We will be decomposing our dummy signal using wavelets of different frequencies. These wavelets +# can be examined using the `generate_morlet_filterbank' function. Here we will use the default parameters +# to define a Morlet filter bank with which we will later use to deconstruct the signal. 
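# For intuition before building the bank: a complex Morlet wavelet is simply a complex
# sinusoid tapered by a Gaussian envelope. A minimal sketch of the textbook form (the
# exact parametrization used by generate_morlet_filterbank, via n_cycles, scaling and
# precision, differs in detail):
import numpy as np


def morlet_sketch(t, w0=5.0):
    # complex sinusoid of angular frequency w0 under a unit-width Gaussian window
    return np.pi ** (-0.25) * np.exp(1j * w0 * t) * np.exp(-0.5 * t**2)


# Rescaling such a wavelet in time shifts its centre frequency, which is how a filter
# bank spanning our frequencies of interest is obtained: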
+# Define the frequency of the wavelets in our filter bank freqs = np.linspace(1, 25, num=25) +# Get the filter bank filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10 + freqs, Fs, n_cycles=1.5, scaling=1.0, precision=20 ) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -offset = 0.2 -for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i - ) - ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") -ax.margins(0) -ax.yaxis.set_visible(False) -ax.spines["left"].set_visible(False) -ax.spines["right"].set_visible(False) -ax.spines["top"].set_visible(False) -ax.set_xlim(-2, 2) -ax.set_xlabel("Time (s)") -ax.set_title("Morlet Wavelet Filter Bank") -plt.show() # %% -# *** -# Effect of scaling -# ------------------ - -freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10 -) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -offset = 0.2 -for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i - ) - ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") -ax.margins(0) -ax.yaxis.set_visible(False) -ax.spines["left"].set_visible(False) -ax.spines["right"].set_visible(False) -ax.spines["top"].set_visible(False) -ax.set_xlim(-2, 2) -ax.set_xlabel("Time (s)") -ax.set_title("Morlet Wavelet Filter Bank") -plt.show() +# Lets plot it. +def plot_filterbank(filter_bank, freqs, title): + fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) + offset = 0.2 + for f_i in range(filter_bank.shape[0]): + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i, :].real + offset * f_i + ) + ax.text(-2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") + ax.margins(0) + ax.yaxis.set_visible(False) + [ax.spines[sp].set_visible(False) for sp in ["left", "right", "top"]] + ax.set_xlim(-2, 2) + ax.set_xlabel("Time (s)") + ax.set_title(title) + plt.show() +title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" +plot_filterbank(filter_bank, freqs, title) # %% # *** -# Decomposing the dummy signal +# Decomposing the Dummy Signal # ------------------ +# Here we will use the `compute_wavelet_transform' function to decompose our signal using the filter bank shown +# above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and +# frequency information for analysis. We will calculate this decomposition and plot it's corresponding +# scalogram. +# Compute the wavelet transform using the parameters above mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15 + sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=20 ) - +# %% +# Lets plot it. 
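# Before plotting, it is worth checking the structure of the output: a complex-valued
# TsdFrame with one row per time point and one column per requested frequency.
print(mwt.shape)         # expected: (len(sig), len(freqs))
print(mwt.values.dtype)  # complex: np.abs() gives amplitude, np.angle() gives phase

# Now the scalogram plot itself: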
def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) @@ -150,24 +133,131 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) plot_timefrequency( mwt.index.values[:], freqs[:], np.transpose(mwt[:, :].values), ax=ax, ) +ax.set_title("Wavelet Decomposition Scalogram") +plt.show() + +# %% +# *** +# Reconstructing the Slow Oscillation and Phase +# ------------------ +# We can see that the decomposition has picked up on the 2Hz component of the signal, as well as the component with +# increasing frequency. In this section, we will extract just the 2Hz component from the wavelet decomposition, +# and see how it compares to the original section. + +# Get the index of the 2Hz frequency +two_hz_freq_idx = np.where(freqs == 2.)[0] +# The 2Hz component is the real component of the wavelet decomposition at this index +slow_oscillation = mwt[:, two_hz_freq_idx].values.real +# The 2Hz wavelet phase is the angle of the wavelet decomposition at this index +slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx].values) + +# %% +# Lets plot it. +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +axd = fig.subplot_mosaic( + [["signal"], ["phase"]], + height_ratios=[1, 0.4], +) +axd["signal"].plot(sig, label="Raw Signal", alpha=0.5) +axd["signal"].plot(t, slow_oscillation, label="2Hz Reconstruction") +axd["signal"].legend() +axd["phase"].plot(t, slow_oscillation_phase, alpha=0.5) +axd["phase"].set_ylabel("Phase (rad)") +axd["signal"].set_ylabel("Signal") +axd["phase"].set_xlabel("Time (s)") +[axd[f].spines[sp].set_visible(False) for sp in ["right", "top"] for f in ["phase", "signal"]] +axd["signal"].get_xaxis().set_visible(False) +axd["signal"].spines["bottom"].set_visible(False) +[axd[k].margins(0) for k in ["signal", "phase"]] +axd["signal"].set_ylim(-2.5,2.5) +axd["phase"].set_ylim(-np.pi, np.pi) plt.show() +# %% +# *** +# Adding in the 15Hz Oscillation +# ------------------ +# Let's see what happens if we also add the 15 Hz component of the wavelet decomposition to the reconstruction. We +# will extract the 15 Hz components, and also the 15Hz wavelet power over time. The wavelet power tells us to what +# extent the 15 Hz frequency is present in our signal at different times. +# +# Finally, we will add this 15 Hz reconstruction to the one shown above, to see if it improves out reconstructed +# signal. + +# Get the index of the 15 Hz frequency +fifteen_hz_freq_idx = np.where(freqs == 15.)[0] +# The 15 Hz component is the real component of the wavelet decomposition at this index +fifteenHz_oscillation = mwt[:, fifteen_hz_freq_idx].values.real +# The 15 Hz poser is the absolute value of the wavelet decomposition at this index +fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx].values) + +# %% +# Lets plot it. 
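# To spell out what np.abs gives us here: for a complex coefficient z = a + ib,
# np.abs(z) = sqrt(a**2 + b**2) is the instantaneous amplitude (the "power" traces we
# plot), while np.angle(z) is its phase. A quick sanity check on the array above:
z = mwt[:, fifteen_hz_freq_idx].values
assert np.allclose(np.abs(z), np.sqrt(z.real**2 + z.imag**2))

# With that established, on to the plot: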
+fig, ax = plt.subplots(2, constrained_layout=True, figsize=(10, 6)) +ax[0].plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") +ax[0].plot(t, fifteenHz_oscillation_power, label="15Hz Power") +ax[1].plot(sig, label="Raw Signal", alpha=0.5) +ax[1].plot(t, slow_oscillation+fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") +[ax[i].set_ylim(-2.5,2.5) for i in range(2)] +[ax[i].margins(0) for i in range(2)] +[ax[i].legend() for i in range(2)] +[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(2)] +ax[0].get_xaxis().set_visible(False) +ax[0].spines["bottom"].set_visible(False) +ax[1].set_xlabel("Time (s)") +[ax[i].set_ylabel("Signal") for i in range(2)] +plt.show() # %% # *** -# Increasing n_cycles increases resolution of decomposition +# Adding ALL the Oscillations! # ------------------ +# Let's now add together the real components of all frequency bands to recreate a version of the original signal. +combined_oscillations = mwt.values.real.sum(axis=1) + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, alpha=0.5, label="Signal") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.legend() +plt.show() + + + +# %% +# *** +# Effect of n_cycles +# ------------------ + +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=7.5, scaling=1.0, precision=20 +) + +plot_filterbank(filter_bank, freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0") + +# %% +# *** +# Let's see what effect this has on the Wavelet Scalogram which is generated... mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10 + sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=20 ) + fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], @@ -175,21 +265,69 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): np.transpose(mwt[:, :].values), ax=ax, ) + +# %% +# *** +# And let's see if that has an effect on the reconstructed version of the signal + +combined_oscillations = mwt.values.real.sum(axis=1) + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, alpha=0.5, label="Signal") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.legend() plt.show() + # %% # *** -# Increasing n_cycles increases resolution of decomposition +# Effect of scaling # ------------------ -mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10 +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=1.5, scaling=2.0, precision=20 ) + +plot_filterbank(filter_bank, freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0") + +# %% +# *** +# Let's see what effect this has on the Wavelet Scalogram which is generated... 
fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=2.0, precision=20 +) plot_timefrequency( mwt.index.values[:], freqs[:], np.transpose(mwt[:, :].values), ax=ax, ) + +# %% +# *** +# And let's see if that has an effect on the reconstructed version of the signal + +combined_oscillations = mwt.values.real.sum(axis=1) + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, alpha=0.5, label="Signal") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.legend() plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 28329092..7cbacc05 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -197,7 +197,7 @@ def compute_wavelet_transform( convolved = convolved_real.values + convolved_imag.values * 1j coef = -np.diff(convolved, axis=0) if norm == "sss": - coef *= coef * (-np.sqrt(scaling) / (freqs / fs)) + coef *= (-np.sqrt(scaling) / (freqs / fs)) elif norm == "amp": coef *= -scaling / (freqs / fs) coef = np.insert( From 3d1ab70b24c6e9ab8dac9945d2d74bcc06e4504a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:21:22 +0100 Subject: [PATCH 33/71] removed tkagg --- docs/examples/tutorial_wavelet_api.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index e37e6c40..8367a33d 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -19,9 +19,6 @@ # # Now, import the necessary libraries: -import matplotlib - -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np From 5df9ff045a9492daef2632c8115f480521722e81 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:34:10 +0100 Subject: [PATCH 34/71] linting --- docs/examples/tutorial_wavelet_api.py | 56 ++++++++++++++++++--------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index 8367a33d..b42cb609 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -34,12 +34,14 @@ # with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. Fs = 2000 -t = np.linspace(0, 5, Fs*5) +t = np.linspace(0, 5, Fs * 5) two_hz_phase = t * 2 * np.pi * 2 two_hz_component = np.sin(two_hz_phase) -increasing_freq_component = np.sin(t * (5+t) * np.pi * 2) -sig = nap.Tsd(d=two_hz_component + increasing_freq_component + - np.random.normal(0,0.1,10000), t=t) +increasing_freq_component = np.sin(t * (5 + t) * np.pi * 2) +sig = nap.Tsd( + d=two_hz_component + increasing_freq_component + np.random.normal(0, 0.1, 10000), + t=t, +) # %% # Lets plot it. 
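# %%
# A small check of the "increasing frequency" component used above (a sketch, assuming
# the `t` array from the previous cell): the instantaneous frequency of
# sin(t * (5 + t) * 2 * pi) is the time-derivative of its phase divided by 2*pi,
# i.e. 5 + 2*t Hz, which runs from 5Hz at t=0 to 15Hz at t=5.
inst_freq = np.gradient(t * (5 + t) * np.pi * 2, t) / (2 * np.pi)
print(inst_freq[0], inst_freq[-1])  # approximately 5.0 and 15.0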
@@ -51,11 +53,11 @@ ax[1].set_title("Increasing Frequency Component") ax[2].set_title("Dummy Signal") [ax[i].margins(0) for i in range(3)] -[ax[i].set_ylim(-2.5,2.5) for i in range(3)] +[ax[i].set_ylim(-2.5, 2.5) for i in range(3)] [ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(3)] [ax[i].set_xlabel("Time (s)") for i in range(3)] [ax[i].set_ylabel("Signal") for i in range(3)] -[ax[i].set_ylim(-2.5,2.5) for i in range(3)] +[ax[i].set_ylim(-2.5, 2.5) for i in range(3)] plt.show() @@ -74,6 +76,7 @@ freqs, Fs, n_cycles=1.5, scaling=1.0, precision=20 ) + # %% # Lets plot it. def plot_filterbank(filter_bank, freqs, title): @@ -82,9 +85,11 @@ def plot_filterbank(filter_bank, freqs, title): for f_i in range(filter_bank.shape[0]): ax.plot( np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i, :].real + offset * f_i + filter_bank[f_i, :].real + offset * f_i, + ) + ax.text( + -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" ) - ax.text(-2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) [ax.spines[sp].set_visible(False) for sp in ["left", "right", "top"]] @@ -93,6 +98,7 @@ def plot_filterbank(filter_bank, freqs, title): ax.set_title(title) plt.show() + title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" plot_filterbank(filter_bank, freqs, title) @@ -110,6 +116,7 @@ def plot_filterbank(filter_bank, freqs, title): sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=20 ) + # %% # Lets plot it. def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): @@ -149,7 +156,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # and see how it compares to the original section. # Get the index of the 2Hz frequency -two_hz_freq_idx = np.where(freqs == 2.)[0] +two_hz_freq_idx = np.where(freqs == 2.0)[0] # The 2Hz component is the real component of the wavelet decomposition at this index slow_oscillation = mwt[:, two_hz_freq_idx].values.real # The 2Hz wavelet phase is the angle of the wavelet decomposition at this index @@ -169,11 +176,15 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["phase"].set_ylabel("Phase (rad)") axd["signal"].set_ylabel("Signal") axd["phase"].set_xlabel("Time (s)") -[axd[f].spines[sp].set_visible(False) for sp in ["right", "top"] for f in ["phase", "signal"]] +[ + axd[f].spines[sp].set_visible(False) + for sp in ["right", "top"] + for f in ["phase", "signal"] +] axd["signal"].get_xaxis().set_visible(False) axd["signal"].spines["bottom"].set_visible(False) [axd[k].margins(0) for k in ["signal", "phase"]] -axd["signal"].set_ylim(-2.5,2.5) +axd["signal"].set_ylim(-2.5, 2.5) axd["phase"].set_ylim(-np.pi, np.pi) plt.show() @@ -189,7 +200,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # signal. 
# Get the index of the 15 Hz frequency -fifteen_hz_freq_idx = np.where(freqs == 15.)[0] +fifteen_hz_freq_idx = np.where(freqs == 15.0)[0] # The 15 Hz component is the real component of the wavelet decomposition at this index fifteenHz_oscillation = mwt[:, fifteen_hz_freq_idx].values.real # The 15 Hz poser is the absolute value of the wavelet decomposition at this index @@ -201,8 +212,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[0].plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") ax[0].plot(t, fifteenHz_oscillation_power, label="15Hz Power") ax[1].plot(sig, label="Raw Signal", alpha=0.5) -ax[1].plot(t, slow_oscillation+fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") -[ax[i].set_ylim(-2.5,2.5) for i in range(2)] +ax[1].plot( + t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction" +) +[ax[i].set_ylim(-2.5, 2.5) for i in range(2)] [ax[i].margins(0) for i in range(2)] [ax[i].legend() for i in range(2)] [ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(2)] @@ -234,7 +247,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): plt.show() - # %% # *** # Effect of n_cycles @@ -245,8 +257,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): freqs, 1000, n_cycles=7.5, scaling=1.0, precision=20 ) -plot_filterbank(filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0") +plot_filterbank( + filter_bank, + freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0", +) # %% # *** @@ -293,8 +308,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): freqs, 1000, n_cycles=1.5, scaling=2.0, precision=20 ) -plot_filterbank(filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0") +plot_filterbank( + filter_bank, + freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0", +) # %% # *** From c942a792f79c224443878699ef6ae76fdf27834a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:39:14 +0100 Subject: [PATCH 35/71] linting --- pynapple/process/signal_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 7cbacc05..d9b79fd4 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -197,7 +197,7 @@ def compute_wavelet_transform( convolved = convolved_real.values + convolved_imag.values * 1j coef = -np.diff(convolved, axis=0) if norm == "sss": - coef *= (-np.sqrt(scaling) / (freqs / fs)) + coef *= -np.sqrt(scaling) / (freqs / fs) elif norm == "amp": coef *= -scaling / (freqs / fs) coef = np.insert( From 04e9d8a99eed5337e7749f4026882465b3b68ba8 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 16:55:25 +0100 Subject: [PATCH 36/71] wavelet api tutorial improved, generate_filterbank returns TdsFrame --- docs/examples/tutorial_wavelet_api.py | 103 +++++++++++++++++++------- pynapple/process/signal_processing.py | 50 ++++--------- 2 files changed, 91 insertions(+), 62 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index b42cb609..fdbd162d 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -3,7 +3,7 @@ Wavelet API tutorial ============ -Working with Wavelets. +Working with Wavelets! 
See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. @@ -72,9 +72,7 @@ # Define the frequency of the wavelets in our filter bank freqs = np.linspace(1, 25, num=25) # Get the filter bank -filter_bank = nap.generate_morlet_filterbank( - freqs, Fs, n_cycles=1.5, scaling=1.0, precision=20 -) +filter_bank = nap.generate_morlet_filterbank(freqs, Fs, n_cycles=1.5, scaling=1.0) # %% @@ -82,11 +80,8 @@ def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 - for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i, :].real + offset * f_i, - ) + for f_i in range(filter_bank.shape[1]): + ax.plot(filter_bank[:, f_i] + offset * f_i) ax.text( -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" ) @@ -112,9 +107,7 @@ def plot_filterbank(filter_bank, freqs, title): # scalogram. # Compute the wavelet transform using the parameters above -mwt = nap.compute_wavelet_transform( - sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=20 -) +mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0) # %% @@ -237,25 +230,77 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) [ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") +ax.set_ylim(-6, 6) ax.margins(0) ax.legend() plt.show() +# %% +# *** +# Parametrization +# ------------------ +# Our reconstruction seems to get the amplitude modulations of our signal correct, but the amplitude is overestimated, +# in particular towards the end of the time period. Often, this is due to a suboptimal choice of parameters, which +# can lead to a low spatial or temporal resolution. Let's explore what changing our parameters does to the +# underlying wavelets. + +freqs = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) +scales = [1.0, 2.0, 3.0] +cycles = [1.0, 2.0, 3.0] + +fig, ax = plt.subplots( + len(scales), len(cycles), constrained_layout=True, figsize=(10, 5) +) +for row_i, sc in enumerate(scales): + for col_i, cyc in enumerate(cycles): + filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=cyc, scaling=sc + ) + ax[row_i, col_i].plot(filter_bank[:, 0]) + ax[row_i, col_i].set_xlim(-15, 15) + ax[row_i, col_i].set_xlabel("Time (s)") + ax[row_i, col_i].set_ylabel("Signal") + [ + ax[row_i, col_i].spines[sp].set_visible(False) + for sp in ["top", "right", "left"] + ] + ax[row_i, col_i].get_yaxis().set_visible(False) + fig.text( + 0.01, + 0.6 / len(scales) + row_i / len(scales), + f"scaling={sc}", + ha="center", + va="center", + rotation="vertical", + fontsize=8, + ) +for col_i, cyc in enumerate(cycles): + ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=8) +fig.suptitle("Parametrization Visualization") +plt.show() + +# %% +# Increasing n_cycles increases the number of wavelet cycles present in the oscillations (cycles) within the +# Gaussian window of the Morlet wavelet. It essentially controls the trade-off between time resolution +# and frequency resolution. 
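# %%
# A quick way to see this numerically (a sketch, assuming the `generate_morlet_filterbank`
# API used in this tutorial, which returns a TsdFrame of wavelets): counting zero-crossings
# of the real part of a single 10Hz wavelet shows how many oscillation cycles sit inside
# the Gaussian window for a short versus a long n_cycles setting.
for cyc in [1.5, 7.5]:
    fb = nap.generate_morlet_filterbank(np.array([10.0]), 1000, n_cycles=cyc, scaling=1.0)
    real_part = np.real(fb.values[:, 0])
    crossings = np.sum(np.diff(np.sign(real_part)) != 0)
    print(f"n_cycles={cyc}: roughly {crossings // 2} cycles in the 10Hz wavelet")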
+# +# The scale parameter determines the dilation or compression of the wavelet. It controls the size of the wavelet in +# time, affecting the overall shape of the wavelet. + # %% # *** # Effect of n_cycles # ------------------ +# Let's increase n_cycles to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=7.5, scaling=1.0, precision=20 -) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0) plot_filterbank( filter_bank, @@ -266,9 +311,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... -mwt = nap.compute_wavelet_transform( - sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=20 -) +mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( @@ -277,6 +320,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): np.transpose(mwt[:, :].values), ax=ax, ) +ax.set_title("Wavelet Decomposition Scalogram") # %% # *** @@ -288,45 +332,47 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) [ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") +ax.set_ylim(-6, 6) ax.margins(0) ax.legend() plt.show() +# %% +# There's a small improvement, but perhaps we can do better. + # %% # *** # Effect of scaling # ------------------ +# Let's increase scaling to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=1.5, scaling=2.0, precision=20 -) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0) plot_filterbank( filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0", + "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=2.0", ) # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -mwt = nap.compute_wavelet_transform( - sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=2.0, precision=20 -) +mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=2.0) plot_timefrequency( mwt.index.values[:], freqs[:], np.transpose(mwt[:, :].values), ax=ax, ) +ax.set_title("Wavelet Decomposition Scalogram") # %% # *** @@ -338,11 +384,12 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Lets plot it. 
fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) [ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") ax.margins(0) +ax.set_ylim(-6, 6) ax.legend() plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index d9b79fd4..0c3f6e13 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -51,35 +51,9 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): return ret -def compute_welch_spectogram(sig, fs=None): - """ - Performs Welch's decomposition on sig, returns output. - Estimates the power spectral density of a signal by segmenting it into overlapping sections, applying a - window function to each segment, computing their FFTs, and averaging the resulting periodograms to reduce noise. - - ..todo: remove this or add binsize parameter - ..todo: be careful of border artifacts - - Parameters - ---------- - sig : pynapple.Tsd or pynapple.TsdFrame - Time series. - fs : float, optional - Sampling rate, in Hz. If None, will be calculated from the given signal - """ - if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError( - "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" - ) - if fs is None: - fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - freqs, spectogram = welch(sig.values, fs=fs, axis=0) - return pd.DataFrame(spectogram, freqs) - - def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ - Defines the complex Morlet wavelet kernel + Defines the complex Morlet wavelet kernel. Parameters ---------- @@ -137,7 +111,7 @@ def compute_wavelet_transform( sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=10, norm=None ): """ - Compute the time-frequency representation of a signal using morlet wavelets. + Compute the time-frequency representation of a signal using Morlet wavelets. Parameters ---------- @@ -192,8 +166,8 @@ def compute_wavelet_transform( sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) - convolved_real = sig.convolve(np.transpose(filter_bank.real)) - convolved_imag = sig.convolve(np.transpose(filter_bank.imag)) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) convolved = convolved_real.values + convolved_imag.values * 1j coef = -np.diff(convolved, axis=0) if norm == "sss": @@ -217,6 +191,8 @@ def compute_wavelet_transform( def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=10): """ + Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, + or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. 
Parameters ---------- @@ -236,14 +212,17 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 Returns ------- - filter_bank : np.ndarray + filter_bank : pynapple.TsdFrame list of Morlet wavelet filters of the frequencies given """ + if len(freqs) == 0: + raise ValueError("Given list of freqs cannot be empty.") filter_bank = [] + time_cutoff = 8 morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) - x = np.linspace(-8, 8, int(2**precision)) + x = np.linspace(-time_cutoff, time_cutoff, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) - max_len = 0 + max_len = -1 for freq in freqs: scale = scaling / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) @@ -253,6 +232,9 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 int_psi_scale = int_psi[j][::-1] if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) + time = np.linspace( + -time_cutoff * scaling / freq, time_cutoff * scaling / freq, max_len + ) filter_bank.append(int_psi_scale) filter_bank = [ np.pad( @@ -262,7 +244,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 ) for arr in filter_bank ] - return np.array(filter_bank) + return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) def _integrate(arr, step): From c49d767451c18afebc573c08ef85f18fd6dad25f Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 16:57:28 +0100 Subject: [PATCH 37/71] welch removed --- pynapple/process/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 0986cc6d..a73dea00 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -18,7 +18,6 @@ from .signal_processing import ( compute_spectogram, compute_wavelet_transform, - compute_welch_spectogram, generate_morlet_filterbank, ) from .tuning_curves import ( From 917f932d20a93d72a3ad5429830921acfec3b6b6 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 16:59:22 +0100 Subject: [PATCH 38/71] welch import removed --- pynapple/process/signal_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 0c3f6e13..1b83ea34 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd -from scipy.signal import welch import pynapple as nap From 10e47fb0a414fe7be888d89d42473367c38c4446 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:22:12 +0100 Subject: [PATCH 39/71] review comments addressed --- .../tutorial_pynapple_wavelets.py} | 47 ++-- docs/examples/tutorial_phase_preferences.py | 240 +++++++++--------- docs/examples/tutorial_signal_processing.py | 181 ++++++------- 3 files changed, 226 insertions(+), 242 deletions(-) rename docs/{examples/tutorial_wavelet_api.py => api_guide/tutorial_pynapple_wavelets.py} (93%) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/api_guide/tutorial_pynapple_wavelets.py similarity index 93% rename from docs/examples/tutorial_wavelet_api.py rename to docs/api_guide/tutorial_pynapple_wavelets.py index fdbd162d..8cdeb9e1 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -15,12 +15,15 @@ # !!! 
warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests tqdm` +# You can install all with `pip install matplotlib requests tqdm seaborn` # # Now, import the necessary libraries: import matplotlib.pyplot as plt import numpy as np +import seaborn + +seaborn.set_theme() import pynapple as nap @@ -58,7 +61,6 @@ [ax[i].set_xlabel("Time (s)") for i in range(3)] [ax[i].set_ylabel("Signal") for i in range(3)] [ax[i].set_ylim(-2.5, 2.5) for i in range(3)] -plt.show() # %% @@ -66,7 +68,7 @@ # Getting our Morlet Wavelet Filter Bank # ------------------ # We will be decomposing our dummy signal using wavelets of different frequencies. These wavelets -# can be examined using the `generate_morlet_filterbank' function. Here we will use the default parameters +# can be examined using the `generate_morlet_filterbank` function. Here we will use the default parameters # to define a Morlet filter bank with which we will later use to deconstruct the signal. # Define the frequency of the wavelets in our filter bank @@ -78,7 +80,7 @@ # %% # Lets plot it. def plot_filterbank(filter_bank, freqs, title): - fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) + fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) offset = 0.2 for f_i in range(filter_bank.shape[1]): ax.plot(filter_bank[:, f_i] + offset * f_i) @@ -91,7 +93,6 @@ def plot_filterbank(filter_bank, freqs, title): ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title(title) - plt.show() title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" @@ -101,7 +102,7 @@ def plot_filterbank(filter_bank, freqs, title): # *** # Decomposing the Dummy Signal # ------------------ -# Here we will use the `compute_wavelet_transform' function to decompose our signal using the filter bank shown +# Here we will use the `compute_wavelet_transform` function to decompose our signal using the filter bank shown # above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and # frequency information for analysis. We will calculate this decomposition and plot it's corresponding # scalogram. @@ -128,9 +129,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): y_ticks = freqs y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.grid(False) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], freqs[:], @@ -138,7 +140,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax=ax, ) ax.set_title("Wavelet Decomposition Scalogram") -plt.show() # %% # *** @@ -179,7 +180,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): [axd[k].margins(0) for k in ["signal", "phase"]] axd["signal"].set_ylim(-2.5, 2.5) axd["phase"].set_ylim(-np.pi, np.pi) -plt.show() # %% # *** @@ -216,7 +216,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[0].spines["bottom"].set_visible(False) ax[1].set_xlabel("Time (s)") [ax[i].set_ylabel("Signal") for i in range(2)] -plt.show() # %% # *** @@ -224,7 +223,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # ------------------ # Let's now add together the real components of all frequency bands to recreate a version of the original signal. 
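# %%
# Before looking at the overlay below, a quick way to quantify how faithful this kind of
# reconstruction is (a sketch, assuming the `mwt` and `sig` objects from the cells above)
# is to correlate the summed real components with the original signal.
reconstruction = np.real(mwt.values).sum(axis=1)
corr = np.corrcoef(sig.values, reconstruction)[0, 1]
print(f"Correlation between the signal and its wavelet reconstruction: {corr:.3f}")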
-combined_oscillations = mwt.values.real.sum(axis=1) +combined_oscillations = mwt.sum(axis=1) # %% # Lets plot it. @@ -238,7 +237,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.set_ylim(-6, 6) ax.margins(0) ax.legend() -plt.show() # %% @@ -265,25 +263,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[row_i, col_i].plot(filter_bank[:, 0]) ax[row_i, col_i].set_xlim(-15, 15) ax[row_i, col_i].set_xlabel("Time (s)") - ax[row_i, col_i].set_ylabel("Signal") + ax[row_i, col_i].set_yticks([]) [ ax[row_i, col_i].spines[sp].set_visible(False) for sp in ["top", "right", "left"] ] - ax[row_i, col_i].get_yaxis().set_visible(False) - fig.text( - 0.01, - 0.6 / len(scales) + row_i / len(scales), - f"scaling={sc}", - ha="center", - va="center", - rotation="vertical", - fontsize=8, - ) + if col_i != 0: + ax[row_i, col_i].get_yaxis().set_visible(False) for col_i, cyc in enumerate(cycles): - ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=8) + ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=10) +for row_i, scl in enumerate(scales): + ax[row_i, 0].set_ylabel(f"scaling={scl}", fontsize=10) fig.suptitle("Parametrization Visualization") -plt.show() # %% # Increasing n_cycles increases the number of wavelet cycles present in the oscillations (cycles) within the @@ -326,7 +317,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.values.real.sum(axis=1) +combined_oscillations = mwt.sum(axis=1) # %% # Lets plot it. @@ -340,7 +331,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.set_ylim(-6, 6) ax.margins(0) ax.legend() -plt.show() # %% # There's a small improvement, but perhaps we can do better. @@ -378,7 +368,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.values.real.sum(axis=1) +combined_oscillations = mwt.sum(axis=1) # %% # Lets plot it. @@ -392,4 +382,3 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.margins(0) ax.set_ylim(-6, 6) ax.legend() -plt.show() diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index b4f1f6ec..f5af3b66 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Grosmark & Buzsáki (2016) Tutorial 2 +Computing Phase Preferences ============ In the previous [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/) tutorial, @@ -20,7 +20,7 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests tqdm` +# You can install all with `pip install matplotlib requests tqdm seaborn` # # First, import the necessary libraries: @@ -29,10 +29,14 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import requests import scipy +import seaborn import tqdm +seaborn.set_theme() + import pynapple as nap # %% @@ -62,7 +66,7 @@ # Let's load and print the full dataset. 
data = nap.load_file(path) -FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) +FS = 1250 # We know from the methods of the paper print(data) @@ -75,7 +79,7 @@ # Define the IntervalSet for this run and instantiate both LFP and # Position TsdFrame objects REM_minute_interval = nap.IntervalSet( - data["rem"]["start"][0] + 90.0, + data["rem"]["start"][0] + 95.0, data["rem"]["start"][0] + 100.0, ) REM_Tsd = data["eeg"].restrict(REM_minute_interval) @@ -88,9 +92,7 @@ (data["units"][i].times() > REM_minute_interval["start"][0]) & (data["units"][i].times() < REM_minute_interval["end"][0]) ] - -# The given dataset has only one channel, so we set channel = 0 here -channel = 0 +spikes_tsdg = data["units"].restrict(REM_minute_interval) # %% # *** @@ -98,19 +100,17 @@ # ----------------------------------- # We should first plot our REM Local Field Potential data. -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) - +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) ax.plot( - REM_Tsd[:, channel], + REM_Tsd, label="REM LFP Data", - color="green", + color="blue", ) ax.set_title("REM Local Field Potential") ax.set_ylabel("LFP (v)") ax.set_xlabel("time (s)") ax.margins(0) ax.legend() -plt.show() # %% # *** @@ -121,32 +121,15 @@ # as we did in the last tutorial, to see get a more informative breakdown of the # frequencies present in the data. -# We must define the frequency set that we'd like to use for our decomposition; -# these have been manually selected based on the frequencies used in -# Frey et. al (2021), but could also be defined as a linspace or logspace -freqs = np.array( - [ - 2.59, - 3.66, - 5.18, - 8.0, - 10.36, - 20.72, - 29.3, - 41.44, - 58.59, - 82.88, - 117.19, - 152.35, - 192.19, - 200.0, - 234.38, - 270.00, - 331.5, - 390.00, - ] -) -mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, channel], fs=None, freqs=freqs) + +# We must define the frequency set that we'd like to use for our decomposition +freqs = np.geomspace(5, 200, 25) +# Compute the wavelet transform on our LFP data +mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, 0], fs=FS, freqs=freqs) + +# %% +# *** +# Now let's plot the calculated wavelet scalogram. # Define wavelet decomposition plotting function @@ -163,9 +146,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): else: x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = freqs + y_ticks = [np.round(f, 2) for f in freqs] y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.grid(False) # And plot it @@ -190,19 +174,24 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["lfp_rem"].get_xaxis().set_visible(False) for spine in ["top", "right", "bottom", "left"]: axd["lfp_rem"].spines[spine].set_visible(False) -plt.show() # %% # *** # Visualizing Theta Band Power and Phase # ----------------------------------- # There seems to be a strong theta frequency present in the data during the maze traversal. -# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well -# they match up. We will also extract and plot the phase of the 8Hz wavelet from the decomposition. -theta_freq_index = 3 +# Let's plot the estimated 7Hz component of the wavelet decomposition on top of our data, and see how well +# they match up. 
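# %%
# A quick check of that impression (a sketch, assuming the `mwt_REM` and `freqs` objects
# computed above): the frequency with the highest mean wavelet power over this epoch
# should sit in the theta range.
mean_power = np.abs(mwt_REM.values).mean(axis=0)
print(f"Frequency with the highest mean power: ~{freqs[np.argmax(mean_power)]:.2f}Hz")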
We will also extract and plot the phase of the 7Hz wavelet from the decomposition. +theta_freq_index = np.argmin(np.abs(7 - freqs)) theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real # calculating phase here -theta_band_phase = np.angle(mwt_REM[:, theta_freq_index].values) +theta_band_phase = nap.Tsd( + t=mwt_REM.index, d=np.angle(mwt_REM[:, theta_freq_index].values) +) + +# %% +# *** +# Now let's plot the theta power and phase, along with the LFP. fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( @@ -213,63 +202,78 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): height_ratios=[0.4, 0.2], ) -axd["theta_pow"].plot( - REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" -) +axd["theta_pow"].plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") axd["theta_pow"].plot( REM_Tsd.index.values, theta_band_reconstruction, - label=f"{freqs[theta_freq_index]}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) axd["theta_pow"].set_ylabel("LFP (v)") axd["theta_pow"].set_xlabel("Time (s)") -axd["theta_pow"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") # +axd["theta_pow"].set_title( + f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power." +) # axd["theta_pow"].legend() -axd["phase"].plot(theta_band_phase) +axd["phase"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) [axd[k].margins(0) for k in ["theta_pow", "phase"]] axd["phase"].set_ylabel("Phase") -plt.show() +axd["phase"].get_xaxis().set_visible(False) # %% # *** # Finding Phase of Spikes # ----------------------------------- -# Now that we have the phase of our theta wavelet, and our spike times, we can find the theta phase at which every -# spike occurs +# Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences +# of each of the units using the compute_1d_tuning_curves function. +# +# We will start by throwing away cells which do not have a high enough firing rate during our interval. + +# Filter units based on firing rate +spikes_tsdg = spikes_tsdg[spikes_tsdg.rate > 5.0] +# Calculate theta phase firing preferences +tuning_curves = nap.compute_1d_tuning_curves( + group=spikes_tsdg, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) +) + +# %% +# *** +# Now we will plot these preferences as smoothed angular histograms. We will select the first 6 units +# to plot. 
+ + +def smoothAngularTuningCurves(tuning_curves, sigma=2): + tmp = np.concatenate( + (tuning_curves.values, tuning_curves.values, tuning_curves.values) + ) + tmp = scipy.ndimage.gaussian_filter1d(tmp, sigma=sigma, axis=0) + return pd.DataFrame( + index=tuning_curves.index, + data=tmp[tuning_curves.shape[0] : tuning_curves.shape[0] * 2], + columns=tuning_curves.columns, + ) -# We will start by throwing away cells which do not have enough -# spikes during our interval -spikes = {k: v for k, v in spikes.items() if len(v) > 20} -# Get phase of each spike -phase = {} -for i in spikes.keys(): - phase_i = [] - for spike in spikes[i]: - phase_i.append( - np.angle( - mwt_REM[ - np.argmin(np.abs(REM_Tsd.index.values - spike)), theta_freq_index - ] - ) - ) - phase[i] = np.array(phase_i) -# Let's plot phase histograms for the first six units to see if there's -# any obvious preferences -fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) -for ri in range(2): - for ci in range(3): - ax[ri, ci].hist( - phase[list(phase.keys())[ri * 3 + ci]], - bins=np.linspace(-np.pi, np.pi, 10), - density=True, - ) - ax[ri, ci].set_xlabel("Phase (rad)") - ax[ri, ci].set_ylabel("Density") - ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") +smoothcurves = smoothAngularTuningCurves(tuning_curves, sigma=2) +fig, axd = plt.subplot_mosaic( + [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], + constrained_layout=True, + figsize=(10, 6), + subplot_kw={"projection": "polar"}, +) +for pl_i, sc_i in enumerate(list(smoothcurves)[:6]): + axd[f"phase_{pl_i}"].plot( + list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), + list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + ) + axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis + axd[f"phase_{pl_i}"].set_ylabel( + "Firing Rate (Hz)" + ) # Firing rate in Hz, on the Y-axis + axd[f"phase_{pl_i}"].set_xticks([]) + axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") fig.suptitle("Phase Preference Histograms of First 6 Units") -plt.show() + # %% # *** @@ -279,30 +283,38 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Now that we have our phases of firing for each unit, we can sort the units by the circular variance of the phase # of their spikes, to isolate the cells with the strongest phase preferences without manual inspection. -variances = { +# Get phase of each spike +phase = {} +for i in spikes_tsdg: + phase_i = [ + theta_band_phase[np.argmin(np.abs(REM_Tsd.index.values - s.index))] + for s in spikes_tsdg[i] + ] + phase[i] = np.array(phase_i) +phase_var = { key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) for key, value in phase.items() } -spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) -phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) - -# Now let's plot phase histograms for the six units with the least -# varied phase of spikes. 
-fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) -for ri in range(2): - for ci in range(3): - ax[ri, ci].hist( - phase[list(phase.keys())[ri * 3 + ci]], - bins=np.linspace(-np.pi, np.pi, 10), - density=True, - ) - ax[ri, ci].set_xlabel("Phase (rad)") - ax[ri, ci].set_ylabel("Density") - ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") -fig.suptitle( - "Phase Preference Histograms of 6 Units with " + "Highest Phase Preference" +phase_var = dict(sorted(phase_var.items(), key=lambda item: item[1])) + +fig, axd = plt.subplot_mosaic( + [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], + constrained_layout=True, + figsize=(10, 6), + subplot_kw={"projection": "polar"}, ) -plt.show() +for pl_i, sc_i in enumerate(list(phase_var.keys())[:6]): + axd[f"phase_{pl_i}"].plot( + list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), + list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + ) + axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis + axd[f"phase_{pl_i}"].set_ylabel( + "Firing Rate (Hz)" + ) # Firing rate in Hz, on the Y-axis + axd[f"phase_{pl_i}"].set_xticks([]) + axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") +fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference ") # %% # *** @@ -311,38 +323,34 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # There is definitely some strong phase preferences happening here. Let's visualize the firing preferences # of the 6 cells we've isolated to get an impression of just how striking these preferences are. -fig = plt.figure(constrained_layout=True, figsize=(10, 12)) +fig = plt.figure(constrained_layout=True, figsize=(10, 8)) axd = fig.subplot_mosaic( [ ["lfp_run"], ["phase_0"], ["phase_1"], ["phase_2"], - ["phase_3"], - ["phase_4"], - ["phase_5"], ], - height_ratios=[0.4, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], + height_ratios=[0.4, 0.2, 0.2, 0.2], ) -[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(6)]] +[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(3)]] axd["lfp_run"].plot( - REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" + REM_Tsd.index.values, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM" ) axd["lfp_run"].plot( REM_Tsd.index.values, theta_band_reconstruction, - label=f"{freqs[theta_freq_index]}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index],2)}Hz oscillations", ) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power.") axd["lfp_run"].legend() -for i in range(6): +for i in range(3): axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) axd[f"phase_{i}"].scatter( - spikes[list(spikes.keys())[i]], phase[list(spikes.keys())[i]] + spikes[list(phase_var.keys())[i]], phase[list(phase_var.keys())[i]] ) axd[f"phase_{i}"].set_ylabel("Phase") - axd[f"phase_{i}"].set_title(f"Unit {list(spikes.keys())[i]}") + axd[f"phase_{i}"].set_title(f"Unit {list(phase_var.keys())[i]}") fig.suptitle("Phase Preference Visualizations") -plt.show() diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index b5d786fe..d9b66be9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ 
-Grosmark & Buzsáki (2016) Tutorial 1 +Computing Wavelet Transform ============ This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). @@ -16,7 +16,7 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests tqdm` +# You can install all with `pip install matplotlib requests tqdm seaborn` # # First, import the necessary libraries: @@ -26,8 +26,11 @@ import matplotlib.pyplot as plt import numpy as np import requests +import seaborn import tqdm +seaborn.set_theme() + import pynapple as nap # %% @@ -57,7 +60,7 @@ # Let's load and print the full dataset. data = nap.load_file(path) -FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) +FS = 1250 print(data) @@ -78,9 +81,7 @@ ) RUN_Tsd = data["eeg"].restrict(RUN_interval) RUN_pos = data["position"].restrict(RUN_interval) - -# The given dataset has only one channel, so we set channel = 0 here -channel = 0 +print(RUN_Tsd) # %% # *** @@ -90,11 +91,11 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [["ephys"], ["pos"]], - height_ratios=[1, 0.2], + height_ratios=[1, 0.4], ) axd["ephys"].plot( - RUN_Tsd[:, channel].restrict( + RUN_Tsd[:, 0].restrict( nap.IntervalSet( data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] ) @@ -103,7 +104,7 @@ color="green", ) axd["ephys"].plot( - RUN_Tsd[:, channel].restrict( + RUN_Tsd[:, 0].restrict( nap.IntervalSet( data["forward_ep"]["end"][run_index], data["forward_ep"]["end"][run_index] + 5.0, @@ -126,17 +127,23 @@ # %% # *** -# Getting the LFP Spectogram +# Getting the LFP Spectrogram # ----------------------------------- -# Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present +# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies. fft = nap.compute_spectogram(RUN_Tsd, fs=int(FS)) +print(fft) + +# %% +# *** +# The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. +# +# Now let's plot it -# Now we will plot it fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot( fft.index, - np.abs(fft.iloc[:, channel]), + np.abs(fft.iloc[:, 0]), alpha=0.5, label="LFP Frequency Power", c="blue", @@ -160,32 +167,14 @@ # LFP characteristics may be different while the animal is running along the track, and when it is finished. # Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. -# We must define the frequency set that we'd like to use for our decomposition; these -# have been manually selected based on the frequencies used in Frey et. al (2021), but -# could also be defined as a linspace or logspace -freqs = np.array( - [ - 2.59, - 3.66, - 5.18, - 8.0, - 10.36, - 20.72, - 29.3, - 41.44, - 58.59, - 82.88, - 117.19, - 152.35, - 192.19, - 200.0, - 234.38, - 270.00, - 331.5, - 390.00, - ] -) -mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, channel], fs=None, freqs=freqs) +# We must define the frequency set that we'd like to use for our decomposition +freqs = np.geomspace(5, 250, 25) +# Compute and print the wavelet transform on our LFP data +mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, 0], fs=FS, freqs=freqs) + +# %% +# *** +# Now let's plot it. 
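# %%
# Before plotting the full scalogram, a quick numerical look at the theta band
# (a sketch, assuming the `mwt_RUN`, `freqs`, `RUN_Tsd` and `data` objects from above):
# wavelet power near 8Hz should be higher while the animal runs along the track
# (the `forward_ep` epochs) than over the whole snippet, which also contains the
# 5 seconds after the run.
theta_idx = np.argmin(np.abs(freqs - 8.0))
theta_power = nap.Tsd(t=RUN_Tsd.index.values, d=np.abs(mwt_RUN[:, theta_idx].values))
print("mean ~8Hz power during the run:", np.mean(theta_power.restrict(data["forward_ep"]).values))
print("mean ~8Hz power over the whole snippet:", np.mean(theta_power.values))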
# Define wavelet decomposition plotting function @@ -202,20 +191,21 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): else: x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = freqs + y_ticks = [np.round(f, 2) for f in freqs] y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.grid(False) # And plot -fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig = plt.figure(constrained_layout=True, figsize=(10, 8)) axd = fig.subplot_mosaic( [ ["wd_run"], ["lfp_run"], ["pos_run"], ], - height_ratios=[1, 0.2, 0.4], + height_ratios=[1.2, 0.2, 0.6], ) plot_timefrequency( RUN_Tsd.index.values[:], @@ -243,36 +233,42 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # There seems to be a strong theta frequency present in the data during the maze traversal. # Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well # they match up -theta_freq_index = 3 + +# Find the index of the frequency closest to theta band +theta_freq_index = np.argmin(np.abs(10 - freqs)) +# Extract its real component, as well as its power envelope theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) + +# %% +# *** +# Now let's visualise the theta band component of the signal over time. + fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [ ["lfp_run"], ["pos_run"], ], - height_ratios=[1, 0.3], + height_ratios=[1, 0.4], ) -axd["lfp_run"].plot( - RUN_Tsd.index.values, RUN_Tsd[:, channel], alpha=0.5, label="LFP Data" -) +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], alpha=0.5, label="LFP Data") axd["lfp_run"].plot( RUN_Tsd.index.values, theta_band_reconstruction, - label=f"{freqs[theta_freq_index]}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) axd["lfp_run"].plot( RUN_Tsd.index.values, theta_band_power_envelope, - label=f"{freqs[theta_freq_index]}Hz power envelope", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", ) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.") axd["pos_run"].plot(RUN_pos) [axd[k].margins(0) for k in ["lfp_run", "pos_run"]] [ @@ -292,9 +288,16 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Let's plot the LFP, along with the 200Hz frequency power, to see if we can isolate these peaks and # see what's going on. -ripple_freq_idx = 13 +# Find the index of the frequency closest to sharp wave ripple oscillations +ripple_freq_idx = np.argmin(np.abs(200 - freqs)) +# Extract its power envelope ripple_power = np.abs(mwt_RUN[:, ripple_freq_idx].values) + +# %% +# *** +# Now let's visualise the 200Hz component of the signal over time. 
+ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( [ @@ -303,7 +306,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], label="LFP Data") axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) @@ -316,7 +319,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["rip_pow"].spines["bottom"].set_visible(False) axd["rip_pow"].spines["left"].set_visible(False) axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") +axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% # *** @@ -325,32 +328,22 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # We can see one significant peak in the 200Hz frequency power. Let's smooth the power curve and threshold # to try to isolate this event. -# define our threshold -threshold = 100 -# smooth our wavelet power -window_size = 51 -window = np.ones(window_size) / window_size -smoother_swr_power = np.convolve( - np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode="same" +# Define threshold +threshold = 6000 +# Smooth wavelet power TsdFrame at the SWR frequency +smoother_swr_power = ( + mwt_RUN[:, ripple_freq_idx] + .abs() + .smooth(std=0.025, windowsize=0.2, time_units="s", norm=False) ) -# isolate our ripple periods -is_ripple = smoother_swr_power > threshold -start_idx = None -ripple_periods = [] -for i in range(len(RUN_Tsd.index.values)): - if is_ripple[i] and start_idx is None: - start_idx = i - elif not is_ripple[i] and start_idx is not None: - axd["rip_pow"].plot( - RUN_Tsd.index.values[start_idx:i], - smoother_swr_power[start_idx:i], - color="red", - linewidth=2, - ) - ripple_periods.append((start_idx, i)) - start_idx = None +# Threshold our TsdFrame +is_ripple = smoother_swr_power.threshold(threshold) + + +# %% +# *** +# Now let's plot the threshold ripple power over time. 
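# %%
# Because `threshold` keeps only the samples above the cut-off, the time support of
# `is_ripple` already gives us candidate ripple epochs as an IntervalSet
# (a sketch, assuming the `is_ripple` object created above).
ripple_ep = is_ripple.time_support
print(ripple_ep)
print("epoch durations (s):", ripple_ep["end"] - ripple_ep["start"])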
-# plot of captured ripple periods fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( [ @@ -359,24 +352,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") -axd["rip_pow"].plot(RUN_Tsd.index.values, smoother_swr_power) -for r in ripple_periods: - axd["rip_pow"].plot( - RUN_Tsd.index.values[r[0] : r[1]], - smoother_swr_power[r[0] : r[1]], - color="red", - linewidth=2, - ) +axd["lfp_run"].plot(RUN_Tsd[:, 0], label="LFP Data") +axd["rip_pow"].plot(smoother_swr_power) +axd["rip_pow"].plot(is_ripple, color="red", linewidth=2) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{freqs[ripple_freq_idx]}Hz oscillation power.") +axd["lfp_run"].set_title(f"{np.round(freqs[ripple_freq_idx], 2)}Hz oscillation power.") axd["rip_pow"].axhline(threshold) [axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] [axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") +axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% # *** @@ -384,21 +371,21 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # ----------------------------------- # Let's zoom in on out detected ripples and have a closer look! -# Filter out ripples which do not last long enough -ripple_periods = [r for r in ripple_periods if r[1] - r[0] > 20] - -# And plot! fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -buffer = 200 +buffer = 0.1 ax.plot( - RUN_Tsd.index.values[r[0] - buffer : r[1] + buffer], - RUN_Tsd[r[0] - buffer : r[1] + buffer], + RUN_Tsd.restrict( + nap.IntervalSet( + start=is_ripple.index.min() - buffer, end=is_ripple.index.max() + buffer + ) + ), color="blue", label="Non-SWR LFP", ) ax.plot( - RUN_Tsd.index.values[r[0] : r[1]], - RUN_Tsd[r[0] : r[1]], + RUN_Tsd.restrict( + nap.IntervalSet(start=is_ripple.index.min(), end=is_ripple.index.max()) + ), color="red", label="SWR", linewidth=2, From fa7952efa2dfe64b1cac4b2fcc7d71b845c16b22 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:26:35 +0100 Subject: [PATCH 40/71] removing welch tests --- tests/test_signal_processing.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 8edc4bd7..510d29da 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -44,35 +44,6 @@ def test_compute_spectogram(): ) -def test_compute_welch_spectogram(): - t = np.linspace(0, 1, 10000) - sig = nap.TsdFrame( - d=np.random.random((10000, 4)), - t=t, - time_support=nap.IntervalSet(start=[0.1, 0.4], end=[0.2, 0.525]), - ) - r = nap.compute_welch_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape[1] == 4 - - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) - r = nap.compute_welch_spectogram(sig) - assert isinstance(r, pd.DataFrame) - - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) - r = nap.compute_welch_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape[1] == 4 - - with pytest.raises(TypeError) as e_info: - nap.compute_welch_spectogram("a_string") - assert ( - str(e_info.value) - == "Currently 
compute_welch_spectogram is only implemented for Tsd or TsdFrame" - ) - - def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) From cac44ff9a8dcd35fbf6647479971f862f0705317 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:35:21 +0100 Subject: [PATCH 41/71] fixed broked phase notebook --- docs/examples/tutorial_phase_preferences.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index f5af3b66..3d42c663 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -86,13 +86,7 @@ # We will also extract spike times from all units in our dataset # which occur during our specified interval -spikes = {} -for i in data["units"].index: - spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > REM_minute_interval["start"][0]) - & (data["units"][i].times() < REM_minute_interval["end"][0]) - ] -spikes_tsdg = data["units"].restrict(REM_minute_interval) +spikes = data["units"].restrict(REM_minute_interval) # %% # *** @@ -230,10 +224,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # We will start by throwing away cells which do not have a high enough firing rate during our interval. # Filter units based on firing rate -spikes_tsdg = spikes_tsdg[spikes_tsdg.rate > 5.0] +spikes = spikes[spikes.rate > 5.0] # Calculate theta phase firing preferences tuning_curves = nap.compute_1d_tuning_curves( - group=spikes_tsdg, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) + group=spikes, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) ) # %% @@ -285,10 +279,10 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): # Get phase of each spike phase = {} -for i in spikes_tsdg: +for i in spikes: phase_i = [ theta_band_phase[np.argmin(np.abs(REM_Tsd.index.values - s.index))] - for s in spikes_tsdg[i] + for s in spikes[i] ] phase[i] = np.array(phase_i) phase_var = { @@ -349,7 +343,7 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): for i in range(3): axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) axd[f"phase_{i}"].scatter( - spikes[list(phase_var.keys())[i]], phase[list(phase_var.keys())[i]] + spikes[list(phase_var.keys())[i]].index, phase[list(phase_var.keys())[i]] ) axd[f"phase_{i}"].set_ylabel("Phase") axd[f"phase_{i}"].set_title(f"Unit {list(phase_var.keys())[i]}") From b0bb20f88982683d32a38e58652fd54249e98d5d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:37:25 +0100 Subject: [PATCH 42/71] better comments on phase notebook --- docs/examples/tutorial_phase_preferences.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 3d42c663..57fc036a 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -291,6 +291,11 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): } phase_var = dict(sorted(phase_var.items(), key=lambda item: item[1])) +# %% +# *** +# And now we plot the phase preference histograms of the 6 units with the least variance in the phase of their +# spiking behaviour. 
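The ranking by `phase_var` above relies on a circular statistic whose definition falls outside this hunk; a standard choice is the mean resultant length |mean(exp(i*phase))|, which is close to 1 when spike phases cluster tightly and close to 0 when they are spread uniformly, with circular variance defined as one minus that. The toy numpy sketch below illustrates the idea on invented phases, not the tutorial's units.

import numpy as np

rng = np.random.default_rng(0)
clustered = rng.normal(np.pi / 4, 0.3, 1000)      # phases concentrated near pi/4
uniform = rng.uniform(-np.pi, np.pi, 1000)        # no phase preference

for name, ph in [("clustered", clustered), ("uniform", uniform)]:
    resultant = np.abs(np.mean(np.exp(1j * ph)))  # mean resultant length
    print(f"{name}: resultant length = {resultant:.2f}, circular variance = {1 - resultant:.2f}")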
+ fig, axd = plt.subplot_mosaic( [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], constrained_layout=True, From 496fbe359f12255fb3814605b22961653fab3df7 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 29 Jul 2024 15:54:15 +0100 Subject: [PATCH 43/71] PR comments addressed, tests added --- docs/examples/tutorial_phase_preferences.py | 5 ---- pynapple/process/signal_processing.py | 32 +++++++++++++++------ tests/test_signal_processing.py | 30 +++++++++++++++++++ 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 57fc036a..18e311c9 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -3,11 +3,6 @@ Computing Phase Preferences ============ -In the previous [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/) tutorial, -we learned how to use Pynapple's signal processing tools with Local Field Potential data. Specifically, we -used wavelet decompositions to isolate Theta band activity during active traversal of a linear track, -as well as to find Sharp Wave Ripples which occurred after traversal. - In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it with spiking data, to find phase preferences of spiking units. diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 1b83ea34..914a368b 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -14,6 +14,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): """ Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. + Parameters ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. @@ -24,6 +25,12 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + Returns + ------- + pandas.DataFrame + Time frequency representation of the input signal, indexes are frequencies, values + are powers. + Notes ----- compute_spectogram computes fft on only a single epoch of data. This epoch be given with the ep @@ -40,7 +47,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): if len(ep) != 1: raise ValueError("Given epoch (or signal time_support) must have length 1") if fs is None: - fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) + fs = sig.rate fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) ret = pd.DataFrame(fft_result, fft_freq) @@ -107,7 +114,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas def compute_wavelet_transform( - sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=10, norm=None + sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=16, norm=None ): """ Compute the time-frequency representation of a signal using Morlet wavelets. @@ -128,7 +135,8 @@ def compute_wavelet_transform( scaling : float Scaling factor. precision: int. - Precision of wavelet to use. Default is 8 + Precision of wavelet to use. . Defines the number of timepoints to evaluate the Morlet wavelet at. 
+ Default is 16 norm : {None, 'sss', 'amp'}, optional Normalization method: * None - no normalization @@ -137,9 +145,18 @@ def compute_wavelet_transform( Returns ------- - pynapple.TsdFrame or pynapple.TsdTensor : 2d array + pynapple.TsdFrame or pynapple.TsdTensor Time frequency representation of the input signal. + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.linspace(0, 1, 1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> freqs = np.linspace(10, 100, 10) + >>> mwt = nap.compute_wavelet_transform(signal, fs=None, freqs=freqs) + Notes ----- This computes the continuous wavelet transform at specified frequencies across time. @@ -158,7 +175,6 @@ def compute_wavelet_transform( fs = sig.rate if isinstance(sig, nap.Tsd): - sig = sig.reshape((sig.shape[0], 1)) output_shape = (sig.shape[0], len(freqs)) else: output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) @@ -176,7 +192,7 @@ def compute_wavelet_transform( coef = np.insert( coef, 1, coef[0, :], axis=0 ) # slightly hacky line, necessary to make output correct shape - cwt = np.swapaxes(coef, 1, 2) + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef if len(output_shape) == 2: return nap.TsdFrame( @@ -188,7 +204,7 @@ def compute_wavelet_transform( ) -def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=10): +def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=16): """ Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. @@ -207,7 +223,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 scaling : float Scaling factor. precision: int. - Precision of wavelet to use. + Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. 
Returns ------- diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 510d29da..bfcd1eec 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +import pywt import pynapple as nap @@ -45,6 +46,35 @@ def test_compute_spectogram(): def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1000) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 50 + assert ( + np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] + == 500 + ) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd( + d=np.sin(t * 10 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 10 + assert ( + np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] + == 500 + ) t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) From 647bbfc10d216f6865ff24bcae195469e122a72d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 29 Jul 2024 15:59:22 +0100 Subject: [PATCH 44/71] unused import removed --- tests/test_signal_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index bfcd1eec..8f853591 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd import pytest -import pywt import pynapple as nap From e63ffe8d2153c1bea709f7b081da1bfdf128f204 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 30 Jul 2024 22:41:21 +0100 Subject: [PATCH 45/71] removing integrate->conv->diff pipeline --- docs/api_guide/tutorial_pynapple_wavelets.py | 71 +++++++++------- pynapple/process/signal_processing.py | 88 ++++++++++---------- tests/test_signal_processing.py | 28 +++++-- 3 files changed, 108 insertions(+), 79 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index 8cdeb9e1..d34de4f2 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -74,16 +74,18 @@ # Define the frequency of the wavelets in our filter bank freqs = np.linspace(1, 25, num=25) # Get the filter bank -filter_bank = nap.generate_morlet_filterbank(freqs, Fs, n_cycles=1.5, scaling=1.0) +filter_bank = nap.generate_morlet_filterbank( + freqs, Fs, gaussian_width=1.5, window_length=1.0 +) # %% # Lets plot it. 
def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) - offset = 0.2 + offset = 1.0 for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i] + offset * f_i) + ax.plot(filter_bank[:, f_i].real() + offset * f_i) ax.text( -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" ) @@ -95,7 +97,7 @@ def plot_filterbank(filter_bank, freqs, title): ax.set_title(title) -title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" +title = "Morlet Wavelet Filter Bank (Real Components): gaussian_width=1.5, window_length=1.0" plot_filterbank(filter_bank, freqs, title) # %% @@ -108,7 +110,9 @@ def plot_filterbank(filter_bank, freqs, title): # scalogram. # Compute the wavelet transform using the parameters above -mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=1.5, window_length=1.0 +) # %% @@ -217,13 +221,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[1].set_xlabel("Time (s)") [ax[i].set_ylabel("Signal") for i in range(2)] + # %% # *** # Adding ALL the Oscillations! # ------------------ # Let's now add together the real components of all frequency bands to recreate a version of the original signal. -combined_oscillations = mwt.sum(axis=1) +combined_oscillations = mwt.sum(axis=1).real() # %% # Lets plot it. @@ -249,19 +254,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # underlying wavelets. freqs = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) -scales = [1.0, 2.0, 3.0] -cycles = [1.0, 2.0, 3.0] +window_lengths = [1.0, 2.0, 3.0] +gaussian_width = [1.0, 2.0, 3.0] fig, ax = plt.subplots( - len(scales), len(cycles), constrained_layout=True, figsize=(10, 5) + len(window_lengths), len(gaussian_width), constrained_layout=True, figsize=(10, 8) ) -for row_i, sc in enumerate(scales): - for col_i, cyc in enumerate(cycles): +for row_i, wl in enumerate(window_lengths): + for col_i, gw in enumerate(gaussian_width): filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=cyc, scaling=sc + freqs, 1000, gaussian_width=gw, window_length=wl, precision=12 ) - ax[row_i, col_i].plot(filter_bank[:, 0]) - ax[row_i, col_i].set_xlim(-15, 15) + ax[row_i, col_i].plot(filter_bank[:, 0].real()) ax[row_i, col_i].set_xlabel("Time (s)") ax[row_i, col_i].set_yticks([]) [ @@ -270,14 +274,15 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ] if col_i != 0: ax[row_i, col_i].get_yaxis().set_visible(False) -for col_i, cyc in enumerate(cycles): - ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=10) -for row_i, scl in enumerate(scales): - ax[row_i, 0].set_ylabel(f"scaling={scl}", fontsize=10) +for col_i, gw in enumerate(gaussian_width): + ax[0, col_i].set_title(f"gaussian_width={gw}", fontsize=10) +for row_i, wl in enumerate(window_lengths): + ax[row_i, 0].set_ylabel(f"window_length={wl}", fontsize=10) fig.suptitle("Parametrization Visualization") + # %% -# Increasing n_cycles increases the number of wavelet cycles present in the oscillations (cycles) within the +# Increasing time_decay increases the number of wavelet cycles present in the oscillations (cycles) within the # Gaussian window of the Morlet wavelet. It essentially controls the trade-off between time resolution # and frequency resolution. 
# @@ -286,23 +291,27 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** -# Effect of n_cycles +# Effect of time_decay # ------------------ -# Let's increase n_cycles to 7.5 and see the effect on the resultant filter bank. +# Let's increase time_decay to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, gaussian_width=7.5, window_length=1.0 +) plot_filterbank( filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0", + "Morlet Wavelet Filter Bank (Real Components): gaussian_width=7.5, center_frequency=1.0", ) # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... -mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=1.0 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( @@ -317,7 +326,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.sum(axis=1) +combined_oscillations = mwt.sum(axis=1).real() # %% # Lets plot it. @@ -343,19 +352,23 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Let's increase scaling to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, gaussian_width=7.5, window_length=2.0 +) plot_filterbank( filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=2.0", + "Morlet Wavelet Filter Bank (Real Components): gaussian_width=7.5, center_frequency=2.0", ) # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=2.0) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0 +) plot_timefrequency( mwt.index.values[:], freqs[:], @@ -368,7 +381,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.sum(axis=1) +combined_oscillations = mwt.sum(axis=1).real() # %% # Lets plot it. diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 914a368b..58ad9fbc 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -57,7 +57,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): return ret -def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): +def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): """ Defines the complex Morlet wavelet kernel. @@ -65,10 +65,10 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ---------- M : int Length of the wavelet - ncycles : float - number of wavelet cycles to use. Default is 1.5 - scaling: float - Scaling factor. 
Default is 1.0 + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. precision: int. Precision of wavelet to use. Default is 8 @@ -79,9 +79,9 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ x = np.linspace(-precision, precision, M) return ( - ((np.pi * ncycles) ** (-0.25)) - * np.exp(-(x**2) / ncycles) - * np.exp(1j * 2 * np.pi * scaling * x) + ((np.pi * gaussian_width) ** (-0.25)) + * np.exp(-(x**2) / gaussian_width) + * np.exp(1j * 2 * np.pi * window_length * x) ) @@ -114,7 +114,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas def compute_wavelet_transform( - sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=16, norm=None + sig, freqs, fs=None, gaussian_width=1.5, window_length=1.0, precision=16, norm="l1" ): """ Compute the time-frequency representation of a signal using Morlet wavelets. @@ -129,19 +129,18 @@ def compute_wavelet_transform( The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. fs : float or None Sampling rate, in Hz. Defaults to sig.rate if None is given. - n_cycles : float or 1d array - Length of the filter, as the number of cycles for each frequency. - If 1d array, this defines n_cycles for each frequency. - scaling : float - Scaling factor. + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. precision: int. Precision of wavelet to use. . Defines the number of timepoints to evaluate the Morlet wavelet at. Default is 16 - norm : {None, 'sss', 'amp'}, optional + norm : {None, 'l1', 'l2'}, optional Normalization method: * None - no normalization - * 'sss' - divide by the square root of the sum of squares - * 'amp' - divide by the sum of amplitudes + * 'l1' - divide by the sum of amplitudes + * 'l2' - divide by the square root of the sum of squares Returns ------- @@ -164,9 +163,11 @@ def compute_wavelet_transform( if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") - if isinstance(n_cycles, (int, float, np.number)): - if n_cycles <= 0: - raise ValueError("Number of cycles must be a positive number.") + if isinstance(gaussian_width, (int, float, np.number)): + if gaussian_width <= 0: + raise ValueError("gaussian_width must be a positive number.") + if norm is not None and norm not in ["l1", "l2"]: + raise ValueError("norm parameter must be 'l1', 'l2', or None.") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) @@ -180,18 +181,18 @@ def compute_wavelet_transform( output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) + filter_bank = generate_morlet_filterbank( + freqs, fs, gaussian_width, window_length, precision + ) convolved_real = sig.convolve(filter_bank.real().values) convolved_imag = sig.convolve(filter_bank.imag().values) convolved = convolved_real.values + convolved_imag.values * 1j - coef = -np.diff(convolved, axis=0) - if norm == "sss": - coef *= -np.sqrt(scaling) / (freqs / fs) - elif norm == "amp": - coef *= -scaling / (freqs / fs) - coef = np.insert( - coef, 1, coef[0, :], axis=0 - ) # slightly hacky line, necessary to make output correct shape + if norm == "l1": + coef = convolved / (fs / freqs) + 
elif norm == "l2": + coef = convolved / (fs / np.sqrt(freqs)) + else: + coef = convolved cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef if len(output_shape) == 2: @@ -204,7 +205,9 @@ def compute_wavelet_transform( ) -def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=16): +def generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 +): """ Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. @@ -217,11 +220,10 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. fs : float Sampling rate, in Hz. - n_cycles : float or 1d array - Length of the filter, as the number of cycles for each frequency. - If 1d array, this defines n_cycles for each frequency. - scaling : float - Scaling factor. + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. precision: int. Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. @@ -233,13 +235,15 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 if len(freqs) == 0: raise ValueError("Given list of freqs cannot be empty.") filter_bank = [] - time_cutoff = 8 - morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) - x = np.linspace(-time_cutoff, time_cutoff, int(2**precision)) - int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + cutoff = 8 + morlet_f = _morlet( + int(2**precision), gaussian_width=gaussian_width, window_length=window_length + ) + x = np.linspace(-cutoff, cutoff, int(2**precision)) + int_psi = np.conj(morlet_f) max_len = -1 for freq in freqs: - scale = scaling / (freq / fs) + scale = window_length / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) j = j.astype(int) # floor if j[-1] >= int_psi.size: @@ -248,7 +252,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) time = np.linspace( - -time_cutoff * scaling / freq, time_cutoff * scaling / freq, max_len + -cutoff * window_length / freq, cutoff * window_length / freq, max_len ) filter_bank.append(int_psi_scale) filter_bank = [ @@ -271,7 +275,7 @@ def _integrate(arr, step): arr : np.ndarray wave function to be integrated step : float - Step size of vgiven wave function array + Step size of given wave function array Returns ------- diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 8f853591..9df76ae4 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -45,10 +45,10 @@ def test_compute_spectogram(): def test_compute_wavelet_transform(): - t = np.linspace(0, 1, 1000) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), t=t, ) freqs = np.linspace(10, 100, 10) @@ -60,10 +60,10 @@ def test_compute_wavelet_transform(): == 500 ) - t = np.linspace(0, 1, 1000) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 10 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + * 
np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), t=t, ) freqs = np.linspace(10, 100, 10) @@ -94,7 +94,7 @@ def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="sss") + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l1") mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] assert mpf == 20 assert mwt.shape == (1000, 10) @@ -102,7 +102,15 @@ def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="amp") + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l2") + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm=None) mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] assert mpf == 20 assert mwt.shape == (1000, 10) @@ -139,5 +147,9 @@ def test_compute_wavelet_transform(): assert mwt.shape == (1024, 10, 4, 2) with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) - assert str(e_info.value) == "Number of cycles must be a positive number." + nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, gaussian_width=-1.5) + assert str(e_info.value) == "gaussian_width must be a positive number." + + +if __name__ == "__main__": + test_compute_wavelet_transform() From 059d2c41907527f79a93eccbd36a3d24841a855f Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 15:29:12 -0400 Subject: [PATCH 46/71] Adding new notebook for psd --- docs/api_guide/tutorial_pynapple_spectrum.py | 113 +++++++++++++++++++ pynapple/process/__init__.py | 3 +- pynapple/process/signal_processing.py | 13 ++- 3 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 docs/api_guide/tutorial_pynapple_spectrum.py diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py new file mode 100644 index 00000000..cb6ac699 --- /dev/null +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" +Power spectral density +====================== + +Working with Wavelets! + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Kipp Freud. + +""" + +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm seaborn` +# +# Now, import the necessary libraries: + +import matplotlib.pyplot as plt +import numpy as np +import seaborn + +seaborn.set_theme() + +import pynapple as nap + +# %% +# *** +# Generating a Dummy Signal +# ------------------ +# Let's generate a dummy signal to analyse with wavelets! +# +# Our dummy dataset will contain two components, a low frequency 2Hz sinusoid combined +# with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. 
+ +F = [2, 10] + +Fs = 2000 +t = np.arange(0, 100, 1/Fs) +sig = nap.Tsd( + t=t, + d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 2, len(t)), + time_support = nap.IntervalSet(0, 10) + ) + +# %% +# Let's plot it +plt.figure() +plt.plot(sig.get(0, 1)) +plt.title("Signal") +plt.show() + + +# %% +# Computing power spectral density (PSD) +# -------------------------------------- +# +# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density` + +psd = nap.compute_power_spectral_density(sig) + +# %% +# Pynapple returns a pandas DataFrame. + +print(psd) + +# %% +# It is then easy to plot it. + +plt.figure() +plt.plot(psd) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.show() + +# %% +# Note that the output of the FFT is truncated to positive frequencies. To get positive and negative frequencies, you can set `full_range=True`. +# By default, the function returns the frequencies up to the Nyquist frequency. +# Let's zoom on the first 20 Hz. + +plt.figure() +plt.plot(psd) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.xlim(0, 20) +plt.show() + +# %% +# We find the two frequencies 2 and 10 Hz. +# +# By default, pynapple assumes a constant sampling rate and a single epoch. For example, computing the FFT over more than 1 epoch will raise an error. +double_ep = nap.IntervalSet([0, 50], [20, 100]) + +try: + nap.compute_power_spectral_density(sig, ep=double_ep) +except ValueError as e: + print(e) + + +# %% +# Computing mean PSD +# ------------------ +# +# It is possible to compute an average PSD over multiple epochs with the function `nap.compute_mean_power_spectral_density`. +# +# In this case, the argument `interval_size` determines the duration of each epochs upon which the fft is computed. +# If not epochs is passed, the function will split the `time_support`. + + + diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index a73dea00..53a80375 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -16,7 +16,8 @@ shuffle_ts_intervals, ) from .signal_processing import ( - compute_spectogram, + compute_power_spectral_density, + compute_mean_power_spectral_density, compute_wavelet_transform, generate_morlet_filterbank, ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 58ad9fbc..44e1d558 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -10,7 +10,7 @@ import pynapple as nap -def compute_spectogram(sig, fs=None, ep=None, full_range=False): +def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): """ Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. 
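Conceptually, the mean PSD introduced in the tutorial above amounts to cutting the signal into epochs of equal duration, taking the FFT of each epoch and averaging the results. The plain-numpy sketch below only illustrates that idea; it is not pynapple's implementation, which (as the following hunks show) works directly on Tsd objects and IntervalSets.

import numpy as np

fs = 1000.0
t = np.arange(0, 100, 1 / fs)
x = np.cos(2 * np.pi * 10 * t) + np.random.normal(0, 3, len(t))

interval_size = 10.0                          # seconds per epoch
n = int(interval_size * fs)                   # samples per epoch
epochs = x[: (len(x) // n) * n].reshape(-1, n)

mean_amp = np.abs(np.fft.fft(epochs, axis=1)).mean(axis=0)
freqs = np.fft.fftfreq(n, 1 / fs)
print(freqs[np.argmax(mean_amp[: n // 2])])   # the 10 Hz component stands out after averaging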
@@ -56,6 +56,17 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret +def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False): + + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + + # split_ep = ep.split(interval_size) + + + def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): """ From eec4740bbe019d3dbfcd4ef284e4341d2d7e221c Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 16:52:50 -0400 Subject: [PATCH 47/71] Adding mean psd notebook --- docs/api_guide/tutorial_pynapple_spectrum.py | 42 +++++++----- docs/examples/tutorial_signal_processing.py | 2 +- pynapple/process/signal_processing.py | 69 +++++++++++++++++++- 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py index cb6ac699..6fe6d727 100644 --- a/docs/api_guide/tutorial_pynapple_spectrum.py +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -3,12 +3,8 @@ Power spectral density ====================== -Working with Wavelets! - See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. -This tutorial was made by Kipp Freud. - """ # %% @@ -29,12 +25,10 @@ # %% # *** -# Generating a Dummy Signal +# Generating a signal # ------------------ -# Let's generate a dummy signal to analyse with wavelets! +# Let's generate a dummy signal with 2Hz and 10Hz sinusoide with white noise. # -# Our dummy dataset will contain two components, a low frequency 2Hz sinusoid combined -# with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. F = [2, 10] @@ -42,16 +36,17 @@ t = np.arange(0, 100, 1/Fs) sig = nap.Tsd( t=t, - d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 2, len(t)), - time_support = nap.IntervalSet(0, 10) + d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 3, len(t)), + time_support = nap.IntervalSet(0, 100) ) # %% # Let's plot it plt.figure() -plt.plot(sig.get(0, 1)) +plt.plot(sig.get(0, 0.4)) plt.title("Signal") -plt.show() +plt.xlabel("Time (s)") + # %% @@ -74,7 +69,7 @@ plt.plot(psd) plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") -plt.show() + # %% # Note that the output of the FFT is truncated to positive frequencies. To get positive and negative frequencies, you can set `full_range=True`. @@ -86,7 +81,7 @@ plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") plt.xlim(0, 20) -plt.show() + # %% # We find the two frequencies 2 and 10 Hz. @@ -106,8 +101,25 @@ # # It is possible to compute an average PSD over multiple epochs with the function `nap.compute_mean_power_spectral_density`. # -# In this case, the argument `interval_size` determines the duration of each epochs upon which the fft is computed. +# In this case, the argument `interval_size` determines the duration of each epochs upon which the FFT is computed. # If not epochs is passed, the function will split the `time_support`. +# +# In this case, the FFT will be computed over epochs of 10 seconds. + +mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=10.0) +# %% +# Let's compare `mean_psd` to `psd`. 
+ +plt.figure() +plt.plot(psd) +plt.plot(mean_psd) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.xlim(0, 20) + +# %% +# As we can see, `nap.compute_mean_power_spectral_density` was able to smooth out the noise. + diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index d9b66be9..2453e0f2 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -131,7 +131,7 @@ # ----------------------------------- # Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies. -fft = nap.compute_spectogram(RUN_Tsd, fs=int(FS)) +fft = nap.compute_power_spectral_density(RUN_Tsd, fs=int(FS)) print(fft) # %% diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 44e1d558..2a994f04 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -56,16 +56,81 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret -def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False): +def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_units="s"): + """Compute mean power spectral density by averaging FFT over epochs of same size. + The parameter `interval_size` controls the duration of the epochs. + + Note that this function assumes a constant sampling rate for sig. + Parameters + ---------- + sig : TYPE + Description + interval_size : TYPE + Description + fs : None, optional + Description + ep : None, optional + Description + full_range : bool, optional + Description + time_units : str, optional + Description + + Returns + ------- + TYPE + Description + + Raises + ------ + RuntimeError + Description + TypeError + Description + """ if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: ep = sig.time_support + if fs is None: + fs = sig.rate + + # Split the ep + split_ep = ep.split(interval_size) + + if len(split_ep) == 0: + raise RuntimeError(f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size") + + # Get the slices of each ep + slices = np.zeros((len(split_ep),2), dtype=int) - # split_ep = ep.split(interval_size) + for i in range(len(split_ep)): + sl = sig.get_slice(split_ep[i,0], split_ep[i,1]) + slices[i,0] = sl.start + slices[i,1] = sl.stop + + # Check what is the signal length + N = np.min(np.diff(slices, 1)) + if N == 0: + raise RuntimeError(f"One epoch doesn't have any signal. 
Check the parameter ep or the time support if no epoch is passed.") + # Get the freqs + fft_freq = np.fft.fftfreq(N, 1 / fs) + + # Compute the fft + fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) + + for i in range(len(slices)): + fft_result += np.fft.fft(sig[slices[i,0]:slices[i,1]].values[0:N], axis=0) + + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + if not full_range: + return ret.loc[ret.index >= 0] + return ret + def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): From 82dfc43adce03e6fcd4074f5454b51e4577c1d9a Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 1 Aug 2024 13:00:26 -0400 Subject: [PATCH 48/71] pushing some failing tests --- pynapple/process/signal_processing.py | 2 +- tests/test_power_spectral_density.py | 60 +++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 tests/test_power_spectral_density.py diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 2a994f04..59af1130 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -106,7 +106,7 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu slices = np.zeros((len(split_ep),2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i,0], split_ep[i,1]) + sl = sig.get_slice(split_ep[i,0], split_ep[i,1], mode='backward') slices[i,0] = sl.start slices[i,1] = sl.stop diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py new file mode 100644 index 00000000..a604315e --- /dev/null +++ b/tests/test_power_spectral_density.py @@ -0,0 +1,60 @@ +import numpy as np +import pandas as pd +import pytest + +import pynapple as nap + + +def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): + t=np.arange(0, duration, 1/fs) + d=np.cos(2*np.pi*f*t) + sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) + tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T + out = np.sum(np.fft.fft(tmp, axis=0), 1) + freq = np.fft.fftfreq(out.shape[0], 1 / fs) + order = np.argsort(freq) + out = out[order] + freq = freq[order] + return (sig, out, freq) + + +def test_basic(): + sig, out, freq = get_signal_and_output() + + psd = nap.compute_mean_power_spectral_density(sig, 10) + + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq>=0]) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq>=0]) + + + + + +@pytest.mark.parametrize("interval_size, expected_exception", [ + (10, None), # Regular case + (200, RuntimeError), # Interval size too large + (1, RuntimeError) # Epoch too small +]) +@setup_signal_and_params +def test_compute_mean_power_spectral_density(sig, interval_size, expected_exception): + if expected_exception: + with pytest.raises(expected_exception): + compute_mean_power_spectral_density(sig, interval_size) + else: + psd = compute_mean_power_spectral_density(sig, interval_size) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + +@pytest.mark.parametrize("full_range", [True, False]) +@setup_signal_and_params +def test_full_range_option(sig, full_range): + interval_size = 10 # Choose a valid interval size for this test + + psd = compute_mean_power_spectral_density(sig, interval_size, full_range=full_range) + + if full_range: + assert (psd.index 
>= 0).all() + else: + assert (psd.index >= 0).any() # Partial range should exclude negative frequencies From 6c3d6e532577d83cb1643d2e8dccd4cb52dc7bc6 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 1 Aug 2024 19:44:52 -0400 Subject: [PATCH 49/71] Adding tests for mean PSD --- pynapple/process/signal_processing.py | 45 +++++---- tests/test_power_spectral_density.py | 130 +++++++++++++++++--------- tests/test_signal_processing.py | 36 ------- 3 files changed, 114 insertions(+), 97 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 59af1130..b7a5f8cd 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -6,8 +6,8 @@ import numpy as np import pandas as pd - -import pynapple as nap +from numbers import Number +from .. import core as nap def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): @@ -56,7 +56,7 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret -def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_units="s"): +def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s"): """Compute mean power spectral density by averaging FFT over epochs of same size. The parameter `interval_size` controls the duration of the epochs. @@ -64,46 +64,53 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu Parameters ---------- - sig : TYPE - Description - interval_size : TYPE - Description + sig : Tsd or TsdFrame + Signal with equispaced samples + interval_size : Number + Epochs size to compute to average the FFT across fs : None, optional - Description + Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` ep : None, optional - Description + The `IntervalSet` to calculate the fft on. Can be any length. full_range : bool, optional - Description - time_units : str, optional - Description + If true, will return full fft frequency range, otherwise will return only positive values + time_unit : str, optional + Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') Returns ------- - TYPE - Description + pandas.DataFrame + Power spectral density. Raises ------ RuntimeError - Description + If splitting the epoch with `interval_size` results in an empty set. TypeError - Description + If `ep` or `sig` are not respectively pynapple time series or interval set. """ if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: ep = sig.time_support + + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") if fs is None: - fs = sig.rate + fs = sig.rate + + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") # Split the ep + interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[0] split_ep = ep.split(interval_size) if len(split_ep) == 0: raise RuntimeError(f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. 
Try decreasing interval_size") # Get the slices of each ep - slices = np.zeros((len(split_ep),2), dtype=int) + slices = np.zeros((len(split_ep),2), dtype=int) for i in range(len(split_ep)): sl = sig.get_slice(split_ep[i,0], split_ep[i,1], mode='backward') @@ -114,7 +121,7 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu N = np.min(np.diff(slices, 1)) if N == 0: - raise RuntimeError(f"One epoch doesn't have any signal. Check the parameter ep or the time support if no epoch is passed.") + raise RuntimeError("One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.") # Get the freqs fft_freq = np.fft.fftfreq(N, 1 / fs) diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index a604315e..0eb39cdd 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -1,60 +1,106 @@ import numpy as np import pandas as pd import pytest - +from contextlib import nullcontext as does_not_raise import pynapple as nap +############################################################ +# Test for mean_power_spectral_density +############################################################ -def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): - t=np.arange(0, duration, 1/fs) - d=np.cos(2*np.pi*f*t) - sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) - tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T - out = np.sum(np.fft.fft(tmp, axis=0), 1) - freq = np.fft.fftfreq(out.shape[0], 1 / fs) - order = np.argsort(freq) - out = out[order] - freq = freq[order] - return (sig, out, freq) +def test_compute_power_spectral_density(): + with pytest.raises(ValueError) as e_info: + t = np.linspace(0, 1, 1000) + sig = nap.Tsd( + d=np.random.random(1000), + t=t, + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ) + r = nap.compute_power_spectral_density(sig) + assert ( + str(e_info.value) == "Given epoch (or signal time_support) must have length 1" + ) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_power_spectral_density(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape == (500, 4) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_power_spectral_density(sig, full_range=True) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1000, 4) -def test_basic(): - sig, out, freq = get_signal_and_output() + with pytest.raises(TypeError) as e_info: + nap.compute_power_spectral_density("a_string") + assert ( + str(e_info.value) + == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" + ) + +############################################################ +# Test for mean_power_spectral_density +############################################################ + +def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): + t=np.arange(0, duration, 1/fs) + d=np.cos(2*np.pi*f*t) + sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) + tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T + tmp = tmp[0:-1] + out = np.sum(np.fft.fft(tmp, axis=0), 1) + freq = np.fft.fftfreq(out.shape[0], 1 / fs) + order = np.argsort(freq) + out = out[order] + freq = freq[order] + return (sig, out, freq) + +def 
test_compute_mean_power_spectral_density(): + sig, out, freq = get_signal_and_output() psd = nap.compute_mean_power_spectral_density(sig, 10) - assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq>=0]) np.testing.assert_array_almost_equal(psd.index.values, freq[freq>=0]) + # Full range + psd = nap.compute_mean_power_spectral_density(sig, 10, full_range=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out) + np.testing.assert_array_almost_equal(psd.index.values, freq) + # TsdFrame + sig2 = nap.TsdFrame(t=sig.t, d=np.repeat(sig.values[:,None], 2, 1), time_support = sig.time_support) + psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:,None],2,1)) + np.testing.assert_array_almost_equal(psd.index.values, freq) - -@pytest.mark.parametrize("interval_size, expected_exception", [ - (10, None), # Regular case - (200, RuntimeError), # Interval size too large - (1, RuntimeError) # Epoch too small -]) -@setup_signal_and_params -def test_compute_mean_power_spectral_density(sig, interval_size, expected_exception): - if expected_exception: - with pytest.raises(expected_exception): - compute_mean_power_spectral_density(sig, interval_size) - else: - psd = compute_mean_power_spectral_density(sig, interval_size) - assert isinstance(psd, pd.DataFrame) - assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - -@pytest.mark.parametrize("full_range", [True, False]) -@setup_signal_and_params -def test_full_range_option(sig, full_range): - interval_size = 10 # Choose a valid interval size for this test - - psd = compute_mean_power_spectral_density(sig, interval_size, full_range=full_range) - - if full_range: - assert (psd.index >= 0).all() - else: - assert (psd.index >= 0).any() # Partial range should exclude negative frequencies +@pytest.mark.parametrize( + "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", + [ + (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), + (*get_signal_and_output(), 10, "a", None, False, "s", pytest.raises(TypeError, match="fs must be of type float or int")), + (*get_signal_and_output(), 10, None, "a", False, "s", pytest.raises(TypeError, match="ep param must be a pynapple IntervalSet object, or None")), + (*get_signal_and_output(), 10, None, None, "a", "s", pytest.raises(TypeError, match="full_range must be of type bool or None")), + (*get_signal_and_output(), 10*1e3, None, None, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10*1e6, None, None, False, "us", does_not_raise()), + (*get_signal_and_output(), 200, None, None, False, "s", pytest.raises(RuntimeError, match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size")), + (*get_signal_and_output(), 10, None, nap.IntervalSet([0, 200], [100,300]), False, "s", pytest.raises(RuntimeError, match="One interval doesn't have any signal associated. 
Check the parameter ep or the time support if no epoch is passed.")), + ] +) +def test_compute_mean_power_spectral_density_raise_errors( + sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation + ): + with expectation: + psd = nap.compute_mean_power_spectral_density(sig, interval_size, fs, ep, full_range, time_units) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 9df76ae4..3bacbb50 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -7,42 +7,6 @@ import pynapple as nap -def test_compute_spectogram(): - with pytest.raises(ValueError) as e_info: - t = np.linspace(0, 1, 1000) - sig = nap.Tsd( - d=np.random.random(1000), - t=t, - time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), - ) - r = nap.compute_spectogram(sig) - assert ( - str(e_info.value) == "Given epoch (or signal time_support) must have length 1" - ) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.random.random(1000), t=t) - r = nap.compute_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape[0] == 500 - - sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) - r = nap.compute_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape == (500, 4) - - sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) - r = nap.compute_spectogram(sig, full_range=True) - assert isinstance(r, pd.DataFrame) - assert r.shape == (1000, 4) - - with pytest.raises(TypeError) as e_info: - nap.compute_spectogram("a_string") - assert ( - str(e_info.value) - == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" - ) - def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1001) From c8a5fc47115f79e817f40cf0e4fa839a564be255 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 1 Aug 2024 19:50:26 -0400 Subject: [PATCH 50/71] linting --- pynapple/process/__init__.py | 2 +- pynapple/process/signal_processing.py | 46 ++++++++++++++++----------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 53a80375..221ce039 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -16,8 +16,8 @@ shuffle_ts_intervals, ) from .signal_processing import ( - compute_power_spectral_density, compute_mean_power_spectral_density, + compute_power_spectral_density, compute_wavelet_transform, generate_morlet_filterbank, ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index b7a5f8cd..4adf0567 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,9 +4,11 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ +from numbers import Number + import numpy as np import pandas as pd -from numbers import Number + from .. import core as nap @@ -56,10 +58,13 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret -def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s"): - """Compute mean power spectral density by averaging FFT over epochs of same size. + +def compute_mean_power_spectral_density( + sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s" +): + """Compute mean power spectral density by averaging FFT over epochs of same size. The parameter `interval_size` controls the duration of the epochs. 
- + Note that this function assumes a constant sampling rate for sig. Parameters @@ -76,12 +81,12 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu If true, will return full fft frequency range, otherwise will return only positive values time_unit : str, optional Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') - + Returns ------- pandas.DataFrame Power spectral density. - + Raises ------ RuntimeError @@ -93,7 +98,7 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: ep = sig.time_support - + if not (fs is None or isinstance(fs, Number)): raise TypeError("fs must be of type float or int") if fs is None: @@ -103,25 +108,31 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu raise TypeError("full_range must be of type bool or None") # Split the ep - interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[0] + interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ + 0 + ] split_ep = ep.split(interval_size) if len(split_ep) == 0: - raise RuntimeError(f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size") - + raise RuntimeError( + f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size" + ) + # Get the slices of each ep - slices = np.zeros((len(split_ep),2), dtype=int) + slices = np.zeros((len(split_ep), 2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i,0], split_ep[i,1], mode='backward') - slices[i,0] = sl.start - slices[i,1] = sl.stop - + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1], mode="backward") + slices[i, 0] = sl.start + slices[i, 1] = sl.stop + # Check what is the signal length N = np.min(np.diff(slices, 1)) if N == 0: - raise RuntimeError("One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.") + raise RuntimeError( + "One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed." 
+ ) # Get the freqs fft_freq = np.fft.fftfreq(N, 1 / fs) @@ -130,14 +141,13 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) for i in range(len(slices)): - fft_result += np.fft.fft(sig[slices[i,0]:slices[i,1]].values[0:N], axis=0) + fft_result += np.fft.fft(sig[slices[i, 0] : slices[i, 1]].values[0:N], axis=0) ret = pd.DataFrame(fft_result, fft_freq) ret.sort_index(inplace=True) if not full_range: return ret.loc[ret.index >= 0] return ret - def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): From 65a1a1a4a155b4066626300ddf91ef383aef0fa5 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 2 Aug 2024 16:26:48 +0100 Subject: [PATCH 51/71] param name changes --- docs/api_guide/tutorial_pynapple_wavelets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index d34de4f2..c0f13862 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -291,7 +291,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** -# Effect of time_decay +# Effect of gaussian_width # ------------------ # Let's increase time_decay to 7.5 and see the effect on the resultant filter bank. @@ -347,9 +347,9 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** -# Effect of scaling +# Effect of window_length # ------------------ -# Let's increase scaling to 2.0 and see the effect on the resultant filter bank. +# Let's increase window_length to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( From b16c7c80a6db1eca021e637c9aa77be4be2c92fd Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 2 Aug 2024 11:54:24 -0400 Subject: [PATCH 52/71] fixed notebooks --- docs/api_guide/tutorial_pynapple_spectrum.py | 2 + docs/examples/tutorial_human_dataset.py | 39 +++++++------------- pynapple/io/folder.py | 6 +-- pynapple/io/interface_npz.py | 11 ++++-- 4 files changed, 26 insertions(+), 32 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py index 6fe6d727..bcb64c50 100644 --- a/docs/api_guide/tutorial_pynapple_spectrum.py +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -14,6 +14,8 @@ # You can install all with `pip install matplotlib requests tqdm seaborn` # # Now, import the necessary libraries: +# +# mkdocs_gallery_thumbnail_number = 3 import matplotlib.pyplot as plt import numpy as np diff --git a/docs/examples/tutorial_human_dataset.py b/docs/examples/tutorial_human_dataset.py index f84cbef6..caeb5f0e 100644 --- a/docs/examples/tutorial_human_dataset.py +++ b/docs/examples/tutorial_human_dataset.py @@ -189,36 +189,23 @@ # ------------------ # # Now that we have the PETH of spiking, we can go one step further. We will plot the mean firing rate of this cell aligned to the boundary for each trial type. Doing this in Pynapple is very simple! 
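A minimal sketch of that pipeline on toy data, for reference alongside the notebook code below (the Poisson counts and the 50 trials are made up; `count`, `smooth`, the 10 ms `bin_size` and `scipy.stats.sem` follow the notebook):

import numpy as np
from scipy.stats import sem
import pynapple as nap

bin_size = 0.01                                   # 10 ms bins, as in the notebook
t = np.arange(0, 1, bin_size)
# toy trial-by-trial spike counts, standing in for NB_peth.count(bin_size)
counts = nap.TsdFrame(t=t, d=np.random.poisson(1.0, (len(t), 50)))

# smooth the counts with a gaussian of std 4*bin_size, then convert to a rate
fr = counts.smooth(bin_size * 4) / bin_size

mean_fr = np.mean(fr.values, axis=1)              # mean rate across trials
error_fr = sem(fr.values, axis=1)                 # SEM across trials
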
- -bin_size = 0.2 # 200ms bin size -step_size = 0.01 # 10ms step size, to make overlapping bins -winsize = int(bin_size / step_size) # Window size - -# %% +# # Use Pynapple to compute binned spike counts - -counts_NB = NB_peth.count(step_size) # Spike counts binned in 10ms steps, for NB trials -counts_HB = HB_peth.count(step_size) # Spike counts binned in 10ms steps, for HB trials +bin_size = 0.01 +counts_NB = NB_peth.count(bin_size) # Spike counts binned in 10ms steps, for NB trials +counts_HB = HB_peth.count(bin_size) # Spike counts binned in 10ms steps, for HB trials # %% -# Smooth the binned spike counts using a window of size 20, for both trial types +# Compute firing rate for both trial types -counts_NB = ( - counts_NB.as_dataframe() - .rolling(winsize, win_type="gaussian", min_periods=1, center=True, axis=0) - .mean(std=0.2 * winsize) -) -counts_HB = ( - counts_HB.as_dataframe() - .rolling(winsize, win_type="gaussian", min_periods=1, center=True, axis=0) - .mean(std=0.2 * winsize) -) +fr_NB = counts_NB / bin_size +fr_HB = counts_HB / bin_size # %% -# Compute firing rate for both trial types +# Smooth the firing rate with a gaussian window with std=4*bin_size +counts_NB = counts_NB.smooth(bin_size*4) +counts_HB = counts_HB.smooth(bin_size*4) -fr_NB = counts_NB * winsize -fr_HB = counts_HB * winsize # %% # Compute the mean firing rate for both trial types @@ -228,9 +215,9 @@ # %% # Compute standard error of mean (SEM) of the firing rate for both trial types - -error_NB = fr_NB.sem(axis=1) -error_HB = fr_HB.sem(axis=1) +from scipy.stats import sem +error_NB = sem(fr_NB, axis=1) +error_HB = sem(fr_HB, axis=1) # %% # Plot the mean +/- SEM of firing rate for both trial types diff --git a/pynapple/io/folder.py b/pynapple/io/folder.py index de1b9ef5..8f7d2f1a 100644 --- a/pynapple/io/folder.py +++ b/pynapple/io/folder.py @@ -4,7 +4,7 @@ # @Author: Guillaume Viejo # @Date: 2023-05-15 15:32:24 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-08-06 17:37:23 +# @Last Modified time: 2024-08-02 11:35:10 """ The Folder class helps to navigate a hierarchical data tree. 
@@ -302,12 +302,12 @@ def metadata(self, name): with open(json_filename, "r") as ff: metadata = json.load(ff) text = "\n".join([" : ".join(it) for it in metadata.items()]) - panel = Panel.fit(text, border_style="green", title=title) + panel = Panel.fit(text, border_style="green", title=str(title)) else: panel = Panel.fit( "No metadata", border_style="red", - title=title, + title=str(title), ) with Console() as console: console.print(panel) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 22da63bd..cedb779b 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -4,7 +4,7 @@ # @Author: Guillaume Viejo # @Date: 2023-07-05 16:03:25 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-02 14:32:25 +# @Last Modified time: 2024-08-02 11:16:07 from pathlib import Path @@ -26,10 +26,15 @@ def _find_class_from_variables(file_variables, data_ndims=None): if data_ndims is not None: - # either TsdTensor or Tsd: + assert EXPECTED_ENTRIES["Tsd"].issubset(file_variables) - return "Tsd" if data_ndims == 1 else "TsdTensor" + if data_ndims == 1: + return "Tsd" + elif data_ndims == 2: + return "TsdFrame" + else: + return "TsdTensor" for possible_type, expected_variables in EXPECTED_ENTRIES.items(): if expected_variables.issubset(file_variables): From 74d9061aa1aacf535e4cdd7e22db13bbfa67c46a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 15:14:09 +0100 Subject: [PATCH 53/71] better tests --- pynapple/process/signal_processing.py | 4 +- tests/test_signal_processing.py | 203 +++++++++++++++++++++++++- 2 files changed, 199 insertions(+), 8 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 4adf0567..25ef8448 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -233,7 +233,7 @@ def compute_wavelet_transform( Normalization method: * None - no normalization * 'l1' - divide by the sum of amplitudes - * 'l2' - divide by the square root of the sum of squares + * 'l2' - divide by the square root of the sum of amplitudes Returns ------- @@ -327,6 +327,8 @@ def generate_morlet_filterbank( """ if len(freqs) == 0: raise ValueError("Given list of freqs cannot be empty.") + if np.min(freqs) <= 0: + raise ValueError("All frequencies in freqs must be strictly positive") filter_bank = [] cutoff = 8 morlet_f = _morlet( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 3bacbb50..b3bae4b9 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,12 +1,98 @@ """Tests of `signal_processing` for pynapple""" import numpy as np -import pandas as pd import pytest import pynapple as nap +def test_generate_morlet_filterbank(): + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.0, 
precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=3.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=3.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + +@pytest.mark.parametrize( + "freqs, fs, gaussian_width, window_length, precision, expectation", + [ + ( + np.linspace(0, 100, 11), + 1000, + 1.5, + 1.0, + 16, + pytest.raises( + ValueError, match="All frequencies in freqs must be strictly positive" + ), + ), + ( + [], + 1000, + 1.5, + 1.0, + 16, + pytest.raises(ValueError, match="Given list of freqs cannot be empty."), + ), + ], +) +def test_generate_morlet_filterbank_raise_errors( + freqs, fs, gaussian_width, window_length, precision, expectation +): + with expectation: + _ = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width, window_length, precision + ) + def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1001) @@ -99,21 +185,124 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 6) + t = np.linspace(0, 1, 1024) sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) - t = np.linspace(0, 1, 1024) # can remove this when we move it + t = np.linspace(0, 1, 1024) sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4, 2) - with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, gaussian_width=-1.5) - assert str(e_info.value) == "gaussian_width must be a positive number." 
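The added assertions that follow check `compute_wavelet_transform` against a manual convolution with the Morlet filterbank; the normalization they encode can be summarized in a few lines (a sketch, with `convolved` standing for the raw complex convolution, `fs` for the sampling rate and `freqs` for the wavelet center frequencies):

import numpy as np

def normalize_wavelet_coefficients(convolved, fs, freqs, norm):
    # scaling convention exercised by the l1 / l2 / None test cases below
    if norm == "l1":
        return convolved / (fs / freqs)            # amplitude-style scaling
    elif norm == "l2":
        return convolved / (fs / np.sqrt(freqs))   # energy-style scaling
    return convolved                               # norm=None leaves the convolution unscaled
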
+ # Testing against manual convolution for l1 norm + t = np.linspace(0, 1, 1024) + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l1" + ) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = convolved / (1024 / freqs) + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + mwt2 = nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + assert np.array_equal(mwt, mwt2) + + # Testing against manual convolution for l2 norm + t = np.linspace(0, 1, 1024) + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l2" + ) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = convolved / (1024 / np.sqrt(freqs)) + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + mwt2 = nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + assert np.array_equal(mwt, mwt2) + + # Testing against manual convolution for no normalization + t = np.linspace(0, 1, 1024) + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm=None + ) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = convolved + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + mwt2 = nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + assert np.array_equal(mwt, mwt2) -if __name__ == "__main__": - test_compute_wavelet_transform() +@pytest.mark.parametrize( + "sig, fs, freqs, gaussian_width, window_length, precision, norm, expectation", + [ + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(0, 600, 10), + 1.5, + 1.0, + 16, + "l1", + pytest.raises( + ValueError, match="All frequencies in freqs must be strictly positive" + ), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + -1.5, + 1.0, + 16, + "l1", + pytest.raises( + ValueError, match="gaussian_width must be a positive number." 
+ ), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + 1.5, + 1.0, + 16, + "l3", + pytest.raises( + ValueError, match="norm parameter must be 'l1', 'l2', or None." + ), + ), + ], +) +def test_compute_wavelet_transform_raise_errors( + sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation +): + with expectation: + _ = nap.compute_wavelet_transform( + sig, freqs, fs, gaussian_width, window_length, precision, norm + ) From 73ee4deac38ebab1f7160f40c11225dd78ef1004 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 15:26:32 +0100 Subject: [PATCH 54/71] one added case for tests --- pynapple/process/signal_processing.py | 15 +++++++-------- tests/test_signal_processing.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 25ef8448..5b8eaa3d 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -216,9 +216,9 @@ def compute_wavelet_transform( ---------- sig : pynapple.Tsd, pynapple.TsdFrame or pynapple.TsdTensor Time series. - freqs : 1d array or list of float + freqs : 1d array or tuple of float If array, frequency values to estimate with morlet wavelets. - If list, define the frequency range, as [freq_start, freq_stop, freq_step]. + If tuple, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. fs : float or None Sampling rate, in Hz. Defaults to sig.rate if None is given. @@ -261,8 +261,9 @@ def compute_wavelet_transform( raise ValueError("gaussian_width must be a positive number.") if norm is not None and norm not in ["l1", "l2"]: raise ValueError("norm parameter must be 'l1', 'l2', or None.") - - if isinstance(freqs, (tuple, list)): + if not isinstance(freqs, (np.ndarray, tuple)): + raise TypeError("`freqs` must be a ndarray or tuple instance.") + if isinstance(freqs, tuple): freqs = _create_freqs(*freqs) if fs is None: @@ -307,10 +308,8 @@ def generate_morlet_filterbank( Parameters ---------- - freqs : 1d array or list of float - If array, frequency values to estimate with morlet wavelets. - If list, define the frequency range, as [freq_start, freq_stop, freq_step]. - The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + freqs : 1d array + Frequency values to estimate with morlet wavelets. fs : float Sampling rate, in Hz. gaussian_width : float diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index b3bae4b9..130c1b7a 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -110,6 +110,17 @@ def test_compute_wavelet_transform(): == 500 ) + t = np.linspace(0, 1, 1001) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = (10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) + assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 10 * np.pi * 2) @@ -297,6 +308,18 @@ def test_compute_wavelet_transform(): ValueError, match="norm parameter must be 'l1', 'l2', or None." 
), ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + None, + 1.5, + 1.0, + 16, + "l1", + pytest.raises( + TypeError, match="`freqs` must be a ndarray or tuple instance." + ), + ), ], ) def test_compute_wavelet_transform_raise_errors( From a5b4f3bdf01b2ec5ade81b3817bb780f06512303 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 5 Aug 2024 15:48:04 -0400 Subject: [PATCH 55/71] Updating tutorial_signal_processing notebool --- docs/examples/tutorial_signal_processing.py | 189 ++++++++------------ 1 file changed, 74 insertions(+), 115 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 2453e0f2..dd0ac745 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -60,28 +60,33 @@ # Let's load and print the full dataset. data = nap.load_file(path) -FS = 1250 + print(data) +# %% +# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. +# +# The `time_support` of the object `data['position']` contains the interval for which the rat was running along the linear track. We will call it `wake_ep`. +# + +FS = 1250 + +eeg = data['eeg'] + +wake_ep = data['position'].time_support # %% # *** -# Selecting slices +# Selecting example # ----------------------------------- # We will consider a single run of the experiment - where the rodent completes a full traversal of the linear track, # followed by 4 seconds of post-traversal activity. -# Define the run to use for this Analysis -run_index = 7 -# Define the IntervalSet for this run and instantiate both LFP and -# Position TsdFrame objects -RUN_interval = nap.IntervalSet( - data["forward_ep"]["start"][run_index], - data["forward_ep"]["end"][run_index] + 4.0, -) -RUN_Tsd = data["eeg"].restrict(RUN_interval) -RUN_pos = data["position"].restrict(RUN_interval) -print(RUN_Tsd) +forward_ep = data['forward_ep'] +RUN_interval = nap.IntervalSet(forward_ep.start[7], forward_ep.end[7] + 4.0) + +eeg_example = eeg.restrict(RUN_interval)[:,0] +pos_example = data['position'].restrict(RUN_interval) # %% # *** @@ -93,71 +98,47 @@ [["ephys"], ["pos"]], height_ratios=[1, 0.4], ) - -axd["ephys"].plot( - RUN_Tsd[:, 0].restrict( - nap.IntervalSet( - data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] - ) - ), - label="Traversal LFP Data", - color="green", -) -axd["ephys"].plot( - RUN_Tsd[:, 0].restrict( - nap.IntervalSet( - data["forward_ep"]["end"][run_index], - data["forward_ep"]["end"][run_index] + 5.0, - ) - ), - label="Post Traversal LFP Data", - color="blue", -) -axd["ephys"].set_title("Traversal & Post Traversal LFP") +axd["ephys"].plot(eeg_example, label="CA1") +axd["ephys"].set_title("EEG (1250 Hz)") axd["ephys"].set_ylabel("LFP (v)") axd["ephys"].set_xlabel("time (s)") axd["ephys"].margins(0) axd["ephys"].legend() -axd["pos"].plot(RUN_pos, color="black") +axd["pos"].plot(pos_example, color="black") axd["pos"].margins(0) axd["pos"].set_xlabel("time (s)") axd["pos"].set_ylabel("Linearized Position") -axd["pos"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos"].set_xlim(RUN_interval[0,0], RUN_interval[0,1]) + # %% # *** # Getting the LFP Spectrogram # ----------------------------------- -# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies. 
+# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies during exploration (`wake_ep`). + + +power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep) +print(power) + -fft = nap.compute_power_spectral_density(RUN_Tsd, fs=int(FS)) -print(fft) # %% # *** # The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. # -# Now let's plot it +# Let's plot the power between 1 and 100 Hz. +# +# The red area outlines the theta rhythm (6-12 Hz) which is proeminent in hippocampal LFP. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -ax.plot( - fft.index, - np.abs(fft.iloc[:, 0]), - alpha=0.5, - label="LFP Frequency Power", - c="blue", - linewidth=2, -) +ax.semilogy(np.abs(power[(power.index>1.0) & (power.index<100)]),alpha=0.5,label="LFP Frequency Power") +ax.axvspan(6, 12, color = 'red', alpha=0.1) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") ax.set_title("LFP Fourier Decomposition") -ax.set_xlim(1, 30) -ax.axvline(9.36, c="orange", label="9.36Hz", alpha=0.5) -ax.axvline(18.74, c="green", label="18.74Hz", alpha=0.5) ax.legend() -# ax.set_yscale('log') -# ax.set_xscale('log') # %% # *** @@ -168,63 +149,41 @@ # Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. # We must define the frequency set that we'd like to use for our decomposition -freqs = np.geomspace(5, 250, 25) +freqs = np.geomspace(3, 250, 100) # Compute and print the wavelet transform on our LFP data -mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, 0], fs=FS, freqs=freqs) +mwt_RUN = nap.compute_wavelet_transform(eeg_example, fs=FS, freqs=freqs) + + +# %% +# `mwt_RUN` is a TsdFrame with each column being the convolution with one wavelet at a particular frequency. +print(mwt_RUN) # %% # *** # Now let's plot it. 
+fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +gs = plt.GridSpec(3, 1, figure=fig, height_ratios=[1.0, 0.5, 0.1]) + +ax0 = plt.subplot(gs[0,0]) +pcmesh = ax0.pcolormesh(mwt_RUN.t, freqs, np.transpose(np.abs(mwt_RUN))) +ax0.grid(False) +ax0.set_yscale("log") +ax0.set_title("Wavelet decomposition") +cbar = plt.colorbar(pcmesh, ax=ax0, orientation='vertical') +ax0.set_label("Amplitude") + +ax1 = plt.subplot(gs[1,0], sharex = ax0) +ax1.plot(eeg_example) +ax1.set_ylabel("LFP (v)") + +ax1 = plt.subplot(gs[2,0], sharex = ax0) +ax1.plot(pos_example) +ax1.set_xlabel("Time (s)") +ax1.set_ylabel("Pos.") + +plt.show() -# Define wavelet decomposition plotting function -def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): - if np.iscomplexobj(powers): - powers = abs(powers) - ax.imshow(powers, aspect="auto", **kwargs) - ax.invert_yaxis() - ax.set_xlabel("Time (s)") - ax.set_ylabel("Frequency (Hz)") - if isinstance(x_ticks, int): - x_tick_pos = np.linspace(0, times.size, x_ticks) - x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) - else: - x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] - ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = [np.round(f, 2) for f in freqs] - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] - ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - ax.grid(False) - - -# And plot -fig = plt.figure(constrained_layout=True, figsize=(10, 8)) -axd = fig.subplot_mosaic( - [ - ["wd_run"], - ["lfp_run"], - ["pos_run"], - ], - height_ratios=[1.2, 0.2, 0.6], -) -plot_timefrequency( - RUN_Tsd.index.values[:], - freqs[:], - np.transpose(mwt_RUN[:, :].values), - ax=axd["wd_run"], -) -axd["wd_run"].set_title(f"Wavelet Decomposition") -axd["lfp_run"].plot(RUN_Tsd) -axd["pos_run"].plot(RUN_pos) -axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -axd["pos_run"].set_ylabel("Lin. Position (cm)") -for k in ["lfp_run", "pos_run"]: - axd[k].margins(0) - if k != "pos_run": - axd[k].set_ylabel("LFP (v)") - axd[k].get_xaxis().set_visible(False) - for spine in ["top", "right", "bottom", "left"]: - axd[k].spines[spine].set_visible(False) # %% # *** @@ -254,14 +213,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], alpha=0.5, label="LFP Data") +axd["lfp_run"].plot(eeg_example, alpha=0.5, label="LFP Data") axd["lfp_run"].plot( - RUN_Tsd.index.values, + eeg_example.index.values, theta_band_reconstruction, label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) axd["lfp_run"].plot( - RUN_Tsd.index.values, + eeg_example.index.values, theta_band_power_envelope, label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", ) @@ -269,14 +228,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.") -axd["pos_run"].plot(RUN_pos) +axd["pos_run"].plot(pos_example) [axd[k].margins(0) for k in ["lfp_run", "pos_run"]] [ axd["pos_run"].spines[sp].set_visible(False) for sp in ["top", "right", "bottom", "left"] ] axd["pos_run"].get_xaxis().set_visible(False) -axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos_run"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["pos_run"].set_ylabel("Lin. 
Position (cm)") axd["lfp_run"].legend() @@ -306,19 +265,19 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], label="LFP Data") +axd["lfp_run"].plot(eeg_example, label="LFP Data") axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") -axd["rip_pow"].plot(RUN_Tsd.index.values, ripple_power) +axd["rip_pow"].plot(eeg_example.index.values, ripple_power) axd["rip_pow"].margins(0) axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].spines["top"].set_visible(False) axd["rip_pow"].spines["right"].set_visible(False) axd["rip_pow"].spines["bottom"].set_visible(False) axd["rip_pow"].spines["left"].set_visible(False) -axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% @@ -352,7 +311,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd[:, 0], label="LFP Data") +axd["lfp_run"].plot(eeg_example, label="LFP Data") axd["rip_pow"].plot(smoother_swr_power) axd["rip_pow"].plot(is_ripple, color="red", linewidth=2) axd["lfp_run"].set_ylabel("LFP (v)") @@ -362,7 +321,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): [axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] [axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] axd["rip_pow"].get_xaxis().set_visible(False) -axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% @@ -374,7 +333,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) buffer = 0.1 ax.plot( - RUN_Tsd.restrict( + eeg_example.restrict( nap.IntervalSet( start=is_ripple.index.min() - buffer, end=is_ripple.index.max() + buffer ) @@ -383,7 +342,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): label="Non-SWR LFP", ) ax.plot( - RUN_Tsd.restrict( + eeg_example.restrict( nap.IntervalSet(start=is_ripple.index.min(), end=is_ripple.index.max()) ), color="red", From b1540ea5d8418ccff69085c61e0f2133bb614902 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 21:47:34 +0100 Subject: [PATCH 56/71] more concise plotting code in docs --- docs/examples/tutorial_signal_processing.py | 107 +++++++++----------- 1 file changed, 48 insertions(+), 59 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index dd0ac745..ac3dc3c9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -64,16 +64,16 @@ print(data) # %% -# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. +# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. # # The `time_support` of the object `data['position']` contains the interval for which the rat was running along the linear track. We will call it `wake_ep`. 
# FS = 1250 -eeg = data['eeg'] +eeg = data["eeg"] -wake_ep = data['position'].time_support +wake_ep = data["position"].time_support # %% # *** @@ -82,11 +82,11 @@ # We will consider a single run of the experiment - where the rodent completes a full traversal of the linear track, # followed by 4 seconds of post-traversal activity. -forward_ep = data['forward_ep'] +forward_ep = data["forward_ep"] RUN_interval = nap.IntervalSet(forward_ep.start[7], forward_ep.end[7] + 4.0) -eeg_example = eeg.restrict(RUN_interval)[:,0] -pos_example = data['position'].restrict(RUN_interval) +eeg_example = eeg.restrict(RUN_interval)[:, 0] +pos_example = data["position"].restrict(RUN_interval) # %% # *** @@ -100,7 +100,7 @@ ) axd["ephys"].plot(eeg_example, label="CA1") axd["ephys"].set_title("EEG (1250 Hz)") -axd["ephys"].set_ylabel("LFP (v)") +axd["ephys"].set_ylabel("LFP (a.u.)") axd["ephys"].set_xlabel("time (s)") axd["ephys"].margins(0) axd["ephys"].legend() @@ -108,8 +108,7 @@ axd["pos"].margins(0) axd["pos"].set_xlabel("time (s)") axd["pos"].set_ylabel("Linearized Position") -axd["pos"].set_xlim(RUN_interval[0,0], RUN_interval[0,1]) - +axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) # %% @@ -123,18 +122,21 @@ print(power) - # %% # *** # The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. # -# Let's plot the power between 1 and 100 Hz. +# Let's plot the power between 1 and 100 Hz. # # The red area outlines the theta rhythm (6-12 Hz) which is proeminent in hippocampal LFP. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -ax.semilogy(np.abs(power[(power.index>1.0) & (power.index<100)]),alpha=0.5,label="LFP Frequency Power") -ax.axvspan(6, 12, color = 'red', alpha=0.1) +ax.semilogy( + np.abs(power[(power.index > 1.0) & (power.index < 100)]), + alpha=0.5, + label="LFP Frequency Power", +) +ax.axvspan(6, 12, color="red", alpha=0.1) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") ax.set_title("LFP Fourier Decomposition") @@ -165,20 +167,20 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) gs = plt.GridSpec(3, 1, figure=fig, height_ratios=[1.0, 0.5, 0.1]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) pcmesh = ax0.pcolormesh(mwt_RUN.t, freqs, np.transpose(np.abs(mwt_RUN))) ax0.grid(False) ax0.set_yscale("log") -ax0.set_title("Wavelet decomposition") -cbar = plt.colorbar(pcmesh, ax=ax0, orientation='vertical') +ax0.set_title("Wavelet Decomposition") +cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical") ax0.set_label("Amplitude") -ax1 = plt.subplot(gs[1,0], sharex = ax0) +ax1 = plt.subplot(gs[1, 0], sharex=ax0) ax1.plot(eeg_example) ax1.set_ylabel("LFP (v)") -ax1 = plt.subplot(gs[2,0], sharex = ax0) -ax1.plot(pos_example) +ax1 = plt.subplot(gs[2, 0], sharex=ax0) +ax1.plot(pos_example, color="black") ax1.set_xlabel("Time (s)") ax1.set_ylabel("Pos.") @@ -194,7 +196,7 @@ # they match up # Find the index of the frequency closest to theta band -theta_freq_index = np.argmin(np.abs(10 - freqs)) +theta_freq_index = np.argmin(np.abs(8 - freqs)) # Extract its real component, as well as its power envelope theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) @@ -206,38 +208,31 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( - [ - ["lfp_run"], - ["pos_run"], - ], + [["ephys"], ["pos"]], height_ratios=[1, 0.4], ) - -axd["lfp_run"].plot(eeg_example, alpha=0.5, label="LFP Data") 
-axd["lfp_run"].plot( +axd["ephys"].plot(eeg_example, label="CA1") +axd["ephys"].plot( eeg_example.index.values, theta_band_reconstruction, label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) -axd["lfp_run"].plot( +axd["ephys"].plot( eeg_example.index.values, theta_band_power_envelope, label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", ) +axd["ephys"].set_title("EEG (1250 Hz)") +axd["ephys"].set_ylabel("LFP (a.u.)") +axd["ephys"].set_xlabel("time (s)") +axd["ephys"].margins(0) +axd["ephys"].legend() +axd["pos"].plot(pos_example, color="black") +axd["pos"].margins(0) +axd["pos"].set_xlabel("time (s)") +axd["pos"].set_ylabel("Linearized Position") +axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) -axd["lfp_run"].set_ylabel("LFP (v)") -axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.") -axd["pos_run"].plot(pos_example) -[axd[k].margins(0) for k in ["lfp_run", "pos_run"]] -[ - axd["pos_run"].spines[sp].set_visible(False) - for sp in ["top", "right", "bottom", "left"] -] -axd["pos_run"].get_xaxis().set_visible(False) -axd["pos_run"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) -axd["pos_run"].set_ylabel("Lin. Position (cm)") -axd["lfp_run"].legend() # %% # *** @@ -266,17 +261,12 @@ height_ratios=[1, 0.4], ) axd["lfp_run"].plot(eeg_example, label="LFP Data") +axd["rip_pow"].plot(eeg_example.index.values, ripple_power) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) -axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") -axd["rip_pow"].plot(eeg_example.index.values, ripple_power) +axd["lfp_run"].set_title(f"EEG (1250 Hz)") axd["rip_pow"].margins(0) -axd["rip_pow"].get_xaxis().set_visible(False) -axd["rip_pow"].spines["top"].set_visible(False) -axd["rip_pow"].spines["right"].set_visible(False) -axd["rip_pow"].spines["bottom"].set_visible(False) -axd["rip_pow"].spines["left"].set_visible(False) axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") @@ -313,14 +303,14 @@ ) axd["lfp_run"].plot(eeg_example, label="LFP Data") axd["rip_pow"].plot(smoother_swr_power) -axd["rip_pow"].plot(is_ripple, color="red", linewidth=2) +axd["rip_pow"].axvspan( + is_ripple.index.min(), is_ripple.index.max(), color="red", alpha=0.3 +) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{np.round(freqs[ripple_freq_idx], 2)}Hz oscillation power.") -axd["rip_pow"].axhline(threshold) +axd["lfp_run"].set_title(f"EEG (1250 Hz)") +axd["rip_pow"].axhline(threshold, linestyle="--", color="black", alpha=0.4) [axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] -[axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] -axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") @@ -341,13 +331,12 @@ color="blue", label="Non-SWR LFP", ) -ax.plot( - eeg_example.restrict( - nap.IntervalSet(start=is_ripple.index.min(), end=is_ripple.index.max()) - ), +ax.axvspan( + is_ripple.index.min(), + is_ripple.index.max(), color="red", - label="SWR", - linewidth=2, + alpha=0.3, + label="SWR LFP", ) ax.margins(0) ax.set_xlabel("Time (s)") From 05a59961a46453ac74958929521f3414bbc6fc2f Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 23:32:15 
+0100 Subject: [PATCH 57/71] signal processing tests to 100% coverage --- pynapple/process/signal_processing.py | 34 +----- tests/test_power_spectral_density.py | 169 ++++++++++++++++++++------ tests/test_signal_processing.py | 25 +++- 3 files changed, 158 insertions(+), 70 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 5b8eaa3d..e0ee1e32 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -178,7 +178,7 @@ def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_base=np.e): +def _create_freqs(freq_start, freq_stop, num_freqs=10, log_scaling=False): """ Creates an array of frequencies. @@ -188,12 +188,10 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas Starting value for the frequency definition. freq_stop: float Stopping value for the frequency definition, inclusive. - freq_step: float, optional - Step value, for linearly spaced values between start and stop. + num_freqs: int, optional + Number of freqs to create. Default 10 log_scaling: Bool If True, will use log spacing with base log_base for frequency spacing. Default False. - log_base: float - If log_scaling==True, this defines the base of the log to use. Returns ------- @@ -201,9 +199,9 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas Frequency indices. """ if not log_scaling: - return np.arange(freq_start, freq_stop + freq_step, freq_step) + return np.linspace(freq_start, freq_stop, num_freqs) else: - return np.logspace(freq_start, freq_stop, base=log_base) + return np.geomspace(freq_start, freq_stop, num_freqs) def compute_wavelet_transform( @@ -358,25 +356,3 @@ def generate_morlet_filterbank( for arr in filter_bank ] return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) - - -def _integrate(arr, step): - """ - Integrates an array with respect to some step param. Used for integrating complex wavelets. 
- - Parameters - ---------- - arr : np.ndarray - wave function to be integrated - step : float - Step size of given wave function array - - Returns - ------- - array - Complex-valued integrated wavelet - - """ - integral = np.cumsum(arr) - integral *= step - return integral diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 0eb39cdd..18503294 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -1,25 +1,18 @@ +import re +from contextlib import nullcontext as does_not_raise + import numpy as np import pandas as pd import pytest -from contextlib import nullcontext as does_not_raise + import pynapple as nap ############################################################ # Test for mean_power_spectral_density ############################################################ + def test_compute_power_spectral_density(): - with pytest.raises(ValueError) as e_info: - t = np.linspace(0, 1, 1000) - sig = nap.Tsd( - d=np.random.random(1000), - t=t, - time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), - ) - r = nap.compute_power_spectral_density(sig) - assert ( - str(e_info.value) == "Given epoch (or signal time_support) must have length 1" - ) t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.random.random(1000), t=t) @@ -37,23 +30,65 @@ def test_compute_power_spectral_density(): assert isinstance(r, pd.DataFrame) assert r.shape == (1000, 4) - with pytest.raises(TypeError) as e_info: - nap.compute_power_spectral_density("a_string") - assert ( - str(e_info.value) - == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" - ) + +@pytest.mark.parametrize( + "sig, fs, ep, full_range, expectation", + [ + ( + nap.Tsd( + d=np.random.random(1000), + t=np.linspace(0, 1, 1000), + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ), + 1000, + None, + False, + pytest.raises( + ValueError, + match=re.escape( + "Given epoch (or signal time_support) must have length 1" + ), + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + "not_ep", + False, + pytest.raises( + TypeError, + match="ep param must be a pynapple IntervalSet object, or None", + ), + ), + ( + "not_a_tsd", + 1000, + None, + False, + pytest.raises( + TypeError, + match="Currently compute_spectogram is only implemented for Tsd or TsdFrame", + ), + ), + ], +) +def test_compute_power_spectral_density_raise_errors( + sig, fs, ep, full_range, expectation +): + with expectation: + psd = nap.compute_power_spectral_density(sig, fs, ep, full_range) ############################################################ # Test for mean_power_spectral_density ############################################################ -def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): - t=np.arange(0, duration, 1/fs) - d=np.cos(2*np.pi*f*t) - sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) - tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T + +def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): + t = np.arange(0, duration, 1 / fs) + d = np.cos(2 * np.pi * f * t) + sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) + tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T tmp = tmp[0:-1] out = np.sum(np.fft.fft(tmp, axis=0), 1) freq = np.fft.fftfreq(out.shape[0], 1 / fs) @@ -62,15 +97,16 @@ def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): freq = freq[order] return (sig, out, freq) + def 
test_compute_mean_power_spectral_density(): - sig, out, freq = get_signal_and_output() + sig, out, freq = get_signal_and_output() psd = nap.compute_mean_power_spectral_density(sig, 10) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq>=0]) - np.testing.assert_array_almost_equal(psd.index.values, freq[freq>=0]) + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) - # Full range + # Full range psd = nap.compute_mean_power_spectral_density(sig, 10, full_range=True) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty @@ -78,29 +114,82 @@ def test_compute_mean_power_spectral_density(): np.testing.assert_array_almost_equal(psd.index.values, freq) # TsdFrame - sig2 = nap.TsdFrame(t=sig.t, d=np.repeat(sig.values[:,None], 2, 1), time_support = sig.time_support) + sig2 = nap.TsdFrame( + t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support + ) psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:,None],2,1)) + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) np.testing.assert_array_almost_equal(psd.index.values, freq) @pytest.mark.parametrize( "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", [ - (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), - (*get_signal_and_output(), 10, "a", None, False, "s", pytest.raises(TypeError, match="fs must be of type float or int")), - (*get_signal_and_output(), 10, None, "a", False, "s", pytest.raises(TypeError, match="ep param must be a pynapple IntervalSet object, or None")), - (*get_signal_and_output(), 10, None, None, "a", "s", pytest.raises(TypeError, match="full_range must be of type bool or None")), - (*get_signal_and_output(), 10*1e3, None, None, False, "ms", does_not_raise()), - (*get_signal_and_output(), 10*1e6, None, None, False, "us", does_not_raise()), - (*get_signal_and_output(), 200, None, None, False, "s", pytest.raises(RuntimeError, match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size")), - (*get_signal_and_output(), 10, None, nap.IntervalSet([0, 200], [100,300]), False, "s", pytest.raises(RuntimeError, match="One interval doesn't have any signal associated. 
Check the parameter ep or the time support if no epoch is passed.")), - ] + (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), + ( + *get_signal_and_output(), + 10, + "a", + None, + False, + "s", + pytest.raises(TypeError, match="fs must be of type float or int"), + ), + ( + *get_signal_and_output(), + 10, + None, + "a", + False, + "s", + pytest.raises( + TypeError, + match="ep param must be a pynapple IntervalSet object, or None", + ), + ), + ( + *get_signal_and_output(), + 10, + None, + None, + "a", + "s", + pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + (*get_signal_and_output(), 10 * 1e3, None, None, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10 * 1e6, None, None, False, "us", does_not_raise()), + ( + *get_signal_and_output(), + 200, + None, + None, + False, + "s", + pytest.raises( + RuntimeError, + match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size", + ), + ), + ( + *get_signal_and_output(), + 10, + None, + nap.IntervalSet([0, 200], [100, 300]), + False, + "s", + pytest.raises( + RuntimeError, + match="One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.", + ), + ), + ], ) def test_compute_mean_power_spectral_density_raise_errors( sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation - ): +): with expectation: - psd = nap.compute_mean_power_spectral_density(sig, interval_size, fs, ep, full_range, time_units) + psd = nap.compute_mean_power_spectral_density( + sig, interval_size, fs, ep, full_range, time_units + ) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 130c1b7a..6310e0af 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -121,6 +121,17 @@ def test_compute_wavelet_transform(): mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = (10, 100, 10, True) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.geomspace(10, 100, 10)) + assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 10 * np.pi * 2) @@ -192,7 +203,7 @@ def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = (1, 51, 10) + freqs = (1, 51, 6) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 6) @@ -320,6 +331,18 @@ def test_compute_wavelet_transform(): TypeError, match="`freqs` must be a ndarray or tuple instance." 
), ), + ( + "not_a_signal", + None, + np.linspace(10, 100, 10), + 1.5, + 1.0, + 16, + "l1", + pytest.raises( + TypeError, match="`sig` must be instance of Tsd, TsdFrame, or TsdTensor" + ), + ), ], ) def test_compute_wavelet_transform_raise_errors( From 05e29b65451f158ec8dc8742a007a4c3f0daf789 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 6 Aug 2024 14:55:22 +0100 Subject: [PATCH 58/71] coverage actually to 100% --- pynapple/process/signal_processing.py | 7 +++++ tests/test_power_spectral_density.py | 22 +++++++++++++ tests/test_signal_processing.py | 45 +++++++++++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index e0ee1e32..7f921402 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -257,6 +257,13 @@ def compute_wavelet_transform( if isinstance(gaussian_width, (int, float, np.number)): if gaussian_width <= 0: raise ValueError("gaussian_width must be a positive number.") + else: + raise TypeError("gaussian_width must be a float or int instance.") + if isinstance(window_length, (int, float, np.number)): + if window_length <= 0: + raise ValueError("window_length must be a positive number.") + else: + raise TypeError("window_length must be a float or int instance.") if norm is not None and norm not in ["l1", "l2"]: raise ValueError("norm parameter must be 'l1', 'l2', or None.") if not isinstance(freqs, (np.ndarray, tuple)): diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 18503294..7d0ec64f 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -30,6 +30,18 @@ def test_compute_power_spectral_density(): assert isinstance(r, pd.DataFrame) assert r.shape == (1000, 4) + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig, ep=sig.time_support) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig, fs=1000) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + @pytest.mark.parametrize( "sig, fs, ep, full_range, expectation", @@ -123,6 +135,16 @@ def test_compute_mean_power_spectral_density(): np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) np.testing.assert_array_almost_equal(psd.index.values, freq) + # TsdFrame + sig2 = nap.TsdFrame( + t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support + ) + psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True, fs=1000) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) + np.testing.assert_array_almost_equal(psd.index.values, freq) + @pytest.mark.parametrize( "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 6310e0af..34b0e79f 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -121,6 +121,17 @@ def test_compute_wavelet_transform(): mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1001), 
[0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = (10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=1001, freqs=freqs) + mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 50 * np.pi * 2) @@ -307,6 +318,40 @@ def test_compute_wavelet_transform(): ValueError, match="gaussian_width must be a positive number." ), ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + 1.5, + -1.0, + 16, + "l1", + pytest.raises(ValueError, match="window_length must be a positive number."), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + "not_number", + 1.0, + 16, + "l1", + pytest.raises( + TypeError, match="gaussian_width must be a float or int instance." + ), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + 1.5, + "not_number", + 16, + "l1", + pytest.raises( + TypeError, match="window_length must be a float or int instance." + ), + ), ( nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), None, From d723b8a660e9d02b4b664fd7fefd93e4ab1f109b Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 6 Aug 2024 20:30:18 +0100 Subject: [PATCH 59/71] changes to notebooks --- docs/api_guide/tutorial_pynapple_wavelets.py | 298 ++++++++++++------- docs/examples/tutorial_signal_processing.py | 2 - 2 files changed, 192 insertions(+), 108 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index c0f13862..f4a20b6d 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -50,17 +50,15 @@ # Lets plot it. fig, ax = plt.subplots(3, constrained_layout=True, figsize=(10, 5)) ax[0].plot(t, two_hz_component) -ax[1].plot(t, increasing_freq_component) -ax[2].plot(sig) ax[0].set_title("2Hz Component") +ax[1].plot(t, increasing_freq_component) ax[1].set_title("Increasing Frequency Component") +ax[2].plot(sig) ax[2].set_title("Dummy Signal") [ax[i].margins(0) for i in range(3)] [ax[i].set_ylim(-2.5, 2.5) for i in range(3)] -[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(3)] [ax[i].set_xlabel("Time (s)") for i in range(3)] [ax[i].set_ylabel("Signal") for i in range(3)] -[ax[i].set_ylim(-2.5, 2.5) for i in range(3)] # %% @@ -83,15 +81,11 @@ # Lets plot it. def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) - offset = 1.0 for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i].real() + offset * f_i) - ax.text( - -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" - ) + ax.plot(filter_bank[:, f_i].real() + 1.5 * f_i) + ax.text(-2.3, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) - [ax.spines[sp].set_visible(False) for sp in ["left", "right", "top"]] ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title(title) @@ -117,33 +111,30 @@ def plot_filterbank(filter_bank, freqs, title): # %% # Lets plot it. 
-def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): - if np.iscomplexobj(powers): - powers = abs(powers) - ax.imshow(powers, aspect="auto", **kwargs) +def plot_timefrequency(freqs, powers, ax=None): + im = ax.imshow(abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") - if isinstance(x_ticks, int): - x_tick_pos = np.linspace(0, times.size, x_ticks) - x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) - else: - x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] - ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = freqs - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] - ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.get_xaxis().set_visible(False) + ax.set(yticks=[np.argmin(np.abs(freqs - val)) for val in freqs], yticklabels=freqs) ax.grid(False) + return im -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -plot_timefrequency( - mwt.index.values[:], - freqs[:], - np.transpose(mwt[:, :].values), - ax=ax, -) -ax.set_title("Wavelet Decomposition Scalogram") +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) # %% # *** @@ -162,7 +153,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # Lets plot it. -fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig = plt.figure(constrained_layout=True, figsize=(10, 4)) axd = fig.subplot_mosaic( [["signal"], ["phase"]], height_ratios=[1, 0.4], @@ -170,20 +161,12 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["signal"].plot(sig, label="Raw Signal", alpha=0.5) axd["signal"].plot(t, slow_oscillation, label="2Hz Reconstruction") axd["signal"].legend() +axd["signal"].set_ylabel("Signal") + axd["phase"].plot(t, slow_oscillation_phase, alpha=0.5) axd["phase"].set_ylabel("Phase (rad)") -axd["signal"].set_ylabel("Signal") axd["phase"].set_xlabel("Time (s)") -[ - axd[f].spines[sp].set_visible(False) - for sp in ["right", "top"] - for f in ["phase", "signal"] -] -axd["signal"].get_xaxis().set_visible(False) -axd["signal"].spines["bottom"].set_visible(False) [axd[k].margins(0) for k in ["signal", "phase"]] -axd["signal"].set_ylim(-2.5, 2.5) -axd["phase"].set_ylim(-np.pi, np.pi) # %% # *** @@ -205,21 +188,24 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # Lets plot it. 
-fig, ax = plt.subplots(2, constrained_layout=True, figsize=(10, 6)) -ax[0].plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") -ax[0].plot(t, fifteenHz_oscillation_power, label="15Hz Power") -ax[1].plot(sig, label="Raw Signal", alpha=0.5) -ax[1].plot( - t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction" -) -[ax[i].set_ylim(-2.5, 2.5) for i in range(2)] -[ax[i].margins(0) for i in range(2)] -[ax[i].legend() for i in range(2)] -[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(2)] -ax[0].get_xaxis().set_visible(False) -ax[0].spines["bottom"].set_visible(False) -ax[1].set_xlabel("Time (s)") -[ax[i].set_ylabel("Signal") for i in range(2)] + +fig = plt.figure(constrained_layout=True, figsize=(10, 4)) +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 1.0]) + +ax0 = plt.subplot(gs[0, 0]) +ax0.plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") +ax0.plot(t, fifteenHz_oscillation_power, label="15Hz Power") +ax0.set_xticklabels([]) + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig, label="Raw Signal", alpha=0.5) +ax1.plot(t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") +ax1.set_xlabel("Time (s)") + +[ + (a.margins(0), a.legend(), a.set_ylim(-2.5, 2.5), a.set_ylabel("Signal")) + for a in [ax0, ax1] +] # %% @@ -235,7 +221,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) -[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") @@ -250,50 +235,59 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # ------------------ # Our reconstruction seems to get the amplitude modulations of our signal correct, but the amplitude is overestimated, # in particular towards the end of the time period. Often, this is due to a suboptimal choice of parameters, which -# can lead to a low spatial or temporal resolution. Let's explore what changing our parameters does to the +# can lead to a low spatial or temporal resolution. Let's visualize what changing our parameters does to the # underlying wavelets. 
-freqs = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) -window_lengths = [1.0, 2.0, 3.0] -gaussian_width = [1.0, 2.0, 3.0] - +window_lengths = [1.0, 3.0] +gaussian_widths = [1.0, 3.0] +colors = np.array([["r", "g"], ["b", "y"]]) fig, ax = plt.subplots( - len(window_lengths), len(gaussian_width), constrained_layout=True, figsize=(10, 8) + len(window_lengths) + 1, + len(gaussian_widths) + 1, + constrained_layout=True, + figsize=(10, 8), ) for row_i, wl in enumerate(window_lengths): - for col_i, gw in enumerate(gaussian_width): - filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, gaussian_width=gw, window_length=wl, precision=12 - ) - ax[row_i, col_i].plot(filter_bank[:, 0].real()) - ax[row_i, col_i].set_xlabel("Time (s)") - ax[row_i, col_i].set_yticks([]) - [ - ax[row_i, col_i].spines[sp].set_visible(False) - for sp in ["top", "right", "left"] - ] - if col_i != 0: - ax[row_i, col_i].get_yaxis().set_visible(False) -for col_i, gw in enumerate(gaussian_width): - ax[0, col_i].set_title(f"gaussian_width={gw}", fontsize=10) -for row_i, wl in enumerate(window_lengths): - ax[row_i, 0].set_ylabel(f"window_length={wl}", fontsize=10) -fig.suptitle("Parametrization Visualization") - + for col_i, gw in enumerate(gaussian_widths): + wavelet = nap.generate_morlet_filterbank( + np.array([1.0]), 1000, gaussian_width=gw, window_length=wl, precision=12 + )[:, 0].real() + ax[row_i, col_i].plot(wavelet, c=colors[row_i, col_i]) + fft = nap.compute_power_spectral_density(wavelet) + for i, j in [(row_i, -1), (-1, col_i)]: + ax[i, j].plot(fft.abs(), c=colors[row_i, col_i]) +for i in range(len(window_lengths)): + for j in range(len(gaussian_widths)): + ax[i, j].set(xlabel="Time (s)", yticks=[]) +for ci, gw in enumerate(gaussian_widths): + ax[0, ci].set_title(f"gaussian_width={gw}", fontsize=10) +for ri, wl in enumerate(window_lengths): + ax[ri, 0].set_ylabel(f"window_length={wl}", fontsize=10) +fig.suptitle("Parametrization Visualization (1 Hz Wavelet)") +ax[-1, -1].set_visible(False) +for i in range(len(window_lengths)): + ax[-1, i].set( + xlim=(0, 2), yticks=[], ylabel="Frequency Response", xlabel="Frequency (Hz)" + ) +for i in range(len(gaussian_widths)): + ax[i, -1].set( + xlim=(0, 2), yticks=[], ylabel="Frequency Response", xlabel="Frequency (Hz)" + ) # %% -# Increasing time_decay increases the number of wavelet cycles present in the oscillations (cycles) within the -# Gaussian window of the Morlet wavelet. It essentially controls the trade-off between time resolution -# and frequency resolution. +# Increasing window_length increases the number of wavelet cycles present in the oscillations (cycles), and +# correspondingly increases the time window that the wavelet covers. # -# The scale parameter determines the dilation or compression of the wavelet. It controls the size of the wavelet in -# time, affecting the overall shape of the wavelet. +# The gaussian_width parameter determines the shape of the gaussian window being convolved with the sinusoidal +# component of the wavelet +# +# Both of these parameters can be tweaked to control for the trade-off between time resolution and frequency resolution. # %% # *** # Effect of gaussian_width # ------------------ -# Let's increase time_decay to 7.5 and see the effect on the resultant filter bank. +# Let's increase gaussian_width to 7.5 and see the effect on the resultant filter bank. 
freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( @@ -313,14 +307,19 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=1.0 ) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -plot_timefrequency( - mwt.index.values[:], - freqs[:], - np.transpose(mwt[:, :].values), - ax=ax, -) -ax.set_title("Wavelet Decomposition Scalogram") +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) # %% # *** @@ -365,17 +364,23 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) mwt = nap.compute_wavelet_transform( sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0 ) -plot_timefrequency( - mwt.index.values[:], - freqs[:], - np.transpose(mwt[:, :].values), - ax=ax, -) -ax.set_title("Wavelet Decomposition Scalogram") + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) # %% # *** @@ -388,10 +393,91 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) -[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") ax.margins(0) ax.set_ylim(-6, 6) ax.legend() + + +# %% +# *** +# Effect of L1 vs L2 normalization +# ------------------ +# compute_wavelet_transform contains two options for normalization; L1, and L2. L1 normalization. +# By default, L1 is used as it creates cleaner looking decomposition images. +# +# L1 normalization often increases the contrast between significant and insignificant coefficients. +# This can result in a sharper and more defined visual representation, making patterns and structures within +# the signal more evident. +# +# L2 normalization is directly related to the energy of the signal. By normalizing using the +# L2 norm, you ensure that the transformed coefficients preserve the energy distribution of the original signal. +# +# Let's compare two wavelet decomposition images, each generated with a different normalization strategy + +mwt_l1 = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l1" +) +mwt_l2 = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l2" +) + +# %% +# Let's plot both the scalograms and see the difference. 
+ +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition - L1 Normalization") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt_l1[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition - L2 Normalization") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt_l2[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) + +# %% +# We see that the l1 normalized image contains a visually clearer image; the 5-15 Hz component of the signal is +# as powerful as the 2 Hz component, so it makes sense that they should be shown with the same power in the scalogram. +# Let's reconstruct the signal using both decompositions and see the resulting reconstruction... + +# %% + +combined_oscillations_l1 = mwt_l1.sum(axis=1).real() +combined_oscillations_l2 = mwt_l2.sum(axis=1).real() + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, label="Signal", linewidth=3, alpha=0.6, c="b") +ax.plot( + t, combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6 +) +ax.plot( + t, combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6 +) +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.set_ylim(-6, 6) +ax.legend() + +# %% +# We see that the reconstruction from the L2 normalized decomposition matched the original signal much more closely, +# this is due to the fact that L2 normalization preserved the energy of the original signal in its reconstruction. 
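[Editor's note — illustrative aside, not part of the patch] The L1/L2 discussion above can be made concrete with a small, self-contained numpy sketch. It is not pynapple's internal implementation; the two normalizations below are only roughly analogous to the `norm="l1"` / `norm="l2"` options of `compute_wavelet_transform`, and every name (`fs_demo`, `kernel`, `pure_tone`, ...) is invented for this example.

import numpy as np

fs_demo = 1000.0
t_demo = np.arange(-1, 1, 1 / fs_demo)
f_demo = 10.0  # center frequency of the demo wavelet, in Hz

# A complex Morlet-like kernel: a complex exponential under a gaussian envelope.
kernel = np.exp(2j * np.pi * f_demo * t_demo) * np.exp(-0.5 * (t_demo / 0.1) ** 2)

# L1 normalization: divide by the sum of absolute values of the kernel.
kernel_l1 = kernel / np.sum(np.abs(kernel))
# L2 normalization: divide by the square root of the sum of squared magnitudes.
kernel_l2 = kernel / np.sqrt(np.sum(np.abs(kernel) ** 2))

# Convolve each kernel with a unit-amplitude tone at the kernel's center frequency.
pure_tone = np.cos(2 * np.pi * f_demo * np.arange(0, 5, 1 / fs_demo))
amp_l1 = np.abs(np.convolve(pure_tone, kernel_l1, mode="same")).max()
amp_l2 = np.abs(np.convolve(pure_tone, kernel_l2, mode="same")).max()

# The same oscillation gets a different coefficient scale under each convention,
# which is one reason the L1 and L2 scalograms above look so different.
print(amp_l1, amp_l2)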
diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index ac3dc3c9..0e874085 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -184,8 +184,6 @@ ax1.set_xlabel("Time (s)") ax1.set_ylabel("Pos.") -plt.show() - # %% # *** From 8d78bb52aa4581e666356a020639e91622fbc132 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 6 Aug 2024 22:00:29 +0100 Subject: [PATCH 60/71] doc plot neatening --- docs/api_guide/tutorial_pynapple_wavelets.py | 2 + docs/examples/tutorial_phase_preferences.py | 186 +++++++++---------- docs/examples/tutorial_signal_processing.py | 2 + 3 files changed, 92 insertions(+), 98 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index f4a20b6d..94a74e97 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -17,6 +17,8 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# mkdocs_gallery_thumbnail_number = 9 +# # Now, import the necessary libraries: import matplotlib.pyplot as plt diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 18e311c9..15d44dda 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -17,6 +17,8 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# mkdocs_gallery_thumbnail_number = 6 +# # First, import the necessary libraries: import math @@ -93,10 +95,9 @@ ax.plot( REM_Tsd, label="REM LFP Data", - color="blue", ) ax.set_title("REM Local Field Potential") -ax.set_ylabel("LFP (v)") +ax.set_ylabel("LFP (a.u.)") ax.set_xlabel("time (s)") ax.margins(0) ax.legend() @@ -122,47 +123,34 @@ # Define wavelet decomposition plotting function -def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): - if np.iscomplexobj(powers): - powers = abs(powers) - ax.imshow(powers, aspect="auto", **kwargs) +def plot_timefrequency(freqs, powers, ax=None): + im = ax.imshow(abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") - if isinstance(x_ticks, int): - x_tick_pos = np.linspace(0, times.size, x_ticks) - x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) - else: - x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] - ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = [np.round(f, 2) for f in freqs] - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] - ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.get_xaxis().set_visible(False) + ax.set( + yticks=[np.argmin(np.abs(freqs - val)) for val in freqs], + yticklabels=np.round(freqs, 2), + ) ax.grid(False) + return im -# And plot it fig = plt.figure(constrained_layout=True, figsize=(10, 6)) -axd = fig.subplot_mosaic( - [ - ["wd_rem"], - ["lfp_rem"], - ], - height_ratios=[1, 0.2], -) -plot_timefrequency( - REM_Tsd.index.values[:], - freqs[:], - np.transpose(mwt_REM[:, :].values), - ax=axd["wd_rem"], -) -axd["wd_rem"].set_title(f"Wavelet Decomposition") -axd["lfp_rem"].plot(REM_Tsd) -axd["lfp_rem"].margins(0) -axd["lfp_rem"].set_ylabel("LFP (v)") -axd["lfp_rem"].get_xaxis().set_visible(False) -for spine in ["top", "right", "bottom", "left"]: - axd["lfp_rem"].spines[spine].set_visible(False) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = 
plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt_REM[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(REM_Tsd) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) + # %% # *** @@ -171,7 +159,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # There seems to be a strong theta frequency present in the data during the maze traversal. # Let's plot the estimated 7Hz component of the wavelet decomposition on top of our data, and see how well # they match up. We will also extract and plot the phase of the 7Hz wavelet from the decomposition. -theta_freq_index = np.argmin(np.abs(7 - freqs)) +theta_freq_index = np.argmin(np.abs(8 - freqs)) theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real # calculating phase here theta_band_phase = nap.Tsd( @@ -182,31 +170,29 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # Now let's plot the theta power and phase, along with the LFP. -fig = plt.figure(constrained_layout=True, figsize=(10, 5)) -axd = fig.subplot_mosaic( - [ - ["theta_pow"], - ["phase"], - ], - height_ratios=[0.4, 0.2], +fig, (ax1, ax2) = plt.subplots( + 2, 1, constrained_layout=True, figsize=(10, 5), height_ratios=[0.4, 0.2] ) -axd["theta_pow"].plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") -axd["theta_pow"].plot( +ax1.plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") +ax1.plot( REM_Tsd.index.values, theta_band_reconstruction, label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) -axd["theta_pow"].set_ylabel("LFP (v)") -axd["theta_pow"].set_xlabel("Time (s)") -axd["theta_pow"].set_title( - f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power." 
-) # -axd["theta_pow"].legend() -axd["phase"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) -[axd[k].margins(0) for k in ["theta_pow", "phase"]] -axd["phase"].set_ylabel("Phase") -axd["phase"].get_xaxis().set_visible(False) +ax1.set( + ylabel="LFP (v)", + title=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.", +) +ax1.get_xaxis().set_visible(False) +ax1.legend() +ax1.margins(0) + +ax2.plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) +ax2.set(ylabel="Phase", xlabel="Time (s)") +ax2.margins(0) + +plt.show() # %% @@ -232,13 +218,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): def smoothAngularTuningCurves(tuning_curves, sigma=2): - tmp = np.concatenate( - (tuning_curves.values, tuning_curves.values, tuning_curves.values) - ) + tmp = np.concatenate([tuning_curves.values] * 3) tmp = scipy.ndimage.gaussian_filter1d(tmp, sigma=sigma, axis=0) return pd.DataFrame( + tmp[tuning_curves.shape[0] : 2 * tuning_curves.shape[0]], index=tuning_curves.index, - data=tmp[tuning_curves.shape[0] : tuning_curves.shape[0] * 2], columns=tuning_curves.columns, ) @@ -250,18 +234,18 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): figsize=(10, 6), subplot_kw={"projection": "polar"}, ) -for pl_i, sc_i in enumerate(list(smoothcurves)[:6]): - axd[f"phase_{pl_i}"].plot( - list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), - list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + +for i, unit in enumerate(list(smoothcurves)[:6]): + ax = axd[f"phase_{i}"] + ax.plot( + list(smoothcurves[unit].index) + [smoothcurves[unit].index[0]], + list(smoothcurves[unit].values) + [smoothcurves[unit].values[0]], ) - axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis - axd[f"phase_{pl_i}"].set_ylabel( - "Firing Rate (Hz)" - ) # Firing rate in Hz, on the Y-axis - axd[f"phase_{pl_i}"].set_xticks([]) - axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") + ax.set(xlabel="Phase (rad)", ylabel="Firing Rate (Hz)", title=f"Unit {unit}") + ax.set_xticks([]) + fig.suptitle("Phase Preference Histograms of First 6 Units") +plt.show() # %% @@ -297,18 +281,18 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): figsize=(10, 6), subplot_kw={"projection": "polar"}, ) -for pl_i, sc_i in enumerate(list(phase_var.keys())[:6]): - axd[f"phase_{pl_i}"].plot( - list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), - list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + +for i, unit in enumerate(list(phase_var.keys())[:6]): + ax = axd[f"phase_{i}"] + ax.plot( + list(smoothcurves[unit].index) + [smoothcurves[unit].index[0]], + list(smoothcurves[unit].values) + [smoothcurves[unit].values[0]], ) - axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis - axd[f"phase_{pl_i}"].set_ylabel( - "Firing Rate (Hz)" - ) # Firing rate in Hz, on the Y-axis - axd[f"phase_{pl_i}"].set_xticks([]) - axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") -fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference ") + ax.set(xlabel="Phase (rad)", ylabel="Firing Rate (Hz)", title=f"Unit {unit}") + ax.set_xticks([]) + +fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference") +plt.show() # %% # *** @@ -317,34 +301,40 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): # There is definitely some strong phase preferences happening here. 
Let's visualize the firing preferences # of the 6 cells we've isolated to get an impression of just how striking these preferences are. -fig = plt.figure(constrained_layout=True, figsize=(10, 8)) -axd = fig.subplot_mosaic( +fig, axd = plt.subplot_mosaic( [ ["lfp_run"], ["phase_0"], ["phase_1"], ["phase_2"], ], + constrained_layout=True, + figsize=(10, 8), height_ratios=[0.4, 0.2, 0.2, 0.2], ) -[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(3)]] -axd["lfp_run"].plot( - REM_Tsd.index.values, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM" -) + +REM_index = REM_Tsd.index.values +axd["lfp_run"].plot(REM_index, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM") axd["lfp_run"].plot( - REM_Tsd.index.values, + REM_index, theta_band_reconstruction, - label=f"{np.round(freqs[theta_freq_index],2)}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", +) +axd["lfp_run"].set( + ylabel="LFP (v)", + xlabel="Time (s)", + title=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.", ) -axd["lfp_run"].set_ylabel("LFP (v)") -axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power.") axd["lfp_run"].legend() +axd["lfp_run"].margins(0) + for i in range(3): - axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) - axd[f"phase_{i}"].scatter( - spikes[list(phase_var.keys())[i]].index, phase[list(phase_var.keys())[i]] - ) - axd[f"phase_{i}"].set_ylabel("Phase") - axd[f"phase_{i}"].set_title(f"Unit {list(phase_var.keys())[i]}") + unit_key = list(phase_var.keys())[i] + ax = axd[f"phase_{i}"] + ax.plot(REM_index, theta_band_phase, alpha=0.2) + ax.scatter(spikes[unit_key].index, phase[unit_key]) + ax.set(ylabel="Phase", title=f"Unit {unit_key}") + ax.margins(0) + fig.suptitle("Phase Preference Visualizations") +plt.show() diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 0e874085..0fbc8942 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -18,6 +18,8 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# mkdocs_gallery_thumbnail_number = 7 +# # First, import the necessary libraries: import math From 7fee0d610fa40afabf53b3d55b3a9cea7bec9d7e Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 6 Aug 2024 17:50:45 -0400 Subject: [PATCH 61/71] Update tutorial_signal_processing.py and tests --- docs/examples/tutorial_signal_processing.py | 287 +++++++++++--------- pynapple/process/signal_processing.py | 52 +++- tests/test_power_spectral_density.py | 116 +++++++- 3 files changed, 302 insertions(+), 153 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index ac3dc3c9..71086630 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Computing Wavelet Transform +Wavelet Transform ============ This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). 
@@ -18,7 +18,10 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# # First, import the necessary libraries: +# +# mkdocs_gallery_thumbnail_number = 6 import math import os @@ -118,41 +121,83 @@ # Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies during exploration (`wake_ep`). -power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep) -print(power) +power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep, norm=True) +print(power) # %% # *** # The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. # # Let's plot the power between 1 and 100 Hz. -# -# The red area outlines the theta rhythm (6-12 Hz) which is proeminent in hippocampal LFP. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -ax.semilogy( - np.abs(power[(power.index > 1.0) & (power.index < 100)]), +ax.plot( + np.abs(power[(power.index >= 1.0) & (power.index <= 100)]), alpha=0.5, label="LFP Frequency Power", ) -ax.axvspan(6, 12, color="red", alpha=0.1) +ax.axvspan(6, 10, color="red", alpha=0.1) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") ax.set_title("LFP Fourier Decomposition") ax.legend() + +# %% +# The red area outlines the theta rhythm (6-10 Hz) which is proeminent in hippocampal LFP. +# Hippocampal theta rhythm appears mostly when the animal is running. +# (See Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666.) +# We can check it here by separating `wake_ep` into `run_ep` and `rest_ep`. +run_ep = data['position'].dropna().find_support(1) +rest_ep = wake_ep.set_diff(run_ep) + +# %% +# `run_ep` and `rest_ep` are IntervalSet with discontinuous epoch. +# The function `nap.compute_power_spectral_density` takes signal with a single epoch to avoid artefacts between epochs jumps. +# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density` which avearge the FFT over multiple epochs of same duration. The parameter `interval_size` controls the duration of those epochs. +# +# In this case, `interval_size` is equal to 1.5 seconds. + +power_run = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=run_ep, norm=True) +power_rest = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=rest_ep, norm=True) + +# %% +# `power_run` and `power_rest` are the power spectral density when the animal is respectively running and resting. + +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot( + np.abs(power_run[(power_run.index >= 3.0) & (power_run.index <= 30)]), + alpha=1, + label="Run", + linewidth=2 +) +ax.plot( + np.abs(power_rest[(power_rest.index >= 3.0) & (power_rest.index <= 30)]), + alpha=1, + label="Rest", + linewidth=2 +) +ax.axvspan(6, 10, color="red", alpha=0.1) +ax.set_xlabel("Freq (Hz)") +ax.set_ylabel("Frequency Power") +ax.set_title("LFP Fourier Decomposition") +ax.legend() + + # %% # *** # Getting the Wavelet Decomposition # ----------------------------------- -# It looks like the prominent frequencies in the data may vary over time. For example, it looks like the -# LFP characteristics may be different while the animal is running along the track, and when it is finished. +# Overall, the prominent frequencies in the data vary over time. The LFP characteristics may be different when the animal is running along the track, and when it is finished. 
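[Editor's note — illustrative aside, not part of the patch] Before the tutorial moves on to the wavelet decomposition, here is a minimal numpy-only sketch of the epoch-averaging idea behind `nap.compute_mean_power_spectral_density` used just above: cut the signal into equal-length chunks, taper each chunk (the patched pynapple implementation uses a Hamming window), take the FFT of every chunk, and average the spectra. All names below (`demo_sig`, `chunk_len`, ...) are invented for this illustration.

import numpy as np

fs_demo = 1000.0                              # sampling rate of the toy signal, in Hz
rng = np.random.default_rng(0)
t_demo = np.arange(0, 30, 1 / fs_demo)
demo_sig = np.sin(2 * np.pi * 8 * t_demo) + rng.normal(0, 2, t_demo.size)

chunk_len = int(1.5 * fs_demo)                # 1.5 s epochs, as in the tutorial above
n_chunks = demo_sig.size // chunk_len
chunks = demo_sig[: n_chunks * chunk_len].reshape(n_chunks, chunk_len)

window = np.hamming(chunk_len)                # taper each epoch to reduce edge artefacts
spectra = np.abs(np.fft.rfft(chunks * window, axis=1))
mean_spectrum = spectra.mean(axis=0)          # averaging across epochs smooths the noise
freqs_demo = np.fft.rfftfreq(chunk_len, 1 / fs_demo)

# The averaged spectrum peaks near the 8 Hz oscillation buried in the noise.
print(freqs_demo[np.argmax(mean_spectrum)])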
# Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. # We must define the frequency set that we'd like to use for our decomposition freqs = np.geomspace(3, 250, 100) + +# %% # Compute and print the wavelet transform on our LFP data + mwt_RUN = nap.compute_wavelet_transform(eeg_example, fs=FS, freqs=freqs) @@ -172,8 +217,9 @@ ax0.grid(False) ax0.set_yscale("log") ax0.set_title("Wavelet Decomposition") +ax0.set_ylabel("Frequency (Hz)") cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical") -ax0.set_label("Amplitude") +ax0.set_ylabel("Amplitude") ax1 = plt.subplot(gs[1, 0], sharex=ax0) ax1.plot(eeg_example) @@ -184,7 +230,6 @@ ax1.set_xlabel("Time (s)") ax1.set_ylabel("Pos.") -plt.show() # %% @@ -192,14 +237,14 @@ # Visualizing Theta Band Power # ----------------------------------- # There seems to be a strong theta frequency present in the data during the maze traversal. -# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well -# they match up +# Let's plot the estimated 6-10Hz component of the wavelet decomposition on top of our data, and see how well they match up. + +theta_freq_index = np.logical_and(freqs>6, freqs<10) + -# Find the index of the frequency closest to theta band -theta_freq_index = np.argmin(np.abs(8 - freqs)) # Extract its real component, as well as its power envelope -theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real -theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) +theta_band_reconstruction = np.mean(mwt_RUN[:,theta_freq_index], 1) +theta_band_power_envelope = np.abs(theta_band_reconstruction) # %% @@ -207,139 +252,111 @@ # Now let's visualise the theta band component of the signal over time. fig = plt.figure(constrained_layout=True, figsize=(10, 6)) -axd = fig.subplot_mosaic( - [["ephys"], ["pos"]], - height_ratios=[1, 0.4], -) -axd["ephys"].plot(eeg_example, label="CA1") -axd["ephys"].plot( - eeg_example.index.values, - theta_band_reconstruction, - label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", -) -axd["ephys"].plot( - eeg_example.index.values, - theta_band_power_envelope, - label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", -) -axd["ephys"].set_title("EEG (1250 Hz)") -axd["ephys"].set_ylabel("LFP (a.u.)") -axd["ephys"].set_xlabel("time (s)") -axd["ephys"].margins(0) -axd["ephys"].legend() -axd["pos"].plot(pos_example, color="black") -axd["pos"].margins(0) -axd["pos"].set_xlabel("time (s)") -axd["pos"].set_ylabel("Linearized Position") -axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) - +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.9]) +ax0 = plt.subplot(gs[0,0]) +ax0.plot(eeg_example, label="CA1") +ax0.set_title("EEG (1250 Hz)") +ax0.set_ylabel("LFP (a.u.)") +ax0.set_xlabel("time (s)") +ax0.legend() +ax1 = plt.subplot(gs[1,0]) +ax1.plot(np.real(theta_band_reconstruction), label="6-10 Hz oscillations") +ax1.plot(theta_band_power_envelope, label="6-10 Hz power envelope") +ax1.set_xlabel("time (s)") +ax1.set_ylabel("Wavelet transform") +ax1.legend() # %% # *** -# Visualizing Sharp Wave Ripple Power +# Visualizing high frequency oscillation # ----------------------------------- -# There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. -# Let's plot the LFP, along with the 200Hz frequency power, to see if we can isolate these peaks and -# see what's going on. 
+# There also seem to be peaks in the 200Hz frequency power after traversal of the maze is complete. Here we use the interval (18356, 18357.5) seconds to zoom in.
 
-# Find the index of the frequency closest to sharp wave ripple oscillations
-ripple_freq_idx = np.argmin(np.abs(200 - freqs))
-# Extract its power envelope
-ripple_power = np.abs(mwt_RUN[:, ripple_freq_idx].values)
+zoom_ep = nap.IntervalSet(18356.0, 18357.5)
+mwt_zoom = mwt_RUN.restrict(zoom_ep)
+
+fig = plt.figure(constrained_layout=True, figsize=(10, 6))
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5])
+ax0 = plt.subplot(gs[0, 0])
+pcmesh = ax0.pcolormesh(mwt_zoom.t, freqs, np.transpose(np.abs(mwt_zoom)))
+ax0.grid(False)
+ax0.set_yscale("log")
+ax0.set_title("Wavelet Decomposition")
+ax0.set_ylabel("Frequency (Hz)")
+cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical")
+cbar.set_label("Amplitude")
+
+ax1 = plt.subplot(gs[1, 0], sharex=ax0)
+ax1.plot(eeg_example.restrict(zoom_ep))
+ax1.set_ylabel("LFP (a.u.)")
+ax1.set_xlabel("Time (s)")
 
 # %%
-# ***
-# Now let's visualise the 200Hz component of the signal over time.
+# Those events are called sharp-wave ripples (see Buzsáki, G. (2015). Hippocampal sharp wave‐ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188.)
+#
+# Among other methods, we can use the wavelet decomposition to isolate them. In this case, we will look at the power of the wavelets for frequencies between 150 and 250 Hz.
 
-fig = plt.figure(constrained_layout=True, figsize=(10, 5))
-axd = fig.subplot_mosaic(
-    [
-        ["lfp_run"],
-        ["rip_pow"],
-    ],
-    height_ratios=[1, 0.4],
-)
-axd["lfp_run"].plot(eeg_example, label="LFP Data")
-axd["rip_pow"].plot(eeg_example.index.values, ripple_power)
-axd["lfp_run"].set_ylabel("LFP (v)")
-axd["lfp_run"].set_xlabel("Time (s)")
-axd["lfp_run"].margins(0)
-axd["lfp_run"].set_title(f"EEG (1250 Hz)")
-axd["rip_pow"].margins(0)
-axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max())
-axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power")
+ripple_freq_index = np.logical_and(freqs>150, freqs<250)
 
 # %%
-# ***
-# Isolating Ripple Times
-# -----------------------------------
-# We can see one significant peak in the 200Hz frequency power. Let's smooth the power curve and threshold
-# to try to isolate this event.
-
-# Define threshold
-threshold = 6000
-# Smooth wavelet power TsdFrame at the SWR frequency
-smoother_swr_power = (
-    mwt_RUN[:, ripple_freq_idx]
-    .abs()
-    .smooth(std=0.025, windowsize=0.2, time_units="s", norm=False)
-)
-# Threshold our TsdFrame
-is_ripple = smoother_swr_power.threshold(threshold)
+# We can compute the mean power for this frequency band.
+
+ripple_power = np.mean(np.abs(mwt_RUN[:, ripple_freq_index]), 1)
 
 # %%
-# ***
-# Now let's plot the threshold ripple power over time.
+# Now let's visualise the 150-250 Hz mean amplitude of the wavelet decomposition over time.
 
 fig = plt.figure(constrained_layout=True, figsize=(10, 5))
-axd = fig.subplot_mosaic(
-    [
-        ["lfp_run"],
-        ["rip_pow"],
-    ],
-    height_ratios=[1, 0.4],
-)
-axd["lfp_run"].plot(eeg_example, label="LFP Data")
-axd["rip_pow"].plot(smoother_swr_power)
-axd["rip_pow"].axvspan(
-    is_ripple.index.min(), is_ripple.index.max(), color="red", alpha=0.3
-)
-axd["lfp_run"].set_ylabel("LFP (v)")
-axd["lfp_run"].set_xlabel("Time (s)")
-axd["lfp_run"].set_title(f"EEG (1250 Hz)")
-axd["rip_pow"].axhline(threshold, linestyle="--", color="black", alpha=0.4)
-[axd[k].margins(0) for k in ["lfp_run", "rip_pow"]]
-axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max())
-axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power")
+gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5])
+ax0 = plt.subplot(gs[0,0])
+ax0.plot(eeg_example.restrict(zoom_ep), label="CA1")
+ax0.set_ylabel("LFP (v)")
+ax0.set_title(f"EEG (1250 Hz)")
+ax1 = plt.subplot(gs[1,0])
+ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz")
+ax1.legend()
+ax1.set_ylabel("Mean Amplitude")
+ax1.set_xlabel("Time (s)")
+
 # %%
-# ***
-# Plotting a Sharp Wave Ripple
-# -----------------------------------
-# Let's zoom in on out detected ripples and have a closer look!
+# It is then easy to isolate ripple times by using the pynapple functions `smooth` and `threshold`. In the following lines, `ripple_power` is smoothed with a gaussian kernel of width 0.005 seconds and thresholded at a value of 100.
+#
+
+smoothed_ripple_power = ripple_power.smooth(0.005)
+
+threshold_ripple_power = smoothed_ripple_power.threshold(100)
+
+# %%
+# `threshold_ripple_power` contains all the time points above 100. The ripple epochs are contained in the `time_support` of the thresholded time series. Here we call it `rip_ep`.
+
+rip_ep = threshold_ripple_power.time_support
+
+
+# %%
+# Now let's plot the ripple epochs as well as the smoothed ripple power.
+# +# We can also plot `rip_ep` as vertical boxes to see if the detection is accurate + +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) +ax0 = plt.subplot(gs[0,0]) +ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") +for s,e in rip_ep.intersect(zoom_ep).values: + ax0.axvspan(s, e, color='red', alpha=0.1, ec=None) +ax0.set_ylabel("LFP (v)") +ax0.set_title(f"EEG (1250 Hz)") +ax1 = plt.subplot(gs[1,0]) +ax1.legend() +ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") +ax1.plot(smoothed_ripple_power.restrict(zoom_ep)) +for s,e in rip_ep.intersect(zoom_ep).values: + ax1.axvspan(s, e, color='red', alpha=0.1, ec=None) +ax1.legend() +ax1.set_ylabel("Mean Amplitude") +ax1.set_xlabel("Time (s)") + -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -buffer = 0.1 -ax.plot( - eeg_example.restrict( - nap.IntervalSet( - start=is_ripple.index.min() - buffer, end=is_ripple.index.max() + buffer - ) - ), - color="blue", - label="Non-SWR LFP", -) -ax.axvspan( - is_ripple.index.min(), - is_ripple.index.max(), - color="red", - alpha=0.3, - label="SWR LFP", -) -ax.margins(0) -ax.set_xlabel("Time (s)") -ax.set_ylabel("LFP (v)") -ax.legend() -ax.set_title("Sharp Wave Ripple Visualization") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 7f921402..722abafa 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -8,11 +8,12 @@ import numpy as np import pandas as pd +from scipy import signal from .. import core as nap -def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): +def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False): """ Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. @@ -26,6 +27,8 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): The epoch to calculate the fft on. Must be length 1. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude Returns ------- @@ -39,9 +42,9 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): parameter otherwise will be sig.time_support, but it must only be a single epoch. 
""" if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError( - "Currently compute_spectogram is only implemented for Tsd or TsdFrame" - ) + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: @@ -50,22 +53,40 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): raise ValueError("Given epoch (or signal time_support) must have length 1") if fs is None: fs = sig.rate + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) + + if norm: + fft_result = fft_result / fft_result.shape[0] + ret = pd.DataFrame(fft_result, fft_freq) ret.sort_index(inplace=True) + if not full_range: return ret.loc[ret.index >= 0] return ret def compute_mean_power_spectral_density( - sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s" + sig, + interval_size, + fs=None, + ep=None, + full_range=False, + norm=False, + time_unit="s", ): """Compute mean power spectral density by averaging FFT over epochs of same size. The parameter `interval_size` controls the duration of the epochs. - Note that this function assumes a constant sampling rate for sig. + To imporve frequency resolution, the signal is multiplied by a Hamming window. + + Note that this function assumes a constant sampling rate for `sig`. Parameters ---------- @@ -79,6 +100,8 @@ def compute_mean_power_spectral_density( The `IntervalSet` to calculate the fft on. Can be any length. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude time_unit : str, optional Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') @@ -94,6 +117,9 @@ def compute_mean_power_spectral_density( TypeError If `ep` or `sig` are not respectively pynapple time series or interval set. 
""" + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: @@ -107,6 +133,9 @@ def compute_mean_power_spectral_density( if not isinstance(full_range, bool): raise TypeError("full_range must be of type bool or None") + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + # Split the ep interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ 0 @@ -137,11 +166,20 @@ def compute_mean_power_spectral_density( # Get the freqs fft_freq = np.fft.fftfreq(N, 1 / fs) + # Get the Hamming window + window = signal.windows.hamming(N) + if sig.ndim == 2: + window = window[:, np.newaxis] + # Compute the fft fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) for i in range(len(slices)): - fft_result += np.fft.fft(sig[slices[i, 0] : slices[i, 1]].values[0:N], axis=0) + tmp = sig[slices[i, 0] : slices[i, 1]].values[0:N] * window + fft_result += np.fft.fft(tmp, axis=0) + + if norm: + fft_result = fft_result / (float(N) * float(len(slices))) ret = pd.DataFrame(fft_result, fft_freq) ret.sort_index(inplace=True) diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 7d0ec64f..dbda75cf 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -4,13 +4,22 @@ import numpy as np import pandas as pd import pytest +from scipy import signal import pynapple as nap + ############################################################ -# Test for mean_power_spectral_density +# Test for power_spectral_density ############################################################ +def get_sorted_fft(data,fs): + fft = np.fft.fft(data, axis=0) + fft_freq = np.fft.fftfreq(len(data), 1 / fs) + order = np.argsort(fft_freq) + if fft.ndim==1: + fft = fft[:,np.newaxis] + return fft_freq[order], fft[order] def test_compute_power_spectral_density(): @@ -20,16 +29,31 @@ def test_compute_power_spectral_density(): assert isinstance(r, pd.DataFrame) assert r.shape[0] == 500 + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a[a>=0]) + np.testing.assert_array_almost_equal(r.values, b[a>=0]) + + r = nap.compute_power_spectral_density(sig, norm=True) + np.testing.assert_array_almost_equal(r.values, b[a>=0]/len(sig)) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) r = nap.compute_power_spectral_density(sig) assert isinstance(r, pd.DataFrame) assert r.shape == (500, 4) + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a[a>=0]) + np.testing.assert_array_almost_equal(r.values, b[a>=0]) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) r = nap.compute_power_spectral_density(sig, full_range=True) assert isinstance(r, pd.DataFrame) assert r.shape == (1000, 4) + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a) + np.testing.assert_array_almost_equal(r.values, b) + t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.random.random(1000), t=t) r = nap.compute_power_spectral_density(sig, ep=sig.time_support) @@ -44,7 +68,7 @@ def test_compute_power_spectral_density(): @pytest.mark.parametrize( - "sig, fs, ep, full_range, expectation", + "sig, fs, ep, full_range, norm, expectation", [ ( nap.Tsd( @@ -55,6 +79,7 @@ def test_compute_power_spectral_density(): 1000, 
None, False, + False, pytest.raises( ValueError, match=re.escape( @@ -67,28 +92,63 @@ def test_compute_power_spectral_density(): 1000, "not_ep", False, + False, pytest.raises( TypeError, match="ep param must be a pynapple IntervalSet object, or None", ), ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + "a", + None, + False, + False, + pytest.raises( + TypeError, + match="fs must be of type float or int", + ), + ), ( "not_a_tsd", 1000, None, False, + False, + pytest.raises( + TypeError, + match="sig must be either a Tsd or a TsdFrame object.", + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + None, + "a", + False, pytest.raises( TypeError, - match="Currently compute_spectogram is only implemented for Tsd or TsdFrame", + match="full_range must be of type bool or None", ), ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + None, + False, + "a", + pytest.raises( + TypeError, + match="norm must be of type bool", + ), + ), ], ) def test_compute_power_spectral_density_raise_errors( - sig, fs, ep, full_range, expectation + sig, fs, ep, full_range, norm, expectation ): with expectation: - psd = nap.compute_power_spectral_density(sig, fs, ep, full_range) + psd = nap.compute_power_spectral_density(sig, fs, ep, full_range, norm) ############################################################ @@ -102,6 +162,7 @@ def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T tmp = tmp[0:-1] + tmp = tmp*signal.windows.hamming(tmp.shape[0])[:,np.newaxis] out = np.sum(np.fft.fft(tmp, axis=0), 1) freq = np.fft.fftfreq(out.shape[0], 1 / fs) order = np.argsort(freq) @@ -125,6 +186,14 @@ def test_compute_mean_power_spectral_density(): np.testing.assert_array_almost_equal(psd.values.flatten(), out) np.testing.assert_array_almost_equal(psd.index.values, freq) + # Norm + psd = nap.compute_mean_power_spectral_density(sig, 10, norm=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(9999.0*10.0)) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) + + # TsdFrame sig2 = nap.TsdFrame( t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support @@ -147,15 +216,26 @@ def test_compute_mean_power_spectral_density(): @pytest.mark.parametrize( - "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", + "sig, out, freq, interval_size, fs, ep, full_range, norm, time_units, expectation", [ - (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), + (*get_signal_and_output(), 10, None, None, False, False, "s", does_not_raise()), + ( + "a", *get_signal_and_output()[1:], + 10, + None, + None, + False, + False, + "s", + pytest.raises(TypeError, match="sig must be either a Tsd or a TsdFrame object."), + ), ( *get_signal_and_output(), 10, "a", None, False, + False, "s", pytest.raises(TypeError, match="fs must be of type float or int"), ), @@ -165,6 +245,7 @@ def test_compute_mean_power_spectral_density(): None, "a", False, + False, "s", pytest.raises( TypeError, @@ -177,17 +258,29 @@ def test_compute_mean_power_spectral_density(): None, None, "a", + False, "s", pytest.raises(TypeError, match="full_range must be of type bool or None"), ), - (*get_signal_and_output(), 10 * 1e3, None, 
None, False, "ms", does_not_raise()), - (*get_signal_and_output(), 10 * 1e6, None, None, False, "us", does_not_raise()), + ( + *get_signal_and_output(), + 10, + None, + None, + None, + "a", + "s", + pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + (*get_signal_and_output(), 10 * 1e3, None, None, False, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10 * 1e6, None, None, False, False, "us", does_not_raise()), ( *get_signal_and_output(), 200, None, None, False, + False, "s", pytest.raises( RuntimeError, @@ -200,6 +293,7 @@ def test_compute_mean_power_spectral_density(): None, nap.IntervalSet([0, 200], [100, 300]), False, + False, "s", pytest.raises( RuntimeError, @@ -209,9 +303,9 @@ def test_compute_mean_power_spectral_density(): ], ) def test_compute_mean_power_spectral_density_raise_errors( - sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation + sig, out, freq, interval_size, fs, ep, full_range, norm, time_units, expectation ): with expectation: psd = nap.compute_mean_power_spectral_density( - sig, interval_size, fs, ep, full_range, time_units + sig, interval_size, fs, ep, full_range, norm, time_units ) From 265be7f78a780f523fcbc7a0460ba80f1bd85f13 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 10:09:05 -0400 Subject: [PATCH 62/71] Missing test for sig processing --- tests/test_power_spectral_density.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index dbda75cf..626b0832 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -265,12 +265,22 @@ def test_compute_mean_power_spectral_density(): ( *get_signal_and_output(), 10, - None, - None, - None, - "a", - "s", + None, # FS + None, # Ep + "a", # full_range + False, # Norm + "s", # Time units pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + ( + *get_signal_and_output(), + 10, + None, # FS + None, # Ep + False, # full_range + "a", # Norm + "s", # Time units + pytest.raises(TypeError, match="norm must be of type bool"), ), (*get_signal_and_output(), 10 * 1e3, None, None, False, False, "ms", does_not_raise()), (*get_signal_and_output(), 10 * 1e6, None, None, False, False, "us", does_not_raise()), From 88857147a1ecc0adc0bc45e7406d6904d59eefe0 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 12:28:23 -0400 Subject: [PATCH 63/71] More update on wavelets --- docs/api_guide/tutorial_pynapple_spectrum.py | 23 ++--- docs/api_guide/tutorial_pynapple_wavelets.py | 98 ++++++++++---------- 2 files changed, 63 insertions(+), 58 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py index bcb64c50..39db2432 100644 --- a/docs/api_guide/tutorial_pynapple_spectrum.py +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -35,11 +35,11 @@ F = [2, 10] Fs = 2000 -t = np.arange(0, 100, 1/Fs) +t = np.arange(0, 200, 1/Fs) sig = nap.Tsd( t=t, d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 3, len(t)), - time_support = nap.IntervalSet(0, 100) + time_support = nap.IntervalSet(0, 200) ) # %% @@ -55,9 +55,9 @@ # Computing power spectral density (PSD) # -------------------------------------- # -# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density` +# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density`. 
With `norm=True`, the output of the FFT is divided by the length of the signal. -psd = nap.compute_power_spectral_density(sig) +psd = nap.compute_power_spectral_density(sig, norm=True) # %% # Pynapple returns a pandas DataFrame. @@ -68,7 +68,7 @@ # It is then easy to plot it. plt.figure() -plt.plot(psd) +plt.plot(np.abs(psd)) plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") @@ -79,7 +79,7 @@ # Let's zoom on the first 20 Hz. plt.figure() -plt.plot(psd) +plt.plot(np.abs(psd)) plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") plt.xlim(0, 20) @@ -108,18 +108,19 @@ # # In this case, the FFT will be computed over epochs of 10 seconds. -mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=10.0) +mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=20.0, norm=True) # %% -# Let's compare `mean_psd` to `psd`. +# Let's compare `mean_psd` to `psd`. In both cases, the ouput is normalized. plt.figure() -plt.plot(psd) -plt.plot(mean_psd) +plt.plot(np.abs(psd), label='PSD') +plt.plot(np.abs(mean_psd), label='Mean PSD (10s)') plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") -plt.xlim(0, 20) +plt.legend() +plt.xlim(0, 15) # %% # As we can see, `nap.compute_mean_power_spectral_density` was able to smooth out the noise. diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index 94a74e97..016d7cc0 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- """ -Wavelet API tutorial -============ +Wavelet Transform +================= -Working with Wavelets! +This tutorial covers the use of `nap.compute_wavelet_transform` to do continuous wavelet transform. By default, pynapple uses Morlet wavelets. + +The function `nap.generate_morlet_filterbank` can help parametrize and visualize the Morlet wavelets. See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. @@ -80,15 +82,16 @@ # %% -# Lets plot it. +# Lets plot it some of the wavelets. + def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i].real() + 1.5 * f_i) - ax.text(-2.3, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") - ax.margins(0) - ax.yaxis.set_visible(False) - ax.set_xlim(-2, 2) + ax.plot(filter_bank[:, f_i].real() + f_i*1.5) + ax.text(-5.5, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") + + ax.set_yticks([]) + ax.set_xlim(-5, 5) ax.set_xlabel("Time (s)") ax.set_title(title) @@ -98,23 +101,28 @@ def plot_filterbank(filter_bank, freqs, title): # %% # *** -# Decomposing the Dummy Signal -# ------------------ +# Continuous wavelet transform +# ---------------------------- # Here we will use the `compute_wavelet_transform` function to decompose our signal using the filter bank shown # above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and # frequency information for analysis. We will calculate this decomposition and plot it's corresponding -# scalogram. +# scalogram (which is another name for time frequency decomposition using wavelets). # Compute the wavelet transform using the parameters above mwt = nap.compute_wavelet_transform( sig, fs=Fs, freqs=freqs, gaussian_width=1.5, window_length=1.0 ) +# %% +# `mwt` for Morlet wavelet transform is a `TsdFrame`. 
Each column is the result of the convolution of the signal with one wavelet. + +print(mwt) # %% # Lets plot it. + def plot_timefrequency(freqs, powers, ax=None): - im = ax.imshow(abs(powers), aspect="auto") + im = ax.imshow(np.abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") @@ -134,7 +142,7 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -149,9 +157,9 @@ def plot_timefrequency(freqs, powers, ax=None): # Get the index of the 2Hz frequency two_hz_freq_idx = np.where(freqs == 2.0)[0] # The 2Hz component is the real component of the wavelet decomposition at this index -slow_oscillation = mwt[:, two_hz_freq_idx].values.real +slow_oscillation = np.real(mwt[:, two_hz_freq_idx]) # The 2Hz wavelet phase is the angle of the wavelet decomposition at this index -slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx].values) +slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx]) # %% # Lets plot it. @@ -161,11 +169,11 @@ def plot_timefrequency(freqs, powers, ax=None): height_ratios=[1, 0.4], ) axd["signal"].plot(sig, label="Raw Signal", alpha=0.5) -axd["signal"].plot(t, slow_oscillation, label="2Hz Reconstruction") +axd["signal"].plot(slow_oscillation, label="2Hz Reconstruction") axd["signal"].legend() axd["signal"].set_ylabel("Signal") -axd["phase"].plot(t, slow_oscillation_phase, alpha=0.5) +axd["phase"].plot(slow_oscillation_phase, alpha=0.5) axd["phase"].set_ylabel("Phase (rad)") axd["phase"].set_xlabel("Time (s)") [axd[k].margins(0) for k in ["signal", "phase"]] @@ -184,9 +192,9 @@ def plot_timefrequency(freqs, powers, ax=None): # Get the index of the 15 Hz frequency fifteen_hz_freq_idx = np.where(freqs == 15.0)[0] # The 15 Hz component is the real component of the wavelet decomposition at this index -fifteenHz_oscillation = mwt[:, fifteen_hz_freq_idx].values.real +fifteenHz_oscillation = np.real(mwt[:, fifteen_hz_freq_idx]) # The 15 Hz poser is the absolute value of the wavelet decomposition at this index -fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx].values) +fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx]) # %% # Lets plot it. @@ -195,13 +203,13 @@ def plot_timefrequency(freqs, powers, ax=None): gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 1.0]) ax0 = plt.subplot(gs[0, 0]) -ax0.plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") -ax0.plot(t, fifteenHz_oscillation_power, label="15Hz Power") +ax0.plot(fifteenHz_oscillation, label="15Hz Reconstruction") +ax0.plot(fifteenHz_oscillation_power, label="15Hz Power") ax0.set_xticklabels([]) ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig, label="Raw Signal", alpha=0.5) -ax1.plot(t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") +ax1.plot(slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction") ax1.set_xlabel("Time (s)") [ @@ -216,13 +224,13 @@ def plot_timefrequency(freqs, powers, ax=None): # ------------------ # Let's now add together the real components of all frequency bands to recreate a version of the original signal. -combined_oscillations = mwt.sum(axis=1).real() +combined_oscillations = np.real(np.sum(mwt, axis=1)) # %% # Lets plot it. 
fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) +ax.plot(combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") @@ -277,19 +285,19 @@ def plot_timefrequency(freqs, powers, ax=None): ) # %% -# Increasing window_length increases the number of wavelet cycles present in the oscillations (cycles), and +# Increasing `window_length` increases the number of wavelet cycles present in the oscillations (cycles), and # correspondingly increases the time window that the wavelet covers. # -# The gaussian_width parameter determines the shape of the gaussian window being convolved with the sinusoidal +# The `gaussian_width` parameter determines the shape of the gaussian window being convolved with the sinusoidal # component of the wavelet # # Both of these parameters can be tweaked to control for the trade-off between time resolution and frequency resolution. # %% # *** -# Effect of gaussian_width +# Effect of `gaussian_width` # ------------------ -# Let's increase gaussian_width to 7.5 and see the effect on the resultant filter bank. +# Let's increase `gaussian_width` to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( @@ -319,7 +327,7 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -348,9 +356,9 @@ def plot_timefrequency(freqs, powers, ax=None): # %% # *** -# Effect of window_length +# Effect of `window_length` # ------------------ -# Let's increase window_length to 2.0 and see the effect on the resultant filter bank. +# Let's increase `window_length` to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( @@ -380,7 +388,7 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -388,13 +396,13 @@ def plot_timefrequency(freqs, powers, ax=None): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.sum(axis=1).real() +combined_oscillations = np.real(np.sum(mwt, axis=1)) # %% # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) +ax.plot(combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") @@ -407,7 +415,7 @@ def plot_timefrequency(freqs, powers, ax=None): # *** # Effect of L1 vs L2 normalization # ------------------ -# compute_wavelet_transform contains two options for normalization; L1, and L2. L1 normalization. +# `compute_wavelet_transform` contains two options for normalization; L1, and L2. # By default, L1 is used as it creates cleaner looking decomposition images. # # L1 normalization often increases the contrast between significant and insignificant coefficients. @@ -417,7 +425,7 @@ def plot_timefrequency(freqs, powers, ax=None): # L2 normalization is directly related to the energy of the signal. 
By normalizing using the # L2 norm, you ensure that the transformed coefficients preserve the energy distribution of the original signal. # -# Let's compare two wavelet decomposition images, each generated with a different normalization strategy +# Let's compare two wavelet decomposition, each generated with a different normalization strategy mwt_l1 = nap.compute_wavelet_transform( sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l1" @@ -437,7 +445,7 @@ def plot_timefrequency(freqs, powers, ax=None): cbar = fig.colorbar(im, ax=ax0, orientation="vertical") ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -449,7 +457,7 @@ def plot_timefrequency(freqs, powers, ax=None): cbar = fig.colorbar(im, ax=ax0, orientation="vertical") ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -460,19 +468,15 @@ def plot_timefrequency(freqs, powers, ax=None): # %% -combined_oscillations_l1 = mwt_l1.sum(axis=1).real() -combined_oscillations_l2 = mwt_l2.sum(axis=1).real() +combined_oscillations_l1 = np.real(np.sum(mwt_l1, axis=1)) +combined_oscillations_l2 = np.real(np.sum(mwt_l2, axis=1)) # %% # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, label="Signal", linewidth=3, alpha=0.6, c="b") -ax.plot( - t, combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6 -) -ax.plot( - t, combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6 -) +ax.plot(combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6) +ax.plot(combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6) ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") From 2793e1cd837d327cdf57c36c3186974214f43a7b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 12:32:24 -0400 Subject: [PATCH 64/71] change title --- docs/examples/tutorial_phase_preferences.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 15d44dda..0abb753b 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Computing Phase Preferences -============ +Spikes-phase coupling +===================== In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it with spiking data, to find phase preferences of spiking units. From 5d3a340199addeaf7a448a44e640d7340081cf32 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 22:18:46 -0400 Subject: [PATCH 65/71] Updating tests --- pynapple/process/signal_processing.py | 126 ++++--- tests/test_signal_processing.py | 499 ++++++++++++++------------ 2 files changed, 328 insertions(+), 297 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 722abafa..4d814549 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -1,7 +1,11 @@ """ -Signal processing tools for Pynapple. 
+# Signal processing tools + +- `nap.compute_power_spectral_density` +- `nap.compute_mean_power_spectral_density` +- `nap.compute_wavelet_transform` +- `nap.generate_morlet_filterbank` -Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ from numbers import Number @@ -23,7 +27,7 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm Time series. fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal - ep : pynapple.IntervalSet or None, optional + ep : None or pynapple.IntervalSet, optional The epoch to calculate the fft on. Must be length 1. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values @@ -90,13 +94,13 @@ def compute_mean_power_spectral_density( Parameters ---------- - sig : Tsd or TsdFrame + sig : pynapple.Tsd or pynapple.TsdFrame Signal with equispaced samples interval_size : Number Epochs size to compute to average the FFT across fs : None, optional Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` - ep : None, optional + ep : None or pynapple.IntervalSet, optional The `IntervalSet` to calculate the fft on. Can be any length. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values @@ -216,32 +220,6 @@ def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, num_freqs=10, log_scaling=False): - """ - Creates an array of frequencies. - - Parameters - ---------- - freq_start : float - Starting value for the frequency definition. - freq_stop: float - Stopping value for the frequency definition, inclusive. - num_freqs: int, optional - Number of freqs to create. Default 10 - log_scaling: Bool - If True, will use log spacing with base log_base for frequency spacing. Default False. - - Returns - ------- - freqs: 1d array - Frequency indices. - """ - if not log_scaling: - return np.linspace(freq_start, freq_stop, num_freqs) - else: - return np.geomspace(freq_start, freq_stop, num_freqs) - - def compute_wavelet_transform( sig, freqs, fs=None, gaussian_width=1.5, window_length=1.0, precision=16, norm="l1" ): @@ -250,26 +228,24 @@ def compute_wavelet_transform( Parameters ---------- - sig : pynapple.Tsd, pynapple.TsdFrame or pynapple.TsdTensor + sig : pynapple.Tsd or pynapple.TsdFrame or pynapple.TsdTensor Time series. - freqs : 1d array or tuple of float - If array, frequency values to estimate with morlet wavelets. - If tuple, define the frequency range, as [freq_start, freq_stop, freq_step]. - The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + freqs : 1d array + Frequency values to estimate with Morlet wavelets. fs : float or None - Sampling rate, in Hz. Defaults to sig.rate if None is given. + Sampling rate, in Hz. Defaults to `sig.rate` if None is given. gaussian_width : float - Defines width of Gaussian to be used in wavelet creation. + Defines width of Gaussian to be used in wavelet creation. Default is 1.5. window_length : float - The length of window to be used for wavelet creation. + The length of window to be used for wavelet creation. Default is 1.0. precision: int. - Precision of wavelet to use. . Defines the number of timepoints to evaluate the Morlet wavelet at. - Default is 16 + Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. + Default is 16. 
norm : {None, 'l1', 'l2'}, optional Normalization method: - * None - no normalization - * 'l1' - divide by the sum of amplitudes - * 'l2' - divide by the square root of the sum of amplitudes + - None - no normalization + - 'l1' - divide by the sum of amplitudes + - 'l2' - divide by the square root of the sum of amplitudes Returns ------- @@ -280,10 +256,10 @@ def compute_wavelet_transform( -------- >>> import numpy as np >>> import pynapple as nap - >>> t = np.linspace(0, 1, 1000) + >>> t = np.arange(0, 1, 1/1000) >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) >>> freqs = np.linspace(10, 100, 10) - >>> mwt = nap.compute_wavelet_transform(signal, fs=None, freqs=freqs) + >>> mwt = nap.compute_wavelet_transform(signal, fs=1000, freqs=freqs) Notes ----- @@ -292,31 +268,28 @@ def compute_wavelet_transform( if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") - if isinstance(gaussian_width, (int, float, np.number)): - if gaussian_width <= 0: - raise ValueError("gaussian_width must be a positive number.") - else: - raise TypeError("gaussian_width must be a float or int instance.") - if isinstance(window_length, (int, float, np.number)): - if window_length <= 0: - raise ValueError("window_length must be a positive number.") - else: - raise TypeError("window_length must be a float or int instance.") + + if not isinstance(freqs, np.ndarray): + raise TypeError("`freqs` must be a ndarray") + if len(freqs) == 0: + raise ValueError("Given list of freqs cannot be empty.") + if np.min(freqs) <= 0: + raise ValueError("All frequencies in freqs must be strictly positive") + + if fs is not None and not isinstance(fs, (int, float, np.number)): + raise TypeError("`fs` must be of type float or int or None") + if norm is not None and norm not in ["l1", "l2"]: raise ValueError("norm parameter must be 'l1', 'l2', or None.") - if not isinstance(freqs, (np.ndarray, tuple)): - raise TypeError("`freqs` must be a ndarray or tuple instance.") - if isinstance(freqs, tuple): - freqs = _create_freqs(*freqs) if fs is None: fs = sig.rate - if isinstance(sig, nap.Tsd): + if sig.ndim == 1: output_shape = (sig.shape[0], len(freqs)) else: output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + sig = np.reshape(sig, (sig.shape[0], np.prod(sig.shape[1:]))) filter_bank = generate_morlet_filterbank( freqs, fs, gaussian_width, window_length, precision @@ -324,6 +297,7 @@ def compute_wavelet_transform( convolved_real = sig.convolve(filter_bank.real().values) convolved_imag = sig.convolve(filter_bank.imag().values) convolved = convolved_real.values + convolved_imag.values * 1j + if norm == "l1": coef = convolved / (fs / freqs) elif norm == "l2": @@ -352,8 +326,8 @@ def generate_morlet_filterbank( Parameters ---------- freqs : 1d array - Frequency values to estimate with morlet wavelets. - fs : float + frequency values to estimate with Morlet wavelets. + fs : float or int Sampling rate, in Hz. gaussian_width : float Defines width of Gaussian to be used in wavelet creation. 
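(Editor's aside, illustrative only and not part of the patch: the hunks around this point
tighten the docstring and the input validation of `generate_morlet_filterbank`. A minimal,
doctest-style sketch of the call, using only the arguments documented above with arbitrary
values, would be:)

>>> import numpy as np
>>> import pynapple as nap
>>> freqs = np.linspace(10, 100, 10)
>>> filter_bank = nap.generate_morlet_filterbank(
...     freqs, fs=1000, gaussian_width=1.5, window_length=1.0, precision=16
... )
>>> filter_bank.shape[1] == len(freqs)  # one wavelet (column) per requested frequency
True
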
@@ -367,10 +341,34 @@ def generate_morlet_filterbank( filter_bank : pynapple.TsdFrame list of Morlet wavelet filters of the frequencies given """ + if not isinstance(freqs, np.ndarray): + raise TypeError("`freqs` must be a ndarray") if len(freqs) == 0: raise ValueError("Given list of freqs cannot be empty.") if np.min(freqs) <= 0: raise ValueError("All frequencies in freqs must be strictly positive") + + if not isinstance(fs, (int, float, np.number)): + raise TypeError("`fs` must be of type float or int ndarray") + + if isinstance(gaussian_width, (int, float, np.number)): + if gaussian_width <= 0: + raise ValueError("gaussian_width must be a positive number.") + else: + raise TypeError("gaussian_width must be a float or int instance.") + + if isinstance(window_length, (int, float, np.number)): + if window_length <= 0: + raise ValueError("window_length must be a positive number.") + else: + raise TypeError("window_length must be a float or int instance.") + + if isinstance(precision, int): + if precision <= 0: + raise ValueError("precision must be a positive number.") + else: + raise TypeError("precision must be a float or int instance.") + filter_bank = [] cutoff = 8 morlet_f = _morlet( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 34b0e79f..6d1fc11a 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -2,6 +2,8 @@ import numpy as np import pytest +import re +from contextlib import nullcontext as does_not_raise import pynapple as nap @@ -76,13 +78,95 @@ def test_generate_morlet_filterbank(): ), ), ( - [], + "a", 1000, 1.5, 1.0, 16, - pytest.raises(ValueError, match="Given list of freqs cannot be empty."), + pytest.raises( + TypeError, match="`freqs` must be a ndarray" + ), + ), + ( + np.array([]), + 1000, + 1.5, + 1.0, + 16, + pytest.raises( + ValueError, match="Given list of freqs cannot be empty." + ), + ), + ( + np.linspace(1, 10, 1), + "a", + 1.5, + 1.0, + 16, + pytest.raises( + TypeError, match="`fs` must be of type float or int ndarray" + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + -1.5, + 1.0, + 16, + pytest.raises( + ValueError, match="gaussian_width must be a positive number." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + "a", + 1.0, + 16, + pytest.raises( + TypeError, match="gaussian_width must be a float or int instance." + ), ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + -1.0, + 16, + pytest.raises( + ValueError, match="window_length must be a positive number." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + "a", + 16, + pytest.raises( + TypeError, match="window_length must be a float or int instance." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + 1.0, + -16, + pytest.raises( + ValueError, match="precision must be a positive number." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + 1.0, + "a", + pytest.raises( + TypeError, match="precision must be a float or int instance." 
+ ), + ), ], ) def test_generate_morlet_filterbank_raise_errors( @@ -94,300 +178,249 @@ def test_generate_morlet_filterbank_raise_errors( ) -def test_compute_wavelet_transform(): - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 50 - assert ( - np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] - == 500 - ) - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = (10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) - assert np.array_equal(mwt, mwt2) +############################################################ +# Test for compute_wavelet_transform +############################################################ +import pynapple as nap +import numpy as np - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = (10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=1001, freqs=freqs) - mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert np.array_equal(mwt, mwt2) - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = (10, 100, 10, True) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.geomspace(10, 100, 10)) - assert np.array_equal(mwt, mwt2) +def get_1d_signal(fs=1000, fc=50): + t = np.arange(0, 2, 1 / fs) + d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) + return nap.Tsd(t, d, time_support=nap.IntervalSet(0, 2)) + +def get_2d_signal(fs=1000, fc=50): + t = np.arange(0, 2, 1 / fs) + d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) + return nap.TsdFrame(t, d[:,np.newaxis], time_support=nap.IntervalSet(0, 2)) + +def get_output_1d(sig, wavelets): + T = sig.shape[0] + M, N = wavelets.shape + out = [] + for n in range(N): + out.append(np.convolve(sig, wavelets[:, n], mode="full")) + out = np.array(out).T + cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) + return out[cut[0] : cut[1]] + +def get_output_2d(sig, wavelets): + T, K = sig.shape + M, N = wavelets.shape + out = [] + for k in range(K): + tmp = [] + for n in range(N): + tmp.append(np.convolve(sig[:,k], wavelets[:, n], mode="full")) + out.append(np.array(tmp)) + out = np.array(out).T + cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) + return out[cut[0] : cut[1]] - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 10 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 10 - assert ( - np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] - == 500 +@pytest.mark.parametrize( + "func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt", + [ + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 50, 
1000), + (get_1d_signal, np.linspace(10, 100, 10), None, 1.5, 1.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 3.0, 1.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 2.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 20, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, "l1", 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, "l2", 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 20, 1000), + (get_2d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 20, 1000), + ], +) +def test_compute_wavelet_transform( + func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt +): + sig = func(1000, fc) + wavelets = nap.generate_morlet_filterbank( + freqs, 1000, gaussian_width, window_length, precision ) + if sig.ndim == 1: + output = get_output_1d(sig.d, wavelets.values) + if sig.ndim == 2: + output = get_output_2d(sig.d, wavelets.values) - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 50 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l1") - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l2") - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm=None) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 70 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10) - - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = (1, 51, 6) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 6) - - t = np.linspace(0, 1, 1024) - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10, 4) - - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = 
np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10, 4, 2) - - # Testing against manual convolution for l1 norm - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l1" - ) - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) - convolved_real = sig.convolve(filter_bank.real().values) - convolved_imag = sig.convolve(filter_bank.imag().values) - convolved = convolved_real.values + convolved_imag.values * 1j - coef = convolved / (1024 / freqs) - cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef - mwt2 = nap.TsdTensor( - t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support - ) - assert np.array_equal(mwt, mwt2) + if norm == "l1": + output = output / (1000 / freqs) + if norm == "l2": + output = output / (1000 / np.sqrt(freqs)) - # Testing against manual convolution for l2 norm - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l2" - ) - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) - convolved_real = sig.convolve(filter_bank.real().values) - convolved_imag = sig.convolve(filter_bank.imag().values) - convolved = convolved_real.values + convolved_imag.values * 1j - coef = convolved / (1024 / np.sqrt(freqs)) - cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef - mwt2 = nap.TsdTensor( - t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + sig, + freqs, + fs=fs, + gaussian_width=gaussian_width, + window_length=window_length, + precision=precision, + norm=norm, ) - assert np.array_equal(mwt, mwt2) - # Testing against manual convolution for no normalization - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm=None - ) - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) - convolved_real = sig.convolve(filter_bank.real().values) - convolved_imag = sig.convolve(filter_bank.imag().values) - convolved = convolved_real.values + convolved_imag.values * 1j - coef = convolved - cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef - mwt2 = nap.TsdTensor( - t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + np.testing.assert_array_almost_equal(output, mwt.values) + assert freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] == fc + assert ( + np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] + == maxt ) - assert np.array_equal(mwt, mwt2) + np.testing.assert_array_almost_equal(mwt.time_support.values, sig.time_support.values) @pytest.mark.parametrize( - "sig, fs, freqs, gaussian_width, window_length, precision, norm, expectation", + 
"sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation", [ + (get_1d_signal(), np.linspace(1, 10, 2), 1000, 1.5, 1, 16, None, does_not_raise()), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(0, 600, 10), + "a", + np.linspace(1, 10, 2), + 1000, 1.5, - 1.0, + 1, 16, - "l1", + None, pytest.raises( - ValueError, match="All frequencies in freqs must be strictly positive" + TypeError, + match=re.escape( + "`sig` must be instance of Tsd, TsdFrame, or TsdTensor" + ), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + get_1d_signal(), + np.linspace(1, 10, 2), + "a", + 1.5, + 1, + 16, None, - np.linspace(1, 600, 10), + pytest.raises( + TypeError, + match=re.escape( + "`fs` must be of type float or int or None" + ), + ), + ), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, -1.5, - 1.0, + 1, 16, - "l1", + None, pytest.raises( - ValueError, match="gaussian_width must be a positive number." + ValueError, + match=re.escape("gaussian_width must be a positive number."), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, + "a", + 1, + 16, None, - np.linspace(1, 600, 10), + pytest.raises( + TypeError, + match=re.escape("gaussian_width must be a float or int instance."), + ), + ), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, 1.5, - -1.0, + -1, 16, - "l1", - pytest.raises(ValueError, match="window_length must be a positive number."), + None, + pytest.raises( + ValueError, + match=re.escape("window_length must be a positive number."), + ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(1, 600, 10), - "not_number", - 1.0, + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, + 1.5, + "a", 16, - "l1", + None, pytest.raises( - TypeError, match="gaussian_width must be a float or int instance." + TypeError, + match=re.escape("window_length must be a float or int instance."), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(1, 600, 10), + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, 1.5, - "not_number", + 1, 16, - "l1", + "a", pytest.raises( - TypeError, match="window_length must be a float or int instance." + ValueError, + match=re.escape("norm parameter must be 'l1', 'l2', or None."), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(1, 600, 10), + get_1d_signal(), + "a", + 1000, 1.5, - 1.0, + 1, 16, - "l3", + None, pytest.raises( - ValueError, match="norm parameter must be 'l1', 'l2', or None." + TypeError, + match=re.escape("`freqs` must be a ndarray"), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - None, + get_1d_signal(), + np.array([]), + 1000, 1.5, - 1.0, + 1, 16, - "l1", + None, pytest.raises( - TypeError, match="`freqs` must be a ndarray or tuple instance." 
+ ValueError, + match=re.escape("Given list of freqs cannot be empty."), ), ), ( - "not_a_signal", + get_1d_signal(), + np.array([-1]), + 1000, + 1.5, + 1, + 16, None, - np.linspace(10, 100, 10), + pytest.raises( + ValueError, + match=re.escape("All frequencies in freqs must be strictly positive"), + ), + ), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, 1.5, - 1.0, + 1, 16, - "l1", + 1, pytest.raises( - TypeError, match="`sig` must be instance of Tsd, TsdFrame, or TsdTensor" + ValueError, + match=re.escape("norm parameter must be 'l1', 'l2', or None."), ), ), + ], ) def test_compute_wavelet_transform_raise_errors( From 5fab4dc4b006ffff279d41f92d9a1f982156c98f Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 8 Aug 2024 16:29:31 +0100 Subject: [PATCH 66/71] check that fft of wavelet is correct gaussian --- tests/test_signal_processing.py | 77 ++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 34b0e79f..91c590a7 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,5 +1,5 @@ """Tests of `signal_processing` for pynapple""" - +import matplotlib.pyplot as plt import numpy as np import pytest @@ -61,6 +61,81 @@ def test_generate_morlet_filterbank(): for i, f in enumerate(freqs): assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + # Checking that the power spectra of the wavelets resemble correct Gaussians + fs = 2000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 1.0 + window_length = 1.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f)/f) ** 2) + assert np.isclose(power.iloc[:,i]/np.max(power.iloc[:,i]), morlet_ft/np.max(morlet_ft), atol=0.1).all() + + fs = 100 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 1.0 + window_length = 1.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + + fs = 100 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=4.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 4.0 + window_length = 1.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + + fs = 100 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=4.0, window_length=3.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 4.0 + window_length = 3.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 
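        # [Editor's note, not part of the patch: the next line evaluates the expected
        #  Gaussian shape of the Morlet wavelet's Fourier spectrum, a bump centred on the
        #  wavelet centre frequency f whose width shrinks as `window_length` and
        #  `gaussian_width` grow; the constant `factor` does not affect the assertion
        #  below because both spectra are normalised by their maxima before comparison.]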
+ morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + + fs = 1000 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.25, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 3.5 + window_length = 1.25 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length * (fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + @pytest.mark.parametrize( "freqs, fs, gaussian_width, window_length, precision, expectation", From a761f6efebab215a5cc982a3a936068747ff923b Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 8 Aug 2024 16:31:09 +0100 Subject: [PATCH 67/71] linting --- tests/test_signal_processing.py | 126 ++++++++++++++++++++------------ 1 file changed, 78 insertions(+), 48 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 594d8550..5ea3bf4f 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,9 +1,11 @@ """Tests of `signal_processing` for pynapple""" + +import re +from contextlib import nullcontext as does_not_raise + import matplotlib.pyplot as plt import numpy as np import pytest -import re -from contextlib import nullcontext as does_not_raise import pynapple as nap @@ -74,9 +76,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 1.0 window_length = 1.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f)/f) ** 2) - assert np.isclose(power.iloc[:,i]/np.max(power.iloc[:,i]), morlet_ft/np.max(morlet_ft), atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 100 freqs = np.linspace(1, 10, 10) @@ -88,10 +96,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 1.0 window_length = 1.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 100 freqs = np.linspace(1, 10, 10) @@ -103,10 +116,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 4.0 window_length = 1.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + 
power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 100 freqs = np.linspace(1, 10, 10) @@ -118,10 +136,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 4.0 window_length = 3.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 1000 freqs = np.linspace(1, 10, 10) @@ -133,10 +156,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 3.5 window_length = 1.25 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length * (fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() @pytest.mark.parametrize( @@ -158,9 +186,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, 16, - pytest.raises( - TypeError, match="`freqs` must be a ndarray" - ), + pytest.raises(TypeError, match="`freqs` must be a ndarray"), ), ( np.array([]), @@ -168,9 +194,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, 16, - pytest.raises( - ValueError, match="Given list of freqs cannot be empty." - ), + pytest.raises(ValueError, match="Given list of freqs cannot be empty."), ), ( np.linspace(1, 10, 1), @@ -178,9 +202,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, 16, - pytest.raises( - TypeError, match="`fs` must be of type float or int ndarray" - ), + pytest.raises(TypeError, match="`fs` must be of type float or int ndarray"), ), ( np.linspace(1, 10, 1), @@ -208,9 +230,7 @@ def test_generate_morlet_filterbank(): 1.5, -1.0, 16, - pytest.raises( - ValueError, match="window_length must be a positive number." - ), + pytest.raises(ValueError, match="window_length must be a positive number."), ), ( np.linspace(1, 10, 1), @@ -228,9 +248,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, -16, - pytest.raises( - ValueError, match="precision must be a positive number." - ), + pytest.raises(ValueError, match="precision must be a positive number."), ), ( np.linspace(1, 10, 1), @@ -241,7 +259,7 @@ def test_generate_morlet_filterbank(): pytest.raises( TypeError, match="precision must be a float or int instance." 
), - ), + ), ], ) def test_generate_morlet_filterbank_raise_errors( @@ -253,12 +271,12 @@ def test_generate_morlet_filterbank_raise_errors( ) +import numpy as np ############################################################ # Test for compute_wavelet_transform ############################################################ import pynapple as nap -import numpy as np def get_1d_signal(fs=1000, fc=50): @@ -266,21 +284,24 @@ def get_1d_signal(fs=1000, fc=50): d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) return nap.Tsd(t, d, time_support=nap.IntervalSet(0, 2)) + def get_2d_signal(fs=1000, fc=50): t = np.arange(0, 2, 1 / fs) d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) - return nap.TsdFrame(t, d[:,np.newaxis], time_support=nap.IntervalSet(0, 2)) + return nap.TsdFrame(t, d[:, np.newaxis], time_support=nap.IntervalSet(0, 2)) + def get_output_1d(sig, wavelets): T = sig.shape[0] M, N = wavelets.shape out = [] for n in range(N): - out.append(np.convolve(sig, wavelets[:, n], mode="full")) + out.append(np.convolve(sig, wavelets[:, n], mode="full")) out = np.array(out).T cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) return out[cut[0] : cut[1]] + def get_output_2d(sig, wavelets): T, K = sig.shape M, N = wavelets.shape @@ -288,12 +309,13 @@ def get_output_2d(sig, wavelets): for k in range(K): tmp = [] for n in range(N): - tmp.append(np.convolve(sig[:,k], wavelets[:, n], mode="full")) + tmp.append(np.convolve(sig[:, k], wavelets[:, n], mode="full")) out.append(np.array(tmp)) out = np.array(out).T cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) return out[cut[0] : cut[1]] + @pytest.mark.parametrize( "func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt", [ @@ -341,13 +363,24 @@ def test_compute_wavelet_transform( np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] == maxt ) - np.testing.assert_array_almost_equal(mwt.time_support.values, sig.time_support.values) + np.testing.assert_array_almost_equal( + mwt.time_support.values, sig.time_support.values + ) @pytest.mark.parametrize( "sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation", [ - (get_1d_signal(), np.linspace(1, 10, 2), 1000, 1.5, 1, 16, None, does_not_raise()), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, + 1.5, + 1, + 16, + None, + does_not_raise(), + ), ( "a", np.linspace(1, 10, 2), @@ -373,11 +406,9 @@ def test_compute_wavelet_transform( None, pytest.raises( TypeError, - match=re.escape( - "`fs` must be of type float or int or None" - ), + match=re.escape("`fs` must be of type float or int or None"), ), - ), + ), ( get_1d_signal(), np.linspace(1, 10, 2), @@ -495,7 +526,6 @@ def test_compute_wavelet_transform( match=re.escape("norm parameter must be 'l1', 'l2', or None."), ), ), - ], ) def test_compute_wavelet_transform_raise_errors( From bda58921d3a04ebc0ba539b2cf4801ae9bbf7066 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 8 Aug 2024 16:35:29 +0100 Subject: [PATCH 68/71] removed bad import --- tests/test_signal_processing.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 5ea3bf4f..b3e2e3d1 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -3,7 +3,6 @@ import re from contextlib import nullcontext as does_not_raise -import matplotlib.pyplot as plt import numpy as np import pytest @@ -271,12 +270,9 @@ def test_generate_morlet_filterbank_raise_errors( ) -import numpy as np - 
############################################################ # Test for compute_wavelet_transform ############################################################ -import pynapple as nap def get_1d_signal(fs=1000, fc=50): From 96510776dd309dd68c457234efae194250c633ca Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 8 Aug 2024 15:22:22 -0400 Subject: [PATCH 69/71] updating --- pynapple/core/base_class.py | 2 +- pynapple/process/signal_processing.py | 2 +- tests/test_power_spectral_density.py | 4 ++-- tests/test_time_series.py | 2 ++ 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 337f304f..c119b57a 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -568,7 +568,7 @@ def _get_slice( # get index of preceding time value idx_start = np.searchsorted(self.t, start, side="left") - if idx_start == len(self.t): + if idx_start == len(self.t) and mode != "restrict": idx_start -= 1 # make sure the index is not out of bound if mode == "before_t": diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 4d814549..cdb5639f 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -155,7 +155,7 @@ def compute_mean_power_spectral_density( slices = np.zeros((len(split_ep), 2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1], mode="backward") + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) slices[i, 0] = sl.start slices[i, 1] = sl.stop diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 626b0832..fc76103c 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -161,7 +161,7 @@ def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): d = np.cos(2 * np.pi * f * t) sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T - tmp = tmp[0:-1] + # tmp = tmp[0:-1] tmp = tmp*signal.windows.hamming(tmp.shape[0])[:,np.newaxis] out = np.sum(np.fft.fft(tmp, axis=0), 1) freq = np.fft.fftfreq(out.shape[0], 1 / fs) @@ -190,7 +190,7 @@ def test_compute_mean_power_spectral_density(): psd = nap.compute_mean_power_spectral_density(sig, 10, norm=True) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(9999.0*10.0)) + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(10000.0*10.0)) np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index b7be2f1a..446a4fd0 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1682,6 +1682,8 @@ def test_get_slice_value_step(start, end, n_points, mode, expected_slice, expect (1, None, slice(0, 1), np.array([1])), (4, None, slice(3, 4), np.array([4])), (5, None, slice(3, 4), np.array([4])), + (-1, 0, slice(0, 0), np.array([])), + (5, 6, slice(4, 4), np.array([])), ] ) @pytest.mark.parametrize("ts", From ee558e560193a594dc673871105c78744ee77ad0 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 9 Aug 2024 15:52:25 +0100 Subject: [PATCH 70/71] addressing comments, slight notebook improvements --- docs/api_guide/tutorial_pynapple_wavelets.py | 20 +++- docs/examples/tutorial_phase_preferences.py | 2 +- 
docs/examples/tutorial_signal_processing.py | 114 ++++++++++++------- pynapple/io/folder.py | 9 -- pynapple/process/signal_processing.py | 65 +++++++---- tests/test_signal_processing.py | 6 + 6 files changed, 143 insertions(+), 73 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index 016d7cc0..20246d0f 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -5,11 +5,14 @@ This tutorial covers the use of `nap.compute_wavelet_transform` to do continuous wavelet transform. By default, pynapple uses Morlet wavelets. +Wavelet are a great tool for capturing changes of spectral characteristics of a signal over time. As neural signals change +and develop over time, wavelet decompositions can aid both visualization and analysis. + The function `nap.generate_morlet_filterbank` can help parametrize and visualize the Morlet wavelets. See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. -This tutorial was made by Kipp Freud. +This tutorial was made by [Kipp Freud](https://kippfreud.com/). """ @@ -84,12 +87,13 @@ # %% # Lets plot it some of the wavelets. + def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i].real() + f_i*1.5) + ax.plot(filter_bank[:, f_i].real() + f_i * 1.5) ax.text(-5.5, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") - + ax.set_yticks([]) ax.set_xlim(-5, 5) ax.set_xlabel("Time (s)") @@ -121,6 +125,7 @@ def plot_filterbank(filter_bank, freqs, title): # %% # Lets plot it. + def plot_timefrequency(freqs, powers, ax=None): im = ax.imshow(np.abs(powers), aspect="auto") ax.invert_yaxis() @@ -209,7 +214,9 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig, label="Raw Signal", alpha=0.5) -ax1.plot(slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction") +ax1.plot( + slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction" +) ax1.set_xlabel("Time (s)") [ @@ -222,6 +229,11 @@ def plot_timefrequency(freqs, powers, ax=None): # *** # Adding ALL the Oscillations! # ------------------ +# We will now learn how to interpret the parameters of the wavelet, and in particular how to trade off the +# accuracy in the frequency decomposition with the accuracy in the time domain reconstruction; + +# Up to this point we have used default wavelet and normalization parameters. +# # Let's now add together the real components of all frequency bands to recreate a version of the original signal. combined_oscillations = np.real(np.sum(mwt, axis=1)) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 0abb753b..5af93b3c 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -8,7 +8,7 @@ Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track. -This tutorial was made by Kipp Freud. +This tutorial was made by [Kipp Freud](https://kippfreud.com/). 
""" # %% diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index dd7d66df..1daf034a 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -7,7 +7,7 @@ Specifically, we will examine Local Field Potential data from a period of active traversal of a linear track. -This tutorial was made by Kipp Freud. +This tutorial was made by [Kipp Freud](https://kippfreud.com/). """ @@ -26,6 +26,7 @@ import math import os +import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np import requests @@ -54,18 +55,11 @@ total=math.ceil(int(r.headers.get("content-length", 0)) // block_size), ): f.write(data) - - -# %% -# *** -# Loading the data -# ------------------ # Let's load and print the full dataset. - data = nap.load_file(path) - print(data) + # %% # First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. # @@ -113,6 +107,9 @@ axd["pos"].set_ylabel("Linearized Position") axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) +# %% +# In the top panel, we can see the lfp trace as a function of time, and on the bottom the mouse position on the linear +# track as a function of time. Position 0 and 1 correspond to the start and end of the trial respectively. # %% # *** @@ -145,22 +142,30 @@ # %% -# The red area outlines the theta rhythm (6-10 Hz) which is proeminent in hippocampal LFP. -# Hippocampal theta rhythm appears mostly when the animal is running. -# (See Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666.) -# We can check it here by separating `wake_ep` into `run_ep` and `rest_ep`. -run_ep = data['position'].dropna().find_support(1) +# The red area outlines the theta rhythm (6-10 Hz) which is proeminent in hippocampal LFP. +# Hippocampal theta rhythm appears mostly when the animal is running [1]. +# We can check it here by separating the wake epochs (`wake_ep`) into run epochs (`run_ep`) and rest epochs (`rest_ep`). + +# The run epoch is the portion of the data for which we have position data +run_ep = data["position"].dropna().find_support(1) +# The rest epoch is the data at all points where we do not have position data rest_ep = wake_ep.set_diff(run_ep) # %% -# `run_ep` and `rest_ep` are IntervalSet with discontinuous epoch. +# `run_ep` and `rest_ep` are IntervalSet with discontinuous epoch. +# # The function `nap.compute_power_spectral_density` takes signal with a single epoch to avoid artefacts between epochs jumps. -# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density` which avearge the FFT over multiple epochs of same duration. The parameter `interval_size` controls the duration of those epochs. -# -# In this case, `interval_size` is equal to 1.5 seconds. +# +# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density` which avearge the FFT over multiple epochs of same duration. The parameter `interval_size` controls the duration of those epochs. +# +# In this case, `interval_size` is equal to 1.5 seconds. 
-power_run = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=run_ep, norm=True) -power_rest = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=rest_ep, norm=True) +power_run = nap.compute_mean_power_spectral_density( + eeg, 1.5, fs=FS, ep=run_ep, norm=True +) +power_rest = nap.compute_mean_power_spectral_density( + eeg, 1.5, fs=FS, ep=rest_ep, norm=True +) # %% # `power_run` and `power_rest` are the power spectral density when the animal is respectively running and resting. @@ -170,13 +175,13 @@ np.abs(power_run[(power_run.index >= 3.0) & (power_run.index <= 30)]), alpha=1, label="Run", - linewidth=2 + linewidth=2, ) ax.plot( np.abs(power_rest[(power_rest.index >= 3.0) & (power_rest.index <= 30)]), alpha=1, label="Rest", - linewidth=2 + linewidth=2, ) ax.axvspan(6, 10, color="red", alpha=0.1) ax.set_xlabel("Freq (Hz)") @@ -223,7 +228,7 @@ ax1 = plt.subplot(gs[1, 0], sharex=ax0) ax1.plot(eeg_example) -ax1.set_ylabel("LFP (v)") +ax1.set_ylabel("LFP (a.u.)") ax1 = plt.subplot(gs[2, 0], sharex=ax0) ax1.plot(pos_example, color="black") @@ -238,11 +243,11 @@ # There seems to be a strong theta frequency present in the data during the maze traversal. # Let's plot the estimated 6-10Hz component of the wavelet decomposition on top of our data, and see how well they match up. -theta_freq_index = np.logical_and(freqs>6, freqs<10) +theta_freq_index = np.logical_and(freqs > 6, freqs < 10) # Extract its real component, as well as its power envelope -theta_band_reconstruction = np.mean(mwt_RUN[:,theta_freq_index], 1) +theta_band_reconstruction = np.mean(mwt_RUN[:, theta_freq_index], 1) theta_band_power_envelope = np.abs(theta_band_reconstruction) @@ -252,13 +257,13 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.9]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) ax0.plot(eeg_example, label="CA1") ax0.set_title("EEG (1250 Hz)") ax0.set_ylabel("LFP (a.u.)") ax0.set_xlabel("time (s)") ax0.legend() -ax1 = plt.subplot(gs[1,0]) +ax1 = plt.subplot(gs[1, 0]) ax1.plot(np.real(theta_band_reconstruction), label="6-10 Hz oscillations") ax1.plot(theta_band_power_envelope, label="6-10 Hz power envelope") ax1.set_xlabel("time (s)") @@ -267,7 +272,12 @@ # %% # *** -# Visualizing high frequency oscillation +# We observe that the theta power is far stronger during the first 4 seconds of the dataset, during which the rat +# is traversing the linear track. + +# %% +# *** +# Visualizing High Frequency Oscillation # ----------------------------------- # There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. Here we use the interval (18356, 18357.5) seconds to zoom in. @@ -292,11 +302,11 @@ ax1.set_xlabel("Time (s)") # %% -# Those events are called Sharp-waves ripples (See : Buzsáki, G. (2015). Hippocampal sharp wave‐ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188.) +# Those events are called Sharp-waves ripples [2]. # # Among other methods, we can use the Wavelet decomposition to isolate them. In this case, we will look at the power of the wavelets for frequencies between 150 to 250 Hz. -ripple_freq_index = np.logical_and(freqs>150, freqs<250) +ripple_freq_index = np.logical_and(freqs > 150, freqs < 250) # %% # We can compute the mean power for this frequency band. 
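(Editor's aside, illustrative only and not part of the patch: one straightforward way to
obtain such a band-limited power trace is to average the magnitude of the wavelet
coefficients over the selected frequency rows. A minimal sketch, assuming the `mwt_RUN`
and `ripple_freq_index` variables from the cells above; the tutorial's own `ripple_power`
used in the following hunks may be computed slightly differently.)

ripple_power_sketch = nap.Tsd(
    t=mwt_RUN.index,
    d=np.mean(np.abs(mwt_RUN[:, ripple_freq_index].values), axis=1),
)
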
@@ -309,11 +319,11 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") -ax0.set_ylabel("LFP (v)") +ax0.set_ylabel("LFP (a.u.)") ax0.set_title(f"EEG (1250 Hz)") -ax1 = plt.subplot(gs[1,0]) +ax1 = plt.subplot(gs[1, 0]) ax1.legend() ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") ax1.legend() @@ -342,20 +352,44 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") -for s,e in rip_ep.intersect(zoom_ep).values: - ax0.axvspan(s, e, color='red', alpha=0.1, ec=None) -ax0.set_ylabel("LFP (v)") +for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): + ax0.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None) +ax0.set_ylabel("LFP (a.u.)") ax0.set_title(f"EEG (1250 Hz)") -ax1 = plt.subplot(gs[1,0]) +ax1 = plt.subplot(gs[1, 0]) ax1.legend() ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") ax1.plot(smoothed_ripple_power.restrict(zoom_ep)) -for s,e in rip_ep.intersect(zoom_ep).values: - ax1.axvspan(s, e, color='red', alpha=0.1, ec=None) +for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): + ax1.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None) ax1.legend() ax1.set_ylabel("Mean Amplitude") ax1.set_xlabel("Time (s)") +# %% +# Finally, let's zoom in on each of our isolated ripples + +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +gs = plt.GridSpec(2, 2, figure=fig, height_ratios=[1.0, 1.0]) +buffer = 0.02 +plt.suptitle("Isolated Sharp Wave Ripples") +for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): + ax = plt.subplot(gs[int(i / 2), i % 2]) + ax.plot(eeg_example.restrict(nap.IntervalSet(s - buffer, e + buffer))) + ax.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None) + ax.set_xlim(s - buffer, e + buffer) + ax.set_xlabel("Time (s)") + ax.set_ylabel("LFP (a.u.)") + + +# %% +# *** +# References +# ----------------------------------- +# +# [1] Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666. +# +# [2] Buzsáki, G. (2015). Hippocampal sharp wave‐ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188. diff --git a/pynapple/io/folder.py b/pynapple/io/folder.py index 8f7d2f1a..a35af18d 100644 --- a/pynapple/io/folder.py +++ b/pynapple/io/folder.py @@ -1,16 +1,7 @@ -#!/usr/bin/env python - -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-05-15 15:32:24 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-08-02 11:35:10 - """ The Folder class helps to navigate a hierarchical data tree. """ - import json import string from collections import UserDict diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 4d814549..99f3d9ba 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -1,5 +1,5 @@ """ -# Signal processing tools +# Signal processing tools. - `nap.compute_power_spectral_density` - `nap.compute_mean_power_spectral_density` @@ -19,7 +19,7 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False): """ - Performs numpy fft on sig, returns output. 
Pynapple assumes a constant sampling rate for sig. + Perform numpy fft on sig, returns output assuming a constant sampling rate for the signal. Parameters ---------- @@ -85,7 +85,9 @@ def compute_mean_power_spectral_density( norm=False, time_unit="s", ): - """Compute mean power spectral density by averaging FFT over epochs of same size. + """ + Compute mean power spectral density by averaging FFT over epochs of same size. + The parameter `interval_size` controls the duration of the epochs. To imporve frequency resolution, the signal is multiplied by a Hamming window. @@ -114,6 +116,14 @@ def compute_mean_power_spectral_density( pandas.DataFrame Power spectral density. + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.arange(0, 1, 1/1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> mpsd = nap.compute_mean_power_spectral_density(signal, 0.1) + Raises ------ RuntimeError @@ -155,7 +165,7 @@ def compute_mean_power_spectral_density( slices = np.zeros((len(split_ep), 2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1], mode="backward") + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) slices[i, 0] = sl.start slices[i, 1] = sl.stop @@ -244,7 +254,7 @@ def compute_wavelet_transform( norm : {None, 'l1', 'l2'}, optional Normalization method: - None - no normalization - - 'l1' - divide by the sum of amplitudes + - 'l1' - (default) divide by the sum of amplitudes - 'l2' - divide by the square root of the sum of amplitudes Returns @@ -320,8 +330,10 @@ def generate_morlet_filterbank( freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 ): """ - Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, - or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. + Generates a Morlet filterbank using the given frequencies and parameters. + + This function can be used purely for visualization, or to convolve with a pynapple Tsd, + TsdFrame, or TsdTensor as part of a wavelet decomposition process. Parameters ---------- @@ -340,6 +352,12 @@ def generate_morlet_filterbank( ------- filter_bank : pynapple.TsdFrame list of Morlet wavelet filters of the frequencies given + + Notes + ----- + This algorithm first computes a single, finely sampled wavelet using the provided hyperparameters. + Wavelets of different frequencies are generated by resampling this mother wavelet with an appropriate step size. + The step size is determined based on the desired frequency and the sampling rate. 
""" if not isinstance(freqs, np.ndarray): raise TypeError("`freqs` must be a ndarray") @@ -369,27 +387,35 @@ def generate_morlet_filterbank( else: raise TypeError("precision must be a float or int instance.") + # Initialize filter bank and parameters filter_bank = [] - cutoff = 8 - morlet_f = _morlet( - int(2**precision), gaussian_width=gaussian_width, window_length=window_length + cutoff = 8 # Define cutoff for wavelet + # Compute a single, finely sampled Morlet wavelet + morlet_f = np.conj( + _morlet( + int(2**precision), + gaussian_width=gaussian_width, + window_length=window_length, + ) ) x = np.linspace(-cutoff, cutoff, int(2**precision)) - int_psi = np.conj(morlet_f) - max_len = -1 + max_len = -1 # Track maximum length of wavelet for freq in freqs: scale = window_length / (freq / fs) + # Calculate the indices for subsampling the wavelet and achieve the right frequency + # After the slicing the size will be reduced, therefore we will pad with 0s. j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) - j = j.astype(int) # floor - if j[-1] >= int_psi.size: - j = np.extract(j < int_psi.size, j) - int_psi_scale = int_psi[j][::-1] - if len(int_psi_scale) > max_len: - max_len = len(int_psi_scale) + j = j.astype(int) # Floor the values to get integer indices + if j[-1] >= morlet_f.size: + j = np.extract(j < morlet_f.size, j) + scaled_morlet = morlet_f[j][::-1] # Scale and reverse wavelet + if len(scaled_morlet) > max_len: + max_len = len(scaled_morlet) time = np.linspace( -cutoff * window_length / freq, cutoff * window_length / freq, max_len ) - filter_bank.append(int_psi_scale) + filter_bank.append(scaled_morlet) + # Pad wavelets to ensure all are of the same length filter_bank = [ np.pad( arr, @@ -398,4 +424,5 @@ def generate_morlet_filterbank( ) for arr in filter_bank ] + # Return filter bank as a TsdFrame return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index b3e2e3d1..a708134f 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -17,6 +17,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 10000 @@ -26,6 +27,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 1000 @@ -35,6 +37,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 10000 @@ -44,6 +47,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 1000 @@ -53,6 +57,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 10000 @@ -62,6 +67,7 @@ def test_generate_morlet_filterbank(): ) power = 
np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() # Checking that the power spectra of the wavelets resemble correct Gaussians From 861042a38785315e4284ef7751813bae739e796e Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 9 Aug 2024 17:04:08 +0100 Subject: [PATCH 71/71] suggested changes and wavelet arange rounding fix --- ...rocessing.py => tutorial_wavelet_decomposition.py} | 2 +- pynapple/process/signal_processing.py | 9 +++------ tests/test_signal_processing.py | 11 ++++++----- 3 files changed, 10 insertions(+), 12 deletions(-) rename docs/examples/{tutorial_signal_processing.py => tutorial_wavelet_decomposition.py} (99%) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_wavelet_decomposition.py similarity index 99% rename from docs/examples/tutorial_signal_processing.py rename to docs/examples/tutorial_wavelet_decomposition.py index 1daf034a..530540ec 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_wavelet_decomposition.py @@ -374,7 +374,7 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) gs = plt.GridSpec(2, 2, figure=fig, height_ratios=[1.0, 1.0]) -buffer = 0.02 +buffer = 0.075 plt.suptitle("Isolated Sharp Wave Ripples") for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): ax = plt.subplot(gs[int(i / 2), i % 2]) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 99f3d9ba..6071e98e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -295,11 +295,8 @@ def compute_wavelet_transform( if fs is None: fs = sig.rate - if sig.ndim == 1: - output_shape = (sig.shape[0], len(freqs)) - else: - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = np.reshape(sig, (sig.shape[0], np.prod(sig.shape[1:]))) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = np.reshape(sig, (sig.shape[0], -1)) filter_bank = generate_morlet_filterbank( freqs, fs, gaussian_width, window_length, precision @@ -405,7 +402,7 @@ def generate_morlet_filterbank( # Calculate the indices for subsampling the wavelet and achieve the right frequency # After the slicing the size will be reduced, therefore we will pad with 0s. 
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) - j = j.astype(int) # Floor the values to get integer indices + j = np.ceil(j).astype(int) # Ceil the values to get integer indices if j[-1] >= morlet_f.size: j = np.extract(j < morlet_f.size, j) scaled_morlet = morlet_f[j][::-1] # Scale and reverse wavelet diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index a708134f..c9c1495b 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -70,6 +70,7 @@ def test_generate_morlet_filterbank(): # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + gaussian_atol = 1e-4 # Checking that the power spectra of the wavelets resemble correct Gaussians fs = 2000 freqs = np.linspace(100, 1000, 10) @@ -88,7 +89,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 100 @@ -108,7 +109,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 100 @@ -128,7 +129,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 100 @@ -148,7 +149,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 1000 @@ -168,7 +169,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all()
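
The tutorial changes above revolve around `nap.compute_mean_power_spectral_density` for contrasting run and rest epochs. A minimal sketch of that workflow follows, using a synthetic signal rather than the tutorial's dataset; the toy signal, sampling rate, epoch boundaries and interval size are illustrative assumptions, not values from the recording.

    import numpy as np
    import pynapple as nap

    # Toy signal (illustrative): an 8 Hz oscillation during the first 5 s ("run"),
    # noise only during the last 5 s ("rest"), sampled at 1000 Hz.
    fs = 1000
    t = np.arange(0, 10, 1 / fs)
    d = np.where(t < 5, np.sin(2 * np.pi * 8 * t), 0.0) + 0.1 * np.random.randn(len(t))
    sig = nap.Tsd(t=t, d=d)

    run_ep = nap.IntervalSet(0, 5)    # epoch containing the oscillation
    rest_ep = nap.IntervalSet(5, 10)  # epoch without it

    # Average the FFT over 1 s chunks of each epoch, as the tutorial does with 1.5 s.
    power_run = nap.compute_mean_power_spectral_density(
        sig, 1.0, fs=fs, ep=run_ep, norm=True
    )
    power_rest = nap.compute_mean_power_spectral_density(
        sig, 1.0, fs=fs, ep=rest_ep, norm=True
    )

    # The 8 Hz peak should dominate the "run" spectrum only.
    print(np.abs(power_run[(power_run.index >= 6) & (power_run.index <= 10)]).max())
    print(np.abs(power_rest[(power_rest.index >= 6) & (power_rest.index <= 10)]).max())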
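
The theta and ripple analyses both follow the same wavelet pattern: call `nap.compute_wavelet_transform`, boolean-index the frequency axis, average the complex coefficients across the band, and take the magnitude as a power envelope. A hedged sketch of that pattern on a synthetic burst follows; the frequency grid, band edges and burst parameters are made up for illustration, and `freqs`/`fs` are passed as keywords since their positional order is not visible in these hunks.

    import numpy as np
    import pynapple as nap

    # Toy LFP (illustrative): a 7 Hz oscillation throughout, plus a brief 80 Hz
    # burst between 2.0 s and 2.2 s, sampled at 1000 Hz.
    fs = 1000
    t = np.arange(0, 4, 1 / fs)
    d = np.sin(2 * np.pi * 7 * t)
    burst = (t > 2.0) & (t < 2.2)
    d[burst] += 0.5 * np.sin(2 * np.pi * 80 * t[burst])
    sig = nap.Tsd(t=t, d=d)

    # Continuous wavelet transform over an illustrative frequency grid.
    freqs = np.linspace(1, 100, 50)
    mwt = nap.compute_wavelet_transform(sig, freqs=freqs, fs=fs)

    # Band power envelope, mirroring the tutorial's theta/ripple extraction.
    band_index = np.logical_and(freqs > 60, freqs < 100)
    band_power = np.abs(np.mean(mwt[:, band_index], 1))

    # The envelope should peak inside the burst window (~2.0-2.2 s).
    print(band_power.index[np.argmax(band_power.values)])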
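
Finally, the test additions check `nap.generate_morlet_filterbank` by locating the spectral peak of each wavelet. The same sanity check can be run interactively roughly as follows; the frequency range and sampling rate are arbitrary choices, and the keyword values simply repeat the defaults shown in the signature above.

    import numpy as np
    import pynapple as nap

    # Build a small bank of Morlet wavelets.
    fs = 1000
    freqs = np.linspace(10, 100, 10)
    fb = nap.generate_morlet_filterbank(
        freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16
    )

    # Power spectrum of the bank: one column per wavelet.
    power = np.abs(nap.compute_power_spectral_density(fb))

    # Each wavelet's spectral peak should sit at (or next to) its nominal center
    # frequency; the test suite above asserts bin-level agreement.
    for i, f in enumerate(freqs):
        print(f, power.index[power.iloc[:, i].argmax()])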