From 3da0d4683924464d3ceab140ed4b3f0c7b7b4915 Mon Sep 17 00:00:00 2001 From: qian-chu Date: Tue, 7 May 2024 01:38:03 +0200 Subject: [PATCH 01/18] add TsGroup.merge_group method --- pynapple/core/ts_group.py | 153 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 971fd560..43df2a72 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1030,6 +1030,159 @@ def getby_category(self, key): sliced = {k: self[list(groups[k])] for k in groups.keys()} return sliced + @staticmethod + def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False): + """ + Merge multiple TsGroup objects into a single TsGroup object + + Parameters + ---------- + *tsgroups : TsGroup + The TsGroup objects to merge + reset_index : bool, optional + If True, the keys will be reset to range(len(data)) + If False, the keys of the TsGroup objects should be non-overlapping and will be preserved + reset_time_support : bool, optional + If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data + If False, the time support of the TsGroup objects should be the same and will be preserved + ignore_metadata : bool, optional + If True, the merged TsGroup will not have any metadata columns other than 'rate' + If False, all metadata columns should be the same and all metadata will be concatenated + + Returns + ------- + TsGroup + TsGroup of merged objects + + Raises + ------ + TypeError + If the input objects are not TsGroup objects + ValueError + If ignore_metadata=False and metadata columns are not the same + If reset_index=False and keys overlap + If reset_time_support=False and time supports are not the same + + Examples + -------- + + >>> import pynapple as nap + >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s') + >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s') + + >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')} + >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a) + + >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} + >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a) + + >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')} + >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a) + + >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} + >>> tsgroup4 = nap.TsGroup(dict2, time_support=time_support_b) + + Merge with default options if have same time_support and non-overlapping indexes: + + >>> tsgroup_12 = nap.TsGroup.merge_group(tsgroup1, tsgroup2) + >>> tsgroup_12 + Index rate + ------- ------ + 0 1.5 + 10 1.5 + + Pass reset_index=True if indexes are overlapping: + + >>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True) + >>> tsgroup_13 + + Index rate + ------- ------ + 0 1.5 + 1 1.5 + + Pass reset_time_support=True if time_supports are different: + + >>> tsgroup_14 = nap.TsGroup.merge_group(tsgroup1, tsgroup4, reset_time_support=True) + >>> tsgroup_14 + >>> tsgroup_14.time_support + + Index rate + ------- ------ + 0 0.3 + 10 0.3 + start end + 0 -5 5 + shape: (1, 2), time unit: sec. + + """ + is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups] + if not all(is_tsgroup): + not_tsgroup_index = [i+1 for i, boo in enumerate(is_tsgroup) if not boo] + raise TypeError(f"Passed variables at positions {not_tsgroup_index} are not TsGroup") + + if len(tsgroups) == 1: + print('Only one TsGroup object provided, no merge needed') + return tsgroups[0] + + tsg1 = tsgroups[0] + items = tsg1.items() + keys = set(tsg1.keys()) + metadata = tsg1._metadata + + for i, tsg in enumerate(tsgroups[1:]): + if not ignore_metadata: + if tsg1.metadata_columns != tsg.metadata_columns: + raise ValueError(f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. " + "Pass ignore_metadata=True to bypass") + metadata = pd.concat([metadata, tsg._metadata], axis=0) + + if not reset_index: + key_overlap = keys.intersection(tsg.keys()) + if key_overlap: + raise ValueError(f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. " + "Pass reset_index=True to bypass") + keys.update(tsg.keys()) + + if reset_time_support: + time_support = None + else: + if not np.array_equal( + tsg1.time_support.as_units('s').to_numpy(), + tsg.time_support.as_units('s').to_numpy() + ): + raise ValueError(f"TsGroup at position {i+2} has different time support from previous TsGroup objects. " + "Pass reset_time_support=True to bypass") + time_support = tsg1.time_support + + items.extend(tsg.items()) + + if reset_index: + metadata.index = range(len(metadata)) + data = {i: ts[1] for i, ts in enumerate(items)} + else: + data = dict(items) + + if ignore_metadata: + return TsGroup( + data, time_support=time_support, bypass_check=False + ) + else: + cols = metadata.columns.drop("rate") + return TsGroup( + data, time_support=time_support, bypass_check=False, **metadata[cols] + ) + + def merge(self, *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False): + """ + Merge the TsGroup object with other TsGroup objects + See `TsGroup.merge_group` for more details + """ + tsgroups = list(tsgroups) + tsgroups.insert(0, self) + return TsGroup.merge_group( + *tsgroups, reset_index=reset_index, reset_time_support=reset_time_support, ignore_metadata=ignore_metadata) + def save(self, filename): """ Save TsGroup object in npz format. The file will contain the timestamps, From 60719f51f6010fbbd5c26bb7194fbe4af2b01917 Mon Sep 17 00:00:00 2001 From: qian-chu Date: Tue, 7 May 2024 01:49:08 +0200 Subject: [PATCH 02/18] doc aesthetics --- pynapple/core/ts_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 43df2a72..d105bde4 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1111,8 +1111,8 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m ------- ------ 0 0.3 10 0.3 - start end - 0 -5 5 + start end + 0 -5 5 shape: (1, 2), time unit: sec. """ From b3874a15299e2c2af5d7f3959cfd5091d27581c6 Mon Sep 17 00:00:00 2001 From: qian-chu Date: Tue, 7 May 2024 01:49:29 +0200 Subject: [PATCH 03/18] Update ts_group.py --- pynapple/core/ts_group.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index d105bde4..e04762fd 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1111,6 +1111,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m ------- ------ 0 0.3 10 0.3 + start end 0 -5 5 shape: (1, 2), time unit: sec. From d24f3312aeab9393809a2219337b083522686084 Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Tue, 7 May 2024 18:21:30 +0200 Subject: [PATCH 04/18] Simplify tsgroup.merge() Co-authored-by: Edoardo Balzani --- pynapple/core/ts_group.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index e04762fd..395b5e31 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1179,10 +1179,8 @@ def merge(self, *tsgroups, reset_index=False, reset_time_support=False, ignore_m Merge the TsGroup object with other TsGroup objects See `TsGroup.merge_group` for more details """ - tsgroups = list(tsgroups) - tsgroups.insert(0, self) return TsGroup.merge_group( - *tsgroups, reset_index=reset_index, reset_time_support=reset_time_support, ignore_metadata=ignore_metadata) + self, *tsgroups, reset_index=reset_index, reset_time_support=reset_time_support, ignore_metadata=ignore_metadata) def save(self, filename): """ From 04a0053fa616fa6aeea6a825d8e9507bcfdf5109 Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Tue, 7 May 2024 18:22:05 +0200 Subject: [PATCH 05/18] Update pynapple/core/ts_group.py Co-authored-by: Edoardo Balzani --- pynapple/core/ts_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 395b5e31..614eb600 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1120,7 +1120,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups] if not all(is_tsgroup): not_tsgroup_index = [i+1 for i, boo in enumerate(is_tsgroup) if not boo] - raise TypeError(f"Passed variables at positions {not_tsgroup_index} are not TsGroup") + raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!") if len(tsgroups) == 1: print('Only one TsGroup object provided, no merge needed') From c642068cb5faac287c64e8c8ccbf16f1815f499f Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Tue, 7 May 2024 18:22:21 +0200 Subject: [PATCH 06/18] Update pynapple/core/ts_group.py Co-authored-by: Edoardo Balzani --- pynapple/core/ts_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 614eb600..fe6fbed3 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1135,7 +1135,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m if not ignore_metadata: if tsg1.metadata_columns != tsg.metadata_columns: raise ValueError(f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. " - "Pass ignore_metadata=True to bypass") + "Set `ignore_metadata=True` to bypass the check.") metadata = pd.concat([metadata, tsg._metadata], axis=0) if not reset_index: From 2d08c338a25a7e901d3a8c9a4b02075280205199 Mon Sep 17 00:00:00 2001 From: qian-chu Date: Fri, 17 May 2024 22:56:02 +0200 Subject: [PATCH 07/18] Update ts_group.py --- pynapple/core/ts_group.py | 90 ++++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 35 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index fe6fbed3..635796e4 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -21,6 +21,7 @@ jitunion_isets, ) from .base_class import Base +from .config import time_index_precision from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like @@ -1031,9 +1032,11 @@ def getby_category(self, key): return sliced @staticmethod - def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False): + def merge_group( + *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False + ): """ - Merge multiple TsGroup objects into a single TsGroup object + Merge multiple TsGroup objects into a single TsGroup object. Parameters ---------- @@ -1052,16 +1055,16 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m Returns ------- TsGroup - TsGroup of merged objects - + A TsGroup of merged objects + Raises ------ TypeError If the input objects are not TsGroup objects ValueError - If ignore_metadata=False and metadata columns are not the same - If reset_index=False and keys overlap - If reset_time_support=False and time supports are not the same + If `ignore_metadata=False` but metadata columns are not the same + If `reset_index=False` but keys overlap + If `reset_time_support=False` but time supports are not the same Examples -------- @@ -1082,7 +1085,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} >>> tsgroup4 = nap.TsGroup(dict2, time_support=time_support_b) - Merge with default options if have same time_support and non-overlapping indexes: + Merge with default options if have the same time support and non-overlapping indexes: >>> tsgroup_12 = nap.TsGroup.merge_group(tsgroup1, tsgroup2) >>> tsgroup_12 @@ -1091,7 +1094,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m 0 1.5 10 1.5 - Pass reset_index=True if indexes are overlapping: + Set `reset_index=True` if indexes are overlapping: >>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True) >>> tsgroup_13 @@ -1101,7 +1104,7 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m 0 1.5 1 1.5 - Pass reset_time_support=True if time_supports are different: + Set `reset_time_support=True` if time supports are different: >>> tsgroup_14 = nap.TsGroup.merge_group(tsgroup1, tsgroup4, reset_time_support=True) >>> tsgroup_14 @@ -1119,43 +1122,51 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m """ is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups] if not all(is_tsgroup): - not_tsgroup_index = [i+1 for i, boo in enumerate(is_tsgroup) if not boo] + not_tsgroup_index = [i + 1 for i, boo in enumerate(is_tsgroup) if not boo] raise TypeError(f"Input at positions {not_tsgroup_index} are not TsGroup!") if len(tsgroups) == 1: - print('Only one TsGroup object provided, no merge needed') + print("Only one TsGroup object provided, no merge needed.") return tsgroups[0] - tsg1 = tsgroups[0] - items = tsg1.items() - keys = set(tsg1.keys()) + tsg1 = tsgroups[0] + items = tsg1.items() + keys = set(tsg1.keys()) metadata = tsg1._metadata for i, tsg in enumerate(tsgroups[1:]): if not ignore_metadata: if tsg1.metadata_columns != tsg.metadata_columns: - raise ValueError(f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. " - "Set `ignore_metadata=True` to bypass the check.") + raise ValueError( + f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. " + "Set `ignore_metadata=True` to bypass the check." + ) metadata = pd.concat([metadata, tsg._metadata], axis=0) - + if not reset_index: key_overlap = keys.intersection(tsg.keys()) if key_overlap: - raise ValueError(f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. " - "Pass reset_index=True to bypass") + raise ValueError( + f"TsGroup at position {i+2} has overlapping keys {key_overlap} with previous TsGroup objects. " + "Set `reset_index=True` to bypass the check." + ) keys.update(tsg.keys()) - + if reset_time_support: time_support = None else: - if not np.array_equal( - tsg1.time_support.as_units('s').to_numpy(), - tsg.time_support.as_units('s').to_numpy() - ): - raise ValueError(f"TsGroup at position {i+2} has different time support from previous TsGroup objects. " - "Pass reset_time_support=True to bypass") + if not np.allclose( + tsg1.time_support.as_units("s").to_numpy(), + tsg.time_support.as_units("s").to_numpy(), + atol=10 ** (-time_index_precision), + rtol=0, + ): + raise ValueError( + f"TsGroup at position {i+2} has different time support from previous TsGroup objects. " + "Set `reset_time_support=True` to bypass the check." + ) time_support = tsg1.time_support - + items.extend(tsg.items()) if reset_index: @@ -1165,22 +1176,31 @@ def merge_group(*tsgroups, reset_index=False, reset_time_support=False, ignore_m data = dict(items) if ignore_metadata: - return TsGroup( - data, time_support=time_support, bypass_check=False - ) + return TsGroup(data, time_support=time_support, bypass_check=False) else: cols = metadata.columns.drop("rate") return TsGroup( data, time_support=time_support, bypass_check=False, **metadata[cols] - ) - - def merge(self, *tsgroups, reset_index=False, reset_time_support=False, ignore_metadata=False): + ) + + def merge( + self, + *tsgroups, + reset_index=False, + reset_time_support=False, + ignore_metadata=False, + ): """ Merge the TsGroup object with other TsGroup objects See `TsGroup.merge_group` for more details """ return TsGroup.merge_group( - self, *tsgroups, reset_index=reset_index, reset_time_support=reset_time_support, ignore_metadata=ignore_metadata) + self, + *tsgroups, + reset_index=reset_index, + reset_time_support=reset_time_support, + ignore_metadata=ignore_metadata, + ) def save(self, filename): """ From 26f633388b1cb81d863b5f77c581a8021ff0f5cb Mon Sep 17 00:00:00 2001 From: qian-chu Date: Sun, 19 May 2024 14:46:01 +0200 Subject: [PATCH 08/18] correctly access nap_config --- pynapple/core/ts_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 635796e4..d8deddd5 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -21,7 +21,7 @@ jitunion_isets, ) from .base_class import Base -from .config import time_index_precision +from .config import nap_config from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like @@ -1158,7 +1158,7 @@ def merge_group( if not np.allclose( tsg1.time_support.as_units("s").to_numpy(), tsg.time_support.as_units("s").to_numpy(), - atol=10 ** (-time_index_precision), + atol=10 ** (-nap_config.time_index_precision), rtol=0, ): raise ValueError( From 5c49a0ea381cd3a5584c976fff9e2bb27f7fff4d Mon Sep 17 00:00:00 2001 From: qian-chu Date: Mon, 20 May 2024 00:50:49 +0200 Subject: [PATCH 09/18] add test for merging groups --- tests/test_ts_group.py | 113 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 109 insertions(+), 4 deletions(-) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 65d23af4..062240ff 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -13,6 +13,7 @@ from collections import UserDict import warnings from contextlib import nullcontext as does_not_raise +import itertools @pytest.fixture @@ -575,15 +576,16 @@ def test_save_npz(self, group): np.testing.assert_array_almost_equal(file['index'], index) np.testing.assert_array_almost_equal(file['meta'], np.arange(len(group), dtype=np.int64)) assert np.all(file['meta2']==np.array(['a', 'b', 'c'])) + file.close() tsgroup3 = nap.TsGroup({ 0: nap.Ts(t=np.arange(0, 20)), }) tsgroup3.save("tsgroup3") - file = np.load("tsgroup3.npz") - assert 'd' not in list(file.keys()) - np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index) + with np.load("tsgroup3.npz") as file: + assert 'd' not in list(file.keys()) + np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index) os.remove("tsgroup.npz") os.remove("tsgroup2.npz") @@ -752,4 +754,107 @@ def test_getitem_attribute_error(self, ts_group): ) def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation): with expectation: - out = ts_group[bool_idx] \ No newline at end of file + out = ts_group[bool_idx] + + def test_merge_complete(self, ts_group): + with pytest.raises(TypeError) as e_info: + nap.TsGroup.merge_group(ts_group, str, dict) + assert str(e_info.value) == f"Input at positions {[2, 3]} are not TsGroup!" + + ts_group2 = nap.TsGroup( + { + 3: nap.Ts(t=np.arange(15)), + 4: nap.Ts(t=np.arange(20)), + }, + time_support=ts_group.time_support, + meta=np.array([12, 13]) + ) + merged = ts_group.merge(ts_group2) + assert len(merged) == 4 + assert np.all(merged.keys() == np.array([1, 2, 3, 4])) + assert np.all(merged.meta == np.array([10, 11, 12, 13])) + np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns) + + @pytest.mark.parametrize( + 'col_name, ignore_metadata, expectation', + [ + ('meta', False, does_not_raise()), + ('meta', True, does_not_raise()), + ('wrong_name', False, pytest.raises(ValueError, match="TsGroup at position 2 has different metadata columns.*")), + ('wrong_name', True, does_not_raise()) + ] + ) + def test_merge_metadata(self, ts_group, col_name, ignore_metadata, expectation): + metadata = pd.DataFrame([12, 13], index=[3, 4], columns=[col_name]) + ts_group2 = nap.TsGroup( + { + 3: nap.Ts(t=np.arange(15)), + 4: nap.Ts(t=np.arange(20)), + }, + time_support=ts_group.time_support, + **metadata + ) + + with expectation: + merged = ts_group.merge(ts_group2, ignore_metadata=ignore_metadata) + + if ignore_metadata: + assert merged.metadata_columns[0] == 'rate' + elif col_name == 'meta': + np.testing.assert_equal(merged.metadata_columns, ts_group.metadata_columns) + + @pytest.mark.parametrize( + 'index, reset_index, expectation', + [ + (np.array([1, 2]), False, pytest.raises(ValueError, match="TsGroup at position 2 has overlapping keys.*")), + (np.array([1, 2]), True, does_not_raise()), + (np.array([3, 4]), False, does_not_raise()), + (np.array([3, 4]), True, does_not_raise()) + ] + ) + def test_merge_index(self, ts_group, index, reset_index, expectation): + ts_group2 = nap.TsGroup( + dict(zip(index, [nap.Ts(t=np.arange(15)), nap.Ts(t=np.arange(20))])), + time_support=ts_group.time_support, + meta=np.array([12, 13]) + ) + + with expectation: + merged = ts_group.merge(ts_group2, reset_index=reset_index) + + if reset_index: + assert np.all(merged.keys() == np.arange(4)) + elif np.all(index == np.array([3, 4])): + assert np.all(merged.keys() == np.array([1, 2, 3, 4])) + + @pytest.mark.parametrize( + 'time_support, reset_time_support, expectation', + [ + (None, False, does_not_raise()), + (None, True, does_not_raise()), + (nap.IntervalSet(start=0, end=1), False, + pytest.raises(ValueError, match="TsGroup at position 2 has different time support.*")), + (nap.IntervalSet(start=0, end=1), True, does_not_raise()) + ] + ) + def test_merge_time_support(self, ts_group, time_support, reset_time_support, expectation): + if time_support is None: + time_support = ts_group.time_support + + ts_group2 = nap.TsGroup( + { + 3: nap.Ts(t=np.arange(15)), + 4: nap.Ts(t=np.arange(20)), + }, + time_support=time_support, + meta=np.array([12, 13]) + ) + + with expectation: + merged = ts_group.merge(ts_group2, reset_time_support=reset_time_support) + + if reset_time_support: + np.testing.assert_array_almost_equal( + ts_group.time_support.as_units("s").to_numpy(), + merged.time_support.as_units("s").to_numpy() + ) \ No newline at end of file From a65f5f5e6c82f4fdf124962c149ee0111ffe0877 Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Mon, 20 May 2024 15:41:28 +0200 Subject: [PATCH 10/18] Update tests/test_ts_group.py Co-authored-by: Edoardo Balzani --- tests/test_ts_group.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 062240ff..1688e741 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -757,9 +757,8 @@ def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation): out = ts_group[bool_idx] def test_merge_complete(self, ts_group): - with pytest.raises(TypeError) as e_info: + with pytest.raises(TypeError, match=f"Input at positions {[2, 3]} are not TsGroup!"): nap.TsGroup.merge_group(ts_group, str, dict) - assert str(e_info.value) == f"Input at positions {[2, 3]} are not TsGroup!" ts_group2 = nap.TsGroup( { From 6bf20a7a849ee6d09ae1fcdcfee52a4d663c0f4c Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Mon, 20 May 2024 16:06:48 +0200 Subject: [PATCH 11/18] Update pynapple/core/ts_group.py Co-authored-by: Edoardo Balzani --- pynapple/core/ts_group.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 9be60cc5..38fcb767 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1106,10 +1106,10 @@ def merge_group( >>> tsgroup_14 >>> tsgroup_14.time_support - Index rate - ------- ------ - 0 0.3 - 10 0.3 + Index rate + ------- ------ + 0 0.3 + 10 0.3 start end 0 -5 5 From 6988758f1de88ecdaa5571c06266895081fcb833 Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Mon, 20 May 2024 16:07:06 +0200 Subject: [PATCH 12/18] Update pynapple/core/ts_group.py Co-authored-by: Edoardo Balzani --- pynapple/core/ts_group.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 38fcb767..26c5f383 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1095,10 +1095,10 @@ def merge_group( >>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True) >>> tsgroup_13 - Index rate - ------- ------ - 0 1.5 - 1 1.5 + Index rate + ------- ------ + 0 1.5 + 1 1.5 Set `reset_time_support=True` if time supports are different: From 41e0dfc5d89cdbdad5393dbe93f2d975ab893a44 Mon Sep 17 00:00:00 2001 From: qian-chu <97355086+qian-chu@users.noreply.github.com> Date: Mon, 20 May 2024 16:39:23 +0200 Subject: [PATCH 13/18] Move docstring to merge() --- pynapple/core/ts_group.py | 140 +++++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 55 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 26c5f383..221c0436 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1062,59 +1062,6 @@ def merge_group( If `reset_index=False` but keys overlap If `reset_time_support=False` but time supports are not the same - Examples - -------- - - >>> import pynapple as nap - >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s') - >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s') - - >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')} - >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a) - - >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} - >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a) - - >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')} - >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a) - - >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} - >>> tsgroup4 = nap.TsGroup(dict2, time_support=time_support_b) - - Merge with default options if have the same time support and non-overlapping indexes: - - >>> tsgroup_12 = nap.TsGroup.merge_group(tsgroup1, tsgroup2) - >>> tsgroup_12 - Index rate - ------- ------ - 0 1.5 - 10 1.5 - - Set `reset_index=True` if indexes are overlapping: - - >>> tsgroup_13 = nap.TsGroup.merge_group(tsgroup1, tsgroup3, reset_index=True) - >>> tsgroup_13 - - Index rate - ------- ------ - 0 1.5 - 1 1.5 - - Set `reset_time_support=True` if time supports are different: - - >>> tsgroup_14 = nap.TsGroup.merge_group(tsgroup1, tsgroup4, reset_time_support=True) - >>> tsgroup_14 - >>> tsgroup_14.time_support - - Index rate - ------- ------ - 0 0.3 - 10 0.3 - - start end - 0 -5 5 - shape: (1, 2), time unit: sec. - """ is_tsgroup = [isinstance(tsg, TsGroup) for tsg in tsgroups] if not all(is_tsgroup): @@ -1187,8 +1134,91 @@ def merge( ignore_metadata=False, ): """ - Merge the TsGroup object with other TsGroup objects - See `TsGroup.merge_group` for more details + Merge the TsGroup object with other TsGroup objects. + Common uses include adding more neurons/channels (supposing each Ts/Tsd corresponds to data from a neuron/channel) or adding more trials (supposing each Ts/Tsd corresponds to data from a trial). + + Parameters + ---------- + *tsgroups : TsGroup + The TsGroup objects to merge with + reset_index : bool, optional + If True, the keys will be reset to range(len(data)) + If False, the keys of the TsGroup objects should be non-overlapping and will be preserved + reset_time_support : bool, optional + If True, the merged TsGroup will merge time supports from all the Ts/Tsd objects in data + If False, the time support of the TsGroup objects should be the same and will be preserved + ignore_metadata : bool, optional + If True, the merged TsGroup will not have any metadata columns other than 'rate' + If False, all metadata columns should be the same and all metadata will be concatenated + + Returns + ------- + TsGroup + A TsGroup of merged objects + + Raises + ------ + TypeError + If the input objects are not TsGroup objects + ValueError + If `ignore_metadata=False` but metadata columns are not the same + If `reset_index=False` but keys overlap + If `reset_time_support=False` but time supports are not the same + + Examples + -------- + + >>> import pynapple as nap + >>> time_support_a = nap.IntervalSet(start=-1, end=1, time_units='s') + >>> time_support_b = nap.IntervalSet(start=-5, end=5, time_units='s') + + >>> dict1 = {0: nap.Ts(t=[-1, 0, 1], time_units='s')} + >>> tsgroup1 = nap.TsGroup(dict1, time_support=time_support_a) + + >>> dict2 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} + >>> tsgroup2 = nap.TsGroup(dict2, time_support=time_support_a) + + >>> dict3 = {0: nap.Ts(t=[-.1, 0, .1], time_units='s')} + >>> tsgroup3 = nap.TsGroup(dict3, time_support=time_support_a) + + >>> dict4 = {10: nap.Ts(t=[-1, 0, 1], time_units='s')} + >>> tsgroup4 = nap.TsGroup(dict2, time_support=time_support_b) + + Merge with default options if have the same time support and non-overlapping indexes: + + >>> tsgroup_12 = tsgroup1.merge(tsgroup2) + >>> tsgroup_12 + Index rate + ------- ------ + 0 1.5 + 10 1.5 + + Set `reset_index=True` if indexes are overlapping: + + >>> tsgroup_13 = tsgroup1.merge(tsgroup3, reset_index=True) + >>> tsgroup_13 + + Index rate + ------- ------ + 0 1.5 + 1 1.5 + + Set `reset_time_support=True` if time supports are different: + + >>> tsgroup_14 = tsgroup1.merge(tsgroup4, reset_time_support=True) + >>> tsgroup_14 + >>> tsgroup_14.time_support + + Index rate + ------- ------ + 0 0.3 + 10 0.3 + + start end + 0 -5 5 + shape: (1, 2), time unit: sec. + + See Also `TsGroup.merge_group` """ return TsGroup.merge_group( self, From e08c0e4820fa505711b55c2fdbe7f9b8c6de1900 Mon Sep 17 00:00:00 2001 From: qian-chu <97355086+qian-chu@users.noreply.github.com> Date: Mon, 20 May 2024 17:05:55 +0200 Subject: [PATCH 14/18] Fix regex error --- tests/test_ts_group.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 1688e741..fc221f60 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -13,8 +13,6 @@ from collections import UserDict import warnings from contextlib import nullcontext as does_not_raise -import itertools - @pytest.fixture def group(): @@ -757,7 +755,7 @@ def test_getitem_boolean_fail(self, ts_group, bool_idx, expectation): out = ts_group[bool_idx] def test_merge_complete(self, ts_group): - with pytest.raises(TypeError, match=f"Input at positions {[2, 3]} are not TsGroup!"): + with pytest.raises(TypeError, match="Input at positions(.*)are not TsGroup!"): nap.TsGroup.merge_group(ts_group, str, dict) ts_group2 = nap.TsGroup( From 2225a2aefbdcdda59dfd8cd977b4c881aa3ccbdb Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 20 May 2024 15:56:37 -0400 Subject: [PATCH 15/18] fixed docstrings --- pynapple/core/ts_group.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 221c0436..b4865d6d 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1188,37 +1188,37 @@ def merge( >>> tsgroup_12 = tsgroup1.merge(tsgroup2) >>> tsgroup_12 - Index rate - ------- ------ - 0 1.5 - 10 1.5 + Index rate + ------- ------ + 0 1.5 + 10 1.5 Set `reset_index=True` if indexes are overlapping: >>> tsgroup_13 = tsgroup1.merge(tsgroup3, reset_index=True) >>> tsgroup_13 - - Index rate - ------- ------ - 0 1.5 - 1 1.5 + Index rate + ------- ------ + 0 1.5 + 1 1.5 Set `reset_time_support=True` if time supports are different: >>> tsgroup_14 = tsgroup1.merge(tsgroup4, reset_time_support=True) >>> tsgroup_14 >>> tsgroup_14.time_support - - Index rate - ------- ------ - 0 0.3 - 10 0.3 + Index rate + ------- ------ + 0 0.3 + 10 0.3 start end 0 -5 5 shape: (1, 2), time unit: sec. - See Also `TsGroup.merge_group` + See Also + -------- + [`TsGroup.merge_group`](./#pynapple.core.ts_group.TsGroup.merge_group) """ return TsGroup.merge_group( self, From 235d06e29c9162e4e1c8c64b0b9f7ecba0213645 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 20 May 2024 16:10:17 -0400 Subject: [PATCH 16/18] add checks for pr to dev --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 349e7157..7c3a45de 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: [ main, dev ] jobs: lint: From 92e0df4c900418b4a04c84c7509327b3b916cbe0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 21 May 2024 12:09:31 -0400 Subject: [PATCH 17/18] change to codecov --- .github/workflows/main.yml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7c3a45de..6cc876b7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -52,14 +52,12 @@ jobs: - name: Test run: | coverage run --source=pynapple --branch -m pytest tests/ - coverage report -m + coverage report -m - - name: Coveralls - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - pip install coveralls - coveralls --service=github + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v4.0.1 + with: + token: ${{ secrets.CODECOV_TOKEN }} check: if: always() needs: From e6c9bc58bff3f19c5f5d26a7cfb4931fd222f563 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 21 May 2024 12:11:42 -0400 Subject: [PATCH 18/18] updated actions --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6cc876b7..0c779db4 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -40,9 +40,9 @@ jobs: # - os: windows-latest # python-version: 3.7 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies