def filter_elements_by_instances(
    self,
    element_names: Iterable[str],
    instances: Iterable[int | str],
    region_names: Iterable[str] | str | None = None,
) -> dict[str, DaskDataFrame | GeoDataFrame | AnnData]:
    """
    Filter elements to contain only certain instances.

    This filters both SpatialElements (points and shapes) as well as tables to only contain certain IDs.
    In case of tables the instance key column of ``table.obs`` will be filtered on and not
    ``table.obs.index``. Filtering labels by ID is currently not supported as this is an expensive
    operation. Should you require this please open an issue on github.com/scverse/spatialdata. Lastly,
    tables not annotating an element cannot be filtered.

    Parameters
    ----------
    element_names
        Name(s) of points, shapes or table elements within the SpatialData object. A single name may be
        passed as a plain string.
    instances
        The instance IDs to filter the elements on. A single string ID is treated as one ID rather than
        as a sequence of characters.
    region_names
        If filtering instances in a table, indicate the region_names (the names of the SpatialElement) for
        which you want to filter the instances of the table. If not specified (or empty), the table
        instances for all regions annotated by the table will be filtered by the given instances.

    Returns
    -------
    A dictionary mapping each requested element name to its filtered element. Points are filtered on
    their instance key column, shapes on their index, and tables on the instance key column of ``obs``.
    Filtered tables are copies whose region metadata lists only the regions still present after
    filtering; they are validated with ``TableModel`` before being returned.

    Raises
    ------
    KeyError
        If a name in ``element_names`` is not an element of the SpatialData object.
    TypeError
        If an element is neither a points, shapes nor table element.
    """
    element_names = [element_names] if isinstance(element_names, str) else list(element_names)
    # Materialise once up front: a one-shot iterable (e.g. a generator) would be exhausted by the
    # first `isin` call and every later element would silently be filtered against nothing.
    instances = [instances] if isinstance(instances, str) else list(instances)
    if region_names is not None:
        region_names = [region_names] if isinstance(region_names, str) else list(region_names)

    element_dict = {}
    for element_name in element_names:
        element = self.get(element_name)
        if element is None:
            raise KeyError(f"`{element_name}` is not an element in the SpatialData object.")

        model = get_model(element)
        if model == PointsModel:
            instance_key = element.attrs[PointsModel.ATTRS_KEY][PointsModel.INSTANCE_KEY]
            element_dict[element_name] = element[element[instance_key].isin(instances)]
        elif model == ShapesModel:
            element_dict[element_name] = element[element.index.isin(instances)]
        elif model == TableModel:
            instance_key = element.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY]
            region_key = element.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]
            table = element
            if region_names:
                table = table[table.obs[region_key].isin(region_names)]
            # Copy so that updating `uns` below does not touch the table stored in this object.
            filtered = table[table.obs[instance_key].isin(instances)].copy()
            # Derive the regions from the *final* table: a region that lost all of its rows to the
            # instance filter must not remain in the metadata, otherwise metadata and obs disagree
            # (NOTE(review): TableModel validation is expected to reject such a mismatch).
            regions = filtered.obs[region_key].cat.remove_unused_categories().cat.categories.tolist()
            filtered.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = regions
            TableModel().validate(filtered)
            element_dict[element_name] = filtered
        else:
            raise TypeError(f"`{model}` is not a valid model for filtering of instances.")
    return element_dict