diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index da024804..457fe0c9 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -12,10 +12,17 @@ def compute_fft(sig, fs): """ - Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor - :param sig: :param fs: :return: + Parameters + ---------- + sig : pynapple.Tsd + The signal time series to obtain time frequency decomposition from + fs : int / float + The sampling frequency of sig + + Returns + ------- + """ if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") @@ -155,7 +162,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp mwt = np.zeros([len(freqs), len(sig)], dtype=complex) for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): mwt[ind, :] = _convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt)) + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) else: mwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex @@ -165,7 +172,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp mwt[:, ind, channel_i] = _convolve_wavelet( sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm ) - return nap.TsdTensor(t=sig.index, d=mwt) + return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) def _convolve_wavelet(