Skip to content

Commit

Permalink
Merge pull request #224 from pytroll/bugfix-check-areas
Browse files Browse the repository at this point in the history
Add helper method for checking areas in compositors
  • Loading branch information
djhoese authored Mar 16, 2018
2 parents 97b731b + a323a52 commit 547819c
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 26 deletions.
56 changes: 37 additions & 19 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,30 @@ def apply_modifier_info(self, origin, destination):
elif o.get(k) is not None:
d[k] = o[k]

def check_areas(self, data_arrays):
if len(data_arrays) == 1:
return data_arrays

if 'x' in data_arrays[0].dims and \
not all(x.sizes['x'] == data_arrays[0].sizes['x']
for x in data_arrays[1:]):
raise IncompatibleAreas("X dimension has different sizes")
if 'y' in data_arrays[0].dims and \
not all(x.sizes['y'] == data_arrays[0].sizes['y']
for x in data_arrays[1:]):
raise IncompatibleAreas("Y dimension has different sizes")

areas = [ds.attrs.get('area') for ds in data_arrays]
if not areas or any(a is None for a in areas):
raise ValueError("Missing 'area' attribute")

if not all(areas[0] == x for x in areas[1:]):
LOG.debug("Not all areas are the same in "
"'{}'".format(self.attrs['name']))
raise IncompatibleAreas("Areas are different")

return data_arrays


class SunZenithCorrectorBase(CompositeBase):

Expand Down Expand Up @@ -382,6 +406,7 @@ def __call__(self, projectables, optional_datasets=None, **info):
sunalt, suna = get_alt_az(vis.attrs['start_time'], lons, lats)
suna = xu.rad2deg(suna)
sunz = sun_zenith_angle(vis.attrs['start_time'], lons, lats)
# FIXME: Make it daskified
sata, satel = get_observer_look(vis.attrs['satellite_longitude'],
vis.attrs['satellite_latitude'],
vis.attrs['satellite_altitude'],
Expand Down Expand Up @@ -577,16 +602,8 @@ class GenericCompositor(CompositeBase):

modes = {1: 'L', 2: 'LA', 3: 'RGB', 4: 'RGBA'}

def check_area_compatibility(self, projectables):
areas = [projectable.attrs.get('area', None)
for projectable in projectables]
areas = [area for area in areas if area is not None]
if areas and areas.count(areas[0]) != len(areas):
LOG.debug("Not all areas are the same in '{}'".format(self.attrs['name']))
raise IncompatibleAreas

def _concat_datasets(self, projectables, mode):
self.check_area_compatibility(projectables)
projectables = self.check_areas(projectables)

try:
data = xr.concat(projectables, 'bands', coords='minimal')
Expand Down Expand Up @@ -1012,16 +1029,10 @@ def __call__(self, datasets, optional_datasets=None, **info):
'the same size. Must resample first.')

new_attrs = {}
p1, p2, p3 = datasets
if optional_datasets:
high_res = optional_datasets[0]
low_res = datasets[["red", "green", "blue"].index(
self.high_resolution_band)]
if high_res.attrs["area"] != low_res.attrs["area"]:
raise IncompatibleAreas("High resolution band is not "
"mapped to the same area as the "
"low resolution bands. Must "
"resample first.")
datasets = self.check_areas(datasets + optional_datasets)
high_res = datasets[-1]
p1, p2, p3 = datasets[:3]
if 'rows_per_scan' in high_res.attrs:
new_attrs.setdefault('rows_per_scan',
high_res.attrs['rows_per_scan'])
Expand All @@ -1035,27 +1046,34 @@ def __call__(self, datasets, optional_datasets=None, **info):
r = high_res
g = p2 * ratio
b = p3 * ratio
g.attrs = p2.attrs.copy()
b.attrs = p3.attrs.copy()
elif self.high_resolution_band == "green":
LOG.debug("Sharpening image with high resolution green band")
ratio = high_res / p2
ratio = ratio.where(xu.isfinite(ratio) | (ratio >= 0), 1.)
r = p1 * ratio
g = high_res
b = p3 * ratio
r.attrs = p1.attrs.copy()
b.attrs = p3.attrs.copy()
elif self.high_resolution_band == "blue":
LOG.debug("Sharpening image with high resolution blue band")
ratio = high_res / p3
ratio = ratio.where(xu.isfinite(ratio) | (ratio >= 0), 1.)
r = p1 * ratio
g = p2 * ratio
b = high_res
r.attrs = p1.attrs.copy()
g.attrs = p2.attrs.copy()
else:
# no sharpening
r = p1
g = p2
b = p3
else:
r, g, b = p1, p2, p3
datasets = self.check_areas(datasets)
r, g, b = datasets[:3]
# combine the masks
mask = ~(da.isnull(r.data) | da.isnull(g.data) | da.isnull(b.data))
r = r.where(mask)
Expand Down
7 changes: 1 addition & 6 deletions satpy/composites/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@ class SimulatedGreen(GenericCompositor):
"""A single-band dataset resembles a Green (0.55 µm)."""

