Skip to content

Commit

Permalink
PI-2472: Optimise the area weighted regridding routine (SciTools#3598)
Browse files Browse the repository at this point in the history
* PI-2472: Tweak area weighting regrid move xdim and ydim axes (SciTools#3594)

* _regrid_area_weighted_array: Set axis order to y_dim, x_dim last dimensions

* _regrid_area_weighted_array: Extra tests for axes ordering

* PI-2472: Tweak area weighting regrid enforce xdim ydim (SciTools#3595)

* _regrid_area_weighted_array: Set axis order to y_dim, x_dim last dimensions

* _regrid_area_weighted_array: Extra tests for axes ordering

* _regrid_area_weighted_array: Ensure x_dim and y_dim

* PI-2472: Tweak area weighting regrid move averaging out of loop (SciTools#3596)

* _regrid_area_weighted_array: Refactor weights and move averaging outside loop
  • Loading branch information
abooton authored and pp-mo committed Jan 14, 2020
1 parent cebfb52 commit 89d2330
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 151 deletions.
357 changes: 243 additions & 114 deletions lib/iris/experimental/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,135 +461,264 @@ def _regrid_area_weighted_array(src_data, x_dim, y_dim,
grid.
"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
src_y_bounds, grid_y_bounds, grid_y_decreasing
)
x_within_bounds = _within_bounds(
src_x_bounds, grid_x_bounds, grid_x_decreasing
)

# Cache which src_bounds are within grid bounds
cached_x_bounds = []
cached_x_indices = []
for (x_0, x_1) in grid_x_bounds:
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
cached_x_bounds.append(x_bounds)
cached_x_indices.append(x_indices)

# Create empty data array to match the new grid.
# Note that dtype is not preserved and that the array is
# masked to allow for regions that do not overlap.
new_shape = list(src_data.shape)
if x_dim is not None:
new_shape[x_dim] = grid_x_bounds.shape[0]
if y_dim is not None:
new_shape[y_dim] = grid_y_bounds.shape[0]
def _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
grid_y_bounds,
grid_x_decreasing,
grid_y_decreasing,
area_func,
circular=False,
):
"""
Compute the area weights used for area-weighted regridding.
"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
src_y_bounds, grid_y_bounds, grid_y_decreasing
)
x_within_bounds = _within_bounds(
src_x_bounds, grid_x_bounds, grid_x_decreasing
)

# Cache which src_bounds are within grid bounds
cached_x_bounds = []
cached_x_indices = []
max_x_indices = 0
for (x_0, x_1) in grid_x_bounds:
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
cached_x_bounds.append(x_bounds)
cached_x_indices.append(x_indices)
# Keep record of the largest slice
if isinstance(x_indices, slice):
x_indices_size = np.sum(x_indices.stop - x_indices.start)
else: # is tuple of indices
x_indices_size = len(x_indices)
if x_indices_size > max_x_indices:
max_x_indices = x_indices_size

# Cache which y src_bounds areas and weights are within grid bounds
cached_y_indices = []
cached_weights = []
max_y_indices = 0
for j, (y_0, y_1) in enumerate(grid_y_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_y_decreasing:
y_0, y_1 = y_1, y_0
y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
cached_y_indices.append(y_indices)
# Keep record of the largest slice
if isinstance(y_indices, slice):
y_indices_size = np.sum(y_indices.stop - y_indices.start)
else: # is tuple of indices
y_indices_size = len(y_indices)
if y_indices_size > max_y_indices:
max_y_indices = y_indices_size

weights_i = []
for i, (x_0, x_1) in enumerate(grid_x_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds = cached_x_bounds[i]
x_indices = cached_x_indices[i]

# Determine whether element i, j overlaps with src and hence
# an area weight should be computed.
# If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
# of wrapped longitudes. However if the src grid is not global
# (i.e. circular) this new cell would include a region outside of
# the extent of the src grid and thus the weight is therefore
# invalid.
outside_extent = x_0 > x_1 and not circular
if (
outside_extent
or not y_within_bounds[j]
or not x_within_bounds[i]
):
weights = False
else:
# Calculate weights based on areas of cropped bounds.
if isinstance(x_indices, tuple) and isinstance(
y_indices, tuple
):
raise RuntimeError(
"Cannot handle split bounds " "in both x and y."
)
weights = area_func(y_bounds, x_bounds)
weights_i.append(weights)
cached_weights.append(weights_i)
return (
tuple(cached_x_indices),
tuple(cached_y_indices),
max_x_indices,
max_y_indices,
tuple(cached_weights),
)

weights_info = _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
grid_y_bounds,
grid_x_decreasing,
grid_y_decreasing,
area_func,
circular,
)
(
cached_x_indices,
cached_y_indices,
max_x_indices,
max_y_indices,
cached_weights,
) = weights_info
# Delete variables that are not needed and would not be available
# if _calculate_regrid_area_weighted_weights was refactored further
del src_x_bounds, src_y_bounds, grid_x_bounds, grid_y_bounds
del grid_x_decreasing, grid_y_decreasing
del area_func, circular

# Ensure we have x_dim and y_dim.
x_dim_orig = x_dim
y_dim_orig = y_dim
if y_dim is None:
src_data = np.expand_dims(src_data, axis=src_data.ndim)
y_dim = src_data.ndim - 1
if x_dim is None:
src_data = np.expand_dims(src_data, axis=src_data.ndim)
x_dim = src_data.ndim - 1
# Move y_dim and x_dim to last dimensions
if not x_dim == src_data.ndim - 1:
src_data = np.moveaxis(src_data, x_dim, -1)
if not y_dim == src_data.ndim - 2:
if x_dim < y_dim:
# note: y_dim was shifted along by one position when
# x_dim was moved to the last dimension
src_data = np.moveaxis(src_data, y_dim - 1, -2)
elif x_dim > y_dim:
src_data = np.moveaxis(src_data, y_dim, -2)
x_dim = src_data.ndim - 1
y_dim = src_data.ndim - 2

# Create empty "pre-averaging" data array that will enable the
# src_data data coresponding to a given target grid point,
# to be stacked per point.
# Note that dtype is not preserved and that the array mask
# allows for regions that do not overlap.
new_shape = list(src_data.shape)
new_shape[x_dim] = len(cached_x_indices)
new_shape[y_dim] = len(cached_y_indices)
num_target_pts = len(cached_y_indices) * len(cached_x_indices)
src_areas_shape = list(src_data.shape)
src_areas_shape[y_dim] = max_y_indices
src_areas_shape[x_dim] = max_x_indices
src_areas_shape += [num_target_pts]
# Use input cube dtype or convert values to the smallest possible float
# dtype when necessary.
dtype = np.promote_types(src_data.dtype, np.float16)
# Create empty arrays to hold src_data per target point, and weights
src_area_datas = np.zeros(src_areas_shape, dtype=np.float64)
src_area_weights = np.zeros(
list((max_y_indices, max_x_indices, num_target_pts))
)

# Flag to indicate whether the original data was a masked array.
src_masked = ma.isMaskedArray(src_data)
src_masked = src_data.mask.any() if ma.isMaskedArray(src_data) else False
if src_masked:
new_data = ma.zeros(new_shape, fill_value=src_data.fill_value,
dtype=dtype)
src_area_masks = np.full(src_areas_shape, True, dtype=np.bool)
else:
new_data = ma.zeros(new_shape, dtype=dtype)
# Assign to mask to explode it, allowing indexed assignment.
new_data.mask = False
new_data_mask = np.full(new_shape, False, dtype=np.bool)

# Axes of data over which the weighted mean is calculated.
axes = []
if y_dim is not None:
axes.append(y_dim)
if x_dim is not None:
axes.append(x_dim)
axis = tuple(axes)

# Simple for loop approach.
indices = [slice(None)] * new_data.ndim
for j, (y_0, y_1) in enumerate(grid_y_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_y_decreasing:
y_0, y_1 = y_1, y_0
y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
for i, (x_0, x_1) in enumerate(grid_x_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds = cached_x_bounds[i]
x_indices = cached_x_indices[i]

# Determine whether to mask element i, j based on overlap with
# src.
# If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
# of wrapped longitudes. However if the src grid is not global
# (i.e. circular) this new cell would include a region outside of
# the extent of the src grid and should therefore be masked.
outside_extent = x_0 > x_1 and not circular
if (outside_extent or not y_within_bounds[j] or not
x_within_bounds[i]):
# Mask out element(s) in new_data
if x_dim is not None:
indices[x_dim] = i
if y_dim is not None:
indices[y_dim] = j
new_data[tuple(indices)] = ma.masked
axis = (y_dim, x_dim)

# Stack the src_area data and weights for each target point
target_pt_ji = -1
for j, y_indices in enumerate(cached_y_indices):
for i, x_indices in enumerate(cached_x_indices):
target_pt_ji += 1
# Determine whether to mask element i, j based on whether
# there are valid weights.
weights = cached_weights[j][i]
if isinstance(weights, bool) and not weights:
if not src_masked:
# Cheat! Fill the data with zeros and weights as one.
# The weighted average result will be the same, but
# we avoid dividing by zero.
src_area_weights[..., target_pt_ji] = 1
new_data_mask[..., j, i] = True
else:
# Calculate weighted mean of data points.
# Slice out relevant data (this may or may not be a view()
# depending on x_indices being a slice or not).
if isinstance(x_indices, tuple) and isinstance(
y_indices, tuple
):
raise RuntimeError(
"Cannot handle split bounds " "in both x and y."
)
# Calculate weights based on areas of cropped bounds.
weights = area_func(y_bounds, x_bounds)

if x_dim is not None:
indices[x_dim] = x_indices
if y_dim is not None:
indices[y_dim] = y_indices
data = src_data[tuple(indices)]

# Transpose weights to match dim ordering in data.
weights_shape_y = weights.shape[0]
weights_shape_x = weights.shape[1]
if x_dim is not None and y_dim is not None and x_dim < y_dim:
weights = weights.T
# Broadcast the weights array to allow numpy's ma.average
# to be called.
weights_padded_shape = [1] * data.ndim
if y_dim is not None:
weights_padded_shape[y_dim] = weights_shape_y
if x_dim is not None:
weights_padded_shape[x_dim] = weights_shape_x
# Assign new shape to raise error on copy.
weights.shape = weights_padded_shape
# Broadcast weights to match shape of data.
_, weights = np.broadcast_arrays(data, weights)

# Calculate weighted mean taking into account missing data.
new_data_pt = _weighted_mean_with_mdtol(
data, weights=weights, axis=axis, mdtol=mdtol)

# Insert data (and mask) values into new array.
if x_dim is not None:
indices[x_dim] = i
if y_dim is not None:
indices[y_dim] = j
new_data[tuple(indices)] = new_data_pt

# Remove new mask if original data was not masked
# and no values in the new array are masked.
if not src_masked and not new_data.mask.any():
new_data = new_data.data
data = src_data[..., y_indices, x_indices]
len_x = data.shape[-1]
len_y = data.shape[-2]
src_area_datas[..., 0:len_y, 0:len_x, target_pt_ji] = data
src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights
if src_masked:
src_area_masks[
..., 0:len_y, 0:len_x, target_pt_ji
] = data.mask

# Broadcast the weights array to allow numpy's ma.average
# to be called.
# Assign new shape to raise error on copy.
src_area_weights.shape = src_area_datas.shape[-3:]
# Broadcast weights to match shape of data.
_, src_area_weights = np.broadcast_arrays(src_area_datas, src_area_weights)

# Mask the data points
if src_masked:
src_area_datas = np.ma.array(src_area_datas, mask=src_area_masks)

# Calculate weighted mean taking into account missing data.
new_data = _weighted_mean_with_mdtol(
src_area_datas, weights=src_area_weights, axis=axis, mdtol=mdtol
)
new_data = new_data.reshape(new_shape)
if src_masked:
new_data_mask = new_data.mask

# Mask the data if originally masked or if the result has masked points
if ma.isMaskedArray(src_data):
new_data = ma.array(
new_data,
mask=new_data_mask,
fill_value=src_data.fill_value,
dtype=dtype,
)
elif new_data_mask.any():
new_data = ma.array(new_data, mask=new_data_mask, dtype=dtype)
else:
new_data = new_data.astype(dtype)

# Restore data to original form
if x_dim_orig is None and y_dim_orig is None:
new_data = np.squeeze(new_data, axis=x_dim)
new_data = np.squeeze(new_data, axis=y_dim)
elif y_dim_orig is None:
new_data = np.squeeze(new_data, axis=y_dim)
new_data = np.moveaxis(new_data, -1, x_dim_orig)
elif x_dim_orig is None:
new_data = np.squeeze(new_data, axis=x_dim)
new_data = np.moveaxis(new_data, -1, y_dim_orig)
elif x_dim_orig < y_dim_orig:
# move the x_dim back first, so that the y_dim will
# then be moved to its original position
new_data = np.moveaxis(new_data, -1, x_dim_orig)
new_data = np.moveaxis(new_data, -1, y_dim_orig)
else:
# move the y_dim back first, so that the x_dim will
# then be moved to its original position
new_data = np.moveaxis(new_data, -2, y_dim_orig)
new_data = np.moveaxis(new_data, -1, x_dim_orig)

return new_data

Expand Down
Loading

0 comments on commit 89d2330

Please sign in to comment.