Skip to content

Commit

Permalink
Use blockwise to extract final result for method="blockwise"
Browse files: browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 25, 2022
1 parent df0da40 commit 2feb8b7
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,31 +1370,42 @@ def dask_groupby_agg(

# extract results from the dict
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
agg_name = f"{name}-{token}"
for ochunk in itertools.product(*ochunks):
if method == "blockwise":
if len(axis) == 1:
inchunk = ochunk
else:
if method == "blockwise" and len(axis) == 1:
result = reduced.map_blocks(
_extract_result,
key=agg.name,
dtype=agg.dtype[agg.name],
chunks=output_chunks,
name=agg_name,
)

else:
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
for ochunk in itertools.product(*ochunks):
if method == "blockwise":
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

result = dask.array.Array(
HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
agg_name,
chunks=output_chunks,
dtype=agg.dtype[agg.name],
)
result = dask.array.Array(
HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
agg_name,
chunks=output_chunks,
dtype=agg.dtype[agg.name],
)

return (result, groups)


def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
return result_dict[key]


def _validate_reindex(
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
) -> bool:
Expand Down

0 comments on commit 2feb8b7

Please sign in to comment.