Skip to content

Commit

Permalink
Merge pull request #383 from pynapple-org/metadata
Browse files Browse the repository at this point in the history
Add groupby and groupby_apply functionality to metadata
  • Loading branch information
gviejo authored Mar 5, 2025
2 parents 2d5a12a + ed2b49a commit b6ea042
Show file tree
Hide file tree
Showing 6 changed files with 869 additions and 78 deletions.
98 changes: 80 additions & 18 deletions doc/user_guide/03_metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,26 @@ import pynapple as nap
# input parameters for TsGroup
group = {
1: nap.Ts(t=np.sort(np.random.uniform(0, 100, 10))),
2: nap.Ts(t=np.sort(np.random.uniform(0, 100, 20))),
3: nap.Ts(t=np.sort(np.random.uniform(0, 100, 30))),
1: nap.Ts(t=np.sort(np.random.uniform(0, 100, 100))),
2: nap.Ts(t=np.sort(np.random.uniform(0, 100, 200))),
3: nap.Ts(t=np.sort(np.random.uniform(0, 100, 300))),
4: nap.Ts(t=np.sort(np.random.uniform(0, 100, 400))),
}
# input parameters for IntervalSet
starts = [0,10,20]
ends = [5,15,25]
starts = [0,35,70]
ends = [30,65,100]
# input parameters for TsdFrame
t = np.arange(5)
d = np.ones((5,3))
d = np.tile([1,2,3], (5, 1))
columns = ["a", "b", "c"]
```

### `TsGroup`
Metadata added to `TsGroup` must match the number of `Ts`/`Tsd` objects, or the length of its `index` property.
```{code-cell} ipython3
metadata = {"region": ["pfc", "ofc", "hpc"]}
metadata = {"region": ["pfc", "pfc", "hpc", "hpc"]}
tsgroup = nap.TsGroup(group, metadata=metadata)
print(tsgroup)
Expand All @@ -64,7 +65,7 @@ When initializing with a DataFrame, the index must align with the input dictiona
```{code-cell} ipython3
metadata = pd.DataFrame(
index=group.keys(),
data=["pfc", "ofc", "hpc"],
data=["pfc", "pfc", "hpc", "hpc"],
columns=["region"]
)
Expand All @@ -88,7 +89,7 @@ print(intervalset)
Metadata can be initialized as a DataFrame using the metadata argument, or it can be inferred when initializing an `IntervalSet` with a DataFrame.
```{code-cell} ipython3
df = pd.DataFrame(
data=[[0, 5, 1, "left"], [10, 15, 0, "right"], [20, 25, 1, "left"]],
data=[[0, 30, 1, "left"], [35, 65, 0, "right"], [70, 100, 1, "left"]],
columns=["start", "end", "reward", "choice"]
)
Expand All @@ -101,7 +102,8 @@ Metadata added to `TsdFrame` must match the number of data columns, or the lengt
```{code-cell} ipython3
metadata = {
"color": ["red", "blue", "green"],
"position": [10,20,30]
"position": [10,20,30],
"label": ["x", "x", "y"]
}
tsdframe = nap.TsdFrame(d=d, t=t, columns=["a", "b", "c"], metadata=metadata)
Expand All @@ -112,8 +114,8 @@ When initializing with a DataFrame, the DataFrame index must match the `TsdFrame
```{code-cell} ipython3
metadata = pd.DataFrame(
index=["a", "b", "c"],
data=[["red", 10], ["blue", 20], ["green", 30]],
columns=["color", "position"],
data=[["red", 10, "x"], ["blue", 20, "x"], ["green", 30, "y"]],
columns=["color", "position", "label"],
)
tsdframe = nap.TsdFrame(d=d, t=t, columns=["a", "b", "c"], metadata=metadata)
Expand All @@ -130,21 +132,21 @@ The remaining metadata examples will be shown on a `TsGroup` object; however, al
### `set_info`
Metadata can be passed as a dictionary or pandas DataFrame as the first positional argument, or metadata can be passed as name-value keyword arguments.
```{code-cell} ipython3
tsgroup.set_info(unit_type=["multi", "single", "single"])
tsgroup.set_info(unit_type=["multi", "single", "single", "single"])
print(tsgroup)
```

### Using dictionary-like keys (square brackets)
Most metadata names can set as a dictionary-like key (i.e. using square brackets). The only exceptions are for `IntervalSet`, where the names "start" and "end" are reserved for class properties.
```{code-cell} ipython3
tsgroup["depth"] = [0, 1, 2]
tsgroup["depth"] = [0, 1, 2, 3]
print(tsgroup)
```

### Using attribute assignment
If the metadata name is unique from other class attributes and methods, and it is formatted properly (i.e. only alpha-numeric characters and underscores), it can be set as an attribute (i.e. using a `.` followed by the metadata name).
```{code-cell} ipython3
tsgroup.label=["MUA", "good", "good"]
tsgroup.label=["MUA", "good", "good", "good"]
print(tsgroup)
```

Expand Down Expand Up @@ -177,20 +179,80 @@ print(tsgroup.region)
User-set metadata is mutable and can be overwritten.
```{code-cell} ipython3
print(tsgroup, "\n")
tsgroup.set_info(region=["A", "B", "C"])
tsgroup.set_info(label=["A", "B", "C", "D"])
print(tsgroup)
```

## Allowed data types
As long as the length of the metadata container matches the length of the object (number of columns for `TsdFrame` and number of indices for `IntervalSet` and `TsGroup`), elements of the metadata can be any data type.
```{code-cell} ipython3
tsgroup.coords = [[1,0],[0,1],[1,1]]
tsgroup.coords = [[1,0],[0,1],[1,1],[2,1]]
print(tsgroup.coords)
```

## Using metadata to slice objects
Metadata can be used to slice or filter objects based on metadata values.
```{code-cell} ipython3
print(tsgroup[tsgroup.label == "good"])
print(tsgroup[tsgroup.label == "A"])
```

## `groupby`: Using metadata to group objects
Similar to pandas, metadata can be used to group objects based on one or more metadata columns using the object method `groupby`, where the first argument is the metadata columns name(s) to group by. This function returns a dictionary with keys corresponding to unique groups and values corresponding to object indices belonging to each group.
```{code-cell} ipython3
print(tsgroup,"\n")
print(tsgroup.groupby("region"))
```

Grouping by multiple metadata columns should be passed as a list.
```{code-cell} ipython3
tsgroup.groupby(["region","unit_type"])
```

The optional argument `get_group` can be provided to return a new object corresponding to a specific group.
```{code-cell} ipython3
tsgroup.groupby("region", get_group="hpc")
```

## `groupby_apply`: Applying functions to object groups
The `groupby_apply` object method allows a specific function to be applied to object groups. The first argument, same as `groupby`, is the metadata column(s) used to group the object. The second argument is the function to apply to each group. If only these two arguments are supplied, it is assumed that the grouped object is the first and only input to the applied function. This function returns a dictionary, where keys correspond to each unique group, and values correspond to the function output on each group.
```{code-cell} ipython3
print(tsdframe,"\n")
print(tsdframe.groupby_apply("label", np.mean))
```

If the applied function requires additional inputs, these can be passed as additional keyword arguments into `groupby_apply`.
```{code-cell} ipython3
feature = nap.Tsd(t=np.arange(100), d=np.repeat([0,1], 50))
tsgroup.groupby_apply(
"region",
nap.compute_1d_tuning_curves,
feature=feature,
nb_bins=2)
```

Alternatively, an anonymous function can be passed instead that defines additional arguments.
```{code-cell} ipython3
func = lambda x: nap.compute_1d_tuning_curves(x, feature=feature, nb_bins=2)
tsgroup.groupby_apply("region", func)
```

An anonymous function can also be used to apply a function where the grouped object is not the first input.
```{code-cell} ipython3
func = lambda x: nap.compute_1d_tuning_curves(
group=tsgroup,
feature=feature,
nb_bins=2,
ep=x)
intervalset.groupby_apply("choice", func)
```

Alternatively, the optional parameter `input_key` can be passed to specify which keyword argument the grouped object corresponds to. Other required arguments of the applied function need to be passed as keyword arguments.
```{code-cell} ipython3
intervalset.groupby_apply(
"choice",
nap.compute_1d_tuning_curves,
input_key="ep",
group=tsgroup,
feature=feature,
nb_bins=2)
```
106 changes: 105 additions & 1 deletion pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def __getitem__(self, key):
output = self.values.__getitem__(key)
metadata = self._metadata.iloc[key].reset_index(drop=True)
return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)
elif isinstance(key, pd.Series):
elif isinstance(key, (pd.Series, pd.Index)):
# use loc for metadata
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True)
Expand Down Expand Up @@ -1198,3 +1198,107 @@ def get_info(self, key):
2 3 y
"""
return _MetadataMixin.get_info(self, key)

@add_meta_docstring("groupby")
def groupby(self, by, get_group=None):
"""
Examples
--------
>>> import pynapple as nap
>>> import numpy as np
>>> times = np.array([[0, 5], [10, 12], [20, 33]])
>>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
>>> ep = nap.IntervalSet(times,metadata=metadata)
>>> print(ep)
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
2 20 33 2 y
shape: (3, 2), time unit: sec.
Grouping by a single column:
>>> ep.groupby("l2")
{'x': [0, 1], 'y': [2]}
Grouping by multiple columns:
>>> ep.groupby(["l1","l2"])
{(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]}
Filtering to a specific group using the output dictionary:
>>> groups = ep.groupby("l2")
>>> ep[groups["x"]]
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
shape: (2, 2), time unit: sec.
Filtering to a specific group using the get_group argument:
>>> ep.groupby("l2", get_group="x")
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
shape: (2, 2), time unit: sec.
"""
return _MetadataMixin.groupby(self, by, get_group)

@add_meta_docstring("groupby_apply")
def groupby_apply(self, by, func, input_key=None, **func_kwargs):
"""
Examples
--------
>>> import pynapple as nap
>>> import numpy as np
>>> times = np.array([[0, 5], [10, 12], [20, 33]])
>>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
>>> ep = nap.IntervalSet(times,metadata=metadata)
>>> print(ep)
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
2 20 33 2 y
shape: (3, 2), time unit: sec.
Apply a numpy function::
>>> ep.groupby_apply("l2", np.mean)
{'x': 6.75, 'y': 26.5}
Apply a custom function:
>>> ep.groupby_apply("l2", lambda x: x.shape[0])
{'x': 2, 'y': 1}
Apply a function with additional arguments:
>>> ep.groupby_apply("l2", np.mean, axis=1)
{'x': array([ 2.5, 11. ]), 'y': array([26.5])}
Applying a function with additional arguments, where the grouped object is not the first argument:
>>> tsg = nap.TsGroup(
... {
... 1: nap.Ts(t=np.arange(0, 40)),
... 2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"),
... 3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"),
... },
... )
>>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)]))
>>> func_kwargs = {
>>> "group": tsg,
>>> "feature": feature,
>>> "nb_bins": 2,
>>> }
>>> ep.groupby_apply("l2", nap.compute_1d_tuning_curves, input_key="ep", **func_kwargs)
{'x': 1 2 3
0.25 1.025641 1.823362 4.216524
0.75 NaN NaN NaN,
'y': 1 2 3
0.25 NaN NaN NaN
0.75 1.025641 1.978022 4.835165}
"""
return _MetadataMixin.groupby_apply(self, by, func, input_key, **func_kwargs)
Loading

0 comments on commit b6ea042

Please sign in to comment.