From c353677c9d50a7870f2a2165ff9cb5845d82f683 Mon Sep 17 00:00:00 2001 From: Elynn Wu Date: Mon, 10 Oct 2022 15:34:29 -0700 Subject: [PATCH 1/4] initial change --- driver/pace/driver/state.py | 11 +- physics/pace/physics/physics_state.py | 270 ++++++++++++++++++-------- 2 files changed, 189 insertions(+), 92 deletions(-) diff --git a/driver/pace/driver/state.py b/driver/pace/driver/state.py index 1fa04e590..6278504d8 100644 --- a/driver/pace/driver/state.py +++ b/driver/pace/driver/state.py @@ -162,14 +162,9 @@ def _overwrite_state_from_restart( for _field in fields(type(state)): if "units" in _field.metadata.keys(): if is_gpu_backend: - if "physics" in restart_file_prefix: - state.__dict__[_field.name][:] = gt_utils.asarray( - df[_field.name].data[:], to_type=cp.ndarray - ) - else: - state.__dict__[_field.name].data[:] = gt_utils.asarray( - df[_field.name].data[:], to_type=cp.ndarray - ) + state.__dict__[_field.name].data[:] = gt_utils.asarray( + df[_field.name].data[:], to_type=cp.ndarray + ) else: state.__dict__[_field.name].data[:] = df[_field.name].data[:] return state diff --git a/physics/pace/physics/physics_state.py b/physics/pace/physics/physics_state.py index 67985fce4..d77eb01f3 100644 --- a/physics/pace/physics/physics_state.py +++ b/physics/pace/physics/physics_state.py @@ -1,176 +1,283 @@ from dataclasses import InitVar, dataclass, field, fields from typing import List, Optional -import gt4py.gtscript as gtscript import xarray as xr import pace.dsl.gt4py_utils as gt_utils import pace.util -from pace.dsl.typing import FloatField, FloatFieldIJ from pace.physics.stencils.microphysics import MicrophysicsState @dataclass() class PhysicsState: - qvapor: FloatField = field(metadata={"name": "specific_humidity", "units": "kg/kg"}) - qliquid: FloatField = field( + qvapor: pace.util.Quantity = field( + metadata={ + "name": "specific_humidity", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "kg/kg", + } + ) + qliquid: pace.util.Quantity = field( metadata={ "name": "cloud_water_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", "intent": "inout", } ) - qice: FloatField = field( - metadata={"name": "cloud_ice_mixing_ratio", "units": "kg/kg", "intent": "inout"} + qice: pace.util.Quantity = field( + metadata={ + "name": "cloud_ice_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "kg/kg", + "intent": "inout", + } ) - qrain: FloatField = field( - metadata={"name": "rain_mixing_ratio", "units": "kg/kg", "intent": "inout"} + qrain: pace.util.Quantity = field( + metadata={ + "name": "rain_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "kg/kg", + "intent": "inout", + } ) - qsnow: FloatField = field( - metadata={"name": "snow_mixing_ratio", "units": "kg/kg", "intent": "inout"} + qsnow: pace.util.Quantity = field( + metadata={ + "name": "snow_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "kg/kg", + "intent": "inout", + } ) - qgraupel: FloatField = field( - metadata={"name": "graupel_mixing_ratio", "units": "kg/kg", "intent": "inout"} + qgraupel: pace.util.Quantity = field( + metadata={ + "name": "graupel_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "kg/kg", + "intent": "inout", + } ) - qo3mr: FloatField = field( - metadata={"name": "ozone_mixing_ratio", "units": "kg/kg", "intent": "inout"} + qo3mr: pace.util.Quantity = field( + metadata={ + "name": "ozone_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "kg/kg", + "intent": "inout", + } ) - qsgs_tke: FloatField = field( + qsgs_tke: pace.util.Quantity = field( metadata={ "name": "turbulent_kinetic_energy", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "m**2/s**2", "intent": "inout", } ) - qcld: FloatField = field( - metadata={"name": "cloud_fraction", "units": "", "intent": "inout"} + qcld: pace.util.Quantity = field( + metadata={ + "name": "cloud_fraction", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "", + "intent": "inout", + } ) - pt: FloatField = field( - metadata={"name": "air_temperature", "units": "degK", "intent": "inout"} + pt: pace.util.Quantity = field( + metadata={ + "name": "air_temperature", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "degK", + "intent": "inout", + } ) - delp: FloatField = field( + delp: pace.util.Quantity = field( metadata={ "name": "pressure_thickness_of_atmospheric_layer", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "Pa", "intent": "inout", } ) - delz: FloatField = field( + delz: pace.util.Quantity = field( metadata={ "name": "vertical_thickness_of_atmospheric_layer", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "m", "intent": "inout", } ) - ua: FloatField = field( - metadata={"name": "eastward_wind", "units": "m/s", "intent": "inout"} + ua: pace.util.Quantity = field( + metadata={ + "name": "eastward_wind", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "m/s", + "intent": "inout", + } ) - va: FloatField = field( - metadata={"name": "northward_wind", "units": "m/s", "intent": "inout"} + va: pace.util.Quantity = field( + metadata={ + "name": "northward_wind", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "m/s", + } ) - w: FloatField = field( - metadata={"name": "vertical_wind", "units": "m/s", "intent": "inout"} + w: pace.util.Quantity = field( + metadata={ + "name": "vertical_wind", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "m/s", + "intent": "inout", + } ) - omga: FloatField = field( + omga: pace.util.Quantity = field( metadata={ "name": "vertical_pressure_velocity", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "Pa/s", "intent": "inout", } ) - physics_updated_specific_humidity: FloatField = field( + physics_updated_specific_humidity: pace.util.Quantity = field( metadata={ - "name": "physics_specific_humidity", + "name": "physics_updated_specific_humidity", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", - "intent": "inout", } ) - physics_updated_qliquid: FloatField = field( + physics_updated_qliquid: pace.util.Quantity = field( metadata={ - "name": "physics_cloud_water_mixing_ratio", + "name": "physics_updated_liquid_water_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", "intent": "inout", } ) - physics_updated_qice: FloatField = field( + physics_updated_qice: pace.util.Quantity = field( metadata={ - "name": "physics_cloud_ice_mixing_ratio", + "name": "physics_updated_ice_water_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", "intent": "inout", } ) - physics_updated_qrain: FloatField = field( + physics_updated_qrain: pace.util.Quantity = field( metadata={ - "name": "physics_rain_mixing_ratio", + "name": "physics_updated_rain_water_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", "intent": "inout", } ) - physics_updated_qsnow: FloatField = field( + physics_updated_qsnow: pace.util.Quantity = field( metadata={ - "name": "physics_snow_mixing_ratio", + "name": "physics_updated_snow_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", "intent": "inout", } ) - physics_updated_qgraupel: FloatField = field( + physics_updated_qgraupel: pace.util.Quantity = field( metadata={ - "name": "physics_graupel_mixing_ratio", + "name": "physics_updated_graupel_mixing_ratio", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "kg/kg", "intent": "inout", } ) - physics_updated_cloud_fraction: FloatField = field( - metadata={"name": "physics_cloud_fraction", "units": "", "intent": "inout"} + physics_updated_cloud_fraction: pace.util.Quantity = field( + metadata={ + "name": "physics_cloud_fraction", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "", + "intent": "inout", + } ) - physics_updated_pt: FloatField = field( - metadata={"name": "physics_air_temperature", "units": "degK", "intent": "inout"} + physics_updated_pt: pace.util.Quantity = field( + metadata={ + "name": "physics_air_temperature", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "degK", + "intent": "inout", + } ) - physics_updated_ua: FloatField = field( - metadata={"name": "physics_eastward_wind", "units": "m/s", "intent": "inout"} + physics_updated_ua: pace.util.Quantity = field( + metadata={ + "name": "physics_eastward_wind", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "m/s", + "intent": "inout", + } ) - physics_updated_va: FloatField = field( - metadata={"name": "physics_northward_wind", "units": "m/s", "intent": "inout"} + physics_updated_va: pace.util.Quantity = field( + metadata={ + "name": "physics_northward_wind", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "m/s", + "intent": "inout", + } ) - delprsi: FloatField = field( + delprsi: pace.util.Quantity = field( metadata={ "name": "model_level_pressure_thickness_in_physics", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "Pa", "intent": "inout", } ) - phii: FloatField = field( + phii: pace.util.Quantity = field( metadata={ "name": "interface_geopotential_height", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_INTERFACE_DIM], "units": "m", "intent": "inout", } ) - phil: FloatField = field( - metadata={"name": "layer_geopotential_height", "units": "m", "intent": "inout"} + phil: pace.util.Quantity = field( + metadata={ + "name": "layer_geopotential_height", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], + "units": "m", + "intent": "inout", + } ) - dz: FloatField = field( + dz: pace.util.Quantity = field( metadata={ "name": "geopotential_height_thickness", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "m", "intent": "inout", } ) - wmp: FloatField = field( + wmp: pace.util.Quantity = field( metadata={ "name": "layer_mean_vertical_velocity_microph", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "units": "m/s", "intent": "inout", } ) - prsi: FloatField = field( - metadata={"name": "interface_pressure", "units": "Pa", "intent": "inout"} + prsi: pace.util.Quantity = field( + metadata={ + "name": "interface_pressure", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_INTERFACE_DIM], + "units": "Pa", + "intent": "inout", + } ) - prsik: FloatField = field( - metadata={"name": "log_interface_pressure", "units": "Pa", "intent": "inout"} + prsik: pace.util.Quantity = field( + metadata={ + "name": "log_interface_pressure", + "dims": [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_INTERFACE_DIM], + "units": "Pa", + "intent": "inout", + } ) - land: FloatFieldIJ = field( - metadata={"name": "land_mask", "units": "-", "intent": "in", "dimensions": "2D"} + land: pace.util.Quantity = field( + metadata={ + "name": "land_mask", + "dims": [pace.util.X_DIM, pace.util.Y_DIM], + "units": "-", + "intent": "in", + } ) quantity_factory: InitVar[pace.util.QuantityFactory] active_packages: InitVar[List[str]] @@ -184,7 +291,7 @@ def __post_init__( [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "unknown", dtype=float, - ).storage + ).data self.microphysics: Optional[MicrophysicsState] = MicrophysicsState( pt=self.pt, qvapor=self.qvapor, @@ -212,15 +319,10 @@ def __post_init__( def init_zeros(cls, quantity_factory, active_packages: List[str]) -> "PhysicsState": initial_storages = {} for _field in fields(cls): - if len(_field.type.axes) == 2: - dims = [pace.util.X_DIM, pace.util.Y_DIM] - else: - dims = [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM] - initial_storages[_field.name] = quantity_factory.zeros( - dims, - _field.metadata["units"], - dtype=float, - ).storage + if "dims" in _field.metadata.keys(): + initial_storages[_field.name] = quantity_factory.zeros( + _field.metadata["dims"], _field.metadata["units"], dtype=float + ).storage return cls( **initial_storages, quantity_factory=quantity_factory, @@ -231,17 +333,17 @@ def init_zeros(cls, quantity_factory, active_packages: List[str]) -> "PhysicsSta def xr_dataset(self): data_vars = {} for name, field_info in self.__dataclass_fields__.items(): - if isinstance(field_info.type, gtscript._FieldDescriptor): - if len(field_info.type.axes) == 2: - dims = [pace.util.X_DIM, pace.util.Y_DIM] - else: - dims = [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM] - data_vars[name] = xr.DataArray( - gt_utils.asarray(getattr(self, name).data), - dims=dims, - attrs={ - "long_name": field_info.metadata["name"], - "units": field_info.metadata.get("units", "unknown"), - }, - ) + if name not in ["quantity_factory", "active_packages"]: + if issubclass(field_info.type, pace.util.Quantity): + dims = [ + f"{dim_name}_{name}" for dim_name in field_info.metadata["dims"] + ] + data_vars[name] = xr.DataArray( + gt_utils.asarray(getattr(self, name).data), + dims=dims, + attrs={ + "long_name": field_info.metadata["name"], + "units": field_info.metadata.get("units", "unknown"), + }, + ) return xr.Dataset(data_vars=data_vars) From ddebe12b478a61bbd3d5fe95284558b60c9d660f Mon Sep 17 00:00:00 2001 From: Elynn Wu Date: Tue, 11 Oct 2022 13:51:30 -0700 Subject: [PATCH 2/4] fix savepoint test --- physics/pace/physics/physics_state.py | 26 ++++++++++++++++++- .../translate/translate_microphysics.py | 5 ++-- .../stencils/testing/translate_physics.py | 3 ++- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/physics/pace/physics/physics_state.py b/physics/pace/physics/physics_state.py index d77eb01f3..6f2f34c58 100644 --- a/physics/pace/physics/physics_state.py +++ b/physics/pace/physics/physics_state.py @@ -1,5 +1,5 @@ from dataclasses import InitVar, dataclass, field, fields -from typing import List, Optional +from typing import Any, List, Mapping, Optional import xarray as xr @@ -329,6 +329,30 @@ def init_zeros(cls, quantity_factory, active_packages: List[str]) -> "PhysicsSta active_packages=active_packages, ) + @classmethod + def init_from_storages( + cls, + storages: Mapping[str, Any], + sizer: pace.util.GridSizer, + quantity_factory: pace.util.QuantityFactory, + active_packages: List[str], + ): + inputs = {} + for _field in fields(cls): + if "dims" in _field.metadata.keys(): + dims = _field.metadata["dims"] + quantity = pace.util.Quantity( + storages[_field.name], + dims, + _field.metadata["units"], + origin=sizer.get_origin(dims), + extent=sizer.get_extent(dims), + ) + inputs[_field.name] = quantity + return cls( + **inputs, quantity_factory=quantity_factory, active_packages=active_packages + ) + @property def xr_dataset(self): data_vars = {} diff --git a/physics/tests/savepoint/translate/translate_microphysics.py b/physics/tests/savepoint/translate/translate_microphysics.py index 74e9f8bb1..aaa8842b2 100644 --- a/physics/tests/savepoint/translate/translate_microphysics.py +++ b/physics/tests/savepoint/translate/translate_microphysics.py @@ -84,8 +84,9 @@ def compute(self, inputs): quantity_factory = pace.util.QuantityFactory.from_backend( sizer, self.stencil_factory.backend ) - physics_state = PhysicsState( - **inputs, + physics_state = PhysicsState.init_from_storages( + inputs, + sizer=sizer, quantity_factory=quantity_factory, active_packages=["microphysics"], ) diff --git a/stencils/pace/stencils/testing/translate_physics.py b/stencils/pace/stencils/testing/translate_physics.py index 4b2c44d3c..11ca023d5 100644 --- a/stencils/pace/stencils/testing/translate_physics.py +++ b/stencils/pace/stencils/testing/translate_physics.py @@ -143,7 +143,8 @@ def slice_output(self, inputs, out_data=None): roll_zero = info["out_roll_zero"] if "out_roll_zero" in info else False index_order = info["order"] if "order" in info else "C" dycore = info["dycore"] if "dycore" in info else False - data_result.synchronize() + if hasattr(data_result, "synchronize"): + data_result.synchronize() if n_dim == 3: npz = data_result.shape[2] k_length = info["kend"] if "kend" in info else npz From 06ce247900d8770d44810a7f0194e48d51fbe481 Mon Sep 17 00:00:00 2001 From: Elynn Wu Date: Tue, 11 Oct 2022 14:27:59 -0700 Subject: [PATCH 3/4] microphysics state uses quantity too --- physics/pace/physics/physics_state.py | 4 +- physics/pace/physics/stencils/microphysics.py | 57 ++++++++++--------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/physics/pace/physics/physics_state.py b/physics/pace/physics/physics_state.py index 6f2f34c58..c7c0b45c4 100644 --- a/physics/pace/physics/physics_state.py +++ b/physics/pace/physics/physics_state.py @@ -291,7 +291,7 @@ def __post_init__( [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM], "unknown", dtype=float, - ).data + ) self.microphysics: Optional[MicrophysicsState] = MicrophysicsState( pt=self.pt, qvapor=self.qvapor, @@ -309,7 +309,7 @@ def __post_init__( delprsi=self.delprsi, wmp=self.wmp, dz=self.dz, - tendency_storage=tendency, + tendency=tendency, land=self.land, ) else: diff --git a/physics/pace/physics/stencils/microphysics.py b/physics/pace/physics/stencils/microphysics.py index 93c052b32..df5595546 100644 --- a/physics/pace/physics/stencils/microphysics.py +++ b/physics/pace/physics/stencils/microphysics.py @@ -12,6 +12,7 @@ import pace.dsl.gt4py_utils as utils import pace.physics.functions.microphysics_funcs as functions +import pace.util import pace.util.constants as constants from pace.dsl.dace.orchestration import orchestrate from pace.dsl.stencil import StencilFactory @@ -1844,24 +1845,24 @@ class MicrophysicsState: def __init__( self, - pt: FloatField, - qvapor: FloatField, - qliquid: FloatField, - qrain: FloatField, - qice: FloatField, - qsnow: FloatField, - qgraupel: FloatField, - qcld: FloatField, - ua: FloatField, - va: FloatField, - delp: FloatField, - delz: FloatField, - omga: FloatField, - delprsi: FloatField, - wmp: FloatField, - dz: FloatField, - tendency_storage: FloatField, - land: FloatField, + pt: pace.util.Quantity, + qvapor: pace.util.Quantity, + qliquid: pace.util.Quantity, + qrain: pace.util.Quantity, + qice: pace.util.Quantity, + qsnow: pace.util.Quantity, + qgraupel: pace.util.Quantity, + qcld: pace.util.Quantity, + ua: pace.util.Quantity, + va: pace.util.Quantity, + delp: pace.util.Quantity, + delz: pace.util.Quantity, + omga: pace.util.Quantity, + delprsi: pace.util.Quantity, + wmp: pace.util.Quantity, + dz: pace.util.Quantity, + tendency: pace.util.Quantity, + land: pace.util.Quantity, ): self.pt = pt self.qvapor = qvapor @@ -1876,16 +1877,16 @@ def __init__( self.delp = delp self.delz = delz self.omga = omga - self.qv_dt = copy.deepcopy(tendency_storage) - self.ql_dt = copy.deepcopy(tendency_storage) - self.qr_dt = copy.deepcopy(tendency_storage) - self.qi_dt = copy.deepcopy(tendency_storage) - self.qs_dt = copy.deepcopy(tendency_storage) - self.qg_dt = copy.deepcopy(tendency_storage) - self.qa_dt = copy.deepcopy(tendency_storage) - self.udt = copy.deepcopy(tendency_storage) - self.vdt = copy.deepcopy(tendency_storage) - self.pt_dt = copy.deepcopy(tendency_storage) + self.qv_dt = copy.deepcopy(tendency) + self.ql_dt = copy.deepcopy(tendency) + self.qr_dt = copy.deepcopy(tendency) + self.qi_dt = copy.deepcopy(tendency) + self.qs_dt = copy.deepcopy(tendency) + self.qg_dt = copy.deepcopy(tendency) + self.qa_dt = copy.deepcopy(tendency) + self.udt = copy.deepcopy(tendency) + self.vdt = copy.deepcopy(tendency) + self.pt_dt = copy.deepcopy(tendency) self.delprsi = delprsi self.wmp = wmp self.dz = dz From 095797efe8af8b90bdd1e633172687ba64891989 Mon Sep 17 00:00:00 2001 From: Elynn Wu Date: Thu, 13 Oct 2022 09:47:39 -0700 Subject: [PATCH 4/4] address PR comments --- physics/pace/physics/physics_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/physics/pace/physics/physics_state.py b/physics/pace/physics/physics_state.py index c7c0b45c4..7499c393d 100644 --- a/physics/pace/physics/physics_state.py +++ b/physics/pace/physics/physics_state.py @@ -1,5 +1,5 @@ from dataclasses import InitVar, dataclass, field, fields -from typing import Any, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional import xarray as xr @@ -336,8 +336,8 @@ def init_from_storages( sizer: pace.util.GridSizer, quantity_factory: pace.util.QuantityFactory, active_packages: List[str], - ): - inputs = {} + ) -> "PhysicsState": + inputs: Dict[str, pace.util.Quantity] = {} for _field in fields(cls): if "dims" in _field.metadata.keys(): dims = _field.metadata["dims"]