Skip to content

Commit

Permalink
Fix EWA resampling tests not properly testing caching
Browse files Browse the repository at this point in the history
  • Loading branch information
djhoese committed Jun 26, 2019
1 parent fa6f33e commit 9e36229
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
13 changes: 6 additions & 7 deletions satpy/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,11 @@ def _call_ll2cr(self, lons, lats, target_geo_def, swath_usage=0):

def precompute(self, cache_dir=None, swath_usage=0, **kwargs):
"""Generate row and column arrays and store it for later use."""
if self.cache:
# this resampler should be used for one SwathDefinition
# no need to recompute ll2cr output again
return None

if kwargs.get('mask') is not None:
LOG.warning("'mask' parameter has no affect during EWA "
"resampling")
Expand Down Expand Up @@ -950,7 +955,7 @@ def prepare_resampler(source_area, destination_area, resampler=None, **resample_

key = (resampler_class,
source_area, destination_area,
hash_dict(resample_kwargs))
hash_dict(resample_kwargs).hexdigest())
try:
resampler_instance = resamplers_cache[key]
except KeyError:
Expand All @@ -962,12 +967,6 @@ def prepare_resampler(source_area, destination_area, resampler=None, **resample_
def resample(source_area, data, destination_area,
resampler=None, **kwargs):
"""Do the resampling."""
if 'resampler_class' in kwargs:
import warnings
warnings.warn("'resampler_class' is deprecated, use 'resampler'",
DeprecationWarning)
resampler = kwargs.pop('resampler_class')

if not isinstance(resampler, BaseResampler):
# we don't use the first argument (cache key)
_, resampler_instance = prepare_resampler(source_area,
Expand Down
26 changes: 20 additions & 6 deletions satpy/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ class TestEWAResampler(unittest.TestCase):

@mock.patch('satpy.resample.fornav')
@mock.patch('satpy.resample.ll2cr')
def test_2d_ewa(self, ll2cr, fornav):
@mock.patch('satpy.resample.SwathDefinition.get_lonlats')
def test_2d_ewa(self, get_lonlats, ll2cr, fornav):
"""Test EWA with a 2D dataset."""
import numpy as np
import xarray as xr
Expand All @@ -223,7 +224,9 @@ def test_2d_ewa(self, ll2cr, fornav):
fornav.return_value = (100 * 200,
np.zeros((200, 100), dtype=np.float32))
_, _, swath_data, source_swath, target_area = get_test_data()
get_lonlats.return_value = (source_swath.lons, source_swath.lats)
swath_data.data = swath_data.data.astype(np.float32)
num_chunks = len(source_swath.lons.chunks[0]) * len(source_swath.lons.chunks[1])

new_data = resample_dataset(swath_data, target_area, resampler='ewa')
self.assertTupleEqual(new_data.shape, (200, 100))
Expand All @@ -232,16 +235,20 @@ def test_2d_ewa(self, ll2cr, fornav):
self.assertIs(new_data.attrs['area'], target_area)
# make sure we can actually compute everything
new_data.compute()
previous_calls = ll2cr.call_count
lonlat_calls = get_lonlats.call_count
ll2cr_calls = ll2cr.call_count

# resample a different dataset and make sure cache is used
data = xr.DataArray(
swath_data.data,
dims=('y', 'x'), attrs={'area': source_swath, 'test': 'test2',
'name': 'test2'})
new_data = resample_dataset(data, target_area, resampler='ewa')
self.assertEqual(ll2cr.call_count, previous_calls)
new_data.compute()
# ll2cr will be called once more because of the computation
self.assertEqual(ll2cr.call_count, ll2cr_calls + num_chunks)
# but we should already have taken the lonlats from the SwathDefinition
self.assertEqual(get_lonlats.call_count, lonlat_calls)
self.assertIn('y', new_data.coords)
self.assertIn('x', new_data.coords)
if CRS is not None:
Expand All @@ -253,7 +260,8 @@ def test_2d_ewa(self, ll2cr, fornav):

@mock.patch('satpy.resample.fornav')
@mock.patch('satpy.resample.ll2cr')
def test_3d_ewa(self, ll2cr, fornav):
@mock.patch('satpy.resample.SwathDefinition.get_lonlats')
def test_3d_ewa(self, get_lonlats, ll2cr, fornav):
"""Test EWA with a 3D dataset."""
import numpy as np
import xarray as xr
Expand All @@ -266,6 +274,8 @@ def test_3d_ewa(self, ll2cr, fornav):
np.zeros((10, 10), dtype=np.float32))
fornav.return_value = ([100 * 200] * 3,
[np.zeros((200, 100), dtype=np.float32)] * 3)
get_lonlats.return_value = (source_swath.lons, source_swath.lats)
num_chunks = len(source_swath.lons.chunks[0]) * len(source_swath.lons.chunks[1])

new_data = resample_dataset(swath_data, target_area, resampler='ewa')
self.assertTupleEqual(new_data.shape, (3, 200, 100))
Expand All @@ -274,16 +284,20 @@ def test_3d_ewa(self, ll2cr, fornav):
self.assertIs(new_data.attrs['area'], target_area)
# make sure we can actually compute everything
new_data.compute()
previous_calls = ll2cr.call_count
lonlat_calls = get_lonlats.call_count
ll2cr_calls = ll2cr.call_count

# resample a different dataset and make sure cache is used
swath_data = xr.DataArray(
swath_data.data,
dims=('bands', 'y', 'x'), coords={'bands': ['R', 'G', 'B']},
attrs={'area': source_swath, 'test': 'test'})
new_data = resample_dataset(swath_data, target_area, resampler='ewa')
self.assertEqual(ll2cr.call_count, previous_calls)
new_data.compute()
# ll2cr will be called once more because of the computation
self.assertEqual(ll2cr.call_count, ll2cr_calls + num_chunks)
# but we should already have taken the lonlats from the SwathDefinition
self.assertEqual(get_lonlats.call_count, lonlat_calls)
self.assertIn('y', new_data.coords)
self.assertIn('x', new_data.coords)
self.assertIn('bands', new_data.coords)
Expand Down

0 comments on commit 9e36229

Please sign in to comment.