diff --git a/flox/core.py b/flox/core.py
index aa54f6757..4835da55b 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -420,7 +420,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
 def factorize_(
     by: tuple,
     axis: T_AxesOpt,
-    expected_groups: tuple[pd.Index, ...] = None,
+    expected_groups: tuple[pd.Index, ...] | None = None,
     reindex: bool = False,
     sort=True,
     fastpath=False,
@@ -873,7 +873,7 @@ def _simple_combine(
     return results


-def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes = None) -> np.ndarray:
+def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes | None = None) -> np.ndarray:
     """copied from dask.array.reductions.mean_combine"""
     from dask.array.core import _concatenate2
     from dask.utils import deepmap
@@ -1071,7 +1071,7 @@ def _reduce_blockwise(
     return result


-def _normalize_indexes(array, flatblocks, blkshape):
+def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
     """
     .blocks accessor can only accept one iterable at a time,
     but can handle multiple slices.
@@ -1083,7 +1083,7 @@ def _normalize_indexes(array, flatblocks, blkshape):
     """
     unraveled = np.unravel_index(flatblocks, blkshape)

-    normalized: list[Union[int, np.ndarray, slice]] = []
+    normalized: list[Union[int, slice, list[int]]] = []
     for ax, idx in enumerate(unraveled):
         i = _unique(idx).squeeze()
         if i.ndim == 0:
@@ -1397,7 +1397,7 @@ def dask_groupby_agg(
     return (result, groups)


-def _collapse_blocks_along_axes(reduced, axis, group_chunks):
+def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray:
     import dask.array
     from dask.highlevelgraph import HighLevelGraph
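
The recurring change here is replacing implicit-Optional annotations (`x: T = None`) with an explicit `T | None`, which is what type checkers such as mypy expect by default now that implicit Optional is disabled. A minimal sketch of the pattern, using a hypothetical `summarize_groups` helper rather than any of flox's actual functions:

```python
from __future__ import annotations

import pandas as pd


def summarize_groups(expected_groups: tuple[pd.Index, ...] | None = None) -> int:
    """Count members across expected groups; None means "not provided"."""
    # The explicit `| None` forces callers (and type checkers) to handle
    # the None case instead of pretending the argument is always a tuple.
    if expected_groups is None:
        return 0
    return sum(len(index) for index in expected_groups)


print(summarize_groups())                                    # 0
print(summarize_groups((pd.Index([1, 2]), pd.Index([3]))))   # 3
```

With `from __future__ import annotations`, the `X | None` spelling also works on Python versions older than 3.10, since annotations are not evaluated at runtime.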