diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2722ca0b38a..296cb6d6803 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,6 +65,9 @@ Bug fixes - Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit` was changed to ``Mapping`` (:pull:`8502`). By `Michael Niklas `_. +- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf` + when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`). + By `Jens Hedegaard Nielsen `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index c59f2f8d81b..1d538bf94ed 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1160,6 +1160,62 @@ def to_netcdf( ... +# if compute cannot be evaluated at type check time +# we may get back either Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> Delayed | None: + ... + + +# if multifile cannot be evaluated at type check time +# we may get back either writer and datastore or Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None: + ... + + +# Any +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: + ... + + def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1d7e82d3044..c8cc579c8b7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3914,6 +3914,23 @@ def to_netcdf( ) -> bytes: ... + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: + ... + # default return None @overload def to_netcdf( @@ -3930,7 +3947,8 @@ def to_netcdf( ) -> None: ... - # compute=False returns dask.Delayed + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None @overload def to_netcdf( self, @@ -3941,10 +3959,9 @@ def to_netcdf( engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, - *, - compute: Literal[False], + compute: bool = True, invalid_netcdf: bool = False, - ) -> Delayed: + ) -> Delayed | None: ... def to_netcdf( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b8093d3dd78..b430b8fddd8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2159,6 +2159,23 @@ def to_netcdf( ) -> bytes: ... + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: + ... + # default return None @overload def to_netcdf( @@ -2175,7 +2192,8 @@ def to_netcdf( ) -> None: ... - # compute=False returns dask.Delayed + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None @overload def to_netcdf( self, @@ -2186,10 +2204,9 @@ def to_netcdf( engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, - *, - compute: Literal[False], + compute: bool = True, invalid_netcdf: bool = False, - ) -> Delayed: + ) -> Delayed | None: ... def to_netcdf(