Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djkirkham committed Aug 25, 2017
1 parent 5d8375a commit d9dc750
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lib/iris/fileformats/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,7 +1981,7 @@ def store(data, cf_var, fill_value):
"data please explicitly provide a fill value."
.format(cube.name()))
elif contains_fill_value:
warnings.warn("Cube '{}' contains data points equal to the fill"
warnings.warn("Cube '{}' contains data points equal to the fill "
"value {}. The points will be interpreted as being "
"masked. Please provide a fill_value argument not "
"equal to any data point.".format(cube.name(),
Expand Down
155 changes: 155 additions & 0 deletions lib/iris/tests/unit/fileformats/netcdf/test_Saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@
"""Unit tests for the `iris.fileformats.netcdf.Saver` class."""

from __future__ import (absolute_import, division, print_function)

from six.moves import (filter, input, map, range, zip) # noqa
import six

# Import iris.tests first so that some things can be initialised before
# importing anything else.
import iris.tests as tests

from contextlib import contextmanager
import re
import warnings

import netCDF4 as nc
import numpy as np
from numpy import ma

import iris
import iris._lazy_data
from iris.coord_systems import (GeogCS, TransverseMercator, RotatedGeogCS,
LambertConformal, Mercator, Stereographic,
LambertAzimuthalEqualArea)
Expand Down Expand Up @@ -326,6 +333,154 @@ def test_valid_max_saved(self):
ds.close()


class Test_write_fill_value(tests.IrisTest):
def _make_cube(self, dtype, lazy=False, masked_value=None,
masked_index=None):
data = np.arange(12, dtype=dtype).reshape(3, 4)
if masked_value is not None:
data = ma.masked_equal(data, masked_value)
if masked_index is not None:
data = np.ma.masked_array(data)
data[masked_index] = ma.masked
if lazy:
data = iris._lazy_data.as_lazy_data(data)
lat = DimCoord(np.arange(3), 'latitude', units='degrees')
lon = DimCoord(np.arange(4), 'longitude', units='degrees')
return Cube(data, standard_name='air_temperature', units='K',
dim_coords_and_dims=[(lat, 0), (lon, 1)])

@contextmanager
def _netCDF_var(self, cube, **kwargs):
# Get the netCDF4 Variable for a cube from a temp file
standard_name = cube.standard_name
with self.temp_filename('.nc') as nc_path:
with Saver(nc_path, 'NETCDF4') as saver:
saver.write(cube, **kwargs)
ds = nc.Dataset(nc_path)
var, = [var for var in ds.variables.values()
if var.standard_name == standard_name]
yield var

@contextmanager
def _warning_check(self, message_text='', expect_warning=True):
# Check that a warning is raised containing a given string, or that
# no warning containing a given string is raised.
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
yield
matches = (message_text in str(warning.message) for warning in w)
warning_raised = any(matches)
msg = "Warning containing text '{}' not raised." if expect_warning \
else "Warning containing text '{}' unexpectedly raised."
self.assertEqual(expect_warning, warning_raised,
msg.format(message_text))

def test_fill_value(self):
# Test that a passed fill value is saved as a _FillValue attribute.
cube = self._make_cube('>f4')
fill_value = 12345.
with self._netCDF_var(cube, fill_value=fill_value) as var:
self.assertEqual(fill_value, var._FillValue)

def test_default_fill_value(self):
# Test that if no fill value is passed then there is no _FillValue.
# attribute.
cube = self._make_cube('>f4')
with self._netCDF_var(cube) as var:
self.assertNotIn('_FillValue', var.ncattrs())

def test_mask_fill_value(self):
# Test that masked data saves correctly when given a fill value.
index = (1, 1)
fill_value = 12345.
cube = self._make_cube('>f4', masked_index=index)
with self._netCDF_var(cube, fill_value=fill_value) as var:
self.assertEqual(fill_value, var._FillValue)
self.assertTrue(var[index].mask)

def test_mask_default_fill_value(self):
# Test that masked data saves correctly using the default fill value.
index = (1, 1)
cube = self._make_cube('>f4', masked_index=index)
with self._netCDF_var(cube) as var:
self.assertNotIn('_FillValue', var.ncattrs())
self.assertTrue(var[index].mask)

def test_mask_lazy_fill_value(self):
# Test that masked lazy data saves correctly when given a fill value.
index = (1, 1)
fill_value = 12345.
cube = self._make_cube('>f4', masked_index=index, lazy=True)
with self._netCDF_var(cube, fill_value=fill_value) as var:
self.assertEqual(var._FillValue, fill_value)
self.assertTrue(var[index].mask)

def test_mask_lazy_default_fill_value(self):
# Test that masked lazy data saves correctly when given a fill value.
index = (1, 1)
cube = self._make_cube('>f4', masked_index=index, lazy=True)
with self._netCDF_var(cube) as var:
self.assertNotIn('_FillValue', var.ncattrs())
self.assertTrue(var[index].mask)

def test_contains_fill_value_passed(self):
# Test that a warning is raised if the data contains the fill value.
cube = self._make_cube('>f4')
fill_value = 1
with self._warning_check(
'contains data points equal to the fill value'):
with self._netCDF_var(cube, fill_value=fill_value):
pass

def test_contains_fill_value_byte(self):
# Test that a warning is raised if the data contains the fill value
# when it is of a byte type.
cube = self._make_cube('>i1')
fill_value = 1
with self._warning_check(
'contains data points equal to the fill value'):
with self._netCDF_var(cube, fill_value=fill_value):
pass

def test_contains_default_fill_value(self):
# Test that a warning is raised if the data contains the default fill
# value if no fill_value argument is supplied.
cube = self._make_cube('>f4')
cube.data[0, 0] = nc.default_fillvals['f4']
with self._warning_check(
'contains data points equal to the fill value'):
with self._netCDF_var(cube):
pass

def test_contains_default_fill_value_byte(self):
# Test that no warning is raised if the data contains the default fill
# value if no fill_value argument is supplied when the data is of a
# byte type.
cube = self._make_cube('>i1')
with self._warning_check(
'contains data points equal to the fill value', False):
with self._netCDF_var(cube):
pass

def test_contains_masked_fill_value(self):
# Test that no warning is raised if the data contains the fill_value at
# a masked point.
fill_value = 1
cube = self._make_cube('>f4', masked_value=fill_value)
with self._warning_check(
'contains data points equal to the fill value', False):
with self._netCDF_var(cube, fill_value=fill_value):
pass

def test_masked_byte_default_fill_value(self):
# Test that a warning is raised when saving masked byte data with no
# fill value supplied.
cube = self._make_cube('>i1', masked_value=1)
with self._warning_check('contains masked byte data', True):
with self._netCDF_var(cube):
pass


class _Common__check_attribute_compliance(object):
def setUp(self):
self.container = mock.Mock(name='container', attributes={})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# (C) British Crown Copyright 2017, Met Office
#
# This file is part of Iris.
#
# Iris is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Iris is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Iris. If not, see <http://www.gnu.org/licenses/>.
"""
Unit tests for the `iris.fileformats.netcdf._FillValueMaskCheckAndStoreTarget`
class.
"""

from __future__ import (absolute_import, division, print_function)
from six.moves import (filter, input, map, range, zip) # noqa

# Import iris.tests first so that some things can be initialised before
# importing anything else.
import iris.tests as tests

import mock
import numpy as np

from iris.fileformats.netcdf import _FillValueMaskCheckAndStoreTarget


class Test__FillValueMaskCheckAndStoreTarget(tests.IrisTest):
def _call_target(self, fill_value, keys, vals):
inner_target = mock.MagicMock()
target = _FillValueMaskCheckAndStoreTarget(inner_target,
fill_value=fill_value)

for key, val in zip(keys, vals):
target[key] = val

calls = [mock.call(key, val) for key, val in zip(keys, vals)]
inner_target.__setitem__.assert_has_calls(calls)

return target

def test___setitem__(self):
self._call_target(None, [1], [2])

def test_no_fill_value_not_masked(self):
# Test when the fill value is not present and the data is not masked
keys = [slice(0, 10), slice(10, 15)]
vals = [np.arange(10), np.arange(5)]
fill_value = 16
target = self._call_target(fill_value, keys, vals)
self.assertFalse(target.contains_value)
self.assertFalse(target.is_masked)

def test_contains_fill_value_not_masked(self):
# Test when the fill value is present and the data is not masked
keys = [slice(0, 10), slice(10, 15)]
vals = [np.arange(10), np.arange(5)]
fill_value = 5
target = self._call_target(fill_value, keys, vals)
self.assertTrue(target.contains_value)
self.assertFalse(target.is_masked)

def test_no_fill_value_masked(self):
# Test when the fill value is not present and the data is masked
keys = [slice(0, 10), slice(10, 15)]
vals = [np.arange(10), np.ma.masked_equal(np.arange(5), 3)]
fill_value = 16
target = self._call_target(fill_value, keys, vals)
self.assertFalse(target.contains_value)
self.assertTrue(target.is_masked)

def test_contains_fill_value_masked(self):
# Test when the fill value is present and the data is masked
keys = [slice(0, 10), slice(10, 15)]
vals = [np.arange(10), np.ma.masked_equal(np.arange(5), 3)]
fill_value = 5
target = self._call_target(fill_value, keys, vals)
self.assertTrue(target.contains_value)
self.assertTrue(target.is_masked)

def test_fill_value_None(self):
# Test when the fill value is None
keys = [slice(0, 10), slice(10, 15)]
vals = [np.arange(10), np.arange(5)]
fill_value = None
target = self._call_target(fill_value, keys, vals)
self.assertFalse(target.contains_value)

def test_contains_masked_fill_value(self):
# Test when the fill value is present but masked the data is masked
keys = [slice(0, 10), slice(10, 15)]
vals = [np.arange(10), np.ma.masked_equal(np.arange(10, 15), 13)]
fill_value = 13
target = self._call_target(fill_value, keys, vals)
self.assertFalse(target.contains_value)
self.assertTrue(target.is_masked)
Loading

0 comments on commit d9dc750

Please sign in to comment.