Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add TsGroup.merge_group method #275

Merged
merged 19 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._core_functions import _count
from ._jitted_functions import jitunion, jitunion_isets
from .base_class import Base
from .config import nap_config
from .interval_set import IntervalSet
from .time_index import TsIndex
from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like
Expand Down Expand Up @@ -1026,6 +1027,177 @@ def getby_category(self, key):
sliced = {k: self[list(groups[k])] for k in groups.keys()}
return sliced

@staticmethod
def merge_group(
    *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False
):
    """
    Merge multiple TsGroup objects into a single TsGroup object.

    Parameters
    ----------
    *tsgroups : TsGroup
        The TsGroup objects to merge.
    reset_index : bool, optional
        If True, the keys will be reset to range(len(data)).
        If False, the keys of the TsGroup objects should be non-overlapping and will be preserved.
    reset_time_support : bool, optional
        If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data.
        If False, the time support of the TsGroup objects should be the same and will be preserved.
    ignore_metadata : bool, optional
        If True, the merged TsGroup will not have any metadata columns other than 'rate'.
        If False, all metadata columns should be the same and all metadata will be concatenated.

    Returns
    -------
    TsGroup
        A TsGroup of merged objects.

    Raises
    ------
    TypeError
        If no input is provided, or if the input objects are not TsGroup objects.
    ValueError
        If `ignore_metadata=False` but metadata columns are not the same.
        If `reset_index=False` but keys overlap.
        If `reset_time_support=False` but time supports are not the same.

    Examples
    --------

    >>> import pynapple as nap
    >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s')
    >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s')

    >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a)

    >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a)

    >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')}
    >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a)

    >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup4 = nap.TsGroup(dict4, time_support=time_support_b)

    Merge with default options if they have the same time support and non-overlapping indexes:

    >>> tsgroup_12 = nap.TsGroup.merge_group(tsgroup1, tsgroup2)
    >>> tsgroup_12
      Index     rate
    -------  -------
          0      1.5
         10      1.5

    Set `reset_index=True` if indexes are overlapping:

    >>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True)
    >>> tsgroup_13
      Index     rate
    -------  -------
          0      1.5
          1      1.5

    Set `reset_time_support=True` if time supports are different:

    >>> tsgroup_14 = nap.TsGroup.merge_group(tsgroup1, tsgroup4, reset_time_support=True)
    >>> tsgroup_14
      Index     rate
    -------  -------
          0      0.3
         10      0.3
    >>> tsgroup_14.time_support
       start    end
    0     -5      5
    shape: (1, 2), time unit: sec.

    """
    # Guard against an empty call: tsgroups[0] below would otherwise raise
    # an uninformative IndexError.
    if len(tsgroups) == 0:
        raise TypeError("No TsGroup objects provided to merge!")

    # Report every offending position (1-based) at once rather than failing
    # on the first non-TsGroup input.
    is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
    if not all(is_tsgroup):
        not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo]
        raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!")

    if len(tsgroups) == 1:
        print("Only one TsGroup object provided, no merge needed.")
        return tsgroups[0]

    # The first group is the reference for keys, metadata columns and
    # time support; all later groups are validated against it.
    tsg1 = tsgroups[0]
    items = tsg1.items()
    keys = set(tsg1.keys())
    metadata = tsg1._metadata

    for i, tsg in enumerate(tsgroups[1:]):
        if not ignore_metadata:
            if tsg1.metadata_columns != tsg.metadata_columns:
                raise ValueError(
                    f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
                    "Set `ignore_metadata=True` to bypass the check."
                )
            metadata = pd.concat([metadata, tsg._metadata], axis=0)

        if not reset_index:
            key_overlap = keys.intersection(tsg.keys())
            if key_overlap:
                raise ValueError(
                    f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
                    "Set `reset_index=True` to bypass the check."
                )
            keys.update(tsg.keys())

        if reset_time_support:
            # Passing time_support=None lets the TsGroup constructor rebuild
            # the time support as the union over all Ts/Tsd objects.
            time_support = None
        else:
            # Compare in seconds with the configured time-index precision.
            if not np.allclose(
                tsg1.time_support.as_units("s").to_numpy(),
                tsg.time_support.as_units("s").to_numpy(),
                atol=10 ** (-nap_config.time_index_precision),
                rtol=0,
            ):
                raise ValueError(
                    f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
                    "Set `reset_time_support=True` to bypass the check."
                )
            time_support = tsg1.time_support

        items.extend(tsg.items())

    if reset_index:
        # Renumber both metadata rows and data keys to 0..n-1, keeping them aligned.
        metadata.index = range(len(metadata))
        data = {i: ts[1] for i, ts in enumerate(items)}
    else:
        data = dict(items)

    if ignore_metadata:
        return TsGroup(data, time_support=time_support, bypass_check=False)
    else:
        # 'rate' is recomputed by the constructor, so only the other columns
        # are forwarded as metadata keyword arguments.
        cols = metadata.columns.drop("rate")
        return TsGroup(
            data, time_support=time_support, bypass_check=False, **metadata[cols]
        )

