Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow filtering by ids #627

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import os
import warnings
from collections.abc import Generator, Mapping
from collections.abc import Generator, Iterable, Mapping
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -2417,6 +2417,59 @@ def __delitem__(self, key: str) -> None:
element_type, _, _ = self._find_element(key)
getattr(self, element_type).__delitem__(key)

def filter_elements_by_instances(
    self,
    element_names: Iterable[str],
    instances: Iterable[int | str],
    region_names: Iterable[str] | str | None = None,
) -> dict[str, DaskDataFrame | GeoDataFrame | AnnData]:
    """
    Filter elements to contain only certain instances.

    This filters both SpatialElements (points and shapes) as well as tables to only contain certain IDs.
    For tables, the instance key column of ``table.obs`` is filtered on, not ``table.obs.index``.
    Filtering labels by ID is currently not supported as this is an expensive operation. Should you
    require this, please open an issue on github.com/scverse/spatialdata. Lastly, tables not annotating
    an element cannot be filtered.

    Parameters
    ----------
    element_names
        Name(s) of points, shapes or table elements within the SpatialData object. A single string is
        treated as one element name.
    instances
        The instance IDs to filter the elements on. Any iterable is accepted; it is materialised once
        so that the same IDs are applied to every requested element. A single string is treated as one
        instance ID.
    region_names
        If filtering instances in a table, the region name(s) (the names of the SpatialElements) for
        which the table instances should be filtered. If not specified (or empty), the table instances
        for all regions annotated by the table are filtered by the given instances.

    Returns
    -------
    A dictionary mapping each requested element name to its filtered element. Points are filtered on
    their instance key column, shapes on their index, and tables on the instance key column of ``obs``.
    Filtered tables are copies whose region metadata is updated and which are validated with
    ``TableModel().validate`` before being returned.

    Raises
    ------
    KeyError
        If a name in ``element_names`` is not an element of the SpatialData object.
    TypeError
        If an element is neither a points, shapes nor table element.
    """
    names = [element_names] if isinstance(element_names, str) else list(element_names)

    # Materialise the IDs up front: a one-shot iterator (e.g. a generator) would otherwise be exhausted
    # by the first ``isin`` call and every subsequent element would be filtered against nothing.
    instance_ids = [instances] if isinstance(instances, str) else list(instances)

    # Normalise the region filter once, outside the loop, for the same reason. An empty string or an
    # empty iterable keeps the historical meaning of "no region filtering".
    if region_names is None:
        region_filter = []
    elif isinstance(region_names, str):
        region_filter = [region_names] if region_names else []
    else:
        region_filter = list(region_names)

    filtered = {}
    for element_name in names:
        element = self.get(element_name)
        if element is None:
            raise KeyError(f"`{element_name}` is not an element in the SpatialData object.")

        model = get_model(element)
        if model == PointsModel:
            instance_key = element.attrs[PointsModel.ATTRS_KEY][PointsModel.INSTANCE_KEY]
            filtered[element_name] = element[element[instance_key].isin(instance_ids)]
        elif model == ShapesModel:
            # Shapes carry their instance IDs in the index.
            filtered[element_name] = element[element.index.isin(instance_ids)]
        elif model == TableModel:
            table_attrs = element.uns[TableModel.ATTRS_KEY]
            instance_key = table_attrs[TableModel.INSTANCE_KEY]
            region_key = table_attrs[TableModel.REGION_KEY_KEY]
            if region_filter:
                element = element[element.obs[region_key].isin(region_filter)]
            # Always derive the annotated regions from the (possibly region-filtered) table so that
            # ``regions`` is defined whether or not a region filter was supplied.
            # NOTE(review): this relies on the region column being categorical and presumably on
            # unused categories being dropped when the table is subset -- confirm against anndata.
            regions = element.obs[region_key].cat.categories.tolist()
            table = element[element.obs[instance_key].isin(instance_ids)].copy()
            table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = regions
            TableModel().validate(table)
            filtered[element_name] = table
        else:
            raise TypeError(f"`{model}` is not a valid model for filtering of instances.")
    return filtered

@property
def attrs(self) -> dict[Any, Any]:
"""
Expand Down
Loading