Avoid broadcasting by variables against each other #186

Merged: 17 commits, Nov 26, 2022
26 changes: 19 additions & 7 deletions flox/core.py
@@ -173,9 +173,10 @@ def find_group_cohorts(labels, chunks, merge: bool = True):

# Build an array with the shape of labels, but where every element is the "chunk number"
# 1. First subset the array appropriately
axis = range(-labels.ndim, 0)
axis = tuple(range(-labels.ndim, 0))
# Easier to create a dask array and use the .blocks property
array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

# Iterate over each block and create a new block of same shape with "chunk number"
shape = tuple(array.blocks.shape[ax] for ax in axis)
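The idea behind this chunk-number trick, as a standalone sketch with made-up shapes (it mirrors the comment above, not flox's exact implementation):

```python
import itertools
import numpy as np
import dask.array

# hypothetical: a 4x6 array split into a 2x2 grid of chunks
chunks = ((2, 2), (3, 3))
array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)

chunk_number = np.empty(array.shape, dtype=int)
for i, idx in enumerate(itertools.product(*(range(n) for n in array.numblocks))):
    # every element covered by block `idx` gets that block's flat index
    slices = tuple(slice(sum(c[:k]), sum(c[: k + 1])) for c, k in zip(chunks, idx))
    chunk_number[slices] = i
print(chunk_number)
# [[0 0 0 1 1 1]
#  [0 0 0 1 1 1]
#  [2 2 2 3 3 3]
#  [2 2 2 3 3 3]]
```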
@@ -479,7 +480,7 @@ def factorize_(
idx, groups = pd.factorize(flat, sort=sort)

found_groups.append(np.array(groups))
factorized.append(idx)
factorized.append(idx.reshape(groupvar.shape))

grp_shape = tuple(len(grp) for grp in found_groups)
ngroups = math.prod(grp_shape)
@@ -489,20 +490,18 @@
# Restore these after the raveling
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
group_idx[nan_by_mask] = -1
group_idx = group_idx.reshape(by[0].shape)
else:
group_idx = factorized[0]

if fastpath:
return group_idx.reshape(by[0].shape), found_groups, grp_shape
return group_idx, found_groups, grp_shape

if np.isscalar(axis) and groupvar.ndim > 1:
# Not reducing along all dimensions of by
# this is OK because for 3D by and axis=(1,2),
# we collapse to a 2D by and axis=-1
offset_group = True
group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
group_idx = group_idx.reshape(-1)
else:
size = ngroups
offset_group = False
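A rough sketch of the label-offset trick described in the comment, using a hypothetical 2D `labels` and a plain `np.bincount` reduction (flox's `offset_labels` is more general than this):

```python
import numpy as np

# hypothetical 2D labels, reducing only along the last axis
labels = np.array([[0, 1, 0],
                   [1, 1, 0]])
ngroups = 2

# offset each leading row into its own block of group ids so a single
# flat reduction over the raveled labels still keeps rows separate
offset = labels + np.arange(labels.shape[0])[:, None] * ngroups
print(offset)
# [[0 1 0]
#  [3 3 2]]

data = np.array([[10., 20., 30.],
                 [1., 2., 3.]])
sums = np.bincount(offset.ravel(), weights=data.ravel(),
                   minlength=labels.shape[0] * ngroups)
print(sums.reshape(labels.shape[0], ngroups))
# [[40. 20.]
#  [ 3.  3.]]
```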
@@ -647,6 +646,8 @@ def chunk_reduce(
else:
nax = by.ndim

assert by.ndim <= array.ndim

final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
final_groups_shape = (1,) * (nax - 1)

@@ -667,9 +668,17 @@
)
groups = groups[0]

if isinstance(axis, Sequence):
needs_broadcast = any(
group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
for ax in range(-len(axis), 0)
)
if needs_broadcast:
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
# always reshape to 1D along group dimensions
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
array = array.reshape(newshape)
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1
empty = np.all(props.nanmask)
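What the size-one broadcast plus flatten amounts to, as a standalone sketch with hypothetical shapes:

```python
import math
import numpy as np

array = np.arange(24).reshape(2, 3, 4)   # e.g. (z, y, x); grouping over (y, x)
group_idx = np.array([[0, 1, 1, 0]])     # shape (1, 4): constant along y

# group_idx is size-1 along y, so broadcast it to the trailing array shape ...
group_idx = np.broadcast_to(group_idx, array.shape[-group_idx.ndim:])   # (3, 4)
# ... then collapse the grouped dimensions into one flat axis
newshape = array.shape[:array.ndim - group_idx.ndim] + (math.prod(array.shape[-group_idx.ndim:]),)
array = array.reshape(newshape)          # (2, 12)
group_idx = group_idx.reshape(-1)        # (12,)
print(array.shape, group_idx.shape)      # (2, 12) (12,)
```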
@@ -1220,7 +1229,9 @@ def dask_groupby_agg(
# chunk numpy arrays like the input array
# This removes an extra rechunk-merge layer that would be
# added otherwise
by = dask.array.from_array(by, chunks=tuple(array.chunks[ax] for ax in range(-by.ndim, 0)))
chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))

by = dask.array.from_array(by, chunks=chunks)
_, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])
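For example, with hypothetical shapes, a `by` that is size-1 along one grouped axis gets a single chunk there while the other axes copy the array's chunking:

```python
import dask.array
import numpy as np

array = dask.array.ones((4, 6), chunks=((2, 2), (3, 3)))
by = np.zeros((1, 6))   # size-1 along the first grouped axis

chunks = tuple(
    array.chunks[ax] if by.shape[ax] != 1 else (1,)
    for ax in range(-by.ndim, 0)
)
print(chunks)            # ((1,), (3, 3))
by = dask.array.from_array(by, chunks=chunks)
```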

# preprocess the array: for argreductions, this zips the index together with the array block
@@ -1424,8 +1435,9 @@ def _validate_reindex(


def _assert_by_is_aligned(shape, by):
assert all(b.ndim == by[0].ndim for b in by[1:])
for idx, b in enumerate(by):
if shape[-b.ndim :] != b.shape:
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
raise ValueError(
"`array` and `by` arrays must be aligned "
"i.e. array.shape[-by.ndim :] == by.shape. "
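The relaxed alignment check accepts either an exact match or a size-1 `by` dimension; a minimal standalone sketch:

```python
shape = (10, 4, 6)   # array.shape
ok    = (1, 6)       # by.shape: size-1 along the second-to-last axis is fine
bad   = (5, 6)       # mismatched and not size-1 -> would raise

def aligned(shape, by_shape):
    return all(j in (i, 1) for i, j in zip(shape[-len(by_shape):], by_shape))

print(aligned(shape, ok), aligned(shape, bad))   # True False
```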
74 changes: 40 additions & 34 deletions flox/xarray.py
@@ -26,20 +26,6 @@
Dims = Union[str, Iterable[Hashable], None]


def _get_input_core_dims(group_names, dim, ds, grouper_dims):
input_core_dims = [[], []]
for g in group_names:
if g in dim:
continue
if g in ds.dims:
input_core_dims[0].extend([g])
if g in grouper_dims:
input_core_dims[1].extend([g])
input_core_dims[0].extend(dim)
input_core_dims[1].extend(dim)
return input_core_dims


def _restore_dim_order(result, obj, by):
def lookup_order(dimension):
if dimension == by.name and by.ndim == 1:
@@ -54,6 +40,27 @@ def lookup_order(dimension):
return result.transpose(*new_order)


def _broadcast_size_one_dims(*arrays, core_dims):
"""Broadcast by adding size-1 dimensions in the right place.

Workaround because apply_ufunc doesn't support this yet.
https://github.com/pydata/xarray/issues/3032#issuecomment-503337637

Specialized to the groupby problem.
"""
array_dims = set(core_dims[0])
broadcasted = [arrays[0]]
for dims, array in zip(core_dims[1:], arrays[1:]):
assert set(dims).issubset(array_dims)
order = [dims.index(d) for d in core_dims[0] if d in dims]
array = array.transpose(*order)
axis = [core_dims[0].index(d) for d in core_dims[0] if d not in dims]
broadcasted.append(np.expand_dims(array, axis))

# ic(tuple(zip(core_dims, (a.shape for a in broadcasted))))
return broadcasted
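A usage sketch of the helper above, with hypothetical dims: a `by` defined only on "x" gets a size-1 "y" axis inserted so it lines up with an array whose core dims are ("y", "x"):

```python
import numpy as np

array = np.arange(12).reshape(3, 4)   # core dims ("y", "x")
by = np.array([0, 0, 1, 1])           # core dims ("x",)

out_array, out_by = _broadcast_size_one_dims(
    array, by, core_dims=[["y", "x"], ["x"]]
)
print(out_by.shape)                    # (1, 4): ready to broadcast against (3, 4)
```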


def xarray_reduce(
obj: T_Dataset | T_DataArray,
*by: T_DataArray | Hashable,
@@ -255,20 +262,11 @@ def xarray_reduce(
elif dim is not None:
dim_tuple = _atleast_1d(dim)
else:
dim_tuple = tuple()
dim_tuple = tuple(grouper_dims)

# broadcast all variables against each other along all dimensions in `by` variables
# don't exclude `dim` because it need not be a dimension in any of the `by` variables!
# in the case where dim is Ellipsis, and by.ndim < obj.ndim
# then we also broadcast `by` to all `obj.dims`
# TODO: avoid this broadcasting
# broadcast to make sure grouper dimensions are present in the array.
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)

# all members of by_broad have the same dimensions
# so we just pull by_broad[0].dims if dim is None
if not dim_tuple:
dim_tuple = tuple(by_broad[0].dims)
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]

if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
@@ -298,7 +296,7 @@
expected_groups = list(expected_groups)
group_names: tuple[Any, ...] = ()
group_sizes: dict[Any, int] = {}
for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)):
group_name = b_.name if not isbin_ else f"{b_.name}_bins"
group_names += (group_name,)

@@ -326,7 +324,10 @@
# This will never be reached
raise ValueError("expect_index cannot be None")

def wrapper(array, *by, func, skipna, **kwargs):
def wrapper(array, *by, func, skipna, core_dims, **kwargs):

array, *by = _broadcast_size_one_dims(array, *by, core_dims=core_dims)

# Handle skipna here because I need to know dtype to make a good default choice.
# We cannot handle this easily for xarray Datasets in xarray_reduce
if skipna and func in ["all", "any", "count"]:
@@ -374,17 +375,21 @@ def wrapper(array, *by, func, skipna, **kwargs):
if is_missing_dim:
missing_dim[k] = v

input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
input_core_dims += [input_core_dims[-1]] * (nby - 1)
# dim_tuple contains dimensions we are reducing over. These need to be the last
# core dimensions to be synchronized with axis.
input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(b.dims) for b in by_da]

output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)
actual = xr.apply_ufunc(
wrapper,
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
*by_broad,
*by_da,
input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
exclude_dims=set(dim_tuple),
output_core_dims=[group_names],
output_core_dims=[output_core_dims],
dask="allowed",
dask_gufunc_kwargs=dict(
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
@@ -404,6 +409,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
"isbin": isbins,
"finalize_kwargs": finalize_kwargs,
"dtype": dtype,
"core_dims": input_core_dims,
},
)
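With hypothetical dims, say an array on ("time", "x", "y"), two `by` variables on ("x", "y") and ("y",), and dim=("y",), the core-dims bookkeeping works out to:

```python
grouper_dims = ("x", "y")          # union of all `by` dims
dim_tuple = ("y",)                 # dims being reduced over
by_dims = [("x", "y"), ("y",)]     # one entry per `by` variable
group_names = ("labels1", "labels2")

# reduced dims must come last so they line up with `axis` inside the wrapper
input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(b) for b in by_dims]
print(input_core_dims)             # [['x', 'y'], ['x', 'y'], ['y']]

output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)
print(output_core_dims)            # ['x', 'labels1', 'labels2']
```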

@@ -413,7 +419,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
if all(d not in ds_broad[var].dims for d in dim_tuple):
actual[var] = ds_broad[var]

for name, expect, by_ in zip(group_names, expected_groups, by_broad):
for name, expect, by_ in zip(group_names, expected_groups, by_da):
# Can't remove this till xarray handles IntervalIndex
if isinstance(expect, pd.IntervalIndex):
expect = expect.to_numpy()
@@ -443,7 +449,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
template = obj

if actual[var].ndim > 1:
actual[var] = _restore_dim_order(actual[var], template, by_broad[0])
actual[var] = _restore_dim_order(actual[var], template, by_da[0])

if missing_dim:
for k, v in missing_dim.items():