def __call__(self, projectables, optional_datasets=None, **attrs):
c01, c02, c03 = projectables
if not all(c.shape == projectables[0].shape
for c in projectables[1:]):
raise IncompatibleAreas("Simulated green can only be made from "
"bands of the same size. Resample "
"first.")
c01, c02, c03 = self.check_areas(projectables)

# Kaba:
# res = (c01 + c02) * 0.45 + 0.1 * c03
Expand Down
3 changes: 2 additions & 1 deletion satpy/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
test_readers, test_resample,
test_scene, test_utils, test_writers,
test_yaml_reader, writer_tests,
test_enhancements)
test_enhancements, compositor_tests)


if sys.version_info < (2, 7):
Expand All @@ -55,6 +55,7 @@ def suite():
mysuite.addTests(test_file_handlers.suite())
mysuite.addTests(test_utils.suite())
mysuite.addTests(test_enhancements.suite())
mysuite.addTests(compositor_tests.suite())

return mysuite

Expand Down
124 changes: 124 additions & 0 deletions satpy/tests/compositor_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 PyTroll developers
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Tests for compositors.
"""


import sys

from satpy.tests.compositor_tests import test_abi

if sys.version_info < (2, 7):
import unittest2 as unittest
else:
import unittest


class TestCheckArea(unittest.TestCase):

"""Test the utility method 'check_areas'."""

def _get_test_ds(self, shape=(50, 100), dims=('y', 'x')):
"""Helper method to get a fake DataArray."""
import xarray as xr
import dask.array as da
from pyresample.geometry import AreaDefinition
data = da.random.random(shape, chunks=25)
area = AreaDefinition(
'test', 'test', 'test',
{'proj': 'eqc', 'lon_0': 0.0,
'lat_0': 0.0},
shape[dims.index('x')], shape[dims.index('y')],
(-20037508.34, -10018754.17, 20037508.34, 10018754.17))
attrs = {'area': area}
return xr.DataArray(data, dims=dims, attrs=attrs)

def test_single_ds(self):
"""Test a single dataset is returned unharmed."""
from satpy.composites import CompositeBase
ds1 = self._get_test_ds()
comp = CompositeBase('test_comp')
ret_datasets = comp.check_areas((ds1,))
self.assertIs(ret_datasets[0], ds1)

def test_mult_ds_area(self):
"""Test multiple datasets successfully pass."""
from satpy.composites import CompositeBase
ds1 = self._get_test_ds()
ds2 = self._get_test_ds()
comp = CompositeBase('test_comp')
ret_datasets = comp.check_areas((ds1, ds2))
self.assertIs(ret_datasets[0], ds1)
self.assertIs(ret_datasets[1], ds2)

def test_mult_ds_no_area(self):
"""Test that all datasets must have an area attribute."""
from satpy.composites import CompositeBase
ds1 = self._get_test_ds()
ds2 = self._get_test_ds()
del ds2.attrs['area']
comp = CompositeBase('test_comp')
self.assertRaises(ValueError, comp.check_areas, (ds1, ds2))

def test_mult_ds_diff_area(self):
"""Test that datasets with different areas fail."""
from satpy.composites import CompositeBase, IncompatibleAreas
from pyresample.geometry import AreaDefinition
ds1 = self._get_test_ds()
ds2 = self._get_test_ds()
ds2.attrs['area'] = AreaDefinition(
'test', 'test', 'test',
{'proj': 'eqc', 'lon_0': 0.0,
'lat_0': 0.0},
100, 50,
(-30037508.34, -20018754.17, 10037508.34, 18754.17))
comp = CompositeBase('test_comp')
self.assertRaises(IncompatibleAreas, comp.check_areas, (ds1, ds2))

def test_mult_ds_diff_dims(self):
"""Test that datasets with different dimensions still pass."""
from satpy.composites import CompositeBase
# x is still 50, y is still 100, even though they are in
# different order
ds1 = self._get_test_ds(shape=(50, 100), dims=('y', 'x'))
ds2 = self._get_test_ds(shape=(3, 100, 50), dims=('bands', 'x', 'y'))
comp = CompositeBase('test_comp')
ret_datasets = comp.check_areas((ds1, ds2))
self.assertIs(ret_datasets[0], ds1)
self.assertIs(ret_datasets[1], ds2)

def test_mult_ds_diff_size(self):
"""Test that datasets with different sizes fail."""
from satpy.composites import CompositeBase, IncompatibleAreas
# x is 50 in this one, 100 in ds2
# y is 100 in this one, 50 in ds2
ds1 = self._get_test_ds(shape=(50, 100), dims=('x', 'y'))
ds2 = self._get_test_ds(shape=(3, 50, 100), dims=('bands', 'y', 'x'))
comp = CompositeBase('test_comp')
self.assertRaises(IncompatibleAreas, comp.check_areas, (ds1, ds2))


def suite():
"""Test suite for all reader tests"""
loader = unittest.TestLoader()
mysuite = unittest.TestSuite()
mysuite.addTests(test_abi.suite())
mysuite.addTest(loader.loadTestsFromTestCase(TestCheckArea))

return mysuite
73 changes: 73 additions & 0 deletions satpy/tests/compositor_tests/test_abi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 PyTroll developers
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Tests for ABI compositors.
"""

