From 27a4e9a51a519778ac662852c28836c8e81b8dbc Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 26 Nov 2022 20:39:36 -0700 Subject: [PATCH] Try cleaning up some expected_groups logic (#175) * Try cleaning up some expected_groups logic * Fix _extract_unknown_groups * Fixes --- flox/core.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/flox/core.py b/flox/core.py index 30db58012..4179bc236 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1164,7 +1164,7 @@ def subset_to_blocks( return dask.array.Array(graph, name, chunks, meta=array) -def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: +def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]: import dask.array from dask.highlevelgraph import HighLevelGraph @@ -1180,7 +1180,7 @@ def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: dask.array.Array( HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]), groups_token, - chunks=group_chunks, + chunks=((np.nan,),), meta=np.array([], dtype=dtype), ), ) @@ -1293,14 +1293,7 @@ def dask_groupby_agg( name=f"{name}-chunk-{token}", ) - if expected_groups is None: - if is_duck_dask_array(by_input): - expected_groups = None - else: - expected_groups = _get_expected_groups(by_input, sort=sort) - group_chunks: tuple[tuple[Union[int, float], ...]] = ( - (len(expected_groups),) if expected_groups is not None else (np.nan,), - ) + group_chunks: tuple[tuple[Union[int, float], ...]] if method in ["map-reduce", "cohorts"]: combine: Callable[..., IntermediateDict] @@ -1333,13 +1326,13 @@ def dask_groupby_agg( aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex), ) if is_duck_dask_array(by_input) and expected_groups is None: - groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype) + groups = _extract_unknown_groups(reduced, dtype=by.dtype) + group_chunks = ((np.nan,),) else: if expected_groups is None: - expected_groups_ = _get_expected_groups(by_input, sort=sort) - else: - expected_groups_ = expected_groups - groups = (expected_groups_.to_numpy(),) + expected_groups = _get_expected_groups(by_input, sort=sort) + groups = (expected_groups.to_numpy(),) + group_chunks = ((len(expected_groups),),) elif method == "cohorts": chunks_cohorts = find_group_cohorts(