Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add TsGroup.merge_group method #275

Merged
merged 19 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._core_functions import _count
from ._jitted_functions import jitunion, jitunion_isets
from .base_class import Base
from .config import nap_config
from .interval_set import IntervalSet
from .time_index import TsIndex
from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like
Expand Down Expand Up @@ -1026,6 +1027,207 @@ def getby_category(self, key):
sliced = {k: self[list(groups[k])] for k in groups.keys()}
return sliced

@staticmethod
def merge_group(
    *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False
):
    """
    Merge multiple TsGroup objects into a single TsGroup object.

    The input TsGroup objects are never modified; a new TsGroup is built
    from their items (unless a single TsGroup is given, in which case it
    is returned as is).

    Parameters
    ----------
    *tsgroups : TsGroup
        The TsGroup objects to merge
    reset_index : bool, optional
        If True, the keys will be reset to range(len(data))
        If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
    reset_time_support : bool, optional
        If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
        If False, the time support of the TsGroup objects should be the same and will be preserved
    ignore_metadata : bool, optional
        If True, the merged TsGroup will not have any metadata columns other than 'rate'
        If False, all metadata columns should be the same and all metadata will be concatenated

    Returns
    -------
    TsGroup
        A TsGroup of merged objects

    Raises
    ------
    TypeError
        If the input objects are not TsGroup objects
    ValueError
        If no TsGroup object is provided
        If `ignore_metadata=False` but metadata columns are not the same
        If `reset_index=False` but keys overlap
        If `reset_time_support=False` but time supports are not the same

    """
    if not tsgroups:
        # Fail with an explicit message instead of an IndexError on tsgroups[0].
        raise ValueError("At least one TsGroup object must be provided.")

    is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
    if not all(is_tsgroup):
        # Positions are reported 1-based.
        not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo]
        raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!")

    if len(tsgroups) == 1:
        print("Only one TsGroup object provided, no merge needed.")
        return tsgroups[0]

    tsg1 = tsgroups[0]
    # Work on a private list so that extending it cannot affect tsg1.
    items = list(tsg1.items())
    keys = set(tsg1.keys())
    # Metadata frames are collected and concatenated once after the loop.
    metadata_frames = [tsg1._metadata]

    if not reset_time_support:
        # Loop-invariant reference values for the time support comparison.
        ref_support = tsg1.time_support.as_units("s").to_numpy()
        atol = 10 ** (-nap_config.time_index_precision)

    # `position` is the 1-based position of `tsg` among the inputs.
    for position, tsg in enumerate(tsgroups[1:], start=2):
        if not ignore_metadata:
            if tsg1.metadata_columns != tsg.metadata_columns:
                raise ValueError(
                    f"TsGroup at position {position} has different metadata columns from previous TsGroup objects. "
                    "Set `ignore_metadata=True` to bypass the check."
                )
            metadata_frames.append(tsg._metadata)

        if not reset_index:
            key_overlap = keys.intersection(tsg.keys())
            if key_overlap:
                raise ValueError(
                    f"TsGroup at position {position} has overlapping keys {key_overlap} with previous TsGroup objects. "
                    "Set `reset_index=True` to bypass the check."
                )
            keys.update(tsg.keys())

        if not reset_time_support:
            support = tsg.time_support.as_units("s").to_numpy()
            # Compare shapes first: time supports with a different number of
            # intervals are different, and must not reach np.allclose where
            # they would either broadcast or fail with a NumPy error.
            if support.shape != ref_support.shape or not np.allclose(
                ref_support, support, atol=atol, rtol=0
            ):
                raise ValueError(
                    f"TsGroup at position {position} has different time support from previous TsGroup objects. "
                    "Set `reset_time_support=True` to bypass the check."
                )

        items.extend(tsg.items())

    time_support = None if reset_time_support else tsg1.time_support

    if reset_index:
        data = {i: ts for i, (_, ts) in enumerate(items)}
    else:
        data = dict(items)

    if ignore_metadata:
        return TsGroup(data, time_support=time_support, bypass_check=False)

    # pd.concat returns a new frame, so re-indexing it below leaves the
    # metadata of the input TsGroup objects untouched.
    metadata = pd.concat(metadata_frames, axis=0)
    if reset_index:
        metadata.index = range(len(metadata))
    # 'rate' is not forwarded as user metadata.
    cols = metadata.columns.drop("rate")
    return TsGroup(
        data, time_support=time_support, bypass_check=False, **metadata[cols]
    )

