From 299b38cc09c5fa2872f901f7549f6adc852cb116 Mon Sep 17 00:00:00 2001 From: Kevin Schwarzwald Date: Wed, 30 Jun 2021 18:32:29 -0400 Subject: [PATCH] allow dataarray inputs into pixel_overlaps [issue 7] xr.DataArrays are now coerced into xr.Datasets at the start of both pixel_overlaps and aggregate. This brings behavior in line with the docs of pixel_overlaps, which allow DataArray inputs. Allowing xr.DataArray inputs for aggregate() will not be added to the docs (if the xr.DataArray is unnamed, then the name 'var' will be given to the variable; not ideal behavior); a warning will be thrown if an xr.DataArray is nevertheless used and is unnamed. Additionally, tests were added to make sure that pixel_overlaps() works with an xr.DataArray input. --- tests/test_core.py | 1 - tests/test_wrappers.py | 71 ++++++++++++++++++++++++++++++++++++++++++ xagg/core.py | 9 ++++++ xagg/wrappers.py | 8 +++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 tests/test_wrappers.py diff --git a/tests/test_core.py b/tests/test_core.py index bafc34e..c53f829 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -458,4 +458,3 @@ def test_aggregate_with_some_nans(): - diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py new file mode 100644 index 0000000..848239c --- /dev/null +++ b/tests/test_wrappers.py @@ -0,0 +1,71 @@ +import pytest +import pandas as pd +import numpy as np +import xarray as xr +import geopandas as gpd +from geopandas import testing as gpdt +from unittest import TestCase +from shapely.geometry import Polygon +import xesmf as xe + +from xagg.wrappers import (pixel_overlaps) + + +##### pixel_overlaps() tests ##### +def test_pixel_overlaps_dataarray(): + # To make sure that pixel_overlaps works with an unnamed dataarray + da = xr.DataArray(data=np.array([[0,1],[2,3]]), + coords={'lat':(['lat'],np.array([0,1])), + 'lon':(['lon'],np.array([0,1]))}, + dims=['lon','lat']) + + # Create polygon covering one pixel + gdf_test = {'name':['test'], + 'geometry':[Polygon([(-0.5,-0.5),(-0.5,0.5),(0.5,0.5),(0.5,-0.5),(-0.5,-0.5)])]} + gdf_test = gpd.GeoDataFrame(gdf_test,crs="EPSG:4326") + + # Calculate pixel_overlaps through the wrapper function, + # which should change the dataarray to a dataframe + wm = pixel_overlaps(da,gdf_test) + + df0 = pd.DataFrame(wm.agg) + + # Define what the output should be + df_compare = pd.DataFrame({'name':['test'],'poly_idx':0, + 'rel_area':[[[1.0]]],'pix_idxs':[[0]], + 'coords':[[(0,0)]]}) + + + assert np.allclose([v for v in df0.rel_area],[v for v in df_compare.rel_area]) + assert np.allclose([v for v in df0.pix_idxs],[v for v in df_compare.pix_idxs]) + assert np.allclose([v for v in df0.coords],[v for v in df_compare.coords]) + + +def test_pixel_overlaps_dataarray_wname(): + # To make sure that pixel_overlaps works with a named dataarray + da = xr.DataArray(data=np.array([[0,1],[2,3]]), + coords={'lat':(['lat'],np.array([0,1])), + 'lon':(['lon'],np.array([0,1]))}, + dims=['lon','lat'], + name='tas') + + # Create polygon covering one pixel + gdf_test = {'name':['test'], + 'geometry':[Polygon([(-0.5,-0.5),(-0.5,0.5),(0.5,0.5),(0.5,-0.5),(-0.5,-0.5)])]} + gdf_test = gpd.GeoDataFrame(gdf_test,crs="EPSG:4326") + + # Calculate pixel_overlaps through the wrapper function, + # which should change the dataarray to a dataframe + wm = pixel_overlaps(da,gdf_test) + + df0 = pd.DataFrame(wm.agg) + + # Define what the output should be + df_compare = pd.DataFrame({'name':['test'],'poly_idx':0, + 'rel_area':[[[1.0]]],'pix_idxs':[[0]], + 'coords':[[(0,0)]]}) + + + assert np.allclose([v for v in df0.rel_area],[v for v in df_compare.rel_area]) + assert np.allclose([v for v in df0.pix_idxs],[v for v in df_compare.pix_idxs]) + assert np.allclose([v for v in df0.coords],[v for v in df_compare.coords]) \ No newline at end of file diff --git a/xagg/core.py b/xagg/core.py index 29f161e..d6ec392 100644 --- a/xagg/core.py +++ b/xagg/core.py @@ -395,6 +395,15 @@ def aggregate(ds,wm): an :class:`xagg.classes.aggregated` object with the aggregated variables """ + # Turn into dataset if dataarray + if type(ds)==xr.core.dataarray.DataArray: + if ds.name is None: + warnings.warn('An unnamed xr.DataArray was inputted instead of a xr.Dataset; the output variable will be "var"') + ds = ds.to_dataset(name='var') + else: + ds = ds.to_dataset() + + # Run ds through fix_ds (to fix lat/lon names, lon coords) ds = fix_ds(ds) diff --git a/xagg/wrappers.py b/xagg/wrappers.py index f51714d..75789cc 100644 --- a/xagg/wrappers.py +++ b/xagg/wrappers.py @@ -1,4 +1,5 @@ import warnings +import xarray as xr from . core import (create_raster_polygons,get_pixel_overlaps) @@ -53,6 +54,13 @@ def pixel_overlaps(ds,gdf_in, input into :func:`xagg.core.aggregate`. """ + + # Turn into dataset if dataarray + if type(ds)==xr.core.dataarray.DataArray: + if ds.name is None: + ds = ds.to_dataset(name='var') + else: + ds = ds.to_dataset() # Create a polygon for each pixel print('creating polygons for each pixel...')