From 45532742338c4c470afca2a05290933bfe4e4a2b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 16 Apr 2024 11:39:07 -0400 Subject: [PATCH 1/2] Fixing repr --- pynapple/core/interval_set.py | 42 +++++++++-- pynapple/core/time_series.py | 137 ++++++++++++++++++++-------------- pynapple/core/ts_group.py | 31 +++----- pynapple/core/utils.py | 23 +++++- tests/test_interval_set.py | 14 +++- tests/test_time_series.py | 8 +- 6 files changed, 168 insertions(+), 87 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index b556cb52..4e350163 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -53,6 +53,7 @@ from .config import nap_config from .time_index import TsIndex from .utils import ( + _get_terminal_size, _IntervalSetSliceHelper, _jitfix_iset, convert_to_numpy, @@ -172,11 +173,39 @@ def __repr__(self): headers = ["start", "end"] bottom = "shape: {}, time unit: sec.".format(self.shape) - return ( - tabulate(self.values, headers=headers, showindex="always", tablefmt="plain") - + "\n" - + bottom - ) + rows = _get_terminal_size()[1] + max_rows = np.maximum(rows - 10, 6) + + if len(self) > max_rows: + n_rows = max_rows // 2 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return ( + tabulate( + self.values[0:n_rows], + headers=headers, + showindex=self.index[0:n_rows], + tablefmt="plain", + ) + + "\n\n...\n" + + tabulate( + self.values[-n_rows:], + headers=[" " * 5, " " * 3], # To align properly the columns + showindex=self.index[-n_rows:], + tablefmt="plain", + ) + + "\n" + + bottom + ) + + else: + return ( + tabulate( + self.values, headers=headers, showindex="always", tablefmt="plain" + ) + + "\n" + + bottom + ) def __str__(self): return self.__repr__() @@ -203,6 +232,9 @@ def __getitem__(self, key, *args, **kwargs): elif isinstance(key, (list, slice, np.ndarray)): output = self.values.__getitem__(key) return IntervalSet(start=output[:, 0], end=output[:, 1]) + elif 
isinstance(key, pd.Series): + output = self.values.__getitem__(key) + return IntervalSet(start=output[:, 0], end=output[:, 1]) elif isinstance(key, tuple): if len(key) == 2: if isinstance(key[1], Number): diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index ffaa33d5..1b2a16b7 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -41,6 +41,7 @@ from .time_index import TsIndex from .utils import ( _concatenate_tsd, + _get_terminal_size, _split_tsd, _TsdFrameSliceHelper, convert_to_numpy, @@ -620,42 +621,39 @@ def __repr__(self): headers = ["Time (s)", ""] bottom = "dtype: {}".format(self.dtype) + ", shape: {}".format(self.shape) + max_rows = 2 + rows = _get_terminal_size()[1] + max_rows = np.maximum(rows - 10, 2) + if len(self): def create_str(array): if array.ndim == 1: if len(array) > 2: - return ( - "[" - + array[0].__repr__() - + " ... " - + array[-1].__repr__() - + "]" - ) - elif len(array) == 2: - return ( - "[" + array[0].__repr__() + "," + array[1].__repr__() + "]" + return np.array2string( + np.array([array[0], array[-1]]), + precision=6, + separator=" ... 
", ) - elif len(array) == 1: - return "[" + array[0].__repr__() + "]" else: - return "[]" + return np.array2string(array, precision=6, separator=", ") else: return "[" + create_str(array[0]) + " ...]" _str_ = [] - if self.shape[0] < 100: - for i, array in zip(self.index, self.values): - _str_.append([i.__repr__(), create_str(array)]) - else: - for i, array in zip(self.index[0:5], self.values[0:5]): + if self.shape[0] > max_rows: + n_rows = max_rows // 2 + for i, array in zip(self.index[0:n_rows], self.values[0:n_rows]): _str_.append([i.__repr__(), create_str(array)]) _str_.append(["...", ""]) for i, array in zip( - self.index[-5:], - self.values[self.values.shape[0] - 5 : self.values.shape[0]], + self.index[-n_rows:], + self.values[self.values.shape[0] - n_rows : self.values.shape[0]], ): _str_.append([i.__repr__(), create_str(array)]) + else: + for i, array in zip(self.index, self.values): + _str_.append([i.__repr__(), create_str(array)]) return tabulate(_str_, headers=headers, colalign=("left",)) + "\n" + bottom @@ -818,40 +816,52 @@ def __repr__(self): headers = ["Time (s)"] + [str(k) for k in self.columns] bottom = "dtype: {}".format(self.dtype) + ", shape: {}".format(self.shape) - max_cols = 5 - try: - max_cols = os.get_terminal_size()[0] // 16 - except Exception: - import shutil - - max_cols = shutil.get_terminal_size().columns // 16 - else: - pass + cols, rows = _get_terminal_size() + max_cols = np.maximum(cols // 100, 5) + max_rows = np.maximum(rows - 10, 2) if self.shape[1] > max_cols: headers = headers[0 : max_cols + 1] + ["..."] + def round_if_float(x): + if isinstance(x, float): + return np.round(x, 5) + else: + return x + with warnings.catch_warnings(): warnings.simplefilter("ignore") if len(self): table = [] end = ["..."] if self.shape[1] > max_cols else [] - if len(self) > 51: - for i, array in zip(self.index[0:5], self.values[0:5, 0:max_cols]): - table.append([i] + [k for k in array] + end) + if len(self) > max_rows: + n_rows = max_rows // 2 + for 
i, array in zip( + self.index[0:n_rows], self.values[0:n_rows, 0:max_cols] + ): + table.append([i] + [round_if_float(k) for k in array] + end) table.append(["..."]) for i, array in zip( - self.index[-5:], + self.index[-n_rows:], self.values[ - self.values.shape[0] - 5 : self.values.shape[0], 0:max_cols + self.values.shape[0] - n_rows : self.values.shape[0], + 0:max_cols, ], ): - table.append([i] + [k for k in array] + end) - return tabulate(table, headers=headers) + "\n" + bottom + table.append([i] + [round_if_float(k) for k in array] + end) + return ( + tabulate(table, headers=headers, colalign=("left",)) + + "\n" + + bottom + ) else: for i, array in zip(self.index, self.values[:, 0:max_cols]): - table.append([i] + [k for k in array] + end) - return tabulate(table, headers=headers) + "\n" + bottom + table.append([i] + [round_if_float(k) for k in array] + end) + return ( + tabulate(table, headers=headers, colalign=("left",)) + + "\n" + + bottom + ) else: return tabulate([], headers=headers) + "\n" + bottom @@ -1053,27 +1063,24 @@ def __repr__(self): headers = ["Time (s)", ""] bottom = "dtype: {}".format(self.dtype) + ", shape: {}".format(self.shape) + max_rows = 2 + rows = _get_terminal_size()[1] + max_rows = np.maximum(rows - 10, 2) + with warnings.catch_warnings(): warnings.simplefilter("ignore") if len(self): - if len(self) < 51: - return ( - tabulate( - np.vstack((self.index, self.values)).T, - headers=headers, - colalign=("left",), - ) - + "\n" - + bottom - ) - else: + if len(self) > max_rows: + n_rows = max_rows // 2 table = [] - for i, v in zip(self.index[0:5], self.values[0:5]): + for i, v in zip(self.index[0:n_rows], self.values[0:n_rows]): table.append([i, v]) table.append(["..."]) for i, v in zip( - self.index[-5:], - self.values[self.values.shape[0] - 5 : self.values.shape[0]], + self.index[-n_rows:], + self.values[ + self.values.shape[0] - n_rows : self.values.shape[0] + ], ): table.append([i, v]) @@ -1082,6 +1089,16 @@ def __repr__(self): + "\n" + 
bottom ) + else: + return ( + tabulate( + np.vstack((self.index, self.values)).T, + headers=headers, + colalign=("left",), + ) + + "\n" + + bottom + ) else: return tabulate([], headers=headers) + "\n" + bottom @@ -1357,14 +1374,20 @@ def __init__(self, t, time_units="s", time_support=None): def __repr__(self): upper = "Time (s)" - if len(self) < 50: - _str_ = "\n".join([i.__repr__() for i in self.index]) - else: + + max_rows = 2 + rows = _get_terminal_size()[1] + max_rows = np.maximum(rows - 10, 2) + + if len(self) > max_rows: + n_rows = max_rows // 2 _str_ = "\n".join( - [i.__repr__() for i in self.index[0:5]] + [i.__repr__() for i in self.index[0:n_rows]] + ["..."] - + [i.__repr__() for i in self.index[-5:]] + + [i.__repr__() for i in self.index[-n_rows:]] ) + else: + _str_ = "\n".join([i.__repr__() for i in self.index]) bottom = "shape: {}".format(len(self.index)) return "\n".join((upper, _str_, bottom)) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 4e97aeed..971fd560 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -24,7 +24,7 @@ from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like -from .utils import convert_to_numpy +from .utils import _get_terminal_size, convert_to_numpy def _union_intervals(i_sets): @@ -259,27 +259,14 @@ def _ts_group_from_keys(self, keys): ) def __repr__(self): - cols = self._metadata.columns.drop("rate") - headers = ["Index", "rate"] + [c for c in cols] + col_names = self._metadata.columns.drop("rate") + headers = ["Index", "rate"] + [c for c in col_names] max_cols = 6 max_rows = 2 - - try: - max_cols, max_rows = os.get_terminal_size() - max_cols = max_cols // 12 - max_rows = max_rows - 10 - except Exception: - import shutil - - max_cols, max_rows = shutil.get_terminal_size() - max_cols = max_cols // 12 - max_rows = max_rows - 10 - else: - pass - - max_rows = np.maximum(max_rows, 2) - max_cols = 
np.maximum(max_cols, 6) + cols, rows = _get_terminal_size() + max_cols = np.maximum(cols // 12, 6) + max_rows = np.maximum(rows - 10, 2) end_line = [] lines = [] @@ -303,7 +290,7 @@ def round_if_float(x): [i, np.round(self._metadata.loc[i, "rate"], 5)] + [ round_if_float(self._metadata.loc[i, c]) - for c in cols[0 : max_cols - 2] + for c in col_names[0 : max_cols - 2] ] + end_line ) @@ -313,7 +300,7 @@ def round_if_float(x): [i, np.round(self._metadata.loc[i, "rate"], 5)] + [ round_if_float(self._metadata.loc[i, c]) - for c in cols[0 : max_cols - 2] + for c in col_names[0 : max_cols - 2] ] + end_line ) @@ -323,7 +310,7 @@ def round_if_float(x): [i, np.round(self._metadata.loc[i, "rate"], 5)] + [ round_if_float(self._metadata.loc[i, c]) - for c in cols[0 : max_cols - 2] + for c in col_names[0 : max_cols - 2] ] + end_line ) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index f2d8c376..32234422 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,12 +2,13 @@ # @Author: Guillaume Viejo # @Date: 2024-02-09 11:45:45 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-04 17:04:38 +# @Last Modified time: 2024-04-15 12:25:49 """ Utility functions """ +import os import warnings from itertools import combinations @@ -17,6 +18,26 @@ from .config import nap_config +def _get_terminal_size(): + """Helper to get terminal size for __repr__ + + Returns + ------- + tuple + + """ + cols = 100 # Default + rows = 2 + try: + cols, rows = os.get_terminal_size() + except Exception: + import shutil + + cols, rows = shutil.get_terminal_size() + + return (cols, rows) + + def is_array_like(obj): """ Check if an object is array-like. 
diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index 9ca0aaaf..f90871a9 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -3,7 +3,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:15:02 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-04 17:18:10 +# @Last Modified time: 2024-04-16 11:35:35 """Tests for IntervalSet of `pynapple` package.""" @@ -169,6 +169,18 @@ def test_get_iset(): ep[:,0,3] assert str(e.value) == "too many indices for IntervalSet: IntervalSet is 2-dimensional" +def test_get_iset_with_series(): + start = np.array([0, 10, 16], dtype=np.float64) + end = np.array([5, 15, 20], dtype=np.float64) + ep = nap.IntervalSet(start=start,end=end) + + bool_series = pd.Series([True, False, True]) + + ep2 = ep[bool_series] + + assert isinstance(ep2, nap.IntervalSet) + np.testing.assert_array_almost_equal(ep2.values, ep[[0,2]].values) + def test_iset_loc(): start = np.array([0, 10, 16], dtype=np.float64) end = np.array([5, 15, 20], dtype=np.float64) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index d662b666..d86283e0 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-04-01 09:57:55 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-03 10:23:28 +# @Last Modified time: 2024-04-16 10:07:56 #!/usr/bin/env python """Tests of time series for `pynapple` package.""" @@ -13,6 +13,12 @@ import pytest +# tsd1 = nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s") +# tsd2 = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 10), time_units="s") +# tsd3 = nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 4), time_units="s") +# tsd4 = nap.Ts(t=np.arange(100), time_units="s") + + def test_create_tsd(): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) assert isinstance(tsd, nap.Tsd) From 99449de800d0d183c6adefe21435f3ab2ce8479f Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 16 
Apr 2024 12:02:43 -0400 Subject: [PATCH 2/2] fixing docs --- docs/examples/tutorial_pynapple_process.py | 45 ++++++++++------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/docs/examples/tutorial_pynapple_process.py b/docs/examples/tutorial_pynapple_process.py index e6165276..c8ad929c 100644 --- a/docs/examples/tutorial_pynapple_process.py +++ b/docs/examples/tutorial_pynapple_process.py @@ -36,44 +36,41 @@ # Discrete correlograms # --------------------- # -# The function to compute cross-correlogram is [*cross_correlogram*](https://peyrachelab.github.io/pynapple/process.correlograms/#pynapple.process.correlograms.cross_correlogram). +# First let's generate some data. Here we have two neurons recorded together. We can group them in a `TsGroup`. # -# -# The function is compiled with [numba](https://numba.pydata.org/) to improve performances. This means it only accepts pure numpy arrays as input arguments. +ts1 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 2000)), time_units="s") +ts2 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 1000)), time_units="s") +epoch = nap.IntervalSet(start=0, end=1000, time_units="s") +ts_group = nap.TsGroup({0: ts1, 1: ts2}, time_support=epoch) -ts1 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 1000)), time_units="s") -ts2 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 10)), time_units="s") +print(ts_group) -ts1_time_array = ts1.as_units("s").index.values -ts2_time_array = ts2.as_units("s").index.values +# %% +# First we can compute their autocorrelograms meaning the number of spikes of a neuron observed in a time window centered around its own spikes. +# For this we can use the function `compute_autocorrelogram`. +# We need to specify the `binsize` and `windowsize` to bin the spike train. 
-binsize = 0.1 # second -cc12, xt = nap.process.correlograms.cross_correlogram( - t1=ts1_time_array, t2=ts2_time_array, binsize=binsize, windowsize=1 # second +autocorrs = nap.compute_autocorrelogram( + group=ts_group, binsize=100, windowsize=1000, time_units="ms", ep=epoch # ms ) - -plt.figure(figsize=(10, 6)) -plt.bar(xt, cc12, binsize) -plt.xlabel("Time t1 (us)") -plt.ylabel("CC") +print(autocorrs, "\n") # %% -# To simplify converting to a numpy.ndarray, pynapple provides wrappers for computing autocorrelogram and crosscorrelogram for TsGroup. The function is then called for each unit or each pairs of units. It returns directly a pandas.DataFrame holding all the correlograms. In this example, autocorrelograms and cross-correlograms are computed for the same TsGroup. - -epoch = nap.IntervalSet(start=0, end=1000, time_units="s") -ts_group = nap.TsGroup({0: ts1, 1: ts2}, time_support=epoch) +# The variable `autocorrs` is a pandas DataFrame with the center of the bins for the index and each column is a neuron. +# +# Similarly, we can compute crosscorrelograms meaning how many spikes of neuron 1 do I observe whenever neuron 0 fires. Here the function +# is called `compute_crosscorrelogram` and takes a `TsGroup` as well. -autocorrs = nap.compute_autocorrelogram( - group=ts_group, binsize=100, windowsize=1000, time_units="ms", ep=epoch # ms # ms -) crosscorrs = nap.compute_crosscorrelogram( - group=ts_group, binsize=100, windowsize=1000, time_units="ms" # ms # ms + group=ts_group, binsize=100, windowsize=1000, time_units="ms" # ms ) -print(autocorrs, "\n") print(crosscorrs, "\n") +# %% +# Column name (0, 1) is read as cross-correlogram of neuron 0 and 1 with neuron 0 being the reference time. + # %% # *** # Peri-Event Time Histogram (PETH)