Skip to content

Commit

Permalink
Merge pull request #275 from qian-chu/merge_tsgroup
Browse files Browse the repository at this point in the history
[ENH] Add TsGroup.merge_group method
  • Loading branch information
gviejo authored May 21, 2024
2 parents 651d3cc + e6c9bc5 commit ff231e0
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 15 deletions.
18 changes: 8 additions & 10 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
branches: [ main, dev ]

jobs:
lint:
Expand Down Expand Up @@ -40,9 +40,9 @@ jobs:
# - os: windows-latest
# python-version: 3.7
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -52,14 +52,12 @@ jobs:
- name: Test
run: |
coverage run --source=pynapple --branch -m pytest tests/
coverage report -m
coverage report -m
- name: Coveralls
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
pip install coveralls
coveralls --service=github
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
check:
if: always()
needs:
Expand Down
202 changes: 202 additions & 0 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._core_functions import _count
from ._jitted_functions import jitunion, jitunion_isets
from .base_class import Base
from .config import nap_config
from .interval_set import IntervalSet
from .time_index import TsIndex
from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like
Expand Down Expand Up @@ -1026,6 +1027,207 @@ def getby_category(self, key):
sliced = {k: self[list(groups[k])] for k in groups.keys()}
return sliced

@staticmethod
def merge_group(
*tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False
):
"""
Merge multiple TsGroup objects into a single TsGroup object.
Parameters
----------
*tsgroups : TsGroup
The TsGroup objects to merge
reset_index : bool, optional
If True, the keys will be reset to range(len(data))
If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
reset_time_support : bool, optional
If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
If False, the time support of the TsGroup objects should be the same and will be preserved
ignore_metadata : bool, optional
If True, the merged TsGroup will not have any metadata columns other than 'rate'
If False, all metadata columns should be the same and all metadata will be concatenated
Returns
-------
TsGroup
A TsGroup of merged objects
Raises
------
TypeError
If the input objects are not TsGroup objects
ValueError
If `ignore_metadata=False` but metadata columns are not the same
If `reset_index=False` but keys overlap
If `reset_time_support=False` but time supports are not the same
"""
is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
if not all(is_tsgroup):
not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo]
raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!")

if len(tsgroups) == 1:
print("Only one TsGroup object provided, no merge needed.")
return tsgroups[0]

tsg1 = tsgroups[0]
items = tsg1.items()
keys = set(tsg1.keys())
metadata = tsg1._metadata

for i, tsg in enumerate(tsgroups[1:]):
if not ignore_metadata:
if tsg1.metadata_columns != tsg.metadata_columns:
raise ValueError(
f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
"Set `ignore_metadata=True` to bypass the check."
)
metadata = pd.concat([metadata, tsg._metadata], axis=0)

if not reset_index:
key_overlap = keys.intersection(tsg.keys())
if key_overlap:
raise ValueError(
f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
"Set `reset_index=True` to bypass the check."
)
keys.update(tsg.keys())

if reset_time_support:
time_support = None
else:
if not np.allclose(
tsg1.time_support.as_units("s").to_numpy(),
tsg.time_support.as_units("s").to_numpy(),
atol=10 ** (-nap_config.time_index_precision),
rtol=0,
):
raise ValueError(
f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
"Set `reset_time_support=True` to bypass the check."
)
time_support = tsg1.time_support

items.extend(tsg.items())

if reset_index:
metadata.index = range(len(metadata))
data = {i: ts[1] for i, ts in enumerate(items)}
else:
data = dict(items)

if ignore_metadata:
return TsGroup(data, time_support=time_support, bypass_check=False)
else:
cols = metadata.columns.drop("rate")
return TsGroup(
data, time_support=time_support, bypass_check=False, **metadata[cols]
)

