Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved DataArray typing #6637

Merged
merged 14 commits into from
May 27, 2022
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
*.py[cod]
__pycache__
.env
.venv

# example caches from Hypothesis
.hypothesis/
Expand Down
28 changes: 17 additions & 11 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Tuple,
Type,
TypeVar,
cast,
)

import numpy as np
Expand All @@ -30,7 +31,7 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import JoinOptions
from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset

DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)

Expand Down Expand Up @@ -559,7 +560,7 @@ def align(self) -> None:
def align(
*objects: DataAlignable,
join: JoinOptions = "inner",
copy=True,
copy: bool = True,
indexes=None,
exclude=frozenset(),
fill_value=dtypes.NA,
Expand Down Expand Up @@ -592,7 +593,7 @@ def align(
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.

copy : bool, optional
copy : bool, default: True
If ``copy=True``, data in the return values is always copied. If
``copy=False`` and reindexing is unnecessary, or can be performed with
only slice operations, then the output may share memory with the input.
Expand All @@ -609,7 +610,7 @@ def align(

Returns
-------
aligned : DataArray or Dataset
aligned : tuple of DataArray or Dataset
Tuple of objects with the same type as `*objects` with aligned
coordinates.

Expand Down Expand Up @@ -935,7 +936,9 @@ def _get_broadcast_dims_map_common_coords(args, exclude):
return dims_map, common_coords


def _broadcast_helper(arg, exclude, dims_map, common_coords):
def _broadcast_helper(
arg: T_DataArrayOrSet, exclude, dims_map, common_coords
) -> T_DataArrayOrSet:

from .dataarray import DataArray
from .dataset import Dataset
Expand All @@ -950,22 +953,25 @@ def _set_dims(var):

return var.set_dims(var_dims_map)

def _broadcast_array(array):
def _broadcast_array(array: T_DataArray) -> T_DataArray:
data = _set_dims(array.variable)
coords = dict(array.coords)
coords.update(common_coords)
return DataArray(data, coords, data.dims, name=array.name, attrs=array.attrs)
return array.__class__(
data, coords, data.dims, name=array.name, attrs=array.attrs
)

def _broadcast_dataset(ds):
def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
data_vars = {k: _set_dims(ds.variables[k]) for k in ds.data_vars}
coords = dict(ds.coords)
coords.update(common_coords)
return Dataset(data_vars, coords, ds.attrs)
return ds.__class__(data_vars, coords, ds.attrs)

# remove casts once https://github.com/python/mypy/issues/12800 is resolved
if isinstance(arg, DataArray):
return _broadcast_array(arg)
return cast("T_DataArrayOrSet", _broadcast_array(arg))
elif isinstance(arg, Dataset):
return _broadcast_dataset(arg)
return cast("T_DataArrayOrSet", _broadcast_dataset(arg))
else:
raise ValueError("all input must be Dataset or DataArray objects")

Expand Down
Loading