Add requires_dask decorator for tests (#177)
* Add requires_dask decorator for tests
- Disable xarray `warn_for_unclosed_files` setting, which gets enabled when importing from `xarray.tests`
- Refactor tests in `test_dataset.py`
tomvothecoder authored Dec 21, 2021
1 parent f9988a7 commit b1a6086
Showing 4 changed files with 86 additions and 71 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
@@ -61,3 +61,7 @@ test = pytest
junit_family=xunit2
addopts = --cov=xcdat --cov-report term --cov-report html:tests_coverage_reports/htmlcov --cov-report xml:tests_coverage_reports/coverage.xml -s
python_files = tests.py test_*.py
# These markers are defined in `xarray.tests` and must be registered here to avoid warnings when importing from that module.
markers =
flaky
network
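
For reference, registering these markers lets tests apply them without triggering PytestUnknownMarkWarning during collection. A minimal sketch of marker usage (the test below is hypothetical, not part of this commit):

import pytest

# "network" is registered under `markers` in setup.cfg above, so applying
# it here no longer emits PytestUnknownMarkWarning at collection time.
@pytest.mark.network
def test_example_requiring_network():
    ...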
3 changes: 3 additions & 0 deletions tests/__init__.py
@@ -1 +1,4 @@
"""Unit test package for xcdat."""
from xarray.core.options import set_options

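# Applied as a plain call (not a context manager), so the option stays set
# globally for the entire test session.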
set_options(warn_for_unclosed_files=False)
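
For comparison, a sketch of the scoped alternative: `xarray.set_options` also works as a context manager, reverting the option when the block exits.

import xarray as xr

# Scoped override: warn_for_unclosed_files reverts to its previous value
# once the with-block exits, unlike the module-level call above.
with xr.set_options(warn_for_unclosed_files=False):
    ...  # open and inspect datasets here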
143 changes: 72 additions & 71 deletions tests/test_dataset.py
@@ -20,13 +20,11 @@

class TestOpenDataset:
@pytest.fixture(autouse=True)
def setUp(self, tmp_path):
def setup(self, tmp_path):
# Create temporary directory to save files.
self.dir = tmp_path / "input_data"
self.dir.mkdir()

# Paths to the dummy datasets.
self.file_path = f"{self.dir}/file.nc"
dir = tmp_path / "input_data"
dir.mkdir()
self.file_path = f"{dir}/file.nc"

def test_only_keeps_specified_var(self):
ds = generate_dataset(cf_compliant=True, has_bounds=True)
@@ -41,39 +39,39 @@ def test_only_keeps_specified_var(self):
warnings.simplefilter("ignore")
ds_mod.to_netcdf(self.file_path)

result_ds = open_dataset(self.file_path, data_var="ts")
result = open_dataset(self.file_path, data_var="ts")
expected = ds.copy()
expected.attrs["xcdat_infer"] = "ts"
assert result_ds.identical(expected)
assert result.identical(expected)

def test_non_cf_compliant_time_is_decoded(self):
# Generate dummy datasets with non-CF compliant time units that aren't
# encoded yet.
ds = generate_dataset(cf_compliant=False, has_bounds=False)
ds.to_netcdf(self.file_path)

result_ds = open_dataset(self.file_path, data_var="ts")
result = open_dataset(self.file_path, data_var="ts")
# Replicates decode_times=False, which adds units to "time" coordinate.
# Refer to xcdat.bounds.BoundsAccessor._add_bounds() for
# how attributes propagate from coord to coord bounds.
result_ds["time_bnds"].attrs["units"] = "months since 2000-01-01"
result["time_bnds"].attrs["units"] = "months since 2000-01-01"

# Generate an expected dataset with non-CF compliant time units that are
# manually encoded
expected_ds = generate_dataset(cf_compliant=True, has_bounds=True)
expected_ds.attrs["xcdat_infer"] = "ts"
expected_ds.time.attrs["units"] = "months since 2000-01-01"
expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01"
expected_ds.time.encoding = {
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
expected.time.attrs["units"] = "months since 2000-01-01"
expected.time_bnds.attrs["units"] = "months since 2000-01-01"
expected.time.encoding = {
"source": None,
"dtype": np.dtype(np.int64),
"original_shape": expected_ds.time.data.shape,
"original_shape": expected.time.data.shape,
"units": "months since 2000-01-01",
"calendar": "proleptic_gregorian",
}

# Check that non-cf compliant time was decoded and bounds were generated.
assert result_ds.identical(expected_ds)
assert result.identical(expected)

def test_preserves_lat_and_lon_bounds_if_they_exist(self):
ds = generate_dataset(cf_compliant=True, has_bounds=True)
@@ -84,16 +82,18 @@ def test_preserves_lat_and_lon_bounds_if_they_exist(self):
warnings.simplefilter("ignore")
ds.to_netcdf(self.file_path)

result_ds = open_dataset(self.file_path, data_var="ts")
result = open_dataset(self.file_path, data_var="ts")
expected = ds.copy()
expected.attrs["xcdat_infer"] = "ts"

assert result_ds.identical(expected)
assert result.identical(expected)

def test_generates_lat_and_lon_bounds_if_they_dont_exist(self):
# Create expected dataset without bounds.
ds = generate_dataset(cf_compliant=True, has_bounds=False)

ds.to_netcdf(self.file_path)
ds.close()

# Make sure bounds don't exist
data_vars = list(ds.data_vars.keys())
Expand All @@ -111,12 +111,10 @@ class TestOpenMfDataset:
@pytest.fixture(autouse=True)
def setUp(self, tmp_path):
# Create temporary directory to save files.
self.dir = tmp_path / "input_data"
self.dir.mkdir()

# Paths to the dummy datasets.
self.file_path1 = f"{self.dir}/file1.nc"
self.file_path2 = f"{self.dir}/file2.nc"
dir = tmp_path / "input_data"
dir.mkdir()
self.file_path1 = f"{dir}/file1.nc"
self.file_path2 = f"{dir}/file2.nc"

def test_only_keeps_specified_var(self):
# Generate two dummy datasets with non-CF compliant time units.
@@ -126,88 +124,89 @@ def test_only_keeps_specified_var(self):
ds2 = ds2.rename_vars({"ts": "tas"})
ds2.to_netcdf(self.file_path2)

result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")

# Replicates decode_times=False, which adds units to "time" coordinate.
# Refer to xcdat.bounds.BoundsAccessor._add_bounds() for
# how attributes propagate from coord to coord bounds.
result_ds.time_bnds.attrs["units"] = "months since 2000-01-01"
result.time_bnds.attrs["units"] = "months since 2000-01-01"

# Generate an expected dataset, which is a combination of both datasets
# with decoded time units and coordinate bounds.
expected_ds = generate_dataset(cf_compliant=True, has_bounds=True)
expected_ds.attrs["xcdat_infer"] = "ts"
expected_ds.time.attrs["units"] = "months since 2000-01-01"
expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01"
expected_ds.time.encoding = {
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
expected.time.attrs["units"] = "months since 2000-01-01"
expected.time_bnds.attrs["units"] = "months since 2000-01-01"
expected.time.encoding = {
"source": None,
"dtype": np.dtype(np.int64),
"original_shape": expected_ds.time.data.shape,
"original_shape": expected.time.data.shape,
"units": "months since 2000-01-01",
"calendar": "proleptic_gregorian",
}

# Check that non-cf compliant time was decoded and bounds were generated.
assert result_ds.identical(expected_ds)
assert result.identical(expected)
result.close()

def test_non_cf_compliant_time_is_decoded(self):
# Generate two dummy datasets with non-CF compliant time units.
ds1 = generate_dataset(cf_compliant=False, has_bounds=False)
ds1.to_netcdf(self.file_path1)
ds2 = generate_dataset(cf_compliant=False, has_bounds=False)
ds2 = ds2.rename_vars({"ts": "tas"})

ds1.to_netcdf(self.file_path1)
ds2.to_netcdf(self.file_path2)

result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
# Replicates decode_times=False, which adds units to "time" coordinate.
# Refer to xcdat.bounds.BoundsAccessor._add_bounds() for
# how attributes propagate from coord to coord bounds.
result_ds.time_bnds.attrs["units"] = "months since 2000-01-01"
result.time_bnds.attrs["units"] = "months since 2000-01-01"

# Generate an expected dataset, which is a combination of both datasets
# with decoded time units and coordinate bounds.
expected_ds = generate_dataset(cf_compliant=True, has_bounds=True)
expected_ds.attrs["xcdat_infer"] = "ts"
expected_ds.time.attrs["units"] = "months since 2000-01-01"
expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01"
expected_ds.time.encoding = {
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
expected.time.attrs["units"] = "months since 2000-01-01"
expected.time_bnds.attrs["units"] = "months since 2000-01-01"
expected.time.encoding = {
"source": None,
"dtype": np.dtype(np.int64),
"original_shape": expected_ds.time.data.shape,
"original_shape": expected.time.data.shape,
"units": "months since 2000-01-01",
"calendar": "proleptic_gregorian",
}

# Check that non-cf compliant time was decoded and bounds were generated.
assert result_ds.identical(expected_ds)
assert result.identical(expected)

def test_preserves_lat_and_lon_bounds_if_they_exist(self):
# Generate two dummy datasets.
ds1 = generate_dataset(cf_compliant=True, has_bounds=True)
ds2 = generate_dataset(cf_compliant=True, has_bounds=True)
ds2 = ds2.rename_vars({"ts": "tas"})

# Suppress UserWarnings regarding missing time.encoding "units" because
# Suppress UserWarning regarding missing time.encoding "units" because
# it is not relevant to this test.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ds1.to_netcdf(self.file_path1)
ds2.to_netcdf(self.file_path2)

# Generate expected dataset, which is a combination of the two datasets.
expected_ds = generate_dataset(cf_compliant=True, has_bounds=True)
expected_ds.attrs["xcdat_infer"] = "ts"
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
# Check that the result is identical to the expected.
result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
assert result_ds.identical(expected_ds)
result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
assert result.identical(expected)

def test_generates_lat_and_lon_bounds_if_they_dont_exist(self):
# Generate two dummy datasets.
ds1 = generate_dataset(cf_compliant=True, has_bounds=False)
ds1.to_netcdf(self.file_path1)

ds2 = generate_dataset(cf_compliant=True, has_bounds=False)
ds2 = ds2.rename_vars({"ts": "tas"})

ds1.to_netcdf(self.file_path1)
ds2.to_netcdf(self.file_path2)

# Make sure no bounds exist in the input file.
@@ -249,7 +248,7 @@ def test_decodes_cf_compliant_time_units(self):
time_attrs = self.time_attrs

# Create an expected dataset with properly decoded time units.
expected_ds = xr.Dataset(
expected = xr.Dataset(
{
"time": xr.DataArray(
name="time",
@@ -272,16 +271,16 @@ def test_decodes_cf_compliant_time_units(self):
input_ds = xr.Dataset({"time": time_coord})

# Check the resulting dataset is identical to the expected.
result_ds = decode_time_units(input_ds)
assert result_ds.identical(expected_ds)
result = decode_time_units(input_ds)
assert result.identical(expected)

# Check the encodings are the same.
expected_ds.time.encoding = {
expected.time.encoding = {
# Default entries when `decode_times=True`
"dtype": np.dtype(np.int64),
"units": time_attrs["units"],
}
assert result_ds.time.encoding == expected_ds.time.encoding
assert result.time.encoding == expected.time.encoding

def test_decodes_non_cf_compliant_time_units_months(self):
# Create a dummy dataset with non-CF compliant time units.
@@ -293,7 +292,7 @@ def test_decodes_non_cf_compliant_time_units_months(self):
input_ds = xr.Dataset({"time": time_coord})

# Create an expected dataset with properly decoded time units.
expected_ds = xr.Dataset(
expected = xr.Dataset(
{
"time": xr.DataArray(
name="time",
@@ -309,18 +308,18 @@ def test_decodes_non_cf_compliant_time_units_months(self):
)

# Check the resulting dataset is identical to the expected.
result_ds = decode_time_units(input_ds)
assert result_ds.identical(expected_ds)
result = decode_time_units(input_ds)
assert result.identical(expected)

# Check result and expected time coordinate encodings are the same.
expected_ds.time.encoding = {
expected.time.encoding = {
"source": None,
"dtype": np.dtype(np.int64),
"original_shape": expected_ds.time.data.shape,
"original_shape": expected.time.data.shape,
"units": time_attrs["units"],
"calendar": "proleptic_gregorian",
}
assert result_ds.time.encoding == expected_ds.time.encoding
assert result.time.encoding == expected.time.encoding

def test_decodes_non_cf_compliant_time_units_years(self):
# Create a dummy dataset with non-CF compliant time units.
@@ -332,7 +331,7 @@ def test_decodes_non_cf_compliant_time_units_years(self):
input_ds = xr.Dataset({"time": time_coord})

# Create an expected dataset with properly decoded time units.
expected_ds = xr.Dataset(
expected = xr.Dataset(
{
"time": xr.DataArray(
name="time",
@@ -348,18 +347,18 @@ def test_decodes_non_cf_compliant_time_units_years(self):
)

# Check the resulting dataset is identical to the expected.
result_ds = decode_time_units(input_ds)
assert result_ds.identical(expected_ds)
result = decode_time_units(input_ds)
assert result.identical(expected)

# Check result and expected time coordinate encodings are the same.
expected_ds.time.encoding = {
expected.time.encoding = {
"source": None,
"dtype": np.dtype(np.int64),
"original_shape": expected_ds.time.data.shape,
"original_shape": expected.time.data.shape,
"units": time_attrs["units"],
"calendar": "proleptic_gregorian",
}
assert result_ds.time.encoding == expected_ds.time.encoding
assert result.time.encoding == expected.time.encoding


class TestInferOrKeepVar:
@@ -380,12 +379,14 @@ def tests_raises_logger_debug_if_only_bounds_data_variables_exist(self, caplog):
assert "This dataset only contains bounds data variables." in caplog.text

def test_raises_error_if_specified_data_var_does_not_exist(self):
ds = self.ds_mod.copy()
with pytest.raises(KeyError):
infer_or_keep_var(self.ds_mod, data_var="nonexistent")
infer_or_keep_var(ds, data_var="nonexistent")

def test_raises_error_if_specified_data_var_is_a_bounds_var(self):
ds = self.ds_mod.copy()
with pytest.raises(KeyError):
infer_or_keep_var(self.ds_mod, data_var="lat_bnds")
infer_or_keep_var(ds, data_var="lat_bnds")

def test_returns_dataset_if_it_only_has_one_non_bounds_data_var(self):
ds = self.ds.copy()
[Diff for the fourth changed file did not render.]
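
Per the commit title, the unrendered fourth file presumably adds the `requires_dask` decorator itself. A minimal sketch of such a decorator, assuming it follows the common `xarray.tests` pattern of skipping a test when an optional dependency is missing (names below are illustrative, not confirmed by the diff):

import importlib

import pytest


def _importorskip(modname: str):
    # Try to import the module and return both an availability flag and a
    # ready-made skipif marker that tests can apply as a decorator.
    try:
        importlib.import_module(modname)
        has_module = True
    except ImportError:
        has_module = False
    marker = pytest.mark.skipif(not has_module, reason=f"requires {modname}")
    return has_module, marker


has_dask, requires_dask = _importorskip("dask")

A test decorated with `@requires_dask` is then collected but skipped on environments where dask is not installed.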
