Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mobt 812 vera threshold interpolation #2079

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions improver/cli/threshold_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Script to run the threshold interpolation plugin."""

from improver import cli


@cli.clizefy
@cli.with_output
def process(
    forecast_at_thresholds: cli.inputcube,
    *,
    thresholds: cli.comma_separated_list,
):
    """
    Use this CLI to modify the probability thresholds in an existing probability
    forecast cube by linearly interpolating between the existing thresholds.

    Args:
        forecast_at_thresholds:
            Cube expected to contain a threshold coordinate.
        thresholds:
            List of the desired output thresholds.

    Returns:
        Cube with forecast values at the desired set of thresholds.
        The threshold coordinate is always the zeroth dimension.
    """
    from improver.utilities.threshold_interpolation import ThresholdInterpolation

    # The plugin declares its thresholds as a list of floats. The values
    # arriving from the command line are presumably still strings (TODO
    # confirm against cli.comma_separated_list), so convert them here; a
    # non-numeric entry then fails at the CLI boundary with a ValueError.
    # float() leaves values that are already floats unchanged.
    numeric_thresholds = [float(value) for value in thresholds]

    result = ThresholdInterpolation(numeric_thresholds)(forecast_at_thresholds)

    return result
212 changes: 212 additions & 0 deletions improver/utilities/threshold_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Script to linearly interpolate thresholds"""

from typing import List, Optional

import iris
import numpy as np
from iris.cube import Cube
from numpy import ndarray

from improver import PostProcessingPlugin
from improver.calibration.utilities import convert_cube_data_to_2d
from improver.ensemble_copula_coupling.utilities import (
interpolate_multiple_rows_same_x,
restore_non_percentile_dimensions,
)
from improver.metadata.probabilistic import (
find_threshold_coordinate,
)
from improver.utilities.cube_manipulation import (
collapse_realizations,
enforce_coordinate_ordering,
)