def merge(
    self,
    *tsgroups,
    reset_index=False,
    reset_time_support=False,
    ignore_metadata=False,
):
    """
    Merge the TsGroup object with other TsGroup objects.

    Parameters
    ----------
    *tsgroups : TsGroup
        The TsGroup objects to merge with.
    reset_index : bool, optional
        If True, the keys will be reset to range(len(data)).
        If False, the keys of the TsGroup objects should be non-overlapping and will be preserved.
    reset_time_support : bool, optional
        If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data.
        If False, the time support of the TsGroup objects should be the same and will be preserved.
    ignore_metadata : bool, optional
        If True, the merged TsGroup will not have any metadata columns other than 'rate'.
        If False, all metadata columns should be the same and all metadata will be concatenated.

    Returns
    -------
    TsGroup
        A TsGroup of merged objects.

    See Also
    --------
    TsGroup.merge_group
        Static method with the full merge logic, detailed error conditions and examples.
    """
    # Delegate to the static merge_group, with self as the first (reference) group.
    return TsGroup.merge_group(
        self,
        *tsgroups,
        reset_index=reset_index,
        reset_time_support=reset_time_support,
        ignore_metadata=ignore_metadata,
    )

def save(self, filename):
"""
Save TsGroup object in npz format. The file will contain the timestamps,
Expand Down
113 changes: 109 additions & 4 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections import UserDict
import warnings
from contextlib import nullcontext as does_not_raise
import itertools


@pytest.fixture
Expand Down Expand Up @@ -575,15 +576,16 @@ def test_save_npz(self, group):
np.testing.assert_array_almost_equal(file['index'], index)
np.testing.assert_array_almost_equal(file['meta'], np.arange(len(group), dtype=np.int64))
assert np.all(file['meta2']==np.array(['a', 'b', 'c']))
file.close()

tsgroup3 = nap.TsGroup({
0: nap.Ts(t=np.arange(0, 20)),
})
tsgroup3.save("tsgroup3")
file = np.load("tsgroup3.npz")

assert 'd' not in list(file.keys())
np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
with np.load("tsgroup3.npz") as file:
assert 'd' not in list(file.keys())
np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)

os.remove("tsgroup.npz")
os.remove("tsgroup2.npz")
Expand Down Expand Up @@ -752,4 +754,107 @@ def test_getitem_attribute_error(self, ts_group):
)
def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation):
with expectation:
out = ts_group[bool_idx]
out = ts_group[bool_idx]

def test_merge_complete(self, ts_group):
    """Full merge round-trip: type validation plus a successful two-group merge."""
    # Use pytest.raises(match=...) as suggested in review. The literal message
    # contains "[2, 3]", which regex would parse as a character class, so match
    # against a wildcard for the positions list instead of the escaped literal.
    with pytest.raises(TypeError, match=r"Input at positions (.*) are not TsGroup!"):
        nap.TsGroup.merge_group(ts_group, str, dict)

    ts_group2 = nap.TsGroup(
        {
            3: nap.Ts(t=np.arange(15)),
            4: nap.Ts(t=np.arange(20)),
        },
        time_support=ts_group.time_support,
        meta=np.array([12, 13]),
    )
    merged = ts_group.merge(ts_group2)
    # Keys, metadata values and metadata columns from both groups are preserved.
    assert len(merged) == 4
    assert np.all(merged.keys() == np.array([1, 2, 3, 4]))
    assert np.all(merged.meta == np.array([10, 11, 12, 13]))
    np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)

@pytest.mark.parametrize(
    "col_name, ignore_metadata, expectation",
    [
        ("meta", False, does_not_raise()),
        ("meta", True, does_not_raise()),
        (
            "wrong_name",
            False,
            pytest.raises(
                ValueError,
                match="TsGroup at position 2 has different metadata columns.*",
            ),
        ),
        ("wrong_name", True, does_not_raise()),
    ],
)
def test_merge_metadata(self, ts_group, col_name, ignore_metadata, expectation):
    """Merging groups whose metadata column name may or may not match the fixture's."""
    series = {3: nap.Ts(t=np.arange(15)), 4: nap.Ts(t=np.arange(20))}
    extra_meta = pd.DataFrame([12, 13], index=[3, 4], columns=[col_name])
    ts_group2 = nap.TsGroup(series, time_support=ts_group.time_support, **extra_meta)

    with expectation:
        merged = ts_group.merge(ts_group2, ignore_metadata=ignore_metadata)

    if ignore_metadata:
        # Only the automatic 'rate' column should remain when metadata is ignored.
        assert merged.metadata_columns[0] == "rate"
    elif col_name == "meta":
        np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)

