Skip to content

Commit

Permalink
added test tsgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed May 22, 2024
1 parent efaad74 commit 7aa66aa
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
63 changes: 45 additions & 18 deletions tests/test_lazy_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
import pandas as pd
import pytest
from pynwb.testing.mock.base import mock_TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

import pynapple as nap

Expand Down Expand Up @@ -207,24 +209,49 @@ def test_lazy_load_nwb(lazy, expected_type):
nwb.io.close()


@pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))])
def test_lazy_load_nwb_no_warnings(data):
    """Lazily-loaded NWB TimeSeries backed by an h5py dataset must not warn.

    Builds an in-memory pynwb file whose TimeSeries data is an open
    ``h5py.Dataset`` (1D/2D/3D via the parametrized ``data``), then checks
    that accessing and counting the resulting pynapple object raises no
    warnings and that the underlying storage is still the lazy h5py dataset.
    """
    file_path = Path('data.h5')

    try:
        # The h5py file must stay open for the whole test: the pynapple
        # object is expected to keep a live reference to the lazy dataset.
        with h5py.File(file_path, 'w') as f:
            f.create_dataset('data', data=data)
            time_series = mock_TimeSeries(name="TimeSeries", data=f["data"])
            nwbfile = mock_NWBFile()
            nwbfile.add_acquisition(time_series)
            nwb = nap.NWBFile(nwbfile)

            # Escalate any warning to an error so a single emitted warning
            # fails the test.
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                tsd = nwb["TimeSeries"]
                tsd.count(0.1)
                # Storage must still be the lazy h5py dataset, not a
                # materialized numpy array.
                assert isinstance(tsd.d, h5py.Dataset)

            nwb.io.close()
    finally:
        # Always remove the temporary HDF5 file, even on failure.
        if file_path.exists():
            file_path.unlink()


def test_tsgroup_no_warinings():
    # NOTE(review): "warinings" is a typo for "warnings"; the name is kept
    # unchanged to preserve the public test identifier.
    """A lazily-loaded TsGroup built from h5py-backed spike times must not warn.

    Writes per-unit spike times to temporary HDF5 files, adds them to an
    in-memory pynwb file as open ``h5py`` datasets, and checks that reading
    and counting the resulting ``TsGroup`` raises no warnings.
    """
    n_units = 2
    # Keep the read handles so they can be closed in ``finally`` — the
    # original leaked them, which also makes ``unlink()`` fail on platforms
    # that lock open files (e.g. Windows). They must stay open while the
    # TsGroup is used, since the units reference the lazy datasets.
    open_files = []
    try:
        for k in range(n_units):
            file_path = Path(f'data_{k}.h5')
            with h5py.File(file_path, 'w') as f:
                f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20)))
        with warnings.catch_warnings():
            nwbfile = mock_NWBFile()

            for k in range(n_units):
                file_path = Path(f'data_{k}.h5')
                fh = h5py.File(file_path, "r")
                open_files.append(fh)
                nwbfile.add_unit(spike_times=fh['spks'])
            nwb = nap.NWBFile(nwbfile)
            # Escalate warnings to errors only for the lazy access itself.
            warnings.simplefilter("error")
            tsgroup = nwb["units"]
            tsgroup.count(0.1)
    finally:
        # Close handles before unlinking, then remove the temporary files.
        for fh in open_files:
            fh.close()
        for k in range(n_units):
            file_path = Path(f'data_{k}.h5')
            if file_path.exists():
                file_path.unlink()
8 changes: 4 additions & 4 deletions tests/test_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,11 @@ def test_add_Units():
nwbfile.add_unit(spike_times=spike_times, quality="good", alpha=alpha[n_units_per_shank])
spks[n_units_per_shank] = spike_times

nwb = nap.NWBFile(nwbfile)
assert len(nwb) == 1
assert "units" in nwb.keys()
nwb_tsgroup = nap.NWBFile(nwbfile)
assert len(nwb_tsgroup) == 1
assert "units" in nwb_tsgroup.keys()

data = nwb['units']
data = nwb_tsgroup['units']
assert isinstance(data, nap.TsGroup)
assert len(data) == n_units
for n in data.keys():
Expand Down

0 comments on commit 7aa66aa

Please sign in to comment.