class ThresholdInterpolation(PostProcessingPlugin):
    """
    Linearly interpolate forecast data on a threshold coordinate to a new
    set of thresholds.

    The plugin locates the threshold coordinate of the input cube, collapses
    any realization coordinate, interpolates the data to the requested
    thresholds and returns a new cube carrying the new threshold coordinate.
    A mask that is constant across the threshold dimension is re-applied to
    the output.
    """

    def __init__(self, thresholds: List[float]):
        """
        Args:
            thresholds:
                List of the desired output thresholds.

        Raises:
            ValueError:
                If the thresholds list is empty.
        """
        if not thresholds:
            raise ValueError("The thresholds list cannot be empty.")
        self.thresholds = thresholds
        # Populated by process() once the input cube is available.
        self.threshold_coord = None

    def mask_checking(self, forecast_at_thresholds: Cube) -> Optional[np.ndarray]:
        """
        Check if the mask is consistent across different slices of the threshold
        coordinate.

        Args:
            forecast_at_thresholds:
                The input cube containing forecast data with a threshold coordinate.

        Returns:
            The mask of the first slice over the threshold coordinate if the
            data is masked and the mask is consistent across the threshold
            coordinate, otherwise None.

        Raises:
            ValueError:
                If the mask varies across different slices of the threshold
                coordinate.
        """
        if not np.ma.is_masked(forecast_at_thresholds.data):
            return None

        coord_name = self.threshold_coord.name()
        (crd_dim,) = forecast_at_thresholds.coord_dims(coord_name)
        # Any non-zero difference between neighbouring threshold slices of the
        # mask means the mask is not constant along the threshold dimension.
        if np.diff(forecast_at_thresholds.data.mask, axis=crd_dim).any():
            raise ValueError(
                f"The mask is expected to be constant across different slices of the {self.threshold_coord.name()}"
                f" dimension, however, in the dataset provided, the mask varies across the {self.threshold_coord.name()}"
                f" dimension. This is not currently supported."
            )

        # The mask is identical for every threshold, so the first slice is
        # representative.
        return next(forecast_at_thresholds.slices_over(coord_name)).data.mask

    def _interpolate_thresholds(
        self,
        forecast_at_thresholds: Cube,
    ) -> np.ndarray:
        """
        Interpolate forecast data to a new set of thresholds.

        This method performs linear interpolation of forecast data from an initial
        set of thresholds to a new set of thresholds. The interpolation is done
        by converting the data to a 2D array, performing the interpolation, and
        then restoring the original dimensions.

        Args:
            forecast_at_thresholds:
                Cube containing forecast data with a threshold coordinate.

        Returns:
            Interpolated forecast data with the new set of thresholds.
        """
        coord_name = self.threshold_coord.name()
        original_thresholds = self.threshold_coord.points

        # Ensure that the threshold dimension is first, so that the
        # conversion to a 2d array produces data in the desired order.
        # NOTE(review): this looks like it reorders the supplied cube in
        # place (the return value is not used) - confirm whether callers
        # rely on their cube being left untouched.
        enforce_coordinate_ordering(forecast_at_thresholds, coord_name)
        forecast_at_reshaped_thresholds = convert_cube_data_to_2d(
            forecast_at_thresholds, coord=coord_name
        )

        forecast_at_interpolated_thresholds = interpolate_multiple_rows_same_x(
            np.array(self.thresholds, dtype=np.float64),
            original_thresholds.astype(np.float64),
            forecast_at_reshaped_thresholds.astype(np.float64),
        )

        forecast_at_interpolated_thresholds = np.transpose(
            forecast_at_interpolated_thresholds
        )

        # A single threshold slice acts as the template for the
        # non-threshold dimensions when reshaping the 2d result.
        forecast_at_thresholds_data = restore_non_percentile_dimensions(
            forecast_at_interpolated_thresholds,
            next(forecast_at_thresholds.slices_over(coord_name)),
            len(self.thresholds),
        )

        return forecast_at_thresholds_data

    def create_cube_with_thresholds(
        self,
        forecast_at_thresholds: Cube,
        cube_data: ndarray,
    ) -> Cube:
        """
        Create a cube with a threshold coordinate based on a template cube extracted
        by slicing over the threshold coordinate.

        The resulting cube will have an extra threshold coordinate compared with
        the template cube. The shape of the cube_data should be the shape of the
        desired output cube.

        Args:
            forecast_at_thresholds:
                Cube containing forecast data with a threshold coordinate.
            cube_data:
                Array containing the interpolated forecast data with the new
                thresholds.

        Returns:
            Cube containing the new threshold coordinate and the interpolated data.
        """
        coord_name = self.threshold_coord.name()
        template_cube = next(forecast_at_thresholds.slices_over(coord_name))
        template_cube.remove_coord(self.threshold_coord)

        # Build one single-threshold cube per requested threshold and merge
        # them to create the new threshold dimension.
        cubes = iris.cube.CubeList([])
        for point in self.thresholds:
            cube = template_cube.copy()
            coord = iris.coords.DimCoord(
                np.array([point], dtype="float32"), units=self.threshold_coord.units
            )
            coord.rename(coord_name)
            coord.var_name = "threshold"
            coord.attributes = self.threshold_coord.attributes
            cube.add_aux_coord(coord)
            cubes.append(cube)
        result = cubes.merge_cube()
        # The merged cube only holds copies of the template data; replace it
        # with the interpolated values.
        result.data = cube_data
        return result

    def process(
        self,
        forecast_at_thresholds: Cube,
    ) -> Cube:
        """
        Process the input cube to interpolate forecast data to a new set of thresholds.

        This method performs the following steps:
        1. Identifies the threshold coordinate in the input cube.
        2. Collapses the realizations if present.
        3. Checks if the mask is consistent across different slices of the
           threshold coordinate.
        4. Interpolates the forecast data to the new set of thresholds.
        5. Creates a new cube with the interpolated threshold data.
        6. Applies the original mask to the new cube if it exists.

        Args:
            forecast_at_thresholds:
                Cube expected to contain a threshold coordinate.

        Returns:
            Cube with forecast values at the desired set of thresholds.
            The threshold coordinate is always the zeroth dimension.

        Raises:
            ValueError:
                If the data is masked and the mask varies across the
                threshold coordinate.
        """
        self.threshold_coord = find_threshold_coordinate(forecast_at_thresholds)

        # Collapse realizations BEFORE extracting the mask. The mask is
        # re-applied to the output cube, which is built from the collapsed
        # cube; a mask taken from the uncollapsed cube would still carry the
        # realization dimension and would not match the output shape.
        # NOTE(review): a reviewer and the coverage report flagged that this
        # branch had no unit-test coverage; a unit test using a cube with a
        # realization dimension is needed.
        if forecast_at_thresholds.coords("realization"):
            forecast_at_thresholds = collapse_realizations(forecast_at_thresholds)

        original_mask = self.mask_checking(forecast_at_thresholds)

        forecast_at_thresholds_data = self._interpolate_thresholds(
            forecast_at_thresholds,
        )
        threshold_cube = self.create_cube_with_thresholds(
            forecast_at_thresholds,
            forecast_at_thresholds_data,
        )
        if original_mask is not None:
            # The stored mask covers a single threshold slice; repeat it for
            # every output threshold.
            original_mask = np.broadcast_to(original_mask, threshold_cube.shape)
            threshold_cube.data = np.ma.MaskedArray(
                threshold_cube.data, mask=original_mask
            )

        return threshold_cube