@pytest.mark.parametrize(
    "index, reset_index, expectation",
    [
        (
            np.array([1, 2]),
            False,
            pytest.raises(
                ValueError, match="TsGroup at position 2 has overlapping keys.*"
            ),
        ),
        (np.array([1, 2]), True, does_not_raise()),
        (np.array([3, 4]), False, does_not_raise()),
        (np.array([3, 4]), True, does_not_raise()),
    ],
)
def test_merge_index(self, ts_group, index, reset_index, expectation):
    """Merging with overlapping vs. disjoint keys, with and without reset_index."""
    timestamps = [nap.Ts(t=np.arange(15)), nap.Ts(t=np.arange(20))]
    ts_group2 = nap.TsGroup(
        dict(zip(index, timestamps)),
        time_support=ts_group.time_support,
        meta=np.array([12, 13]),
    )

    with expectation:
        merged = ts_group.merge(ts_group2, reset_index=reset_index)

    if reset_index:
        # Keys are renumbered from zero regardless of the original indexes.
        assert np.all(merged.keys() == np.arange(4))
    elif np.all(index == np.array([3, 4])):
        assert np.all(merged.keys() == np.array([1, 2, 3, 4]))

@pytest.mark.parametrize(
    "time_support, reset_time_support, expectation",
    [
        (None, False, does_not_raise()),
        (None, True, does_not_raise()),
        (
            nap.IntervalSet(start=0, end=1),
            False,
            pytest.raises(
                ValueError, match="TsGroup at position 2 has different time support.*"
            ),
        ),
        (nap.IntervalSet(start=0, end=1), True, does_not_raise()),
    ],
)
def test_merge_time_support(self, ts_group, time_support, reset_time_support, expectation):
    """Merging with matching vs. differing time supports, with and without reset."""
    # None means "reuse the fixture's time support" (the matching case).
    if time_support is None:
        time_support = ts_group.time_support

    ts_group2 = nap.TsGroup(
        {
            3: nap.Ts(t=np.arange(15)),
            4: nap.Ts(t=np.arange(20)),
        },
        time_support=time_support,
        meta=np.array([12, 13]),
    )

    with expectation:
        merged = ts_group.merge(ts_group2, reset_time_support=reset_time_support)

    if reset_time_support:
        np.testing.assert_array_almost_equal(
            ts_group.time_support.as_units("s").to_numpy(),
            merged.time_support.as_units("s").to_numpy(),
        )
Loading