diff --git a/flox/core.py b/flox/core.py
index a51b6ed3f..063a8fa3a 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -803,10 +803,15 @@ def _aggregate(
     keepdims,
     fill_value: Any,
     reindex: bool,
+    return_array: bool,
-) -> FinalResultsDict:
+) -> Union[FinalResultsDict, np.ndarray]:
     """Final aggregation step of tree reduction"""
     results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
-    return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
+    finalized = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
+    if return_array:
+        return finalized[agg.name]
+    else:
+        return finalized
 
 
 def _expand_dims(results: IntermediateDict) -> IntermediateDict:
@@ -1287,6 +1292,7 @@ def dask_groupby_agg(
     group_chunks: tuple[tuple[Union[int, float], ...]] = (
         (len(expected_groups),) if expected_groups is not None else (np.nan,),
     )
+    groups_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
 
     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
@@ -1316,16 +1322,32 @@ def dask_groupby_agg(
         reduced = tree_reduce(
             intermediate,
             combine=partial(combine, agg=agg),
-            aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
+            aggregate=partial(
+                aggregate,
+                expected_groups=expected_groups,
+                reindex=reindex,
+                return_array=not groups_are_unknown,
+            ),
         )
-        if is_duck_dask_array(by_input) and expected_groups is None:
+        if groups_are_unknown:
             groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
+            result = dask.array.map_blocks(
+                _extract_result,
+                reduced,
+                chunks=reduced.chunks[: -len(axis)] + group_chunks,
+                drop_axis=axis[:-1],
+                dtype=agg.dtype[agg.name],
+                key=agg.name,
+                name=f"{name}-{token}",
+            )
+
         else:
             if expected_groups is None:
                 expected_groups_ = _get_expected_groups(by_input, sort=sort)
             else:
                 expected_groups_ = expected_groups
             groups = (expected_groups_.to_numpy(),)
+            result = reduced
 
     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
@@ -1344,12 +1366,14 @@ def dask_groupby_agg(
                 tree_reduce(
                     reindexed,
                     combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                    aggregate=partial(
+                        aggregate, expected_groups=index, reindex=True, return_array=True
+                    ),
                 )
             )
             groups_.append(cohort)
 
-        reduced = dask.array.concatenate(reduced_, axis=-1)
+        result = dask.array.concatenate(reduced_, axis=-1)
         groups = (np.concatenate(groups_),)
         group_chunks = (tuple(len(cohort) for cohort in groups_),)
 
@@ -1375,21 +1399,14 @@ def dask_groupby_agg(
     for ax, chunks in zip(axis, group_chunks):
         adjust_chunks[ax] = chunks
 
-    result = dask.array.blockwise(
-        _extract_result,
-        inds[: -len(axis)] + (inds[-1],),
-        reduced,
-        inds,
-        adjust_chunks=adjust_chunks,
-        dtype=agg.dtype[agg.name],
-        key=agg.name,
-        name=f"{name}-{token}",
-    )
-
     return (result, groups)
 
 
 def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
+    from dask.array.core import deepfirst
+
+    if not isinstance(result_dict, dict):
+        result_dict = deepfirst(result_dict)
     return result_dict[key]
 
 
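
For reviewers who want to see the `map_blocks` pattern from the unknown-groups branch in isolation, here is a minimal, self-contained sketch. It is illustrative only: the random array, the summing lambda, and the chunk sizes are stand-ins (in the patch the blocks are `FinalResultsDict`s and the function is `_extract_result`), but the call shape is the same: `drop_axis` removes the already-reduced axes, and an explicit `chunks=` re-declares the output chunking that dask cannot infer.

```python
import dask.array as da

# Stand-in for the tree-reduced intermediate: a single chunk along the
# reduced axis (axis 0), two chunks along the surviving axis (axis 1).
reduced = da.random.random((4, 6), chunks=(4, 3))

result = da.map_blocks(
    lambda blk: blk.sum(axis=0),  # stand-in for _extract_result
    reduced,
    drop_axis=0,          # axis 0 does not appear in the output
    chunks=((3, 3),),     # declare the true chunking of the extracted array
    dtype=reduced.dtype,
)
assert result.shape == (6,)
print(result.compute())
```

The `deepfirst` import added to `_extract_result` appears to guard the case where dask hands the function a (possibly nested) list of blocks rather than a bare results dict; `deepfirst` simply descends to the first leaf before the `key` lookup.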