diff --git a/flox/core.py b/flox/core.py index 1437e506c..9ff99b583 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1344,17 +1344,8 @@ def dask_groupby_agg( # find number of groups in each chunk, this is needed for output chunks # along the reduced axis slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) - if expected_groups is None: - groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) - else: - # For cohorts, we could be indexing a block with groups that - # are not in the cohort (usually for nD `by`) - # Only keep the expected groups. - groups_in_block = tuple( - np.intersect1d(by_input[slc], expected_groups) for slc in slices - ) + groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) groups = (np.concatenate(groups_in_block),) - ngroups_per_block = tuple(len(grp) for grp in groups_in_block) group_chunks = (ngroups_per_block,)