Skip to content

Commit

Permalink
Merge pull request #305 from eschombu/tsgroup-from-iter
Browse files Browse the repository at this point in the history
Create a TsGroup from an iterable of Ts/Tsd objects
  • Loading branch information
gviejo authored Aug 12, 2024
2 parents 8f21e99 + 64ffdc0 commit 43d4e7f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
13 changes: 9 additions & 4 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def __init__(
Parameters
----------
data : dict
Dictionary containing Ts/Tsd objects, keys should contain integer values or should be convertible
to integer.
data : dict or iterable
Dictionary or iterable of Ts/Tsd objects. The keys should be integer-convertible; if a non-dict iterator is
passed, its values will be used to create a dict with integer keys.
time_support : IntervalSet, optional
The time support of the TsGroup. Ts/Tsd objects will be restricted to the time support if passed.
If no time support is specified, TsGroup will merge time supports from all the Ts/Tsd objects in data.
Expand Down Expand Up @@ -117,13 +117,16 @@ def __init__(

self._initialized = False

if not isinstance(data, dict):
data = dict(enumerate(data))

# convert all keys to integer
try:
keys = [int(k) for k in data.keys()]
except Exception:
raise ValueError("All keys must be convertible to integer.")

# check that there were no floats with decimal points in keys.i
# check that there were no floats with decimal points in keys.
# i.e. 0.5 is not a valid key
if not all(np.allclose(keys[j], float(k)) for j, k in enumerate(data.keys())):
raise ValueError("All keys must have integer value!}")
Expand All @@ -135,6 +138,8 @@ def __init__(

data = {keys[j]: data[k] for j, k in enumerate(data.keys())}
self.index = np.sort(keys)
# Make sure data dict and index are ordered the same
data = {k: data[k] for k in self.index}

self._metadata = pd.DataFrame(index=self.index, columns=["rate"], dtype="float")

Expand Down
12 changes: 10 additions & 2 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,19 @@ def test_create_ts_group(self, group):
assert isinstance(tsgroup, UserDict)
assert len(tsgroup) == 3

def test_create_ts_group_from_iter(self, group):
tsgroup = nap.TsGroup(group.values())
assert isinstance(tsgroup, UserDict)
assert len(tsgroup) == 3

def test_create_ts_group_from_invalid(self):
with pytest.raises(AttributeError):
tsgroup = nap.TsGroup(np.arange(0, 200))

@pytest.mark.parametrize(
"test_dict, expectation",
[
({"1": nap.Ts(np.arange(10)), "2":nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), "2": nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="Two dictionary keys contain the same integer")),
Expand Down Expand Up @@ -82,7 +91,6 @@ def test_initialize_from_dict(self, test_dict, expectation):
def test_metadata_len_match(self, tsgroup):
assert len(tsgroup._metadata) == len(tsgroup)


def test_create_ts_group_from_array(self):
with warnings.catch_warnings(record=True) as w:
nap.TsGroup({
Expand Down

0 comments on commit 43d4e7f

Please sign in to comment.