def merge(
    self,
    *tsgroups,
    reset_index=False,
    reset_time_support=False,
    ignore_metadata=False,
):
    """
    Merge this TsGroup with one or more other TsGroup objects.

    This is the instance-level counterpart of `TsGroup.merge_group`: the
    current object is placed first, followed by `tsgroups`, and all options
    are forwarded unchanged.

    Typical uses are adding more neurons/channels (when each Ts/Tsd holds the
    data of one neuron/channel) or adding more trials (when each Ts/Tsd holds
    the data of one trial).

    Parameters
    ----------
    *tsgroups : TsGroup
        The TsGroup objects to merge with
    reset_index : bool, optional
        If True, the keys will be reset to range(len(data))
        If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
    reset_time_support : bool, optional
        If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
        If False, the time support of the TsGroup objects should be the same and will be preserved
    ignore_metadata : bool, optional
        If True, the merged TsGroup will not have any metadata columns other than 'rate'
        If False, all metadata columns should be the same and all metadata will be concatenated

    Returns
    -------
    TsGroup
        A TsGroup of merged objects

    Raises
    ------
    TypeError
        If the input objects are not TsGroup objects
    ValueError
        If `ignore_metadata=False` but metadata columns are not the same
        If `reset_index=False` but keys overlap
        If `reset_time_support=False` but time supports are not the same

    Examples
    --------
    >>> import pynapple as nap
    >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s')
    >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s')

    >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a)

    >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a)

    >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')}
    >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a)

    >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
    >>> tsgroup4 = nap.TsGroup(dict4, time_support=time_support_b)

    Merge with default options if have the same time support and non-overlapping indexes:

    >>> tsgroup_12 = tsgroup1.merge(tsgroup2)
    >>> tsgroup_12
      Index    rate
    -------  ------
          0     1.5
         10     1.5

    Set `reset_index=True` if indexes are overlapping:

    >>> tsgroup_13 = tsgroup1.merge(tsgroup3, reset_index=True)
    >>> tsgroup_13
      Index    rate
    -------  ------
          0     1.5
          1     1.5

    Set `reset_time_support=True` if time supports are different:

    >>> tsgroup_14 = tsgroup1.merge(tsgroup4, reset_time_support=True)
    >>> tsgroup_14
      Index    rate
    -------  ------
          0     0.3
         10     0.3
    >>> tsgroup_14.time_support
                start    end
           0       -5      5
    shape: (1, 2), time unit: sec.

    See Also
    --------
    [`TsGroup.merge_group`](./#pynapple.core.ts_group.TsGroup.merge_group)
    """
    # Collect the options once and hand everything to the static implementation.
    merge_options = {
        "reset_index": reset_index,
        "reset_time_support": reset_time_support,
        "ignore_metadata": ignore_metadata,
    }
    return TsGroup.merge_group(self, *tsgroups, **merge_options)