def merge(
    self,
    *tsgroups,
    reset_index=False,
    reset_time_support=False,
    ignore_metadata=False,
):
    """
    Merge the TsGroup object with other TsGroup objects.
    Common uses include adding more neurons/channels (supposing each Ts/Tsd corresponds to data from a neuron/channel) or adding more trials (supposing each Ts/Tsd corresponds to data from a trial).

    This is the instance-level entry point: it forwards `self` followed by
    `*tsgroups` to `TsGroup.merge_group` with the same keyword options.

    Parameters
    ----------
    *tsgroups : TsGroup
        The TsGroup objects to merge with
    reset_index : bool, optional
        If True, the keys will be reset to range(len(data))
        If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
    reset_time_support : bool, optional
        If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
        If False, the time support of the TsGroup objects should be the same and will be preserved
    ignore_metadata : bool, optional
        If True, the merged TsGroup will not have any metadata columns other than 'rate'
        If False, all metadata columns should be the same and all metadata will be concatenated

    Returns
    -------
    TsGroup
        A TsGroup of merged objects

    Raises
    ------
    TypeError
        If the input objects are not TsGroup objects
    ValueError
        If `ignore_metadata=False` but metadata columns are not the same
        If `reset_index=False` but keys overlap
        If `reset_time_support=False` but time supports are not the same

    Examples
    --------

    >>> import pynapple as nap
    >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s')
    >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s')

    >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a)

    >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a)

    >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')}
    >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a)

    >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup4 = nap.TsGroup(dict4, time_support=time_support_b)

    Merge with default options if have the same time support and non-overlapping indexes:

    >>> tsgroup_12 = tsgroup1.merge(tsgroup2)
    >>> tsgroup_12
      Index    rate
    -------  ------
          0     1.5
         10     1.5

    Set `reset_index=True` if indexes are overlapping:

    >>> tsgroup_13 = tsgroup1.merge(tsgroup3, reset_index=True)
    >>> tsgroup_13
      Index    rate
    -------  ------
          0     1.5
          1     1.5

    Set `reset_time_support=True` if time supports are different:

    >>> tsgroup_14 = tsgroup1.merge(tsgroup4, reset_time_support=True)
    >>> tsgroup_14
      Index    rate
    -------  ------
          0     0.3
         10     0.3

    >>> tsgroup_14.time_support
                start    end
           0       -5      5
    shape: (1, 2), time unit: sec.

    See Also
    --------
    [`TsGroup.merge_group`](./#pynapple.core.ts_group.TsGroup.merge_group)
    """
    return TsGroup.merge_group(
        self,
        *tsgroups,
        reset_index=reset_index,
        reset_time_support=reset_time_support,
        ignore_metadata=ignore_metadata,
    )

