Skip to content

Commit

Permalink
Error if multiple results for scalar "standard_name" key
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Sep 27, 2020
1 parent c1d3197 commit c626d74
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
22 changes: 12 additions & 10 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,32 +984,34 @@ def __getitem__(self, key: Union[str, List[str]]):
)

if scalar_key:
axis_coord_mapper = _get_axis_coord_single
key = (key,) # type: ignore
else:
axis_coord_mapper = _get_axis_coord

def check_results(names, k):
if scalar_key and len(names) > 1:
raise ValueError(
f"Receive multiple variables for key {k!r}: {names}. "
f"Expected only one. Please pass a list [{k!r}] "
f"instead to get all variables matching {k!r}."
)

varnames: List[Hashable] = []
coords: List[Hashable] = []
successful = dict.fromkeys(key, False)
for k in key:
if k in _AXIS_NAMES + _COORD_NAMES:
try:
names = axis_coord_mapper(self._obj, k)
except KeyError as e:
raise KeyError(
f"Receive multiple variables for key {k!r}. Expected only one. Please pass a list [{k!r}] instead to get all variables matching {k!r}."
)
raise e
names = _get_axis_coord(self._obj, k)
check_results(names, k)
successful[k] = bool(names)
coords.extend(names)
elif k in _CELL_MEASURES:
measure = _get_measure(self._obj, k)
check_results(measure, k)
successful[k] = bool(measure)
if measure:
varnames.extend(measure)
elif not isinstance(self._obj, DataArray):
stdnames = _get_with_standard_name(self._obj, k)
check_results(stdnames, k)
successful[k] = bool(stdnames)
varnames.extend(stdnames)
coords.extend(list(set(stdnames) & set(self._obj.coords)))
Expand Down
4 changes: 3 additions & 1 deletion cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_getitem_standard_name():

ds = airds.copy(deep=True)
ds["air2"] = ds.air
actual = ds.cf["air_temperature"]
with pytest.raises(ValueError):
ds.cf["air_temperature"]
actual = ds.cf[["air_temperature"]]
expected = ds[["air", "air2"]]
assert_identical(actual, expected)

Expand Down

0 comments on commit c626d74

Please sign in to comment.