def save(self, filename):
"""
Save TsGroup object in npz format. The file will contain the timestamps,
Expand Down
112 changes: 107 additions & 5 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import warnings
from contextlib import nullcontext as does_not_raise


@pytest.fixture
def group():
"""Fixture to be used in all tests."""
Expand Down Expand Up @@ -575,15 +574,16 @@ def test_save_npz(self, group):
np.testing.assert_array_almost_equal(file['index'], index)
np.testing.assert_array_almost_equal(file['meta'], np.arange(len(group), dtype=np.int64))
assert np.all(file['meta2']==np.array(['a', 'b', 'c']))
file.close()

tsgroup3 = nap.TsGroup({
0: nap.Ts(t=np.arange(0, 20)),
})
tsgroup3.save("tsgroup3")
file = np.load("tsgroup3.npz")

assert 'd' not in list(file.keys())
np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
with np.load("tsgroup3.npz") as file:
assert 'd' not in list(file.keys())
np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)

os.remove("tsgroup.npz")
os.remove("tsgroup2.npz")
Expand Down Expand Up @@ -752,4 +752,106 @@ def test_getitem_attribute_error(self, ts_group):
)
def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation):
with expectation:
out = ts_group[bool_idx]
out = ts_group[bool_idx]

def test_merge_complete(self, ts_group):
    """Check type validation and a plain merge with default options."""
    # Anything that is not a TsGroup must be rejected.
    with pytest.raises(TypeError, match="Input at positions(.*)are not TsGroup!"):
        nap.TsGroup.merge_group(ts_group, str, dict)

    extra_data = {
        3: nap.Ts(t=np.arange(15)),
        4: nap.Ts(t=np.arange(20)),
    }
    other_group = nap.TsGroup(
        extra_data,
        time_support=ts_group.time_support,
        meta=np.array([12, 13]),
    )

    merged = ts_group.merge(other_group)

    # NOTE(review): the expected values below imply the `ts_group` fixture
    # holds keys 1 and 2 with `meta` values 10 and 11 - confirm against the fixture.
    expected_keys = np.array([1, 2, 3, 4])
    expected_meta = np.array([10, 11, 12, 13])

    assert len(merged) == 4
    assert np.all(merged.keys() == expected_keys)
    assert np.all(merged.meta == expected_meta)
    np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)

@pytest.mark.parametrize(
    "col_name, ignore_metadata, expectation",
    [
        ("meta", False, does_not_raise()),
        ("meta", True, does_not_raise()),
        (
            "wrong_name",
            False,
            pytest.raises(
                ValueError,
                match="TsGroup at position 2 has different metadata columns.*",
            ),
        ),
        ("wrong_name", True, does_not_raise()),
    ],
)
def test_merge_metadata(self, ts_group, col_name, ignore_metadata, expectation):
    """Check metadata column validation and the `ignore_metadata` switch."""
    # The second group carries one metadata column whose name is parametrized.
    other_meta = pd.DataFrame([12, 13], index=[3, 4], columns=[col_name])
    other_data = {
        3: nap.Ts(t=np.arange(15)),
        4: nap.Ts(t=np.arange(20)),
    }
    other_group = nap.TsGroup(
        other_data, time_support=ts_group.time_support, **other_meta
    )

    with expectation:
        merged = ts_group.merge(other_group, ignore_metadata=ignore_metadata)

    # `merged` is only inspected in the parametrizations that do not raise.
    if ignore_metadata:
        assert merged.metadata_columns[0] == "rate"
    elif col_name == "meta":
        np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)

@pytest.mark.parametrize(
    "index, reset_index, expectation",
    [
        (
            np.array([1, 2]),
            False,
            pytest.raises(
                ValueError, match="TsGroup at position 2 has overlapping keys.*"
            ),
        ),
        (np.array([1, 2]), True, does_not_raise()),
        (np.array([3, 4]), False, does_not_raise()),
        (np.array([3, 4]), True, does_not_raise()),
    ],
)
def test_merge_index(self, ts_group, index, reset_index, expectation):
    """Check key-overlap validation and the `reset_index` switch."""
    timestamps = (nap.Ts(t=np.arange(15)), nap.Ts(t=np.arange(20)))
    other_data = {key: ts for key, ts in zip(index, timestamps)}
    other_group = nap.TsGroup(
        other_data,
        time_support=ts_group.time_support,
        meta=np.array([12, 13]),
    )

    with expectation:
        merged = ts_group.merge(other_group, reset_index=reset_index)

    # `merged` is only inspected in the parametrizations that do not raise.
    if reset_index:
        assert np.all(merged.keys() == np.arange(4))
    elif np.all(index == np.array([3, 4])):
        assert np.all(merged.keys() == np.array([1, 2, 3, 4]))

@pytest.mark.parametrize(
    "time_support, reset_time_support, expectation",
    [
        (None, False, does_not_raise()),
        (None, True, does_not_raise()),
        (
            nap.IntervalSet(start=0, end=1),
            False,
            pytest.raises(
                ValueError,
                match="TsGroup at position 2 has different time support.*",
            ),
        ),
        (nap.IntervalSet(start=0, end=1), True, does_not_raise()),
    ],
)
def test_merge_time_support(self, ts_group, time_support, reset_time_support, expectation):
    """Check time support validation and the `reset_time_support` switch."""
    # `None` stands for "reuse the time support of the fixture group".
    support = ts_group.time_support if time_support is None else time_support

    other_data = {
        3: nap.Ts(t=np.arange(15)),
        4: nap.Ts(t=np.arange(20)),
    }
    other_group = nap.TsGroup(
        other_data, time_support=support, meta=np.array([12, 13])
    )

    with expectation:
        merged = ts_group.merge(other_group, reset_time_support=reset_time_support)

    if reset_time_support:
        expected_support = ts_group.time_support.as_units("s").to_numpy()
        merged_support = merged.time_support.as_units("s").to_numpy()
        np.testing.assert_array_almost_equal(expected_support, merged_support)

0 comments on commit ff231e0

Please sign in to comment.