diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 349e7157..0c779db4 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -5,7 +5,7 @@ on:
   push:
     branches: [ main ]
   pull_request:
-    branches: [ main ]
+    branches: [ main, dev ]
 
 jobs:
   lint:
@@ -40,9 +40,9 @@ jobs:
         #   - os: windows-latest
         #     python-version: 3.7
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
@@ -52,14 +52,12 @@ jobs:
       - name: Test
         run: |
           coverage run --source=pynapple --branch -m pytest tests/
-          coverage report -m
+          coverage report -m
-      - name: Coveralls
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          pip install coveralls
-          coveralls --service=github
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4.0.1
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
 
   check:
     if: always()
     needs:
diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py
index e35a6937..b4865d6d 100644
--- a/pynapple/core/ts_group.py
+++ b/pynapple/core/ts_group.py
@@ -17,6 +17,7 @@
 from ._core_functions import _count
 from ._jitted_functions import jitunion, jitunion_isets
 from .base_class import Base
+from .config import nap_config
 from .interval_set import IntervalSet
 from .time_index import TsIndex
 from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like
@@ -1026,6 +1027,207 @@ def getby_category(self, key):
         sliced = {k: self[list(groups[k])] for k in groups.keys()}
         return sliced
 
+    @staticmethod
+    def merge_group(
+        *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False
+    ):
+        """
+        Merge multiple TsGroup objects into a single TsGroup object.
+
+        Parameters
+        ----------
+        *tsgroups : TsGroup
+            The TsGroup objects to merge
+        reset_index : bool, optional
+            If True, the keys will be reset to range(len(data))
+            If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
+        reset_time_support : bool, optional
+            If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
+            If False, the time support of the TsGroup objects should be the same and will be preserved
+        ignore_metadata : bool, optional
+            If True, the merged TsGroup will not have any metadata columns other than 'rate'
+            If False, all metadata columns should be the same and all metadata will be concatenated
+
+        Returns
+        -------
+        TsGroup
+            A TsGroup of merged objects
+
+        Raises
+        ------
+        TypeError
+            If the input objects are not TsGroup objects
+        ValueError
+            If `ignore_metadata=False` but metadata columns are not the same
+            If `reset_index=False` but keys overlap
+            If `reset_time_support=False` but time supports are not the same
+
+        """
+        is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups]
+        if not all(is_tsgroup):
+            not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo]
+            raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!")
+
+        if len(tsgroups) == 1:
+            print("Only one TsGroup object provided, no merge needed.")
+            return tsgroups[0]
+
+        tsg1 = tsgroups[0]
+        items = tsg1.items()
+        keys = set(tsg1.keys())
+        metadata = tsg1._metadata
+
+        for i, tsg in enumerate(tsgroups[1:]):
+            if not ignore_metadata:
+                if tsg1.metadata_columns != tsg.metadata_columns:
+                    raise ValueError(
+                        f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
+                        "Set `ignore_metadata=True` to bypass the check."
+                    )
+                metadata = pd.concat([metadata, tsg._metadata], axis=0)
+
+            if not reset_index:
+                key_overlap = keys.intersection(tsg.keys())
+                if key_overlap:
+                    raise ValueError(
+                        f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. "
+                        "Set `reset_index=True` to bypass the check."
+                    )
+                keys.update(tsg.keys())
+
+            if reset_time_support:
+                time_support = None
+            else:
+                if not np.allclose(
+                    tsg1.time_support.as_units("s").to_numpy(),
+                    tsg.time_support.as_units("s").to_numpy(),
+                    atol=10 ** (-nap_config.time_index_precision),
+                    rtol=0,
+                ):
+                    raise ValueError(
+                        f"TsGroup at position {i+2} has different time support from previous TsGroup objects. "
+                        "Set `reset_time_support=True` to bypass the check."
+                    )
+                time_support = tsg1.time_support
+
+            items.extend(tsg.items())
+
+        if reset_index:
+            metadata.index = range(len(metadata))
+            data = {i: ts[1] for i, ts in enumerate(items)}
+        else:
+            data = dict(items)
+
+        if ignore_metadata:
+            return TsGroup(data, time_support=time_support, bypass_check=False)
+        else:
+            cols = metadata.columns.drop("rate")
+            return TsGroup(
+                data, time_support=time_support, bypass_check=False, **metadata[cols]
+            )
+
+    def merge(
+        self,
+        *tsgroups,
+        reset_index=False,
+        reset_time_support=False,
+        ignore_metadata=False,
+    ):
+        """
+        Merge the TsGroup object with other TsGroup objects.
+        Common uses include adding more neurons/channels (supposing each Ts/Tsd corresponds to data from a neuron/channel) or adding more trials (supposing each Ts/Tsd corresponds to data from a trial).
+
+        Parameters
+        ----------
+        *tsgroups : TsGroup
+            The TsGroup objects to merge with
+        reset_index : bool, optional
+            If True, the keys will be reset to range(len(data))
+            If False, the keys of the TsGroup objects should be non-overlapping and will be preserved
+        reset_time_support : bool, optional
+            If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data
+            If False, the time support of the TsGroup objects should be the same and will be preserved
+        ignore_metadata : bool, optional
+            If True, the merged TsGroup will not have any metadata columns other than 'rate'
+            If False, all metadata columns should be the same and all metadata will be concatenated
+
+        Returns
+        -------
+        TsGroup
+            A TsGroup of merged objects
+
+        Raises
+        ------
+        TypeError
+            If the input objects are not TsGroup objects
+        ValueError
+            If `ignore_metadata=False` but metadata columns are not the same
+            If `reset_index=False` but keys overlap
+            If `reset_time_support=False` but time supports are not the same
+
+        Examples
+        --------
+
+        >>> import pynapple as nap
+        >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s')
+        >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s')
+
+        >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')}
+        >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a)
+
+        >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
+        >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a)
+
+        >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')}
+        >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a)
+
+        >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')}
+        >>> tsgroup4 = nap.TsGroup(dict4, time_support=time_support_b)
+
+        Merge with default options if the TsGroups have the same time support and non-overlapping indexes:
+
+        >>> tsgroup_12 = tsgroup1.merge(tsgroup2)
+        >>> tsgroup_12
+          Index     rate
+        -------  ------
+              0      1.5
+             10      1.5
+
+        Set `reset_index=True` if indexes are overlapping:
+
+        >>> tsgroup_13 = tsgroup1.merge(tsgroup3, reset_index=True)
+        >>> tsgroup_13
+          Index     rate
+        -------  ------
+              0      1.5
+              1      1.5
+
+        Set `reset_time_support=True` if time supports are different:
+
+        >>> tsgroup_14 = tsgroup1.merge(tsgroup4, reset_time_support=True)
+        >>> tsgroup_14
+          Index     rate
+        -------  ------
+              0      0.3
+             10      0.3
+
+        >>> tsgroup_14.time_support
+          start    end
+         0    -5      5
+        shape: (1, 2), time unit: sec.
+
+        See Also
+        --------
+        [`TsGroup.merge_group`](./#pynapple.core.ts_group.TsGroup.merge_group)
+        """
+        return TsGroup.merge_group(
+            self,
+            *tsgroups,
+            reset_index=reset_index,
+            reset_time_support=reset_time_support,
+            ignore_metadata=ignore_metadata,
+        )
+
     def save(self, filename):
         """
         Save TsGroup object in npz format. The file will contain the timestamps,
diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py
index 65d23af4..fc221f60 100644
--- a/tests/test_ts_group.py
+++ b/tests/test_ts_group.py
@@ -14,7 +14,6 @@
 import warnings
 from contextlib import nullcontext as does_not_raise
 
-
 @pytest.fixture
 def group():
     """Fixture to be used in all tests."""
@@ -575,15 +574,16 @@ def test_save_npz(self, group):
         np.testing.assert_array_almost_equal(file['index'], index)
         np.testing.assert_array_almost_equal(file['meta'], np.arange(len(group), dtype=np.int64))
         assert np.all(file['meta2']==np.array(['a', 'b', 'c']))
+        file.close()
 
         tsgroup3 = nap.TsGroup({
             0: nap.Ts(t=np.arange(0, 20)),
             })
         tsgroup3.save("tsgroup3")
 
-        file = np.load("tsgroup3.npz")
-        assert 'd' not in list(file.keys())
-        np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
+        with np.load("tsgroup3.npz") as file:
+            assert 'd' not in list(file.keys())
+            np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index)
 
         os.remove("tsgroup.npz")
         os.remove("tsgroup2.npz")
@@ -752,4 +752,106 @@ def test_getitem_attribute_error(self, ts_group):
     )
     def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation):
         with expectation:
-            out = ts_group[bool_idx]
\ No newline at end of file
+            out = ts_group[bool_idx]
+
+    def test_merge_complete(self, ts_group):
+        with pytest.raises(TypeError, match="Input at positions(.*)are not TsGroup!"):
+            nap.TsGroup.merge_group(ts_group, str, dict)
+
+        ts_group2 = nap.TsGroup(
+            {
+                3: nap.Ts(t=np.arange(15)),
+                4: nap.Ts(t=np.arange(20)),
+            },
+            time_support=ts_group.time_support,
+            meta=np.array([12, 13])
+        )
+        merged = ts_group.merge(ts_group2)
+        assert len(merged) == 4
+        assert np.all(merged.keys() == np.array([1, 2, 3, 4]))
+        assert np.all(merged.meta == np.array([10, 11, 12, 13]))
+        np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)
+
+    @pytest.mark.parametrize(
+        'col_name, ignore_metadata, expectation',
+        [
+            ('meta', False, does_not_raise()),
+            ('meta', True, does_not_raise()),
+            ('wrong_name', False, pytest.raises(ValueError, match="TsGroup at position 2 has different metadata columns.*")),
+            ('wrong_name', True, does_not_raise())
+        ]
+    )
+    def test_merge_metadata(self, ts_group, col_name, ignore_metadata, expectation):
+        metadata = pd.DataFrame([12, 13], index=[3, 4], columns=[col_name])
+        ts_group2 = nap.TsGroup(
+            {
+                3: nap.Ts(t=np.arange(15)),
+                4: nap.Ts(t=np.arange(20)),
+            },
+            time_support=ts_group.time_support,
+            **metadata
+        )
+
+        with expectation:
+            merged = ts_group.merge(ts_group2, ignore_metadata=ignore_metadata)
+
+        if ignore_metadata:
+            assert merged.metadata_columns[0] == 'rate'
+        elif col_name == 'meta':
+            np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns)
+
+    @pytest.mark.parametrize(
+        'index, reset_index, expectation',
+        [
+            (np.array([1, 2]), False, pytest.raises(ValueError, match="TsGroup at position 2 has overlapping keys.*")),
+            (np.array([1, 2]), True, does_not_raise()),
+            (np.array([3, 4]), False, does_not_raise()),
+            (np.array([3, 4]), True, does_not_raise())
+        ]
+    )
+    def test_merge_index(self, ts_group, index, reset_index, expectation):
+        ts_group2 = nap.TsGroup(
+            dict(zip(index, [nap.Ts(t=np.arange(15)), nap.Ts(t=np.arange(20))])),
+            time_support=ts_group.time_support,
+            meta=np.array([12, 13])
+        )
+
+        with expectation:
+            merged = ts_group.merge(ts_group2, reset_index=reset_index)
+
+        if reset_index:
+            assert np.all(merged.keys() == np.arange(4))
+        elif np.all(index == np.array([3, 4])):
+            assert np.all(merged.keys() == np.array([1, 2, 3, 4]))
+
+    @pytest.mark.parametrize(
+        'time_support, reset_time_support, expectation',
+        [
+            (None, False, does_not_raise()),
+            (None, True, does_not_raise()),
+            (nap.IntervalSet(start=0, end=1), False,
+             pytest.raises(ValueError, match="TsGroup at position 2 has different time support.*")),
+            (nap.IntervalSet(start=0, end=1), True, does_not_raise())
+        ]
+    )
+    def test_merge_time_support(self, ts_group, time_support, reset_time_support, expectation):
+        if time_support is None:
+            time_support = ts_group.time_support
+
+        ts_group2 = nap.TsGroup(
+            {
+                3: nap.Ts(t=np.arange(15)),
+                4: nap.Ts(t=np.arange(20)),
+            },
+            time_support=time_support,
+            meta=np.array([12, 13])
+        )
+
+        with expectation:
+            merged = ts_group.merge(ts_group2, reset_time_support=reset_time_support)
+
+        if reset_time_support:
+            np.testing.assert_array_almost_equal(
+                ts_group.time_support.as_units("s").to_numpy(),
+                merged.time_support.as_units("s").to_numpy()
+            )
\ No newline at end of file
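
For context, a minimal usage sketch of the merge API added by this patch, assuming a pynapple build that includes this change; the spike times, epoch, and the `label` metadata column below are invented for illustration and are not part of the diff:

import numpy as np
import pynapple as nap

# Shared epoch so the default time-support check passes (illustrative values).
ep = nap.IntervalSet(start=0, end=100)

# Two groups with non-overlapping keys and the same metadata column.
ts_a = nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))}, time_support=ep, label=np.array([1]))
ts_b = nap.TsGroup({1: nap.Ts(t=np.arange(0, 100, 2))}, time_support=ep, label=np.array([2]))

# Keys are preserved and metadata is concatenated.
merged = ts_a.merge(ts_b)
print(merged.keys(), merged.metadata_columns)

# Overlapping keys need reset_index=True; mismatched metadata columns can be
# dropped with ignore_metadata=True. The static variant accepts any number of groups.
ts_c = nap.TsGroup({0: nap.Ts(t=np.arange(0, 50))}, time_support=ep)
merged2 = nap.TsGroup.merge_group(ts_a, ts_c, reset_index=True, ignore_metadata=True)
print(merged2.keys())  # keys reset to 0, 1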