From 9db70c43c01505edbe67bc0fa4805a9ebfaad25c Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Fri, 16 Sep 2022 19:07:54 +0200 Subject: [PATCH 1/5] Fixed missing attributes after Dataset destagger --- tests/test_accessors.py | 8 ++++++++ xwrf/accessors.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index a64db416..207c2f6c 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -110,5 +110,13 @@ def test_dataset_destagger(test_grid): or destaggered[varname].attrs['stagger'] == '' ) + # Check preservation of variable attrs + for varname in set(test_grid.data_vars).intersection(set(destaggered.data_vars)): + # have to pop 'units' too, because dimensionless units have attr removed on postprocess + for key in ['stagger', 'units']: + if key in test_grid[varname].attrs: + test_grid[varname].attrs.pop(key) + assert set(test_grid[varname].attrs.keys()) <= set(destaggered[varname].attrs.keys()) + # Check that attrs are preserved assert destaggered.attrs == test_grid.attrs diff --git a/xwrf/accessors.py b/xwrf/accessors.py index 5473cc21..0b4a3aea 100644 --- a/xwrf/accessors.py +++ b/xwrf/accessors.py @@ -175,6 +175,7 @@ def destagger(self, staggered_to_unstaggered_dims: dict[str, str] | None = None) # Found a staggered dim # TODO: should we raise an error if somehow end up with more than just one # staggered dim, or just pick one from the set like below? + _attrs = var_data.attrs this_staggered_dim = this_staggered_dims.pop() new_data_vars[var_name] = _destag_variable( var_data.variable, @@ -185,6 +186,9 @@ def destagger(self, staggered_to_unstaggered_dims: dict[str, str] | None = None) else staggered_to_unstaggered_dims[this_staggered_dim] ), ) + if 'stagger' in _attrs: + _attrs.pop('stagger') + new_data_vars[var_name].attrs = _attrs else: # No staggered dims new_data_vars[var_name] = var_data.variable From fda79991f4e6e702d5d57853b24163175fc5ec76 Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Fri, 16 Sep 2022 19:51:10 +0200 Subject: [PATCH 2/5] Fixed same bug for DataArray and added test --- tests/test_accessors.py | 6 ++++++ xwrf/accessors.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 207c2f6c..c0f52c1f 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -79,6 +79,11 @@ def test_dataarray_destagger(test_grid): xr.testing.assert_allclose(destaggered['XLAT'], test_grid['XLAT']) xr.testing.assert_allclose(destaggered['XLONG'], test_grid['XLONG']) + # Check attributes are preserved + if 'stagger' in data.attrs: + data.attrs.pop('stagger') + assert set(destaggered.attrs.keys()) == set(data.attrs.keys()) + @pytest.mark.parametrize('test_grid', ['lambert_conformal', 'mercator'], indirect=True) def test_dataarray_destagger_with_exclude(test_grid): @@ -116,6 +121,7 @@ def test_dataset_destagger(test_grid): for key in ['stagger', 'units']: if key in test_grid[varname].attrs: test_grid[varname].attrs.pop(key) + # because of xwrf.postprocess, the destaggered attrs will include more information assert set(test_grid[varname].attrs.keys()) <= set(destaggered[varname].attrs.keys()) # Check that attrs are preserved diff --git a/xwrf/accessors.py b/xwrf/accessors.py index 0b4a3aea..a40a83fa 100644 --- a/xwrf/accessors.py +++ b/xwrf/accessors.py @@ -65,6 +65,7 @@ def destagger( and/or use cases. For full accuracy, auxiliary coordinates should be re-computed from dimension coordinates or obtained from the original dataset. """ + _attrs = self.xarray_obj.variable.attrs new_variable = _destag_variable( self.xarray_obj.variable, stagger_dim=stagger_dim, unstag_dim_name=unstaggered_dim_name ) @@ -86,7 +87,9 @@ def destagger( else: new_coords[coord_name] = coord_data.variable - return xr.DataArray(new_variable, coords=new_coords) + if 'stagger' in _attrs: + _attrs.pop('stagger') + return xr.DataArray(new_variable, coords=new_coords, attrs=_attrs) @xr.register_dataset_accessor('xwrf') From a9faa147ed51b7dd3406756c2585f7ec3557ee96 Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Fri, 16 Sep 2022 21:02:01 +0200 Subject: [PATCH 3/5] Muuuch cleaner implementation, clearly didn't have enough coffee yet ;) --- tests/test_accessors.py | 14 ++++++-------- xwrf/accessors.py | 9 +-------- xwrf/destagger.py | 2 +- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index c0f52c1f..794c603f 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -80,9 +80,9 @@ def test_dataarray_destagger(test_grid): xr.testing.assert_allclose(destaggered['XLONG'], test_grid['XLONG']) # Check attributes are preserved - if 'stagger' in data.attrs: - data.attrs.pop('stagger') - assert set(destaggered.attrs.keys()) == set(data.attrs.keys()) + assert set(destaggered.attrs.keys()) == set(data.attrs.keys()) - { + 'stagger', + } @pytest.mark.parametrize('test_grid', ['lambert_conformal', 'mercator'], indirect=True) @@ -117,12 +117,10 @@ def test_dataset_destagger(test_grid): # Check preservation of variable attrs for varname in set(test_grid.data_vars).intersection(set(destaggered.data_vars)): - # have to pop 'units' too, because dimensionless units have attr removed on postprocess - for key in ['stagger', 'units']: - if key in test_grid[varname].attrs: - test_grid[varname].attrs.pop(key) # because of xwrf.postprocess, the destaggered attrs will include more information - assert set(test_grid[varname].attrs.keys()) <= set(destaggered[varname].attrs.keys()) + assert set(test_grid[varname].attrs.keys()) - {'stagger', 'units'} <= set( + destaggered[varname].attrs.keys() + ) # Check that attrs are preserved assert destaggered.attrs == test_grid.attrs diff --git a/xwrf/accessors.py b/xwrf/accessors.py index a40a83fa..5473cc21 100644 --- a/xwrf/accessors.py +++ b/xwrf/accessors.py @@ -65,7 +65,6 @@ def destagger( and/or use cases. For full accuracy, auxiliary coordinates should be re-computed from dimension coordinates or obtained from the original dataset. """ - _attrs = self.xarray_obj.variable.attrs new_variable = _destag_variable( self.xarray_obj.variable, stagger_dim=stagger_dim, unstag_dim_name=unstaggered_dim_name ) @@ -87,9 +86,7 @@ def destagger( else: new_coords[coord_name] = coord_data.variable - if 'stagger' in _attrs: - _attrs.pop('stagger') - return xr.DataArray(new_variable, coords=new_coords, attrs=_attrs) + return xr.DataArray(new_variable, coords=new_coords) @xr.register_dataset_accessor('xwrf') @@ -178,7 +175,6 @@ def destagger(self, staggered_to_unstaggered_dims: dict[str, str] | None = None) # Found a staggered dim # TODO: should we raise an error if somehow end up with more than just one # staggered dim, or just pick one from the set like below? - _attrs = var_data.attrs this_staggered_dim = this_staggered_dims.pop() new_data_vars[var_name] = _destag_variable( var_data.variable, @@ -189,9 +185,6 @@ def destagger(self, staggered_to_unstaggered_dims: dict[str, str] | None = None) else staggered_to_unstaggered_dims[this_staggered_dim] ), ) - if 'stagger' in _attrs: - _attrs.pop('stagger') - new_data_vars[var_name].attrs = _attrs else: # No staggered dims new_data_vars[var_name] = var_data.variable diff --git a/xwrf/destagger.py b/xwrf/destagger.py index 9752db87..b0f82bad 100644 --- a/xwrf/destagger.py +++ b/xwrf/destagger.py @@ -73,7 +73,7 @@ def _destag_variable(datavar, stagger_dim=None, unstag_dim_name=None): return xr.Variable( dims=tuple(str(unstag_dim_name) if dim == stagger_dim else dim for dim in center_mean.dims), data=center_mean.data, - attrs=_drop_attrs(center_mean.attrs, ('stagger',)), + attrs=_drop_attrs(datavar.attrs, ('stagger',)), encoding=center_mean.encoding, fastpath=True, ) From 4f3806849421fb16605076b3c24998904daf7d53 Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Fri, 16 Sep 2022 13:45:22 -0600 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: jthielen --- xwrf/destagger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xwrf/destagger.py b/xwrf/destagger.py index b0f82bad..b8f539c0 100644 --- a/xwrf/destagger.py +++ b/xwrf/destagger.py @@ -74,7 +74,7 @@ def _destag_variable(datavar, stagger_dim=None, unstag_dim_name=None): dims=tuple(str(unstag_dim_name) if dim == stagger_dim else dim for dim in center_mean.dims), data=center_mean.data, attrs=_drop_attrs(datavar.attrs, ('stagger',)), - encoding=center_mean.encoding, + encoding=datavar.encoding, fastpath=True, ) From 27d1b6715c1e1cc0826d959270d005b476b5152e Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Fri, 16 Sep 2022 21:48:06 +0200 Subject: [PATCH 5/5] Added attrs test for _destag_var --- tests/test_destagger.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_destagger.py b/tests/test_destagger.py index 695068f8..f79c968d 100644 --- a/tests/test_destagger.py +++ b/tests/test_destagger.py @@ -63,14 +63,16 @@ def test_destag_variable_multiple_dims(): ], ) def test_destag_variable_1d(unstag_dim_name, expected_output_dim_name): - staggered = xr.Variable(('bottom_top_stag',), np.arange(5), attrs={'stagger': 'Z'}) + staggered = xr.Variable( + ('bottom_top_stag',), np.arange(5), attrs={'foo': 'bar', 'stagger': 'Z'} + ) output = _destag_variable(staggered, unstag_dim_name=unstag_dim_name) # Check values np.testing.assert_array_almost_equal(output.values, 0.5 + np.arange(4)) # Check dim name assert output.dims[0] == expected_output_dim_name # Check attrs - assert not output.attrs + assert output.attrs == {'foo': 'bar'} def test_destag_variable_2d():