Skip to content

Commit

Permalink
Update ts_group.py
Browse files Browse the repository at this point in the history
  • Loading branch information
qian-chu committed May 17, 2024
1 parent c642068 commit 2d08c33
Showing 1 changed file with 55 additions and 35 deletions.
90 changes: 55 additions & 35 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
jitunion_isets,
)
from .base_class import Base
from .config import time_index_precision
from .interval_set import IntervalSet
from .time_index import TsIndex
from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like
Expand Down Expand Up @@ -1031,9 +1032,11 @@ def getby_category(self, key):
return sliced

@staticmethod
def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False):
def merge_group(
*tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False
):
"""
Merge multiple TsGroup objects into a single TsGroup object
Merge multiple TsGroup objects into a single TsGroup object.
Parameters
----------
Expand All @@ -1052,16 +1055,16 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m
Returns
-------
TsGroup
TsGroup of merged objects
A TsGroup of merged objects
Raises
------
TypeError
If the input objects are not TsGroup objects
ValueError
If ignore_metadata=False and metadata columns are not the same
If reset_index=False and keys overlap
If reset_time_support=False and time supports are not the same
If `ignore_metadata=False` but metadata columns are not the same
If `reset_index=False` but keys overlap
If `reset_time_support=False` but time supports are not the same
Examples
--------
Expand All @@ -1082,7 +1085,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m
>>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
>>> tsgroup4 = nap.TsGroup(dict2, time_support=time_support_b)
Merge with default options if have same time_support and non-overlapping indexes:
Merge with default options if have the same time support and non-overlapping indexes:
>>> tsgroup_12 = nap.TsGroup.merge_group(tsgroup1, tsgroup2)
>>> tsgroup_12
Expand All @@ -1091,7 +1094,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m
0 1.5
10 1.5
Pass reset_index=True if indexes are overlapping:
Set `reset_index=True` if indexes are overlapping:
>>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True)
>>> tsgroup_13
Expand All @@ -1101,7 +1104,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m
0 1.5
1 1.5
Pass reset_time_support=True if time_supports are different:
Set `reset_time_support=True` if time supports are different:
>>> tsgroup_14 = nap.TsGroup.merge_group(tsgroup1, tsgroup4, reset_time_support=True)
>>> tsgroup_14
Expand All @@ -1119,43 +1122,51 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m
"""
is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
if not all(is_tsgroup):
not_tsgroup_index = [i+1 for i, boo in enumerate(is_tsgroup) if not boo]
not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo]
raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!")

if len(tsgroups) == 1:
print('Only one TsGroup object provided, no merge needed')
print("Only one TsGroup object provided, no merge needed.")
return tsgroups[0]

tsg1 = tsgroups[0]
items = tsg1.items()
keys = set(tsg1.keys())
tsg1 = tsgroups[0]
items = tsg1.items()
keys = set(tsg1.keys())
metadata = tsg1._metadata

for i, tsg in enumerate(tsgroups[1:]):
if not ignore_metadata:
if tsg1.metadata_columns != tsg.metadata_columns:
raise ValueError(f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
"Set `ignore_metadata=True` to bypass the check.")
raise ValueError(
f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
"Set `ignore_metadata=True` to bypass the check."
)
metadata = pd.concat([metadata, tsg._metadata], axis=0)

if not reset_index:
key_overlap = keys.intersection(tsg.keys())
if key_overlap:
raise ValueError(f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
"Pass reset_index=True to bypass")
raise ValueError(
f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
"Set `reset_index=True` to bypass the check."
)
keys.update(tsg.keys())

if reset_time_support:
time_support = None
else:
if not np.array_equal(
tsg1.time_support.as_units('s').to_numpy(),
tsg.time_support.as_units('s').to_numpy()
):
raise ValueError(f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
"Pass reset_time_support=True to bypass")
if not np.allclose(
tsg1.time_support.as_units("s").to_numpy(),
tsg.time_support.as_units("s").to_numpy(),
atol=10 ** (-time_index_precision),
rtol=0,
):
raise ValueError(
f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
"Set `reset_time_support=True` to bypass the check."
)
time_support = tsg1.time_support

items.extend(tsg.items())

if reset_index:
Expand All @@ -1165,22 +1176,31 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m
data = dict(items)

if ignore_metadata:
return TsGroup(
data, time_support=time_support, bypass_check=False
)
return TsGroup(data, time_support=time_support, bypass_check=False)
else:
cols = metadata.columns.drop("rate")
return TsGroup(
data, time_support=time_support, bypass_check=False, **metadata[cols]
)

def merge(self, *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False):
)

def merge(
self,
*tsgroups,
reset_index=False,
reset_time_support=False,
ignore_metadata=False,
):
"""
Merge the TsGroup object with other TsGroup objects
See `TsGroup.merge_group` for more details
"""
return TsGroup.merge_group(
self, *tsgroups, reset_index=reset_index, reset_time_support=reset_time_support, ignore_metadata=ignore_metadata)
self,
*tsgroups,
reset_index=reset_index,
reset_time_support=reset_time_support,
ignore_metadata=ignore_metadata,
)

def save(self, filename):
"""
Expand Down

0 comments on commit 2d08c33

Please sign in to comment.