diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 349e7157..0c779db4 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -5,7 +5,7 @@ on:
   push:
     branches: [ main ]
   pull_request:
-    branches: [ main ]
+    branches: [ main, dev ]
 
 jobs:
   lint:
@@ -40,9 +40,9 @@ jobs:
       #   - os: windows-latest
       #     python-version: 3.7
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
@@ -52,14 +52,12 @@ jobs:
       - name: Test
         run: |
           coverage run --source=pynapple --branch -m pytest tests/
-          coverage report -m 
+          coverage report -m
-      - name: Coveralls
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          pip install coveralls
-          coveralls --service=github
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4.0.1
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
   check:
     if: always()
     needs:
diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py
index b969e787..6a47fdb4 100644
--- a/pynapple/core/_core_functions.py
+++ b/pynapple/core/_core_functions.py
@@ -144,6 +144,6 @@ def _threshold(time_array, data_array, starts, ends, thr, method):
     if get_backend() == "jax":
         from pynajax.jax_core_threshold import threshold
 
-        return threshold(time_array, data_array, starts, ends, thr, method)
+        return threshold(time_array, data_array[:], starts, ends, thr, method)
     else:
-        return jitthreshold(time_array, data_array, starts, ends, thr, method)
+        return jitthreshold(time_array, data_array[:], starts, ends, thr, method)
diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py
index a30d143a..4de37635 100644
--- a/pynapple/core/_jitted_functions.py
+++ b/pynapple/core/_jitted_functions.py
@@ -312,11 +312,16 @@ def jitthreshold(time_array, data_array, starts, ends, thr, method="above"):
     return (new_time_array, new_data_array, new_starts, new_ends)
 
 
-@jit(nopython=True)
 def jitbin_array(time_array, data_array, starts, ends, bin_size):
+    """Slice first for compatibility with lazy loading."""
     idx, countin = jitrestrict_with_count(time_array, starts, ends)
-    time_array = time_array[idx]
-    data_array = data_array[idx]
+    return _jitbin_array(
+        countin, time_array[idx], data_array[idx], starts, ends, bin_size
+    )
+
+
+@jit(nopython=True)
+def _jitbin_array(countin, time_array, data_array, starts, ends, bin_size):
     m = starts.shape[0]
     f = data_array.shape[1:]
 
diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index 140baf31..41b2f77b 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -68,10 +68,18 @@ class BaseTsd(Base, NDArrayOperatorsMixin, abc.ABC):
     Implement most of the shared functions across concrete classes `Tsd`, `TsdFrame`, `TsdTensor`
     """
 
-    def __init__(self, t, d, time_units="s", time_support=None):
+    def __init__(self, t, d, time_units="s", time_support=None, load_array=True):
         super().__init__(t, time_units, time_support)
 
-        self.values = convert_to_array(d, "d")
+        if load_array:
+            self.values = convert_to_array(d, "d")
+        else:
+            if not is_array_like(d):
+                raise TypeError(
+                    "Data should be array-like, i.e. be indexable, iterable and have attributes "
+                    "`shape`, `ndim` and `dtype`."
+                )
+            self.values = d
 
         assert len(self.index) == len(
             self.values
@@ -220,36 +228,38 @@ def __array_function__(self, func, types, args, kwargs):
 
     def as_array(self):
         """
-        Return the data as a numpy.ndarray
+        Return the data.
 
         Returns
         -------
-        out: numpy.ndarray
+        out: array-like
             _
         """
         return self.values
 
     def data(self):
         """
-        Return the data as a numpy.ndarray
+        Return the data.
 
         Returns
        -------
-        out: numpy.ndarray
+        out: array-like
             _
         """
         return self.values
 
     def to_numpy(self):
         """
-        Return the data as a numpy.ndarray. Mostly useful for matplotlib plotting when calling `plot(tsd)`
+        Return the data as a numpy.ndarray.
+
+        Mostly useful for matplotlib plotting when calling `plot(tsd)`.
         """
-        return self.values
+        return np.asarray(self.values)
 
     def copy(self):
         """Copy the data, index and time support"""
         return self.__class__(
-            t=self.index.copy(), d=self.values.copy(), time_support=self.time_support
+            t=self.index.copy(), d=self.values[:].copy(), time_support=self.time_support
         )
 
     def value_from(self, data, ep=None):
@@ -662,7 +672,9 @@ class TsdTensor(BaseTsd):
         The time support of the time series
     """
 
-    def __init__(self, t, d, time_units="s", time_support=None, **kwargs):
+    def __init__(
+        self, t, d, time_units="s", time_support=None, load_array=True, **kwargs
+    ):
         """
         TsdTensor initializer
 
@@ -677,7 +689,7 @@ def __init__(self, t, d, time_units="s", time_support=None, **kwargs):
         time_support : IntervalSet, optional
             The time support of the TsdFrame object
         """
-        super().__init__(t, d, time_units, time_support)
+        super().__init__(t, d, time_units, time_support, load_array)
 
         assert (
             self.values.ndim >= 3
@@ -831,7 +843,15 @@ class TsdFrame(BaseTsd):
         The time support of the time series
     """
 
-    def __init__(self, t, d=None, time_units="s", time_support=None, columns=None):
+    def __init__(
+        self,
+        t,
+        d=None,
+        time_units="s",
+        time_support=None,
+        columns=None,
+        load_array=True,
+    ):
         """
         TsdFrame initializer
         A pandas.DataFrame can be passed directly
 
@@ -859,7 +879,7 @@ def __init__(self, t, d=None, time_units="s", time_support=None, columns=None):
         else:
             assert d is not None, "Missing argument d when initializing TsdFrame"
 
-        super().__init__(t, d, time_units, time_support)
+        super().__init__(t, d, time_units, time_support, load_array)
 
         assert self.values.ndim <= 2, "Data should be 1 or 2 dimensional."
 
@@ -1084,7 +1104,7 @@ def save(self, filename):
         np.savez(
             filename,
             t=self.index.values,
-            d=self.values,
+            d=self.values[:],
             start=self.time_support.start,
             end=self.time_support.end,
             columns=cols_name,
@@ -1108,7 +1128,9 @@ class Tsd(BaseTsd):
         The time support of the time series
     """
 
-    def __init__(self, t, d=None, time_units="s", time_support=None, **kwargs):
+    def __init__(
+        self, t, d=None, time_units="s", time_support=None, load_array=True, **kwargs
+    ):
         """
         Tsd Initializer.
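The `load_array` flag added above is the entry point for lazy loading across `Tsd`, `TsdFrame` and `TsdTensor`. A minimal usage sketch, assuming a local `data.h5` file whose one-dimensional `data` dataset is illustrative rather than part of this changeset:

    import h5py
    import numpy as np
    import pynapple as nap

    # Write a small HDF5 file so the sketch is self-contained (names are made up).
    with h5py.File("data.h5", "w") as f:
        f.create_dataset("data", data=np.arange(100.0))

    # load_array=False stores the h5py.Dataset as-is instead of copying it into memory.
    h5_data = h5py.File("data.h5", "r")["data"]
    tsd = nap.Tsd(t=np.arange(100), d=h5_data, load_array=False)

    print(type(tsd.values))      # h5py Dataset: nothing has been read yet
    print(type(tsd.to_numpy()))  # numpy.ndarray: to_numpy() materializes the values

This is also why `to_numpy()` now wraps the values in `np.asarray` and why the slicing (`[:]`) appears throughout the core functions: the slice forces a concrete array only at the point where computation actually needs one.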
@@ -1129,7 +1151,7 @@ def __init__(self, t, d=None, time_units="s", time_support=None, **kwargs):
         else:
             assert d is not None, "Missing argument d when initializing Tsd"
 
-        super().__init__(t, d, time_units, time_support)
+        super().__init__(t, d, time_units, time_support, load_array)
 
         assert self.values.ndim == 1, "Data should be 1 dimensional"
 
diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py
index e35a6937..b4865d6d 100644
--- a/pynapple/core/ts_group.py
+++ b/pynapple/core/ts_group.py
@@ -17,6 +17,7 @@ from ._core_functions import _count
 from ._jitted_functions import jitunion, jitunion_isets
 from .base_class import Base
+from .config import nap_config
 from .interval_set import IntervalSet
 from .time_index import TsIndex
 from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like
@@ -1026,6 +1027,207 @@ def getby_category(self, key):
         sliced = {k: self[list(groups[k])] for k in groups.keys()}
         return sliced
 
+    @staticmethod
+    def merge_group(
+        *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False
+    ):
+        """
+        Merge multiple TsGroup objects into a single TsGroup object.
+
+        Parameters
+        ----------
+        *tsgroups : TsGroup
+            The TsGroup objects to merge
+        reset_index : bool, optional
+            If True, the keys will be reset to range(len(data))
+            If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
+        reset_time_support : bool, optional
+            If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
+            If False, the time support of the TsGroup objects should be the same and will be preserved
+        ignore_metadata : bool, optional
+            If True, the merged TsGroup will not have any metadata columns other than 'rate'
+            If False, all metadata columns should be the same and all metadata will be concatenated
+
+        Returns
+        -------
+        TsGroup
+            A TsGroup of merged objects
+
+        Raises
+        ------
+        TypeError
+            If the input objects are not TsGroup objects
+        ValueError
+            If `ignore_metadata=False` but metadata columns are not the same
+            If `reset_index=False` but keys overlap
+            If `reset_time_support=False` but time supports are not the same
+
+        """
+        is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
+        if not all(is_tsgroup):
+            not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo]
+            raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!")
+
+        if len(tsgroups) == 1:
+            print("Only one TsGroup object provided, no merge needed.")
+            return tsgroups[0]
+
+        tsg1 = tsgroups[0]
+        items = tsg1.items()
+        keys = set(tsg1.keys())
+        metadata = tsg1._metadata
+
+        for i, tsg in enumerate(tsgroups[1:]):
+            if not ignore_metadata:
+                if tsg1.metadata_columns != tsg.metadata_columns:
+                    raise ValueError(
+                        f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
+                        "Set `ignore_metadata=True` to bypass the check."
+                    )
+                metadata = pd.concat([metadata, tsg._metadata], axis=0)
+
+            if not reset_index:
+                key_overlap = keys.intersection(tsg.keys())
+                if key_overlap:
+                    raise ValueError(
+                        f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
+                        "Set `reset_index=True` to bypass the check."
+                    )
+                keys.update(tsg.keys())
+
+            if reset_time_support:
+                time_support = None
+            else:
+                if not np.allclose(
+                    tsg1.time_support.as_units("s").to_numpy(),
+                    tsg.time_support.as_units("s").to_numpy(),
+                    atol=10 ** (-nap_config.time_index_precision),
+                    rtol=0,
+                ):
+                    raise ValueError(
+                        f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
+                        "Set `reset_time_support=True` to bypass the check."
+                    )
+                time_support = tsg1.time_support
+
+            items.extend(tsg.items())
+
+        if reset_index:
+            metadata.index = range(len(metadata))
+            data = {i: ts[1] for i, ts in enumerate(items)}
+        else:
+            data = dict(items)
+
+        if ignore_metadata:
+            return TsGroup(data, time_support=time_support, bypass_check=False)
+        else:
+            cols = metadata.columns.drop("rate")
+            return TsGroup(
+                data, time_support=time_support, bypass_check=False, **metadata[cols]
+            )
+
+    def merge(
+        self,
+        *tsgroups,
+        reset_index=False,
+        reset_time_support=False,
+        ignore_metadata=False,
+    ):
+        """
+        Merge the TsGroup object with other TsGroup objects.
+        Common uses include adding more neurons/channels (supposing each Ts/Tsd corresponds to data from a neuron/channel) or adding more trials (supposing each Ts/Tsd corresponds to data from a trial).
+
+        Parameters
+        ----------
+        *tsgroups : TsGroup
+            The TsGroup objects to merge with
+        reset_index : bool, optional
+            If True, the keys will be reset to range(len(data))
+            If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
+        reset_time_support : bool, optional
+            If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
+            If False, the time support of the TsGroup objects should be the same and will be preserved
+        ignore_metadata : bool, optional
+            If True, the merged TsGroup will not have any metadata columns other than 'rate'
+            If False, all metadata columns should be the same and all metadata will be concatenated
+
+        Returns
+        -------
+        TsGroup
+            A TsGroup of merged objects
+
+        Raises
+        ------
+        TypeError
+            If the input objects are not TsGroup objects
+        ValueError
+            If `ignore_metadata=False` but metadata columns are not the same
+            If `reset_index=False` but keys overlap
+            If `reset_time_support=False` but time supports are not the same
+
+        Examples
+        --------
+
+        >>> import pynapple as nap
+        >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s')
+        >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s')
+
+        >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')}
+        >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a)
+
+        >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
+        >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a)
+
+        >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')}
+        >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a)
+
+        >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
+        >>> tsgroup4 = nap.TsGroup(dict4, time_support=time_support_b)
+
+        Merge with default options if the TsGroup objects have the same time support and non-overlapping indexes:
+
+        >>> tsgroup_12 = tsgroup1.merge(tsgroup2)
+        >>> tsgroup_12
+          Index     rate
+        -------  ------
+              0     1.5
+             10     1.5
+
+        Set `reset_index=True` if indexes are overlapping:
+
+        >>> tsgroup_13 = tsgroup1.merge(tsgroup3, reset_index=True)
+        >>> tsgroup_13
+          Index     rate
+        -------  ------
+              0     1.5
+              1     1.5
+
+        Set `reset_time_support=True` if time supports are different:
+
+        >>> tsgroup_14 = tsgroup1.merge(tsgroup4, reset_time_support=True)
+        >>> tsgroup_14
+          Index     rate
+        -------  ------
+              0     0.3
+             10     0.3
+
+        >>> tsgroup_14.time_support
+           start    end
+        0     -5      5
+        shape: (1, 2), time unit: sec.
+
+        See Also
+        --------
+        [`TsGroup.merge_group`](./#pynapple.core.ts_group.TsGroup.merge_group)
+        """
+        return TsGroup.merge_group(
+            self,
+            *tsgroups,
+            reset_index=reset_index,
+            reset_time_support=reset_time_support,
+            ignore_metadata=ignore_metadata,
+        )
+
     def save(self, filename):
         """
         Save TsGroup object in npz format. The file will contain the timestamps,
diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py
index 1cfc76f4..70f33956 100644
--- a/pynapple/io/interface_nwb.py
+++ b/pynapple/io/interface_nwb.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 # @Author: Guillaume Viejo
 # @Date:   2023-08-01 11:54:45
-# @Last Modified by:   gviejo
-# @Last Modified time: 2023-10-19 12:16:55
+# @Last Modified by:   Guillaume Viejo
+# @Last Modified time: 2024-05-21 15:28:27
 
 """
 Pynapple class to interface with NWB files.
@@ -72,7 +72,7 @@ def _extract_compatible_data_from_nwbfile(nwbfile):
     return data
 
-def _make_interval_set(obj):
+def _make_interval_set(obj, **kwargs):
     """Helper function to make IntervalSet
 
     Parameters
     ----------
@@ -128,13 +128,15 @@ def _make_interval_set(obj):
     return obj
 
 
-def _make_tsd(obj):
+def _make_tsd(obj, lazy_loading=True):
     """Helper function to make Tsd
 
     Parameters
     ----------
     obj : pynwb.misc.TimeSeries
         NWB object
+    lazy_loading: bool
+        If True, return a memory-view of the data; otherwise, load the data into memory.
 
     Returns
     -------
@@ -142,24 +144,29 @@
 
     """
-    d = obj.data[:]
+    d = obj.data
+    if not lazy_loading:
+        d = d[:]
+
     if obj.timestamps is not None:
         t = obj.timestamps[:]
     else:
         t = obj.starting_time + np.arange(obj.num_samples) / obj.rate
 
-    data = nap.Tsd(t=t, d=d)
+    data = nap.Tsd(t=t, d=d, load_array=not lazy_loading)
 
     return data
 
 
-def _make_tsd_tensor(obj):
+def _make_tsd_tensor(obj, lazy_loading=True):
     """Helper function to make TsdTensor
 
     Parameters
     ----------
     obj : pynwb.misc.TimeSeries
         NWB object
+    lazy_loading: bool
+        If True, return a memory-view of the data; otherwise, load the data into memory.
 
     Returns
     -------
@@ -167,24 +174,29 @@
 
     """
-    d = obj.data[:]
+    d = obj.data
+    if not lazy_loading:
+        d = d[:]
+
     if obj.timestamps is not None:
         t = obj.timestamps[:]
     else:
         t = obj.starting_time + np.arange(obj.num_samples) / obj.rate
 
-    data = nap.TsdTensor(t=t, d=d)
+    data = nap.TsdTensor(t=t, d=d, load_array=not lazy_loading)
 
     return data
 
 
-def _make_tsd_frame(obj):
+def _make_tsd_frame(obj, lazy_loading=True):
     """Helper function to make TsdFrame
 
     Parameters
     ----------
     obj : pynwb.misc.TimeSeries
         NWB object
+    lazy_loading: bool
+        If True, return a memory-view of the data; otherwise, load the data into memory.
 
     Returns
     -------
@@ -192,7 +204,10 @@
 
     """
-    d = obj.data[:]
+    d = obj.data
+    if not lazy_loading:
+        d = d[:]
+
     if obj.timestamps is not None:
         t = obj.timestamps[:]
     else:
@@ -232,12 +247,12 @@
     else:
         columns = np.arange(obj.data.shape[1])
 
-    data = nap.TsdFrame(t=t, d=d, columns=columns)
+    data = nap.TsdFrame(t=t, d=d, columns=columns, load_array=not lazy_loading)
 
     return data
 
 
-def _make_tsgroup(obj):
+def _make_tsgroup(obj, **kwargs):
     """Helper function to make TsGroup
 
     Parameters
@@ -301,7 +316,7 @@
     return tsgroup
 
 
-def _make_ts(obj):
+def _make_ts(obj, **kwargs):
     """Helper function to make Ts
 
     Parameters
@@ -355,12 +370,14 @@
         "TsGroup": _make_tsgroup,
     }
 
-    def __init__(self, file):
+    def __init__(self, file, lazy_loading=True):
         """
         Parameters
         ----------
         file : str or pynwb.file.NWBFile
             Valid file to a NWB file
+        lazy_loading: bool
+            If True, return a memory-view of the data; otherwise, load the data into memory.
 
         Raises
         ------
@@ -391,6 +408,8 @@
 
         self._view = [[k, self.data[k]["type"]] for k in self.data.keys()]
 
+        self._lazy_loading = lazy_loading
+
         UserDict.__init__(self, self.data)
 
     def __str__(self):
@@ -443,7 +462,9 @@ def __getitem__(self, key):
         if isinstance(self.data[key], dict) and "id" in self.data[key]:
             obj = self.nwb.objects[self.data[key]["id"]]
             try:
-                data = self._f_eval[self.data[key]["type"]](obj)
+                data = self._f_eval[self.data[key]["type"]](
+                    obj, lazy_loading=self._lazy_loading
+                )
             except Exception:
                 warnings.warn(
                     "Failed to build {}.\n Returning the NWB object for manual inspection".format(
diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py
index 0713d26c..9d9ea5ff 100644
--- a/pynapple/process/_process_functions.py
+++ b/pynapple/process/_process_functions.py
@@ -190,7 +190,7 @@ def _perievent_trigger_average(
             time_target_array,
             count_array,
             time_array,
-            data_array,
+            data_array[:],
             starts,
             ends,
             windows,
@@ -204,7 +204,7 @@
             time_target_array,
             count_array,
             time_array,
-            np.expand_dims(data_array, -1),
+            np.expand_dims(data_array[:], -1),
             starts,
             ends,
             windows,
@@ -216,7 +216,7 @@
             time_target_array,
             count_array,
             time_array,
-            data_array,
+            data_array[:],
             starts,
             ends,
             windows,
diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py
new file mode 100644
index 00000000..9618e52e
--- /dev/null
+++ b/tests/test_lazy_loading.py
@@ -0,0 +1,260 @@
+import os.path
+import warnings
+from contextlib import nullcontext as does_not_raise
+from pathlib import Path
+
+import h5py
+import numpy as np
+import pandas as pd
+import pytest
+from pynwb.testing.mock.base import mock_TimeSeries
+from pynwb.testing.mock.file import mock_NWBFile
+
+import pynapple as nap
+
+
+@pytest.mark.parametrize(
+    "time, data, expectation",
+    [
+        (np.arange(12), np.arange(12), does_not_raise()),
+        (np.arange(12), "not_an_array", pytest.raises(TypeError, match="Data should be array-like"))
+    ]
+)
+def test_lazy_load_hdf5_is_array(time, data, expectation):
+    file_path = Path('data.h5')
+    try:
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+        h5_data = h5py.File(file_path, 'r')["data"]
+        with expectation:
+            nap.Tsd(t=time, d=h5_data, load_array=False)
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+@pytest.mark.parametrize(
+    "time, data",
+    [
+        (np.arange(12), np.arange(12)),
+    ]
+)
+@pytest.mark.parametrize("convert_flag", [True, False])
+def test_lazy_load_hdf5_conversion_flag(time, data, convert_flag):
+    file_path = Path('data.h5')
+    try:
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+        # get the tsd
+        h5_data = h5py.File(file_path, 'r')["data"]
+        tsd = nap.Tsd(t=time, d=h5_data, load_array=convert_flag)
+        if convert_flag:
+            assert isinstance(tsd.d, np.ndarray)
+        else:
+            assert isinstance(tsd.d, h5py.Dataset)
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+@pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))])
+@pytest.mark.parametrize("cls", [nap.Tsd, nap.TsdFrame, nap.TsdTensor])
+@pytest.mark.parametrize("func", [np.exp, lambda x: x*2])
+def test_lazy_load_hdf5_apply_func(time, data, func, cls):
+    """Apply a unary function to a lazy loaded array."""
+    file_path = Path('data.h5')
+    try:
+        if cls is nap.TsdFrame:
+            data = data[:, None]
+        elif cls is nap.TsdTensor:
+            data = data[:, None, None]
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+        # get the tsd
+        h5_data = h5py.File(file_path, 'r')["data"]
+        # lazy load and apply function
+        res = func(cls(t=time, d=h5_data, load_array=False))
+        assert isinstance(res, cls)
+        assert isinstance(res.d, np.ndarray)
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+@pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))])
+@pytest.mark.parametrize("cls", [nap.Tsd, nap.TsdFrame, nap.TsdTensor])
+@pytest.mark.parametrize(
+    "method_name, args",
+    [
+        ("bin_average", [0.1]),
+        ("count", [0.1]),
+        ("interpolate", [nap.Ts(t=np.linspace(0, 12, 50))]),
+        ("convolve", [np.ones(3)]),
+        ("smooth", [2]),
+        ("dropna", [True]),
+        ("value_from", [nap.Tsd(t=np.linspace(0, 12, 20), d=np.random.normal(size=20))]),
+        ("copy", []),
+        ("get", [2, 7])
+    ]
+)
+def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls):
+    file_path = Path('data.h5')
+    try:
+        if cls is nap.TsdFrame:
+            data = data[:, None]
+        elif cls is nap.TsdTensor:
+            data = data[:, None, None]
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+        # get the tsd
+        h5_data = h5py.File(file_path, 'r')["data"]
+        # lazy load and apply function
+        tsd = cls(t=time, d=h5_data, load_array=False)
+        func = getattr(tsd, method_name)
+        out = func(*args)
+        assert isinstance(out.d, np.ndarray)
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+@pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))])
+@pytest.mark.parametrize(
+    "method_name, args, expected_out_type",
+    [
+        ("threshold", [3], nap.Tsd),
+        ("as_series", [], pd.Series),
+        ("as_units", ['ms'], pd.Series),
+        ("to_tsgroup", [], nap.TsGroup)
+    ]
+)
+def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, expected_out_type):
+    file_path = Path('data.h5')
+    try:
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+        # get the tsd
+        h5_data = h5py.File(file_path, 'r')["data"]
+        # lazy load and apply function
+        tsd = nap.Tsd(t=time, d=h5_data, load_array=False)
+        func = getattr(tsd, method_name)
+        assert isinstance(func(*args), expected_out_type)
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+@pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))])
+@pytest.mark.parametrize(
+    "method_name, args, expected_out_type",
+    [
+        ("as_dataframe", [], pd.DataFrame),
+    ]
+)
+def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, args, expected_out_type):
+    file_path = Path('data.h5')
+    try:
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data[:, None])
+        # get the tsd
+        h5_data = h5py.File(file_path, 'r')["data"]
+        # lazy load and apply function
+        tsd = nap.TsdFrame(t=time, d=h5_data, load_array=False)
+        func = getattr(tsd, method_name)
+        assert isinstance(func(*args), expected_out_type)
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+def test_lazy_load_hdf5_tsdframe_loc():
+    file_path = Path('data.h5')
+    data = np.arange(10).reshape(5, 2)
+    try:
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+        # get the tsd
+        h5_data = h5py.File(file_path, 'r')["data"]
+        # lazy load and apply function
+        tsd = nap.TsdFrame(t=np.arange(data.shape[0]), d=h5_data, load_array=False).loc[1]
+        assert isinstance(tsd, nap.Tsd)
+        assert all(tsd.d == np.array([1, 3, 5, 7, 9]))
+
+    finally:
+        # delete file
+        if file_path.exists():
+            file_path.unlink()
+
+
+@pytest.mark.parametrize(
+    "lazy, expected_type",
+    [
+        (True, h5py.Dataset),
+        (False, np.ndarray),
+    ]
+)
+def test_lazy_load_nwb(lazy, expected_type):
+    try:
+        nwb = nap.NWBFile("tests/nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy)
+    except Exception:
+        nwb = nap.NWBFile("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy)
+
+    tsd = nwb["z"]
+    assert isinstance(tsd.d, expected_type)
+    nwb.io.close()
+
+
+@pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))])
+def test_lazy_load_nwb_no_warnings(data):
+    file_path = Path('data.h5')
+
+    try:
+        with h5py.File(file_path, 'w') as f:
+            f.create_dataset('data', data=data)
+            time_series = mock_TimeSeries(name="TimeSeries", data=f["data"])
+            nwbfile = mock_NWBFile()
+            nwbfile.add_acquisition(time_series)
+            nwb = nap.NWBFile(nwbfile)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("error")
+                tsd = nwb["TimeSeries"]
+                tsd.count(0.1)
+                assert isinstance(tsd.d, h5py.Dataset)
+
+    finally:
+        if file_path.exists():
+            file_path.unlink()
+
+
+def test_tsgroup_no_warnings():
+    n_units = 2
+    try:
+        for k in range(n_units):
+            file_path = Path(f'data_{k}.h5')
+            with h5py.File(file_path, 'w') as f:
+                f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20)))
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+
+            nwbfile = mock_NWBFile()
+
+            for k in range(n_units):
+                file_path = Path(f'data_{k}.h5')
+                spike_times = h5py.File(file_path, "r")['spks']
+                nwbfile.add_unit(spike_times=spike_times)
+
+            nwb = nap.NWBFile(nwbfile)
+            tsgroup = nwb["units"]
+            tsgroup.count(0.1)
+
+    finally:
+        for k in range(n_units):
+            file_path = Path(f'data_{k}.h5')
+            if file_path.exists():
+                file_path.unlink()
diff --git a/tests/test_nwb.py b/tests/test_nwb.py
index bf3558e8..943726e9 100644
--- a/tests/test_nwb.py
+++ b/tests/test_nwb.py
@@ -519,11 +519,11 @@ def test_add_Units():
             nwbfile.add_unit(spike_times=spike_times, quality="good", alpha=alpha[n_units_per_shank])
             spks[n_units_per_shank] = spike_times
 
-    nwb = nap.NWBFile(nwbfile)
-    assert len(nwb) == 1
-    assert "units" in nwb.keys()
+    nwb_tsgroup = nap.NWBFile(nwbfile)
+    assert len(nwb_tsgroup) == 1
+    assert "units" in nwb_tsgroup.keys()
 
-    data = nwb['units']
+    data = nwb_tsgroup['units']
     assert isinstance(data, nap.TsGroup)
     assert len(data) == n_units
     for n in data.keys():
diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py
index 65d23af4..fc221f60 100644
--- a/tests/test_ts_group.py
+++ b/tests/test_ts_group.py
@@ -14,7 +14,6 @@ import warnings
 from contextlib import nullcontext as does_not_raise
 
-
 @pytest.fixture
 def group():
     """Fixture to be used in all tests."""
@@ -575,15 +574,16 @@ def test_save_npz(self, group):
         np.testing.assert_array_almost_equal(file['index'], index)
         np.testing.assert_array_almost_equal(file['meta'], np.arange(len(group), dtype=np.int64))
         assert np.all(file['meta2']==np.array(['a', 'b', 'c']))
+        file.close()
 
         tsgroup3 = nap.TsGroup({
                 0: nap.Ts(t=np.arange(0, 20)),
             })
         tsgroup3.save("tsgroup3")
 
-        file = np.load("tsgroup3.npz")
-        assert 'd' not in list(file.keys())
-        np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
+        with np.load("tsgroup3.npz") as file:
+            assert 'd' not in list(file.keys())
+            np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
 
         os.remove("tsgroup.npz")
         os.remove("tsgroup2.npz")
@@ -752,4 +752,106 @@ def test_getitem_attribute_error(self, ts_group):
     )
     def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation):
         with expectation:
-            out = ts_group[bool_idx]
\ No newline at end of file
+            out = ts_group[bool_idx]
+
+    def test_merge_complete(self, ts_group):
+        with pytest.raises(TypeError, match="Input at positions(.*)are not TsGroup!"):
+            nap.TsGroup.merge_group(ts_group, str, dict)
+
+        ts_group2 = nap.TsGroup(
+            {
+                3: nap.Ts(t=np.arange(15)),
+                4: nap.Ts(t=np.arange(20)),
+            },
+            time_support=ts_group.time_support,
+            meta=np.array([12, 13])
+        )
+        merged = ts_group.merge(ts_group2)
+        assert len(merged) == 4
+        assert np.all(merged.keys() == np.array([1, 2, 3, 4]))
+        assert np.all(merged.meta == np.array([10, 11, 12, 13]))
+        np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)
+
+    @pytest.mark.parametrize(
+        'col_name, ignore_metadata, expectation',
+        [
+            ('meta', False, does_not_raise()),
+            ('meta', True, does_not_raise()),
+            ('wrong_name', False, pytest.raises(ValueError, match="TsGroup at position 2 has different metadata columns.*")),
+            ('wrong_name', True, does_not_raise())
+        ]
+    )
+    def test_merge_metadata(self, ts_group, col_name, ignore_metadata, expectation):
+        metadata = pd.DataFrame([12, 13], index=[3, 4], columns=[col_name])
+        ts_group2 = nap.TsGroup(
+            {
+                3: nap.Ts(t=np.arange(15)),
+                4: nap.Ts(t=np.arange(20)),
+            },
+            time_support=ts_group.time_support,
+            **metadata
+        )
+
+        with expectation:
+            merged = ts_group.merge(ts_group2, ignore_metadata=ignore_metadata)
+
+        if ignore_metadata:
+            assert merged.metadata_columns[0] == 'rate'
+        elif col_name == 'meta':
+            np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)
+
+    @pytest.mark.parametrize(
+        'index, reset_index, expectation',
+        [
+            (np.array([1, 2]), False, pytest.raises(ValueError, match="TsGroup at position 2 has overlapping keys.*")),
+            (np.array([1, 2]), True, does_not_raise()),
+            (np.array([3, 4]), False, does_not_raise()),
+            (np.array([3, 4]), True, does_not_raise())
+        ]
+    )
+    def test_merge_index(self, ts_group, index, reset_index, expectation):
+        ts_group2 = nap.TsGroup(
+            dict(zip(index, [nap.Ts(t=np.arange(15)), nap.Ts(t=np.arange(20))])),
+            time_support=ts_group.time_support,
+            meta=np.array([12, 13])
+        )
+
+        with expectation:
+            merged = ts_group.merge(ts_group2, reset_index=reset_index)
+
+        if reset_index:
+            assert np.all(merged.keys() == np.arange(4))
+        elif np.all(index == np.array([3, 4])):
+            assert np.all(merged.keys() == np.array([1, 2, 3, 4]))
+
+    @pytest.mark.parametrize(
+        'time_support, reset_time_support, expectation',
+        [
+            (None, False, does_not_raise()),
+            (None, True, does_not_raise()),
+            (nap.IntervalSet(start=0, end=1), False,
+             pytest.raises(ValueError, match="TsGroup at position 2 has different time support.*")),
+            (nap.IntervalSet(start=0, end=1), True, does_not_raise())
+        ]
+    )
+    def test_merge_time_support(self, ts_group, time_support, reset_time_support, expectation):
+        if time_support is None:
+            time_support = ts_group.time_support
+
+        ts_group2 = nap.TsGroup(
+            {
+                3: nap.Ts(t=np.arange(15)),
+                4: nap.Ts(t=np.arange(20)),
+            },
+            time_support=time_support,
+            meta=np.array([12, 13])
+        )
+
+        with expectation:
+            merged = ts_group.merge(ts_group2, reset_time_support=reset_time_support)
+
+        if reset_time_support:
+            np.testing.assert_array_almost_equal(
+                ts_group.time_support.as_units("s").to_numpy(),
+                merged.time_support.as_units("s").to_numpy()
+            )
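To close, a hedged sketch of the merge API these tests exercise (the groups, keys, and spike times below are made up for illustration; only `merge`, `merge_group`, and `reset_index` come from the changeset above):

    import numpy as np
    import pynapple as nap

    support = nap.IntervalSet(start=0, end=10)
    tsg_a = nap.TsGroup({0: nap.Ts(t=np.arange(10))}, time_support=support)
    tsg_b = nap.TsGroup({0: nap.Ts(t=np.arange(10) + 0.5)}, time_support=support)

    # The keys overlap, so reset_index=True renumbers the merged group from 0.
    merged = tsg_a.merge(tsg_b, reset_index=True)
    print(merged.keys())  # [0, 1]

    # The static method is equivalent and accepts any number of groups at once.
    merged2 = nap.TsGroup.merge_group(tsg_a, tsg_b, reset_index=True)

With the default `reset_index=False`, the same call would raise the "overlapping keys" ValueError checked in `test_merge_index`, pointing the user at the bypass flag.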