Skip to content

Commit b8b7857

Browse files
Add extra overload for to_netcdf (#8268)
* Add extra overload for to_netcdf The current signature does not match with pyright if a non literal bool is passed. * fix typing * add entry to whats-new --------- Co-authored-by: Michael Niklas <[email protected]>
1 parent d44bfd7 commit b8b7857

File tree

4 files changed

+101
-8
lines changed

4 files changed

+101
-8
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ Bug fixes
6565
- Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit`
6666
was changed to ``Mapping`` (:pull:`8502`).
6767
By `Michael Niklas <https://github.com/headtr1ck>`_.
68+
- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf`
69+
when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`).
70+
By `Jens Hedegaard Nielsen <https://github.com/jenshnielsen>`_.
6871

6972
Documentation
7073
~~~~~~~~~~~~~

xarray/backends/api.py

+56
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,62 @@ def to_netcdf(
11601160
...
11611161

11621162

1163+
# if compute cannot be evaluated at type check time
1164+
# we may get back either Delayed or None
1165+
@overload
1166+
def to_netcdf(
1167+
dataset: Dataset,
1168+
path_or_file: str | os.PathLike,
1169+
mode: Literal["w", "a"] = "w",
1170+
format: T_NetcdfTypes | None = None,
1171+
group: str | None = None,
1172+
engine: T_NetcdfEngine | None = None,
1173+
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
1174+
unlimited_dims: Iterable[Hashable] | None = None,
1175+
compute: bool = False,
1176+
multifile: Literal[False] = False,
1177+
invalid_netcdf: bool = False,
1178+
) -> Delayed | None:
1179+
...
1180+
1181+
1182+
# if multifile cannot be evaluated at type check time
1183+
# we may get back either writer and datastore or Delayed or None
1184+
@overload
1185+
def to_netcdf(
1186+
dataset: Dataset,
1187+
path_or_file: str | os.PathLike,
1188+
mode: Literal["w", "a"] = "w",
1189+
format: T_NetcdfTypes | None = None,
1190+
group: str | None = None,
1191+
engine: T_NetcdfEngine | None = None,
1192+
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
1193+
unlimited_dims: Iterable[Hashable] | None = None,
1194+
compute: bool = False,
1195+
multifile: bool = False,
1196+
invalid_netcdf: bool = False,
1197+
) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None:
1198+
...
1199+
1200+
1201+
# Any
1202+
@overload
1203+
def to_netcdf(
1204+
dataset: Dataset,
1205+
path_or_file: str | os.PathLike | None,
1206+
mode: Literal["w", "a"] = "w",
1207+
format: T_NetcdfTypes | None = None,
1208+
group: str | None = None,
1209+
engine: T_NetcdfEngine | None = None,
1210+
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
1211+
unlimited_dims: Iterable[Hashable] | None = None,
1212+
compute: bool = False,
1213+
multifile: bool = False,
1214+
invalid_netcdf: bool = False,
1215+
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
1216+
...
1217+
1218+
11631219
def to_netcdf(
11641220
dataset: Dataset,
11651221
path_or_file: str | os.PathLike | None = None,

xarray/core/dataarray.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -3914,6 +3914,23 @@ def to_netcdf(
39143914
) -> bytes:
39153915
...
39163916

3917+
# compute=False returns dask.Delayed
3918+
@overload
3919+
def to_netcdf(
3920+
self,
3921+
path: str | PathLike,
3922+
mode: Literal["w", "a"] = "w",
3923+
format: T_NetcdfTypes | None = None,
3924+
group: str | None = None,
3925+
engine: T_NetcdfEngine | None = None,
3926+
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
3927+
unlimited_dims: Iterable[Hashable] | None = None,
3928+
*,
3929+
compute: Literal[False],
3930+
invalid_netcdf: bool = False,
3931+
) -> Delayed:
3932+
...
3933+
39173934
# default return None
39183935
@overload
39193936
def to_netcdf(
@@ -3930,7 +3947,8 @@ def to_netcdf(
39303947
) -> None:
39313948
...
39323949

3933-
# compute=False returns dask.Delayed
3950+
# if compute cannot be evaluated at type check time
3951+
# we may get back either Delayed or None
39343952
@overload
39353953
def to_netcdf(
39363954
self,
@@ -3941,10 +3959,9 @@ def to_netcdf(
39413959
engine: T_NetcdfEngine | None = None,
39423960
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
39433961
unlimited_dims: Iterable[Hashable] | None = None,
3944-
*,
3945-
compute: Literal[False],
3962+
compute: bool = True,
39463963
invalid_netcdf: bool = False,
3947-
) -> Delayed:
3964+
) -> Delayed | None:
39483965
...
39493966

39503967
def to_netcdf(

xarray/core/dataset.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -2159,6 +2159,23 @@ def to_netcdf(
21592159
) -> bytes:
21602160
...
21612161

2162+
# compute=False returns dask.Delayed
2163+
@overload
2164+
def to_netcdf(
2165+
self,
2166+
path: str | PathLike,
2167+
mode: Literal["w", "a"] = "w",
2168+
format: T_NetcdfTypes | None = None,
2169+
group: str | None = None,
2170+
engine: T_NetcdfEngine | None = None,
2171+
encoding: Mapping[Any, Mapping[str, Any]] | None = None,
2172+
unlimited_dims: Iterable[Hashable] | None = None,
2173+
*,
2174+
compute: Literal[False],
2175+
invalid_netcdf: bool = False,
2176+
) -> Delayed:
2177+
...
2178+
21622179
# default return None
21632180
@overload
21642181
def to_netcdf(
@@ -2175,7 +2192,8 @@ def to_netcdf(
21752192
) -> None:
21762193
...
21772194

2178-
# compute=False returns dask.Delayed
2195+
# if compute cannot be evaluated at type check time
2196+
# we may get back either Delayed or None
21792197
@overload
21802198
def to_netcdf(
21812199
self,
@@ -2186,10 +2204,9 @@ def to_netcdf(
21862204
engine: T_NetcdfEngine | None = None,
21872205
encoding: Mapping[Any, Mapping[str, Any]] | None = None,
21882206
unlimited_dims: Iterable[Hashable] | None = None,
2189-
*,
2190-
compute: Literal[False],
2207+
compute: bool = True,
21912208
invalid_netcdf: bool = False,
2192-
) -> Delayed:
2209+
) -> Delayed | None:
21932210
...
21942211

21952212
def to_netcdf(

0 commit comments

Comments
 (0)