Skip to content

Commit

Permalink
[#2] conforming dot-product outs to orig structure
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Dec 3, 2021
1 parent 0770ca8 commit 4787d99
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
11 changes: 11 additions & 0 deletions xagg/aux.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def find_rel_area(df):
return df


def list_or_first(ser):
    """Collapse one grouped column to a single value, or keep it as a list.

    Written to be passed to ``groupby(...).agg`` so that each group of
    rows is reduced column by column.

    Parameters
    ----------
    ser : iterable with a ``name`` attribute (e.g. a ``pandas.Series``)
        The values of one column within one group.

    Returns
    -------
    object or list
        The first value if every value in `ser` compares equal to it and
        the column is not one of the per-pixel columns; otherwise the
        values as a list. An empty input yields an empty list.
    """
    values = list(ser)
    # only the columns associated with the pixels should have multiple values;
    # for all other columns (those associated with the polygons), it should be
    # safe to return just the first item
    pixel_columns = ('pix_idx', 'coords', 'rel_area', 'lat', 'lon')

    # Pixel columns always stay lists, so skip the equality scan for them.
    # The emptiness guard avoids indexing values[0] on an empty group, which
    # would otherwise raise IndexError (all() is vacuously true on no items).
    if not values or ser.name in pixel_columns:
        return values

    first = values[0]
    if all(x == first for x in values):
        return first
    return values


def fix_ds(ds,var_cipher = {'latitude':{'latitude':'lat','longitude':'lon'},
'Latitude':{'Latitude':'lat','Longitude':'lon'},
'Lat':{'Lat':'lat','Lon':'lon'},
Expand Down
3 changes: 2 additions & 1 deletion xagg/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ class weightmap(object):
""" Class for mapping from pixels to polygons, output from :func:`xagg.wrappers.pixel_overlaps`
"""
def __init__(self,agg,source_grid,geometry,weights='nowghts'):
def __init__(self,agg,source_grid,geometry,overlap_da,weights='nowghts'):
self.agg = agg
self.source_grid = source_grid
self.geometry = geometry
self.weights = weights
self.overlap_da = overlap_da

def diag_fig(self,poly_idx):
""" (NOT YET IMPLEMENTED) a diagnostic figure of the aggregation
Expand Down
49 changes: 38 additions & 11 deletions xagg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
import xesmf as xe

from . aux import (find_rel_area,fix_ds,get_bnds,subset_find)
from . aux import (find_rel_area,fix_ds,get_bnds,subset_find,list_or_first)
from . classes import (weightmap,aggregated)


Expand Down Expand Up @@ -306,22 +306,40 @@ def get_pixel_overlaps(gdf_in,pix_agg):
overlaps = gpd.overlay(gdf_in.to_crs(epsg_set),
pix_agg['gdf_pixels'].to_crs(epsg_set),
how='intersection')

overlaps = overlaps.groupby('poly_idx').apply(find_rel_area)
overlaps['lat'] = overlaps['lat'].astype(float)
overlaps['lon'] = overlaps['lon'].astype(float)


# Drop 'geometry' eventually, just for size/clarity
overlaps = overlaps.drop('geometry', axis=1)

# Now, group by poly_idx (each polygon in the shapefile)
ov_groups = overlaps.groupby('poly_idx')

overlap_info = ov_groups.agg(list_or_first)

overlap_info = overlap_info.rename(columns={'pix_idx': 'pix_idxs'})

# Zip lat, lon columns into a list of (lat,lon) coordinates
# (separate from above because as of 12/20, named aggs with
# multiple columns is still an open issue in the pandas github)
overlap_info['coords'] = overlap_info.apply(lambda row: list(zip(row['lat'],row['lon'])),axis=1)
overlap_info = overlap_info.drop(columns=['lat','lon'])

# Reset index to make poly_idx a column for merging with gdf_in
overlap_info = overlap_info.reset_index()

# Merge in pixel overlaps to the input polygon geodataframe
overlap_columns = ['pix_idxs', 'rel_area', 'coords', 'poly_idx']
gdf_in = pd.merge(gdf_in, overlap_info[overlap_columns],'outer', on='poly_idx')

# make the weight grid an xarray dataset for later dot product
idx_cols = ['lat', 'lon', 'poly_idx']
overlaps_da = overlaps.set_index(idx_cols)['rel_area'].to_xarray()
overlaps_da = overlaps_da.stack(loc=['lat', 'lon'])
wm_out = weightmap(agg=overlaps_da,
overlap_da = overlaps.set_index(idx_cols)['rel_area'].to_xarray()
overlap_da = overlap_da.stack(loc=['lat', 'lon'])
overlap_da = overlap_da.fillna(0)
wm_out = weightmap(agg=gdf_in.drop('geometry', axis=1),
source_grid=pix_agg['source_grid'],
geometry=gdf_in.geometry)
geometry=gdf_in.geometry,
overlap_da = overlap_da)

if 'weights' in pix_agg['gdf_pixels'].columns:
wm_out.weights = pix_agg['gdf_pixels'].weights
Expand Down Expand Up @@ -404,10 +422,19 @@ def aggregate(ds,wm):
if ('bnds' not in ds[var].dims) & ('loc' in ds[var].dims):
print('aggregating '+var+'...')
var_array = ds[var]
var_array = wm.agg.dot(var_array)
var_array = wm.overlap_da.dot(var_array)
data_dict[var] = var_array

ds_combined = xr.Dataset(data_dict)
df_combined = ds_combined.to_dataframe().reset_index()
df_combined = df_combined.groupby('poly_idx').agg(list_or_first)

wm.agg = pd.merge(wm.agg, df_combined, on='poly_idx')
for var in ds.var():
if ('bnds' not in ds[var].dims) & ('loc' in ds[var].dims):
# convert to list of arrays - NOT SURE THIS IS THE RIGHT THING TO
# DO, JUST TRYING TO MATCH ORIGINAL FORMAT
wm.agg[var] = wm.agg[var].apply(np.array).apply(lambda x: [x])

# Put in class format
agg_out = aggregated(agg=wm.agg,source_grid=wm.source_grid,
Expand Down

0 comments on commit 4787d99

Please sign in to comment.