Skip to content

Commit

Permalink
FOr all methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 25, 2022
1 parent 2feb8b7 commit 91110e6
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,6 @@ def dask_groupby_agg(

import dask.array
from dask.array.core import slices_from_chunks
from dask.highlevelgraph import HighLevelGraph

# I think _tree_reduce expects this
assert isinstance(axis, Sequence)
Expand Down Expand Up @@ -1369,35 +1368,23 @@ def dask_groupby_agg(
raise ValueError(f"Unknown method={method}.")

# extract results from the dict
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
agg_name = f"{name}-{token}"
if method == "blockwise" and len(axis) == 1:
result = reduced.map_blocks(
_extract_result,
key=agg.name,
dtype=agg.dtype[agg.name],
chunks=output_chunks,
name=agg_name,
)

else:
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
for ochunk in itertools.product(*ochunks):
if method == "blockwise":
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

result = dask.array.Array(
HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
agg_name,
chunks=output_chunks,
dtype=agg.dtype[agg.name],
)
adjust_chunks = {inds[ax]: lambda: 0 for ax in axis}
if method == "blockwise" and len(axis) > 1:
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
group_chunks = np.unravel_index(group_chunks, nblocks)
for ax, chunks in zip(axis, group_chunks):
adjust_chunks[ax] = chunks

result = dask.array.blockwise(
_extract_result,
inds[: -len(axis)] + (inds[-1],),
reduced,
inds,
adjust_chunks=adjust_chunks,
dtype=agg.dtype[agg.name],
key=agg.name,
name=f"{name}-{token}",
)

return (result, groups)

Expand Down

0 comments on commit 91110e6

Please sign in to comment.