Skip to content

Commit

Permalink
allow dataarray inputs into pixel_overlaps [issue 7]
Browse files Browse the repository at this point in the history
xr.DataArrays are now coerced into xr.Datasets at the start of both pixel_overlaps and aggregate. This brings behavior in line with the docs of pixel_overlaps, which allow DataArray inputs. Allowing xr.DataArray inputs for aggregate() will not be added to the docs (if the xr.DataArray is unnamed, then the name 'var' will be given to the variable; not ideal behavior); a warning will be thrown if an xr.DataArray is nevertheless used and is unnamed.

Additionally, tests were added to make sure that pixel_overlaps() works with an xr.DataArray input.
  • Loading branch information
ks905383 committed Jun 30, 2021
1 parent d8f9be0 commit 299b38c
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 1 deletion.
1 change: 0 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,4 +458,3 @@ def test_aggregate_with_some_nans():




71 changes: 71 additions & 0 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import pandas as pd
import numpy as np
import xarray as xr
import geopandas as gpd
from geopandas import testing as gpdt
from unittest import TestCase
from shapely.geometry import Polygon
import xesmf as xe

from xagg.wrappers import (pixel_overlaps)


##### pixel_overlaps() tests #####
def test_pixel_overlaps_dataarray():
# To make sure that pixel_overlaps works with an unnamed dataarray
da = xr.DataArray(data=np.array([[0,1],[2,3]]),
coords={'lat':(['lat'],np.array([0,1])),
'lon':(['lon'],np.array([0,1]))},
dims=['lon','lat'])

# Create polygon covering one pixel
gdf_test = {'name':['test'],
'geometry':[Polygon([(-0.5,-0.5),(-0.5,0.5),(0.5,0.5),(0.5,-0.5),(-0.5,-0.5)])]}
gdf_test = gpd.GeoDataFrame(gdf_test,crs="EPSG:4326")

# Calculate pixel_overlaps through the wrapper function,
# which should change the dataarray to a dataframe
wm = pixel_overlaps(da,gdf_test)

df0 = pd.DataFrame(wm.agg)

# Define what the output should be
df_compare = pd.DataFrame({'name':['test'],'poly_idx':0,
'rel_area':[[[1.0]]],'pix_idxs':[[0]],
'coords':[[(0,0)]]})


assert np.allclose([v for v in df0.rel_area],[v for v in df_compare.rel_area])
assert np.allclose([v for v in df0.pix_idxs],[v for v in df_compare.pix_idxs])
assert np.allclose([v for v in df0.coords],[v for v in df_compare.coords])


def test_pixel_overlaps_dataarray_wname():
# To make sure that pixel_overlaps works with a named dataarray
da = xr.DataArray(data=np.array([[0,1],[2,3]]),
coords={'lat':(['lat'],np.array([0,1])),
'lon':(['lon'],np.array([0,1]))},
dims=['lon','lat'],
name='tas')

# Create polygon covering one pixel
gdf_test = {'name':['test'],
'geometry':[Polygon([(-0.5,-0.5),(-0.5,0.5),(0.5,0.5),(0.5,-0.5),(-0.5,-0.5)])]}
gdf_test = gpd.GeoDataFrame(gdf_test,crs="EPSG:4326")

# Calculate pixel_overlaps through the wrapper function,
# which should change the dataarray to a dataframe
wm = pixel_overlaps(da,gdf_test)

df0 = pd.DataFrame(wm.agg)

# Define what the output should be
df_compare = pd.DataFrame({'name':['test'],'poly_idx':0,
'rel_area':[[[1.0]]],'pix_idxs':[[0]],
'coords':[[(0,0)]]})


assert np.allclose([v for v in df0.rel_area],[v for v in df_compare.rel_area])
assert np.allclose([v for v in df0.pix_idxs],[v for v in df_compare.pix_idxs])
assert np.allclose([v for v in df0.coords],[v for v in df_compare.coords])
9 changes: 9 additions & 0 deletions xagg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,15 @@ def aggregate(ds,wm):
an :class:`xagg.classes.aggregated` object with the aggregated variables
"""
# Turn into dataset if dataarray
if type(ds)==xr.core.dataarray.DataArray:
if ds.name is None:
warnings.warn('An unnamed xr.DataArray was inputted instead of a xr.Dataset; the output variable will be "var"')
ds = ds.to_dataset(name='var')
else:
ds = ds.to_dataset()


# Run ds through fix_ds (to fix lat/lon names, lon coords)
ds = fix_ds(ds)

Expand Down
8 changes: 8 additions & 0 deletions xagg/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import xarray as xr

from . core import (create_raster_polygons,get_pixel_overlaps)

Expand Down Expand Up @@ -53,6 +54,13 @@ def pixel_overlaps(ds,gdf_in,
input into :func:`xagg.core.aggregate`.
"""

# Turn into dataset if dataarray
if type(ds)==xr.core.dataarray.DataArray:
if ds.name is None:
ds = ds.to_dataset(name='var')
else:
ds = ds.to_dataset()

# Create a polygon for each pixel
print('creating polygons for each pixel...')
Expand Down

0 comments on commit 299b38c

Please sign in to comment.