Skip to content

Commit

Permalink
Merge pull request #304 from eschombu/interval-tuples
Browse files Browse the repository at this point in the history
IntervalSet creation from (start, end) pairs
  • Loading branch information
gviejo authored Aug 13, 2024
2 parents 43d4e7f + fc6f731 commit e3bf7cc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
23 changes: 18 additions & 5 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,13 @@ def __init__(self, start, end=None, time_units="s"):
Parameters
----------
start : numpy.ndarray or number or pandas.DataFrame or pandas.Series
Beginning of intervals
start : numpy.ndarray or number or pandas.DataFrame or pandas.Series or iterable of (start, end) pairs
Beginning of intervals. Alternatively, the `end` argument can be left out and `start` can be one of the
following:
- IntervalSet
- pandas.DataFrame with columns ["start", "end"]
- iterable of (start, end) pairs
- a single (start, end) pair
end : numpy.ndarray or number or pandas.Series, optional
Ends of intervals
time_units : str, optional
Expand All @@ -108,8 +113,8 @@ def __init__(self, start, end=None, time_units="s"):
"""
if isinstance(start, IntervalSet):
end = start.values[:, 1].astype(np.float64)
start = start.values[:, 0].astype(np.float64)
end = start.end.astype(np.float64)
start = start.start.astype(np.float64)

elif isinstance(start, pd.DataFrame):
assert (
Expand All @@ -125,7 +130,15 @@ def __init__(self, start, end=None, time_units="s"):
start = start["start"].values.astype(np.float64)

else:
assert end is not None, "Missing end argument when initializing IntervalSet"
if end is None:
# Require iterable of (start, end) tuples
try:
start_end_array = np.array(list(start)).reshape(-1, 2)
start, end = zip(*start_end_array)
except (TypeError, ValueError):
raise ValueError(
"Unable to Interpret the input. Please provide a list of start-end pairs."
)

args = {"start": start, "end": end}

Expand Down
17 changes: 17 additions & 0 deletions tests/test_interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,23 @@ def test_create_iset_from_mock_array():
np.testing.assert_array_almost_equal(ep.start, start)
np.testing.assert_array_almost_equal(ep.end, end)

def test_create_iset_from_tuple():
start = 0
end = 5
ep = nap.IntervalSet((start, end))
assert isinstance(ep, nap.core.interval_set.IntervalSet)
np.testing.assert_array_almost_equal(start, ep.start[0])
np.testing.assert_array_almost_equal(end, ep.end[0])

def test_create_iset_from_tuple_iter():
start = [0, 10, 16, 25]
end = [5, 15, 20, 40]
pairs = zip(start, end)
ep = nap.IntervalSet(pairs)
assert isinstance(ep, nap.core.interval_set.IntervalSet)
np.testing.assert_array_almost_equal(start, ep.start)
np.testing.assert_array_almost_equal(end, ep.end)

def test_create_iset_from_unknown_format():
with pytest.raises(RuntimeError) as e:
nap.IntervalSet(start="abc", end=[1, 2])
Expand Down

0 comments on commit e3bf7cc

Please sign in to comment.