Merge pull request #312 from davidhassell/dask-digitize
dask: `Data.digitize`
sadielbartholomew authored Feb 4, 2022
2 parents ffb5555 + a669221 commit d6cd3d1
Showing 2 changed files with 65 additions and 62 deletions.
88 changes: 35 additions & 53 deletions cf/data/data.py
@@ -1656,13 +1656,16 @@ def dumps(self):

return json_dumps(d, default=convert_to_builtin_type)

@daskified(_DASKIFIED_VERBOSE)
@_inplace_enabled(default=False)
def digitize(
self,
bins,
upper=False,
open_ends=False,
closed_ends=None,
return_bins=False,
inplace=False,
):
"""Return the indices of the bins to which each value belongs.
@@ -1747,6 +1750,8 @@ def digitize(
return_bins: `bool`, optional
If True then also return the bins in their 2-d form.
{{inplace: `bool`, optional}}
:Returns:
`Data`, [`Data`]
@@ -1755,7 +1760,7 @@
If *return_bins* is True then also return the bins in
their 2-d form.
**Examples:**
**Examples**
>>> d = cf.Data(numpy.arange(12).reshape(3, 4))
[[ 0 1 2 3]
@@ -1811,9 +1816,9 @@
[ 1 1 1 --]]
"""
out = self.copy()
d = _inplace_enabled_define_and_cleanup(self)

org_units = self.Units
org_units = d.Units

bin_units = getattr(bins, "Units", None)

@@ -1830,12 +1835,16 @@
else:
bin_units = org_units

bins = np.asanyarray(bins)
# Get bins as a numpy array
if isinstance(bins, np.ndarray):
bins = bins.copy()
else:
bins = np.asanyarray(bins)

if bins.ndim > 2:
raise ValueError(
"The 'bins' parameter must be scalar, 1-d or 2-d"
"Got: {!r}".format(bins)
"The 'bins' parameter must be scalar, 1-d or 2-d. "
f"Got: {bins!r}"
)

two_d_bins = None
@@ -1848,7 +1857,7 @@
if bins.shape[1] != 2:
raise ValueError(
"The second dimension of the 'bins' parameter must "
"have size 2. Got: {!r}".format(bins)
f"have size 2. Got: {bins!r}"
)

bins.sort(axis=1)
@@ -1858,11 +1867,9 @@
for i, (u, l) in enumerate(zip(bins[:-1, 1], bins[1:, 0])):
if u > l:
raise ValueError(
"Overlapping bins: {}, {}".format(
tuple(bins[i]), tuple(bins[i + i])
)
f"Overlapping bins: "
f"{tuple(bins[i])}, {tuple(bins[i + i])}"
)
# --- End: for

two_d_bins = bins
bins = np.unique(bins)
@@ -1900,8 +1907,8 @@
"scalar."
)

mx = self.max().datum()
mn = self.min().datum()
mx = d.max().datum()
mn = d.min().datum()
bins = np.linspace(mn, mx, int(bins) + 1, dtype=float)

delete_bins = []
@@ -1913,7 +1920,8 @@
"Can't set open_ends=True when closed_ends is True."
)

bins = bins.astype(float, copy=True)
if bins.dtype.kind != "f":
bins = bins.astype(float, copy=False)

epsilon = np.finfo(float).eps
ndim = bins.ndim
@@ -1923,53 +1931,27 @@
else:
mx = bins[(-1,) * ndim]
bins[(-1,) * ndim] += abs(mx) * epsilon
# --- End: if

if not open_ends:
delete_bins.insert(0, 0)
delete_bins.append(bins.size)

if return_bins and two_d_bins is None:
x = np.empty((bins.size - 1, 2), dtype=bins.dtype)
x[:, 0] = bins[:-1]
x[:, 1] = bins[1:]
two_d_bins = x

config = out.partition_configuration(readonly=True)

for partition in out.partitions.matrix.flat:
partition.open(config)
array = partition.array

mask = None
if np.ma.isMA(array):
mask = array.mask.copy()

array = np.digitize(array, bins, right=upper)

if delete_bins:
for n, d in enumerate(delete_bins):
d -= n
array = np.ma.where(array == d, np.ma.masked, array)
array = np.ma.where(array > d, array - 1, array)
# --- End: if

if mask is not None:
array = np.ma.where(mask, np.ma.masked, array)

partition.subarray = array
partition.Units = _units_None

partition.close()

out.dtype = int

out.override_units(_units_None, inplace=True)
# Digitise the array
dx = d._get_dask()
dx = da.digitize(dx, bins, right=upper)
d._set_dask(dx, reset_mask_hardness=True)
d.override_units(_units_None, inplace=True)

if return_bins:
return out, type(self)(two_d_bins, units=bin_units)
if two_d_bins is None:
two_d_bins = np.empty((bins.size - 1, 2), dtype=bins.dtype)
two_d_bins[:, 0] = bins[:-1]
two_d_bins[:, 1] = bins[1:]

return out
two_d_bins = type(self)(two_d_bins, units=bin_units)
return d, two_d_bins

return d

def median(
self,
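For context, a minimal standalone sketch (not part of this commit) of the behaviour the daskified method now delegates to: `dask.array.digitize` lazily assigns each element a bin index, and the 2-d `(lower, upper)` form of the bins is rebuilt from the 1-d edges as in the `return_bins` branch above. The array values and bin edges here are illustrative.

```python
import dask.array as da
import numpy as np

# Illustrative data and monotonically increasing 1-d bin edges
dx = da.from_array(np.arange(12).reshape(3, 4), chunks=(2, 2))
bins = np.array([2.0, 6.0, 10.0, 50.0, 100.0])

# Lazily assign each element a bin index; right=False gives intervals
# closed on the left and open on the right, i.e. [lower, upper)
indices = da.digitize(dx, bins, right=False)
print(indices.compute())

# Rebuild the 2-d (lower, upper) form of the bins from the 1-d edges,
# mirroring the return_bins=True branch of the new code
two_d_bins = np.empty((bins.size - 1, 2), dtype=bins.dtype)
two_d_bins[:, 0] = bins[:-1]
two_d_bins[:, 1] = bins[1:]
print(two_d_bins)
```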
39 changes: 30 additions & 9 deletions cf/test/test_Data.py
@@ -804,7 +804,6 @@ def test_Data__init__dtype_mask(self):
self.assertTrue((d.array == a).all())
self.assertTrue((d.mask.array == np.ma.getmaskarray(a)).all())

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
def test_Data_digitize(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return
@@ -829,15 +828,37 @@ def test_Data_digitize(self):
b = np.digitize(a, [2, 6, 10, 50, 100], right=upper)

self.assertTrue((e.array == b).all())

e.where(
cf.set([e.minimum(), e.maximum()]),
cf.masked,
e - 1,
inplace=True,
self.assertTrue(
(np.ma.getmask(e.array) == np.ma.getmask(b)).all()
)
f = d.digitize(bins, upper=upper)
self.assertTrue(e.equals(f, verbose=2))

# TODODASK: Reinstate the following test when
# __sub__, minimum, and maximum have
# been daskified

# e.where(
# cf.set([e.minimum(), e.maximum()]),
# cf.masked,
# e - 1,
# inplace=True,
# )
# f = d.digitize(bins, upper=upper)
# self.assertTrue(e.equals(f, verbose=2))

# Check returned bins
bins = [2, 6, 10, 50, 100]
e, b = d.digitize(bins, return_bins=True)
self.assertTrue(
(b.array == [[2, 6], [6, 10], [10, 50], [50, 100]]).all()
)
self.assertTrue(b.Units == d.Units)

# Check digitized units
self.assertTrue(e.Units == cf.Units(None))

# Check inplace
self.assertIsNone(d.digitize(bins, inplace=True))
self.assertTrue(d.equals(e))

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'")
def test_Data_cumsum(self):
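As a quick cross-check of the expectations exercised in `test_Data_digitize` above, the same bin assignment can be reproduced directly with `numpy.digitize`; the data and bins below are illustrative rather than taken from the test file, and only the raw index assignment is shown (not cf's additional handling of out-of-range values).

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
bins = [2, 6, 10, 50, 100]

# upper=False in cf maps to right=False here: intervals are
# [lower, upper), so the value 6 lands in the bin starting at 6
lower_closed = np.digitize(a, bins, right=False)

# upper=True maps to right=True: intervals are (lower, upper],
# so the value 6 now lands in the bin ending at 6
upper_closed = np.digitize(a, bins, right=True)

print(lower_closed)
print(upper_closed)
```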