10 changes: 7 additions & 3 deletions improver_tests/acceptance/SHA256SUMS
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
a6a84d0142796e4b9ca7bd3f0ad78586ea77684f5df02732fdda2ab54233cbb6 ./aggregate-reliability-tables/basic/multiple_tables_kgo.nc
c0c38c5b1ba16fd5f7310b6d2ee0ab9d8b6bcdb3b23c8475193eefae1eb14a17 ./aggregate-reliability-tables/basic/reliability_table.nc
9eeda326cdc66d93e591c87e9f7ceb17cea329397169fb0518f50da3bf4dc61f ./aggregate-reliability-tables/basic/reliability_table_2.nc
9d000581df59946c188943734f480182fcc05eb3cd6c0e38d1dc4fb0c1e93b93 ./apply-beta-recalibration/config.json
68791ea2cedce742ddcd8d9d3cc2fd72bcbb603c0b98a694cddb0d32713019e8 ./apply-beta-recalibration/forecast.nc
93f8710ce0672e4d55cb627ff030c4ef3d40f40bf131a359b1d83a4595a71ee6 ./apply-beta-recalibration/kgo.nc
d82436afd61b2f9739e920c61293dc2bac32292f98255c86a632dd3dff504394 ./apply-bias-correction/20220814T0300Z-PT0003H00M-wind_speed_at_10m.nc
5d259b136452c34a790d97359eebe0e250995463d040d24b77588e43236596cf ./apply-bias-correction/fcst_with_comment/kgo.nc
6b878745e994b0a1516a7b790983c129b87d8f53f6d0e1661e56b1f7ca9fc67a ./apply-bias-correction/masked_bias_data/20220814T0300Z-PT0003H00M-wind_speed_at_10m.nc
Expand Down Expand Up @@ -881,6 +878,13 @@ eb6f7c3f646c4c51a0964b9a19367f43d6e3762ff5523b982cfaf7bf2610f091 ./temporal-int
e3b8f51a0be52c4fead55f95c0e3da29ee3d93f92deed26314e60ad43e8fd5ef ./temporal-interpolate/uv/20181220T1200Z-PT0024H00M-uv_index.nc
b3fde693b3a8e144cb8f9ee9ff23c51ef92701858667cff850b2a49986bacaab ./temporal-interpolate/uv/kgo_t1.nc
1065ae1f25e6bc6df8d02e61c1f8ef92ab3dae679595d5165bd94d9c740adb2c ./temporal-interpolate/uv/kgo_t1_daynight.nc
391049857a990aa80f904fcd2097398df7cd2461b64474dd97e5d03f85ed9705 ./threshold-interpolation/extra_thresholds_kgo.nc
f9d78e938e1f72a18dde7fedfd9d7a93aeac5e667047d48719eeb5f13f2eab04 ./threshold-interpolation/input.nc
e829b8c3cfba204bf798df7865f77bf4d258000ffbf1279c16728a834979224f ./threshold-interpolation/input_realization.nc
9ff6512643b3ec2b20456fbfed3fb096e6d461816f232d319d15d7b8f0d49494 ./threshold-interpolation/masked_cube_kgo.nc
ec73679ff5e308a2bb4d21283262118f8d9fbb6a425309b76d5865a97a773c40 ./threshold-interpolation/masked_input.nc
6058009963941b539117ea44792277253d87c7a1c81318e4836406b5c0b88525 ./threshold-interpolation/realization_collapse_kgo.nc
46b875c8f214610e52ab956ccb5992b35c690a8eb738638ae15401728ef6224d ./threshold-interpolation/save_netcdf.py
ac93ed67c9947547e5879af6faaa329fede18afd822c720ac3afcb18fa41077a ./threshold/basic/input.nc
eb3fdc9400401ec47d95961553aed452abcbd91891d0fbca106b3a05131adaa9 ./threshold/basic/kgo.nc
6b50fa16b663869b3e3fbff36197603886ff7383b2df2a8ba92579bcc9461a16 ./threshold/below_threshold/kgo.nc
Expand Down
52 changes: 52 additions & 0 deletions improver_tests/acceptance/test_threshold_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Test for the threshold interpolation CLI."""

import pytest

from . import acceptance as acc

# Markers applied to every test in this module: the acceptance-test marker and
# (presumably, going by its name) a skip when the KGO data is unavailable.
pytestmark = [pytest.mark.acc, acc.skip_if_kgo_missing]
# CLI name derived from this file's path by the acceptance helper.
CLI = acc.cli_name_with_dashes(__file__)
# Callable the tests below use to invoke that CLI with a list of arguments.
run_cli = acc.run_cli(CLI)


def test_basic(tmp_path):
    """Run the CLI with a thresholds argument and compare with the KGO."""
    data_dir = acc.kgo_root() / "threshold-interpolation"
    expected = data_dir / "extra_thresholds_kgo.nc"
    actual = tmp_path / "output.nc"

    run_cli(
        [
            data_dir / "input.nc",
            "--thresholds",
            "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0",
            "--output",
            str(actual),
        ]
    )

    acc.compare(actual, expected)


def test_realization_collapse(tmp_path):
    """Run the CLI on an input with realizations and compare with the KGO."""
    data_dir = acc.kgo_root() / "threshold-interpolation"
    expected = data_dir / "realization_collapse_kgo.nc"
    actual = tmp_path / "output.nc"

    run_cli(
        [
            data_dir / "input_realization.nc",
            "--thresholds",
            "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0",
            "--output",
            str(actual),
        ]
    )

    acc.compare(actual, expected)


def test_masked_cube(tmp_path):
    """Run the CLI on a masked input cube and compare with the KGO."""
    data_dir = acc.kgo_root() / "threshold-interpolation"
    expected = data_dir / "masked_cube_kgo.nc"
    actual = tmp_path / "output.nc"

    run_cli(
        [
            data_dir / "masked_input.nc",
            "--thresholds",
            "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0",
            "--output",
            str(actual),
        ]
    )

    acc.compare(actual, expected)
Loading