From 100a192a7972eb58b73a1880f458f0bdba1b97ad Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 15 Nov 2022 13:04:36 -0700 Subject: [PATCH 1/2] Some typing --- flox/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/core.py b/flox/core.py index aa54f6757..dcfb325a3 100644 --- a/flox/core.py +++ b/flox/core.py @@ -420,7 +420,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]: def factorize_( by: tuple, axis: T_AxesOpt, - expected_groups: tuple[pd.Index, ...] = None, + expected_groups: tuple[pd.Index, ...] | None = None, reindex: bool = False, sort=True, fastpath=False, @@ -873,7 +873,7 @@ def _simple_combine( return results -def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes = None) -> np.ndarray: +def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes | None = None) -> np.ndarray: """copied from dask.array.reductions.mean_combine""" from dask.array.core import _concatenate2 from dask.utils import deepmap @@ -1071,7 +1071,7 @@ def _reduce_blockwise( return result -def _normalize_indexes(array, flatblocks, blkshape): +def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple: """ .blocks accessor can only accept one iterable at a time, but can handle multiple slices. @@ -1397,7 +1397,7 @@ def dask_groupby_agg( return (result, groups) -def _collapse_blocks_along_axes(reduced, axis, group_chunks): +def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axis, group_chunks) -> DaskArray: import dask.array from dask.highlevelgraph import HighLevelGraph From efb8416dee88b1d14b6778116f2263a758e1a6e7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 15 Nov 2022 13:12:17 -0700 Subject: [PATCH 2/2] fixes. --- flox/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index dcfb325a3..4835da55b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1083,7 +1083,7 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple: """ unraveled = np.unravel_index(flatblocks, blkshape) - normalized: list[Union[int, np.ndarray, slice]] = [] + normalized: list[Union[int, slice, list[int]]] = [] for ax, idx in enumerate(unraveled): i = _unique(idx).squeeze() if i.ndim == 0: @@ -1397,7 +1397,7 @@ def dask_groupby_agg( return (result, groups) -def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axis, group_chunks) -> DaskArray: +def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray: import dask.array from dask.highlevelgraph import HighLevelGraph