def save(self, filename):
"""
Save TsGroup object in npz format. The file will contain the timestamps,
Expand Down
112 changes: 107 additions & 5 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import warnings
from contextlib import nullcontext as does_not_raise


@pytest.fixture
def group():
"""Fixture to be used in all tests."""
Expand Down Expand Up @@ -575,15 +574,16 @@ def test_save_npz(self, group):
np.testing.assert_array_almost_equal(file['index'], index)
np.testing.assert_array_almost_equal(file['meta'], np.arange(len(group), dtype=np.int64))
assert np.all(file['meta2']==np.array(['a', 'b', 'c']))
file.close()

tsgroup3 = nap.TsGroup({
0: nap.Ts(t=np.arange(0, 20)),
})
tsgroup3.save("tsgroup3")
file = np.load("tsgroup3.npz")

assert 'd' not in list(file.keys())
np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
with np.load("tsgroup3.npz") as file:
assert 'd' not in list(file.keys())
np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)

os.remove("tsgroup.npz")
os.remove("tsgroup2.npz")
Expand Down Expand Up @@ -752,4 +752,106 @@ def test_getitem_attribute_error(self, ts_group):
)
def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation):
with expectation:
out = ts_group[bool_idx]
out = ts_group[bool_idx]

def test_merge_complete(self, ts_group):
    """Check type validation and a default merge of two compatible groups.

    The expected keys and `meta` values below imply that the `ts_group`
    fixture holds keys 1 and 2 with `meta` 10 and 11 (fixture not visible
    here -- confirm against its definition).
    """
    # Anything that is not a TsGroup must be rejected up front.
    with pytest.raises(TypeError, match="Input at positions(.*)are not TsGroup!"):
        nap.TsGroup.merge_group(ts_group, str, dict)

    # Second group: new keys, same time support, same metadata column.
    extra_units = {
        3: nap.Ts(t=np.arange(15)),
        4: nap.Ts(t=np.arange(20)),
    }
    ts_group2 = nap.TsGroup(
        extra_units,
        time_support=ts_group.time_support,
        meta=np.array([12, 13]),
    )

    merged = ts_group.merge(ts_group2)

    assert len(merged) == 4
    expected_keys = np.array([1, 2, 3, 4])
    assert np.all(merged.keys() == expected_keys)
    expected_meta = np.array([10, 11, 12, 13])
    assert np.all(merged.meta == expected_meta)
    np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)

@pytest.mark.parametrize(
    'col_name, ignore_metadata, expectation',
    [
        ('meta', False, does_not_raise()),
        ('meta', True, does_not_raise()),
        (
            'wrong_name',
            False,
            pytest.raises(ValueError, match="TsGroup at position 2 has different metadata columns.*"),
        ),
        ('wrong_name', True, does_not_raise()),
    ]
)
def test_merge_metadata(self, ts_group, col_name, ignore_metadata, expectation):
    """Check how `merge` handles matching and mismatching metadata columns.

    A mismatching column name is only an error when `ignore_metadata` is
    False; the merged object is inspected only in the non-raising cases.
    """
    other_units = {
        3: nap.Ts(t=np.arange(15)),
        4: nap.Ts(t=np.arange(20)),
    }
    # One metadata column, named by the parametrization, indexed by the keys.
    other_info = pd.DataFrame([12, 13], index=[3, 4], columns=[col_name])
    ts_group2 = nap.TsGroup(
        other_units,
        time_support=ts_group.time_support,
        **other_info
    )

    with expectation:
        merged = ts_group.merge(ts_group2, ignore_metadata=ignore_metadata)

    if ignore_metadata:
        # Metadata dropped: the first column must be 'rate'.
        assert merged.metadata_columns[0] == 'rate'
    elif col_name == 'meta':
        # Metadata kept: columns are those of the first group.
        np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)

@pytest.mark.parametrize(
    'index, reset_index, expectation',
    [
        (
            np.array([1, 2]),
            False,
            pytest.raises(ValueError, match="TsGroup at position 2 has overlapping keys.*"),
        ),
        (np.array([1, 2]), True, does_not_raise()),
        (np.array([3, 4]), False, does_not_raise()),
        (np.array([3, 4]), True, does_not_raise()),
    ]
)
def test_merge_index(self, ts_group, index, reset_index, expectation):
    """Check how `merge` handles overlapping and non-overlapping keys.

    Overlapping keys are only an error when `reset_index` is False; with
    `reset_index=True` the merged keys are expected to be 0..3.
    """
    spike_trains = [nap.Ts(t=np.arange(15)), nap.Ts(t=np.arange(20))]
    ts_group2 = nap.TsGroup(
        {key: ts for key, ts in zip(index, spike_trains)},
        time_support=ts_group.time_support,
        meta=np.array([12, 13])
    )

    with expectation:
        merged = ts_group.merge(ts_group2, reset_index=reset_index)

    if reset_index:
        assert np.all(merged.keys() == np.arange(4))
    elif np.all(index == np.array([3, 4])):
        # Non-overlapping keys are preserved as given.
        assert np.all(merged.keys() == np.array([1, 2, 3, 4]))

@pytest.mark.parametrize(
    'time_support, reset_time_support, expectation',
    [
        (None, False, does_not_raise()),
        (None, True, does_not_raise()),
        (
            nap.IntervalSet(start=0, end=1),
            False,
            pytest.raises(ValueError, match="TsGroup at position 2 has different time support.*"),
        ),
        (nap.IntervalSet(start=0, end=1), True, does_not_raise()),
    ]
)
def test_merge_time_support(self, ts_group, time_support, reset_time_support, expectation):
    """Check how `merge` handles identical and differing time supports.

    A `time_support` of None in the parametrization means "reuse the time
    support of the `ts_group` fixture". A differing time support is only an
    error when `reset_time_support` is False.
    """
    support = ts_group.time_support if time_support is None else time_support

    other_units = {
        3: nap.Ts(t=np.arange(15)),
        4: nap.Ts(t=np.arange(20)),
    }
    ts_group2 = nap.TsGroup(
        other_units,
        time_support=support,
        meta=np.array([12, 13])
    )

    with expectation:
        merged = ts_group.merge(ts_group2, reset_time_support=reset_time_support)

    if reset_time_support:
        # The recomputed time support is compared with the fixture's one.
        expected = ts_group.time_support.as_units("s").to_numpy()
        actual = merged.time_support.as_units("s").to_numpy()
        np.testing.assert_array_almost_equal(expected, actual)