From ecc3e544cfb89ff6a3b0476e3e26deed1d1125cb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 15:15:07 -0600 Subject: [PATCH 01/10] Significantly faster cohorts detection. --- flox/core.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index e5518b551..d0e7b9872 100644 --- a/flox/core.py +++ b/flox/core.py @@ -175,6 +175,22 @@ def _unique(a: np.ndarray) -> np.ndarray: return np.sort(pd.unique(a.reshape(-1))) +def get_chunk_shape(array_chunks, index): + # from dask.array.slicing import normalize_index + + # if not isinstance(index, tuple): + # index = (index,) + # if sum(isinstance(ind, (np.ndarray, list)) for ind in index) > 1: + # raise ValueError("Can only slice with a single list") + # if any(ind is None for ind in index): + # raise ValueError("Slicing with np.newaxis or None is not supported") + # index = normalize_index(index, array.numblocks) + index = tuple(slice(k, k + 1) for k in index) # type: ignore + chunks = tuple(c[i] for c, i in zip(array_chunks, index)) + chunkshape = tuple(itertools.chain(*chunks)) + return chunkshape + + @memoize def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: """ @@ -214,8 +230,10 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: # Iterate over each block and create a new block of same shape with "chunk number" shape = tuple(array.blocks.shape[ax] for ax in axis) blocks = np.empty(math.prod(shape), dtype=object) - for idx, block in enumerate(array.blocks.ravel()): - blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx) + array_chunks = tuple(np.array(c) for c in array.chunks) + for idx, blockindex in enumerate(np.ndindex(array.shape)): + chunkshape = get_chunk_shape(array_chunks, blockindex) + blocks[idx] = np.full(chunkshape, idx) which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1) raveled = labels.reshape(-1) From 8628490c8edfa29b9de38835ec51bb90d5f4a31f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 15:17:04 -0600 Subject: [PATCH 02/10] cleanup --- flox/core.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/flox/core.py b/flox/core.py index d0e7b9872..fe89d8e3b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -176,15 +176,6 @@ def _unique(a: np.ndarray) -> np.ndarray: def get_chunk_shape(array_chunks, index): - # from dask.array.slicing import normalize_index - - # if not isinstance(index, tuple): - # index = (index,) - # if sum(isinstance(ind, (np.ndarray, list)) for ind in index) > 1: - # raise ValueError("Can only slice with a single list") - # if any(ind is None for ind in index): - # raise ValueError("Slicing with np.newaxis or None is not supported") - # index = normalize_index(index, array.numblocks) index = tuple(slice(k, k + 1) for k in index) # type: ignore chunks = tuple(c[i] for c, i in zip(array_chunks, index)) chunkshape = tuple(itertools.chain(*chunks)) From bdf41bfc8f8110f31bb2c6ef7e75caba884d63af Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 15:26:07 -0600 Subject: [PATCH 03/10] add benchmark --- asv_bench/benchmarks/cohorts.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py index dbfbe8cd5..fb936c2bc 100644 --- a/asv_bench/benchmarks/cohorts.py +++ b/asv_bench/benchmarks/cohorts.py @@ -1,8 +1,10 @@ import dask import numpy as np import pandas as pd +import xarray as xr import flox +from flox.xarray import xarray_reduce class Cohorts: @@ -125,3 +127,13 @@ class PerfectMonthlyRechunked(PerfectMonthly): def setup(self, *args, **kwargs): super().setup() super().rechunk() + + +def time_cohorts_era5_single(): + TIME = 900 # 92044 in Google ARCO ERA5 + da = xr.DataArray( + dask.array.ones((TIME, 721, 1440), chunks=(1, -1, -1)), + dims=("time", "lat", "lon"), + coords=dict(time=pd.date_range("1959-01-01", freq="6H", periods=TIME)), + ) + xarray_reduce(da, da.time.dt.day, method="cohorts", func="any") From 6c5ab08a71cc2a569bbedf4a0932faa5aae1cf4e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 15:42:19 -0600 Subject: [PATCH 04/10] update types --- flox/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index fe89d8e3b..b6992334b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -175,8 +175,8 @@ def _unique(a: np.ndarray) -> np.ndarray: return np.sort(pd.unique(a.reshape(-1))) -def get_chunk_shape(array_chunks, index): - index = tuple(slice(k, k + 1) for k in index) # type: ignore +def get_chunk_shape(array_chunks, index: tuple[int, ...]) -> tuple[int, ...]: + index = tuple(slice(k, k + 1) for k in index) chunks = tuple(c[i] for c, i in zip(array_chunks, index)) chunkshape = tuple(itertools.chain(*chunks)) return chunkshape From de90a7432cc1e90211a382b76732afa03f12da05 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 15:53:46 -0600 Subject: [PATCH 05/10] fix --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index b6992334b..192c0d7dd 100644 --- a/flox/core.py +++ b/flox/core.py @@ -222,7 +222,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: shape = tuple(array.blocks.shape[ax] for ax in axis) blocks = np.empty(math.prod(shape), dtype=object) array_chunks = tuple(np.array(c) for c in array.chunks) - for idx, blockindex in enumerate(np.ndindex(array.shape)): + for idx, blockindex in enumerate(np.ndindex(array.numblocks)): chunkshape = get_chunk_shape(array_chunks, blockindex) blocks[idx] = np.full(chunkshape, idx) which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1) From 31f849e6c2128ea656c4dadf635e56909db875c2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 16:29:27 -0600 Subject: [PATCH 06/10] single chunk optimization --- flox/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 192c0d7dd..e2900e0f2 100644 --- a/flox/core.py +++ b/flox/core.py @@ -238,7 +238,11 @@ def invert(x) -> tuple[np.ndarray, ...]: chunks_cohorts = tlz.groupby(invert, label_chunks.keys()) - if merge: + # If our dataset has chunksize one along the axis, + # then no merging is possible. + single_chunks = all((ac == 1).all() for ac in array_chunks) + + if merge and not single_chunks: # First sort by number of chunks occupied by cohort sorted_chunks_cohorts = dict( sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True) From e417d58c3705317c9c4adb05c174814bc4dada25 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 16:29:35 -0600 Subject: [PATCH 07/10] more optimization --- flox/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/core.py b/flox/core.py index e2900e0f2..cd301cb36 100644 --- a/flox/core.py +++ b/flox/core.py @@ -220,12 +220,12 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: # Iterate over each block and create a new block of same shape with "chunk number" shape = tuple(array.blocks.shape[ax] for ax in axis) - blocks = np.empty(math.prod(shape), dtype=object) + blocks = np.empty(shape, dtype=object) array_chunks = tuple(np.array(c) for c in array.chunks) - for idx, blockindex in enumerate(np.ndindex(array.numblocks)): + for blockindex in np.ndindex(array.numblocks): chunkshape = get_chunk_shape(array_chunks, blockindex) - blocks[idx] = np.full(chunkshape, idx) - which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1) + blocks[blockindex] = np.full(chunkshape, np.ravel_multi_index(blockindex, array.numblocks)) + which_chunk = np.block(blocks.tolist()).reshape(-1) raveled = labels.reshape(-1) # these are chunks where a label is present From ab83e295d16c137c236dbf4a221961be8f29849c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 16:54:33 -0600 Subject: [PATCH 08/10] more optimization --- flox/core.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/flox/core.py b/flox/core.py index cd301cb36..f34cfb82b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -175,13 +175,6 @@ def _unique(a: np.ndarray) -> np.ndarray: return np.sort(pd.unique(a.reshape(-1))) -def get_chunk_shape(array_chunks, index: tuple[int, ...]) -> tuple[int, ...]: - index = tuple(slice(k, k + 1) for k in index) - chunks = tuple(c[i] for c, i in zip(array_chunks, index)) - chunkshape = tuple(itertools.chain(*chunks)) - return chunkshape - - @memoize def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: """ @@ -215,16 +208,17 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: # 1. First subset the array appropriately axis = range(-labels.ndim, 0) # Easier to create a dask array and use the .blocks property - array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks) + array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks) labels = np.broadcast_to(labels, array.shape[-labels.ndim :]) # Iterate over each block and create a new block of same shape with "chunk number" shape = tuple(array.blocks.shape[ax] for ax in axis) + # Use a numpy object array to enable assignment in the loop blocks = np.empty(shape, dtype=object) array_chunks = tuple(np.array(c) for c in array.chunks) - for blockindex in np.ndindex(array.numblocks): - chunkshape = get_chunk_shape(array_chunks, blockindex) - blocks[blockindex] = np.full(chunkshape, np.ravel_multi_index(blockindex, array.numblocks)) + for idx, blockindex in enumerate(np.ndindex(array.numblocks)): + chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex)) + blocks[blockindex] = np.full(chunkshape, idx) which_chunk = np.block(blocks.tolist()).reshape(-1) raveled = labels.reshape(-1) From e05b64c4841f983c46fe3d340df712ec1e5a17d8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 16:57:22 -0600 Subject: [PATCH 09/10] add comment --- flox/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flox/core.py b/flox/core.py index f34cfb82b..e2784ae99 100644 --- a/flox/core.py +++ b/flox/core.py @@ -214,6 +214,8 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict: # Iterate over each block and create a new block of same shape with "chunk number" shape = tuple(array.blocks.shape[ax] for ax in axis) # Use a numpy object array to enable assignment in the loop + # TODO: is it possible to just use a nested list? + # That is what we need for `np.block` blocks = np.empty(shape, dtype=object) array_chunks = tuple(np.array(c) for c in array.chunks) for idx, blockindex in enumerate(np.ndindex(array.numblocks)): From fa934069f94d5bfc2f6ad62b644edf00c38cc29e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 10 Oct 2023 17:31:00 -0600 Subject: [PATCH 10/10] fix benchmark --- asv_bench/benchmarks/cohorts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py index fb936c2bc..21707d448 100644 --- a/asv_bench/benchmarks/cohorts.py +++ b/asv_bench/benchmarks/cohorts.py @@ -14,7 +14,7 @@ def setup(self, *args, **kwargs): raise NotImplementedError def time_find_group_cohorts(self): - flox.core.find_group_cohorts(self.by, self.array.chunks) + flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis]) # The cache clear fails dependably in CI # Not sure why try: