Skip to content

Commit

Permalink
getitem support for standard_name (#39)
Browse files Browse the repository at this point in the history
* Update pre-commit

* Rework getitem for standard_name support

* Add datasets.py
  • Loading branch information
dcherian authored Jun 23, 2020
1 parent 34384de commit d454cb4
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 70 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ repos:
files: .+\.py$
# https://github.com/python/black#version-control-integration
- repo: https://github.com/python/black
rev: stable
rev: 19.10b0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.761 # Must match ci/requirements/*.yml
rev: v0.781 # Must match ci/requirements/*.yml
hooks:
- id: mypy
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194
Expand Down
161 changes: 138 additions & 23 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import functools
import inspect
import itertools
import textwrap
from collections import ChainMap
from contextlib import suppress
from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
from typing import (
Callable,
Hashable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Union,
)

import xarray as xr
from xarray import DataArray, Dataset
Expand Down Expand Up @@ -106,6 +117,11 @@
]


def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
""" The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
return [item for item in lst if item != [None]] # type: ignore


def _get_axis_coord_single(
var: Union[xr.DataArray, xr.Dataset],
key: str,
Expand Down Expand Up @@ -176,8 +192,8 @@ def _get_axis_coord(
results: Set = set()
for coord in search_in:
for criterion, valid_values in coordinate_criteria.items():
if key in valid_values: # type: ignore
expected = valid_values[key] # type: ignore
if key in valid_values:
expected = valid_values[key]
if var.coords[coord].attrs.get(criterion, None) in expected:
results.update((coord,))

Expand Down Expand Up @@ -246,6 +262,31 @@ def _get_measure(
}


def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> List[str]:
""" returns a list of variable names with standard names matching name. """
if isinstance(name, str):
name = [name]

varnames = []
counts = dict.fromkeys(name, 0)
for vname, var in ds.variables.items():
stdname = var.attrs.get("standard_name", None)
if stdname in name:
varnames.append(str(vname))
counts[stdname] += 1

return varnames


def _get_list_standard_names(obj: xr.Dataset) -> List[str]:
""" Returns a sorted list of standard names in Dataset. """
names = []
for k, v in obj.variables.items():
if "standard_name" in v.attrs:
names.append(v.attrs["standard_name"])
return sorted(names)


def _getattr(
obj: Union[DataArray, Dataset],
attr: str,
Expand Down Expand Up @@ -503,6 +544,16 @@ def _describe(self):
text += f"\t{measure}: unsupported\n"
else:
text += f"\t{measure}: {_get_measure(self._obj, measure, error=False, default=None)}\n"

text += "\nStandard Names:\n"
if isinstance(self._obj, xr.DataArray):
text += "\tunsupported\n"
else:
stdnames = _get_list_standard_names(self._obj)
text += "\t"
text += "\n".join(
textwrap.wrap(f"{stdnames!r}", 70, break_long_words=False)
)
return text

def describe(self):
Expand All @@ -529,32 +580,96 @@ def get_valid_keys(self) -> Set[str]:
]
if measures:
varnames.append(*measures)

if not isinstance(self._obj, xr.DataArray):
varnames.extend(_get_list_standard_names(self._obj))
return set(varnames)

def __getitem__(self, key: Union[str, List[str]]):

kind = str(type(self._obj).__name__)
scalar_key = isinstance(key, str)
if scalar_key:
key = (key,) # type: ignore

varnames: List[Hashable] = []
coords: List[Hashable] = []
successful = dict.fromkeys(key, False)
for k in key:
if k in _AXIS_NAMES + _COORD_NAMES:
names = _get_axis_coord(self._obj, k)
successful[k] = bool(names)
varnames.extend(_strip_none_list(names))
coords.extend(_strip_none_list(names))
elif k in _CELL_MEASURES:
if isinstance(self._obj, xr.Dataset):
raise NotImplementedError(
"Invalid key {k!r}. Cell measures not implemented for Dataset yet."
)
else:
measure = _get_measure(self._obj, k)
successful[k] = bool(measure)
if measure:
varnames.append(measure)
elif not isinstance(self._obj, xr.DataArray):
stdnames = _filter_by_standard_names(self._obj, k)
successful[k] = bool(stdnames)
varnames.extend(stdnames)
coords.extend(list(set(stdnames).intersection(set(self._obj.coords))))

# these are not special names but could be variable names in underlying object
# we allow this so that we can return variables with appropriate CF auxiliary variables
varnames.extend([k for k, v in successful.items() if not v])
assert len(varnames) > 0

try:
# TODO: make this a get_auxiliary_variables function
# make sure to set coordinate variables referred to in "coordinates" attribute
for name in varnames:
attrs = self._obj[name].attrs
if "coordinates" in attrs:
coords.extend(attrs.get("coordinates").split(" "))

if "cell_measures" in attrs:
measures = [
_get_measure(self._obj[name], measure)
for measure in _CELL_MEASURES
if measure in attrs["cell_measures"]
]
coords.extend(_strip_none_list(measures))

varnames.extend(coords)
if isinstance(self._obj, xr.DataArray):
ds = self._obj._to_temp_dataset()
else:
ds = self._obj
ds = ds.reset_coords()[varnames]
if isinstance(self._obj, DataArray):
if scalar_key and len(ds.variables) == 1:
# single dimension coordinates
return ds[list(ds.variables.keys())[0]].squeeze(drop=True)
elif scalar_key and len(ds.coords) > 1:
raise NotImplementedError(
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
"Please open an issue."
)
elif not scalar_key:
return ds.set_coords(coords)
else:
return ds.set_coords(coords)

except KeyError:
raise KeyError(
f"{kind}.cf does not understand the key {k!r}. "
f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
)


@xr.register_dataset_accessor("cf")
class CFDatasetAccessor(CFAccessor):
def __getitem__(self, key):
if key in _AXIS_NAMES + _COORD_NAMES:
varnames = _get_axis_coord(self._obj, key)
return self._obj.reset_coords()[varnames].set_coords(varnames)
elif key in _CELL_MEASURES:
raise NotImplementedError("measures not implemented for Dataset yet.")
else:
raise KeyError(
f"Dataset.cf does not understand the key {key!r}. Use Dataset.cf.describe() to see a list of key names that can be interpreted."
)
pass


@xr.register_dataarray_accessor("cf")
class CFDataArrayAccessor(CFAccessor):
def __getitem__(self, key):
if key in _AXIS_NAMES + _COORD_NAMES:
varname = _get_axis_coord_single(self._obj, key)
return self._obj[varname].reset_coords(drop=True)
elif key in _CELL_MEASURES:
return self._obj[_get_measure(self._obj, key)]
else:
raise KeyError(
f"DataArray.cf does not understand the key {key!r}. Use DataArray.cf.describe() to see a list of key names that can be interpreted."
)
pass
49 changes: 49 additions & 0 deletions cf_xarray/tests/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import xarray as xr

airds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50))
airds.air.attrs["cell_measures"] = "area: cell_area"
airds.air.attrs["standard_name"] = "air_temperature"
airds.coords["cell_area"] = (
xr.DataArray(np.cos(airds.lat * np.pi / 180))
* xr.ones_like(airds.lon)
* 105e3
* 110e3
)

ds_no_attrs = airds.copy(deep=True)
for variable in ds_no_attrs.variables:
ds_no_attrs[variable].attrs = {}


popds = xr.Dataset()
popds.coords["TLONG"] = (
("nlat", "nlon"),
np.ones((20, 30)),
{"axis": "X", "units": "degrees_east"},
)
popds.coords["TLAT"] = (
("nlat", "nlon"),
2 * np.ones((20, 30)),
{"axis": "Y", "units": "degrees_north"},
)
popds.coords["ULONG"] = (
("nlat", "nlon"),
0.5 * np.ones((20, 30)),
{"axis": "X", "units": "degrees_east"},
)
popds.coords["ULAT"] = (
("nlat", "nlon"),
2.5 * np.ones((20, 30)),
{"axis": "Y", "units": "degrees_north"},
)
popds["UVEL"] = (
("nlat", "nlon"),
np.ones((20, 30)) * 15,
{"coordinates": "ULONG ULAT", "standard_name": "sea_water_x_velocity"},
)
popds["TEMP"] = (
("nlat", "nlon"),
np.ones((20, 30)) * 15,
{"coordinates": "TLONG TLAT", "standard_name": "sea_water_potential_temperature"},
)
Loading

0 comments on commit d454cb4

Please sign in to comment.