import sys

if sys.version_info < (2, 7):
import unittest2 as unittest
else:
import unittest


class TestABIComposites(unittest.TestCase):
def test_simulated_green(self):
import xarray as xr
import dask.array as da
import numpy as np
from satpy.composites.abi import SimulatedGreen
from pyresample.geometry import AreaDefinition
rows = 5
cols = 10
area = AreaDefinition(
'test', 'test', 'test',
{'proj': 'eqc', 'lon_0': 0.0,
'lat_0': 0.0},
cols, rows,
(-20037508.34, -10018754.17, 20037508.34, 10018754.17))

comp = SimulatedGreen('green', prerequisites=('C01', 'C02', 'C03'),
standard_name='toa_bidirectional_reflectance')
c01 = xr.DataArray(da.zeros((rows, cols), chunks=25) + 0.25,
dims=('y', 'x'),
attrs={'name': 'C01', 'area': area})
c02 = xr.DataArray(da.zeros((rows, cols), chunks=25) + 0.30,
dims=('y', 'x'),
attrs={'name': 'C02', 'area': area})
c03 = xr.DataArray(da.zeros((rows, cols), chunks=25) + 0.35,
dims=('y', 'x'),
attrs={'name': 'C03', 'area': area})
res = comp((c01, c02, c03))
self.assertIsInstance(res, xr.DataArray)
self.assertIsInstance(res.data, da.Array)
self.assertEqual(res.attrs['name'], 'green')
self.assertEqual(res.attrs['standard_name'],
'toa_bidirectional_reflectance')
data = res.compute()
np.testing.assert_allclose(data, 0.28025)


def suite():
"""The test suite for test_scene.
"""
loader = unittest.TestLoader()
mysuite = unittest.TestSuite()
mysuite.addTest(loader.loadTestsFromTestCase(TestABIComposites))
return mysuite

0 comments on commit 547819c

Please sign in to comment.