[ENH] Add TsGroup.merge_group method #275

Merged
merged 19 commits into from
May 21, 2024
Changes from 3 commits
154 changes: 154 additions & 0 deletions pynapple/core/ts_group.py
@@ -1030,6 +1030,160 @@ def getby_category(self, key):
        sliced = {k: self[list(groups[k])] for k in groups.keys()}
        return sliced

    @staticmethod
    def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False):
        """
        Merge multiple TsGroup objects into a single TsGroup object

        Parameters
        ----------
        *tsgroups : TsGroup
            The TsGroup objects to merge
        reset_index : bool, optional
            If True, the keys will be reset to range(len(data))
            If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
        reset_time_support : bool, optional
            If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
            If False, the time support of the TsGroup objects should be the same and will be preserved
        ignore_metadata : bool, optional
            If True, the merged TsGroup will not have any metadata columns other than 'rate'
            If False, all metadata columns should be the same and all metadata will be concatenated

        Returns
        -------
        TsGroup
            TsGroup of merged objects

        Raises
        ------
        TypeError
            If the input objects are not TsGroup objects
        ValueError
            If ignore_metadata=False and metadata columns are not the same
            If reset_index=False and keys overlap
            If reset_time_support=False and time supports are not the same

        Examples
        --------

        >>> import pynapple as nap
        >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s')
        >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s')

        >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')}
        >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a)

        >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
        >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a)

        >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')}
        >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a)

        >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
        >>> tsgroup4 = nap.TsGroup(dict4, time_support=time_support_b)

        Merge with default options when the groups share the same time_support and have non-overlapping indexes:

        >>> tsgroup_12 = nap.TsGroup.merge_group(tsgroup1, tsgroup2)
        >>> tsgroup_12
          Index    rate
        -------  ------
              0     1.5
             10     1.5

        Pass reset_index=True if indexes are overlapping:

        >>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True)
        >>> tsgroup_13
          Index    rate
        -------  ------
              0     1.5
              1     1.5

        Pass reset_time_support=True if time_supports are different:

        >>> tsgroup_14 = nap.TsGroup.merge_group(tsgroup1, tsgroup4, reset_time_support=True)
        >>> tsgroup_14
          Index    rate
        -------  ------
              0     0.3
             10     0.3

        >>> tsgroup_14.time_support
             start    end
        0       -5      5
        shape: (1, 2), time unit: sec.

        """
        is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
        if not all(is_tsgroup):
            not_tsgroup_index = [i+1 for i, boo in enumerate(is_tsgroup) if not boo]
            raise TypeError(f"Passed variables at positions {not_tsgroup_index} are not TsGroup")

        if len(tsgroups) == 1:
            print('Only one TsGroup object provided, no merge needed')
            return tsgroups[0]

        tsg1 = tsgroups[0]
        items = tsg1.items()
        keys = set(tsg1.keys())
        metadata = tsg1._metadata

        for i, tsg in enumerate(tsgroups[1:]):
            if not ignore_metadata:
                if tsg1.metadata_columns != tsg.metadata_columns:
                    raise ValueError(f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
                                     "Pass ignore_metadata=True to bypass")
                metadata = pd.concat([metadata, tsg._metadata], axis=0)

            if not reset_index:
                key_overlap = keys.intersection(tsg.keys())
                if key_overlap:
                    raise ValueError(f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
                                     "Pass reset_index=True to bypass")
                keys.update(tsg.keys())

            if reset_time_support:
                time_support = None
            else:
                if not np.array_equal(
                    tsg1.time_support.as_units('s').to_numpy(),
                    tsg.time_support.as_units('s').to_numpy()
                ):
                    raise ValueError(f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
                                     "Pass reset_time_support=True to bypass")
                time_support = tsg1.time_support

            items.extend(tsg.items())

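        if reset_index:
            if not ignore_metadata:
                # Only the concatenated copy is re-indexed. With ignore_metadata=True,
                # `metadata` still refers to the first group's own frame and must not be mutated.
                metadata.index = range(len(metadata))
            data = {i: ts[1] for i, ts in enumerate(items)}
        else:
            data = dict(items)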
        if ignore_metadata:
            return TsGroup(
                data, time_support=time_support, bypass_check=False
            )
        else:
            cols = metadata.columns.drop("rate")
            return TsGroup(
                data, time_support=time_support, bypass_check=False, **metadata[cols]
            )

    def merge(self, *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False):
        """
        Merge the TsGroup object with other TsGroup objects
        See `TsGroup.merge_group` for more details
Collaborator

Since this is probably the most user-facing method, I would add a full numpydoc docstring here too!

Contributor Author

That's a good suggestion. Could you offer some tips on how to do so while reducing redundancy?

Collaborator

I suggest moving the detailed docstring from merge_group to merge.
In merge_group, keep the same initial paragraph describing the method, plus the Parameters and Returns sections. I'll add a See Also section that points to merge.
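
The layout suggested above could look roughly like the following. This is only a sketch of the docstring structure: `TsGroupSketch` is a hypothetical stand-in class (not part of pynapple), only one keyword argument is shown, and the method bodies are placeholders rather than the PR's implementation.

```python
class TsGroupSketch:
    """Hypothetical stand-in for TsGroup, used only to show docstring layout."""

    @staticmethod
    def merge_group(*tsgroups, reset_index=False):
        """
        Merge multiple TsGroup objects into a single TsGroup object.

        Parameters
        ----------
        *tsgroups : TsGroup
            The TsGroup objects to merge
        reset_index : bool, optional
            If True, the keys will be reset to range(len(data))

        Returns
        -------
        TsGroup
            TsGroup of merged objects

        See Also
        --------
        TsGroupSketch.merge : detailed notes, raised errors, and examples
        """
        # Placeholder body: the real merging logic is not reproduced here.
        return list(tsgroups)

    def merge(self, *tsgroups, reset_index=False):
        """
        Merge this TsGroup object with other TsGroup objects.

        Parameters
        ----------
        *tsgroups : TsGroup
            The TsGroup objects to merge with
        reset_index : bool, optional
            If True, the keys will be reset to range(len(data))

        Returns
        -------
        TsGroup
            TsGroup of merged objects

        Raises
        ------
        ValueError
            If reset_index=False and keys overlap

        Examples
        --------
        >>> merged = tsgroup1.merge(tsgroup2)
        """
        # The instance method only prepends itself and delegates.
        return TsGroupSketch.merge_group(self, *tsgroups, reset_index=reset_index)
```

With this split, the Raises and Examples sections live in one place (the user-facing method), and the static method stays short while still pointing readers to the full description.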

"""
        tsgroups = list(tsgroups)
        tsgroups.insert(0, self)
        return TsGroup.merge_group(
            *tsgroups, reset_index=reset_index, reset_time_support=reset_time_support, ignore_metadata=ignore_metadata)
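
The key handling in `merge_group` can be illustrated without pynapple. The sketch below uses plain dicts of timestamp lists in place of TsGroup objects. `merge_dicts` is a hypothetical helper written for this illustration, and it covers only the type check, the key-overlap check, and `reset_index`, not metadata or time support.

```python
def merge_dicts(*groups, reset_index=False):
    """Merge dicts of timestamp lists, mirroring the key checks in the diff.

    Hypothetical helper: plain dicts stand in for TsGroup objects.
    """
    # Report 1-based positions of arguments with the wrong type.
    not_dict = [i + 1 for i, g in enumerate(groups) if not isinstance(g, dict)]
    if not_dict:
        raise TypeError(f"Passed variables at positions {not_dict} are not dict")

    items = []
    keys = set()
    for position, group in enumerate(groups, start=1):
        if not reset_index:
            # Keys are preserved, so they must be unique across all groups.
            overlap = keys.intersection(group.keys())
            if overlap:
                raise ValueError(
                    f"Group at position {position} has overlapping keys {overlap} "
                    "with previous groups. Pass reset_index=True to bypass"
                )
            keys.update(group.keys())
        items.extend(group.items())

    if reset_index:
        # Discard the original keys and renumber in insertion order.
        return {i: value for i, (_, value) in enumerate(items)}
    return dict(items)


group_a = {0: [-1.0, 0.0, 1.0]}
group_b = {10: [-1.0, 0.0, 1.0]}
group_c = {0: [-0.1, 0.0, 0.1]}

# Non-overlapping keys are preserved.
merged_ab = merge_dicts(group_a, group_b)

# Overlapping keys need reset_index=True, after which keys are renumbered.
merged_ac = merge_dicts(group_a, group_c, reset_index=True)
```

Calling `merge_dicts(group_a, group_c)` without `reset_index=True` raises `ValueError`, which matches the behavior described in the Raises section of the docstring.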

    def save(self, filename):
        """
        Save TsGroup object in npz format. The file will contain the timestamps,
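
The `reset_time_support=True` example in the docstring ends with a single support spanning -5 to 5. The sketch below shows one way such a result can arise, under the assumption that "merging" time supports means taking the union of their intervals. `union_intervals` is a hypothetical helper working on plain `(start, end)` tuples, not pynapple's own routine.

```python
def union_intervals(*supports):
    """Return the union of (start, end) intervals as a sorted, non-overlapping list."""
    # Flatten all supports into one list of intervals sorted by start time.
    intervals = sorted(iv for support in supports for iv in support)
    merged = []
    for start, end in intervals:
        if merged and start <= merged[-1][1]:
            # Overlaps or touches the previous interval: extend it.
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


support_a = [(-1.0, 1.0)]  # time support of tsgroup1 in the docstring example
support_b = [(-5.0, 5.0)]  # time support of tsgroup4 in the docstring example

combined = union_intervals(support_a, support_b)

# Three timestamps over the 10 s combined support give the 0.3 rate shown above.
total_duration = sum(end - start for start, end in combined)
rate = 3 / total_duration
```

Under this assumption the smaller support is absorbed by the larger one, which is consistent with the `start=-5`, `end=5` output and the 0.3 rate in the docstring example.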