diff --git a/docs/api.rst b/docs/api.rst index 765ef8d7..47d2b130 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -11,8 +11,9 @@ Top-level API dataset.open_dataset dataset.open_mfdataset + dataset.has_cf_compliant_time + dataset.decode_non_cf_time dataset.infer_or_keep_var - dataset.decode_time_units dataset.get_inferred_var .. currentmodule:: xarray diff --git a/tests/fixtures.py b/tests/fixtures.py index cef9c925..e59f51b0 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -32,18 +32,20 @@ ], dims=["time"], attrs={ + "axis": "T", "long_name": "time", "standard_name": "time", - "axis": "T", }, ) time_non_cf = xr.DataArray( data=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dims=["time"], attrs={ + "units": "months since 2000-01-01", + "calendar": "standard", + "axis": "T", "long_name": "time", "standard_name": "time", - "axis": "T", }, ) @@ -72,18 +74,18 @@ time_bnds_non_cf = xr.DataArray( name="time_bnds", data=[ - [datetime(1999, 12, 16, 12), datetime(2000, 1, 16, 12)], - [datetime(2000, 1, 16, 12), datetime(2000, 2, 15, 12)], - [datetime(2000, 2, 15, 12), datetime(2000, 3, 16, 12)], - [datetime(2000, 3, 16, 12), datetime(2000, 4, 16)], - [datetime(2000, 4, 16), datetime(2000, 5, 16, 12)], - [datetime(2000, 5, 16, 12), datetime(2000, 6, 16)], - [datetime(2000, 6, 16), datetime(2000, 7, 16, 12)], - [datetime(2000, 7, 16, 12), datetime(2000, 8, 16, 12)], - [datetime(2000, 8, 16, 12), datetime(2000, 9, 16)], - [datetime(2000, 9, 16), datetime(2000, 10, 16, 12)], - [datetime(2000, 10, 16, 12), datetime(2000, 11, 16)], - [datetime(2000, 11, 16), datetime(2000, 12, 16)], + [-1, 0], + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [5, 6], + [6, 7], + [7, 8], + [8, 9], + [9, 10], + [10, 11], ], coords={"time": time_non_cf}, dims=["time", "bnds"], @@ -172,19 +174,18 @@ def generate_dataset(cf_compliant: bool, has_bounds: bool) -> xr.Dataset: ) if cf_compliant: - ds = ds.assign({"time_bnds": time_bnds.copy()}) - ds = ds.assign_coords({"time": time_cf.copy()}) + ds.coords["time"] = time_cf.copy() + ds["time_bnds"] = time_bnds.copy() elif not cf_compliant: - ds = ds.assign({"time_bnds": time_bnds_non_cf.copy()}) - ds = ds.assign_coords({"time": time_non_cf.copy()}) - ds["time"] = ds.time.assign_attrs(units="months since 2000-01-01") + ds.coords["time"] = time_non_cf.copy() + ds["time_bnds"] = time_bnds_non_cf.copy() # If the "bounds" attribute is included in an existing DataArray and # added to a new Dataset, it will get dropped. Therefore, it needs to be # assigned to the DataArrays after they are added to Dataset. 
- ds["lat"] = ds.lat.assign_attrs(bounds="lat_bnds") - ds["lon"] = ds.lon.assign_attrs(bounds="lon_bnds") - ds["time"] = ds.time.assign_attrs(bounds="time_bnds") + ds["lat"].attrs["bounds"] = "lat_bnds" + ds["lon"].attrs["bounds"] = "lon_bnds" + ds["time"].attrs["bounds"] = "time_bnds" elif not has_bounds: ds = xr.Dataset( @@ -193,9 +194,8 @@ def generate_dataset(cf_compliant: bool, has_bounds: bool) -> xr.Dataset: ) if cf_compliant: - ds = ds.assign_coords({"time": time_cf.copy()}) + ds.coords["time"] = time_cf.copy() elif not cf_compliant: - ds = ds.assign_coords({"time": time_non_cf.copy()}) - ds["time"] = ds.time.assign_attrs(units="months since 2000-01-01") + ds.coords["time"] = time_non_cf.copy() return ds diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 7341af4e..a21760e3 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -7,9 +7,11 @@ from tests.fixtures import generate_dataset from xcdat.dataset import ( - _check_dataset_for_cf_compliant_time, - decode_time_units, + _preprocess_non_cf_dataset, + _split_time_units_attr, + decode_non_cf_time, get_inferred_var, + has_cf_compliant_time, infer_or_keep_var, open_dataset, open_mfdataset, @@ -42,39 +44,40 @@ def test_only_keeps_specified_var(self): warnings.simplefilter("ignore") ds_mod.to_netcdf(self.file_path) - result_ds = open_dataset(self.file_path, data_var="ts") + result = open_dataset(self.file_path, data_var="ts") expected = ds.copy() expected.attrs["xcdat_infer"] = "ts" - assert result_ds.identical(expected) + assert result.identical(expected) + + def test_non_cf_compliant_time_is_not_decoded(self): + ds = generate_dataset(cf_compliant=False, has_bounds=True) + ds.to_netcdf(self.file_path) + + result = open_dataset(self.file_path, decode_times=False) + expected = generate_dataset(cf_compliant=False, has_bounds=True) + expected.attrs["xcdat_infer"] = "ts" + + assert result.identical(expected) def test_non_cf_compliant_time_is_decoded(self): - # Generate dummy datasets with non-CF compliant time units that aren't - # encoded yet. ds = generate_dataset(cf_compliant=False, has_bounds=False) ds.to_netcdf(self.file_path) - result_ds = open_dataset(self.file_path, data_var="ts") - # Replicates decode_times=False, which adds units to "time" coordinate. - # Refer to xcdat.bounds.BoundsAccessor._add_bounds() for - # how attributes propagate from coord to coord bounds. - result_ds["time_bnds"].attrs["units"] = "months since 2000-01-01" - - # Generate an expected dataset with non-CF compliant time units that are - # manually encoded - expected_ds = generate_dataset(cf_compliant=True, has_bounds=True) - expected_ds.attrs["xcdat_infer"] = "ts" - expected_ds.time.attrs["units"] = "months since 2000-01-01" - expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01" - expected_ds.time.encoding = { - "source": None, + result = open_dataset(self.file_path, data_var="ts") + expected = generate_dataset(cf_compliant=True, has_bounds=True) + expected.attrs["xcdat_infer"] = "ts" + expected.time.attrs["calendar"] = "standard" + expected.time.attrs["units"] = "months since 2000-01-01" + expected.time.encoding = { + "source": result.time.encoding["source"], "dtype": np.dtype(np.int64), - "original_shape": expected_ds.time.data.shape, + "original_shape": expected.time.data.shape, "units": "months since 2000-01-01", - "calendar": "proleptic_gregorian", + "calendar": "standard", } - # Check that non-cf compliant time was decoded and bounds were generated. 
- assert result_ds.identical(expected_ds) + assert result.identical(expected) + assert result.time.encoding == expected.time.encoding def test_preserves_lat_and_lon_bounds_if_they_exist(self): ds = generate_dataset(cf_compliant=True, has_bounds=True) @@ -85,96 +88,27 @@ def test_preserves_lat_and_lon_bounds_if_they_exist(self): warnings.simplefilter("ignore") ds.to_netcdf(self.file_path) - result_ds = open_dataset(self.file_path, data_var="ts") + result = open_dataset(self.file_path, data_var="ts") expected = ds.copy() expected.attrs["xcdat_infer"] = "ts" - assert result_ds.identical(expected) + assert result.identical(expected) def test_generates_lat_and_lon_bounds_if_they_dont_exist(self): # Create expected dataset without bounds. ds = generate_dataset(cf_compliant=True, has_bounds=False) ds.to_netcdf(self.file_path) - # Make sure bounds don't exist data_vars = list(ds.data_vars.keys()) assert "lat_bnds" not in data_vars assert "lon_bnds" not in data_vars - # Check bounds were generated. result = open_dataset(self.file_path, data_var="ts") result_data_vars = list(result.data_vars.keys()) assert "lat_bnds" in result_data_vars assert "lon_bnds" in result_data_vars -class TestCheckTimeCfCompliant: - @pytest.fixture(autouse=True) - def setUp(self, tmp_path): - # Create temporary directory to save files. - self.dir = tmp_path / "input_data" - self.dir.mkdir() - - # Paths to the dummy datasets. - self.file_path = f"{self.dir}/file.nc" - - def test_non_cf_compliant_time(self): - # Generate dummy datasets with non-CF compliant time units - ds = generate_dataset(cf_compliant=False, has_bounds=False) - ds.to_netcdf(self.file_path) - - result = _check_dataset_for_cf_compliant_time(self.file_path) - - # Check that False is returned when the dataset has non-cf_compliant time - assert result is False - - def test_cf_compliant_time(self): - # Generate dummy datasets with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) - - result = _check_dataset_for_cf_compliant_time(self.file_path) - - # Check that True is returned when the dataset has cf_compliant time - assert result is True - - def test_no_time_axis(self): - # Generate dummy datasets with CF compliant time - ds = generate_dataset(cf_compliant=True, has_bounds=False) - # remove time axis - ds = ds.isel(time=0) - ds = ds.squeeze(drop=True) - ds = ds.reset_coords() - ds = ds.drop_vars("time") - ds.to_netcdf(self.file_path) - - result = _check_dataset_for_cf_compliant_time(self.file_path) - - # Check that None is returned when there is no time axis - assert result is None - - def test_glob_cf_compliant_time(self): - # Generate dummy datasets with CF compliant time - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) - - result = _check_dataset_for_cf_compliant_time(f"{self.dir}" + "/*.nc") - - # Check that the wildcard path input is correctly evaluated - assert result is True - - def test_list_cf_compliant_time(self): - # Generate dummy datasets with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) - - flist = [self.file_path, self.file_path, self.file_path] - result = _check_dataset_for_cf_compliant_time(flist) - - # Check that the list input is correctly evaluated - assert result is True - - class TestOpenMfDataset: @pytest.fixture(autouse=True) def setUp(self, tmp_path): @@ -194,63 +128,60 @@ def test_only_keeps_specified_var(self): ds2 = ds2.rename_vars({"ts": "tas"}) 
ds2.to_netcdf(self.file_path2) - result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") - - # Replicates decode_times=False, which adds units to "time" coordinate. - # Refer to xcdat.bounds.BoundsAccessor._add_bounds() for - # how attributes propagate from coord to coord bounds. - result_ds.time_bnds.attrs["units"] = "months since 2000-01-01" - - # Generate an expected dataset, which is a combination of both datasets - # with decoded time units and coordinate bounds. - expected_ds = generate_dataset(cf_compliant=True, has_bounds=True) - expected_ds.attrs["xcdat_infer"] = "ts" - expected_ds.time.attrs["units"] = "months since 2000-01-01" - expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01" - expected_ds.time.encoding = { - "source": None, + result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") + expected = generate_dataset(cf_compliant=True, has_bounds=True) + expected.attrs["xcdat_infer"] = "ts" + expected.time.attrs["calendar"] = "standard" + expected.time.attrs["units"] = "months since 2000-01-01" + + expected.time.encoding = { + "source": result.time.encoding["source"], "dtype": np.dtype(np.int64), - "original_shape": expected_ds.time.data.shape, + "original_shape": expected.time.data.shape, "units": "months since 2000-01-01", - "calendar": "proleptic_gregorian", + "calendar": "standard", } - # Check that non-cf compliant time was decoded and bounds were generated. - assert result_ds.identical(expected_ds) + assert result.identical(expected) + assert result.time.encoding == expected.time.encoding + + def test_non_cf_compliant_time_is_not_decoded(self): + ds1 = generate_dataset(cf_compliant=False, has_bounds=True) + ds1.to_netcdf(self.file_path1) + ds2 = generate_dataset(cf_compliant=False, has_bounds=True) + ds2 = ds2.rename_vars({"ts": "tas"}) + ds2.to_netcdf(self.file_path2) + + result = open_mfdataset([self.file_path1, self.file_path2], decode_times=False) + + expected = ds1.merge(ds2) + expected.attrs["xcdat_infer"] = "None" + assert result.identical(expected) def test_non_cf_compliant_time_is_decoded(self): - # Generate two dummy datasets with non-CF compliant time units. ds1 = generate_dataset(cf_compliant=False, has_bounds=False) ds1.to_netcdf(self.file_path1) ds2 = generate_dataset(cf_compliant=False, has_bounds=False) ds2 = ds2.rename_vars({"ts": "tas"}) ds2.to_netcdf(self.file_path2) - result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") - # Replicates decode_times=False, which adds units to "time" coordinate. - # Refer to xcdat.bounds.BoundsAccessor._add_bounds() for - # how attributes propagate from coord to coord bounds. - result_ds.time_bnds.attrs["units"] = "months since 2000-01-01" - - # Generate an expected dataset, which is a combination of both datasets - # with decoded time units and coordinate bounds. 
- expected_ds = generate_dataset(cf_compliant=True, has_bounds=True) - expected_ds.attrs["xcdat_infer"] = "ts" - expected_ds.time.attrs["units"] = "months since 2000-01-01" - expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01" - expected_ds.time.encoding = { - "source": None, + result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") + expected = generate_dataset(cf_compliant=True, has_bounds=True) + expected.attrs["xcdat_infer"] = "ts" + expected.time.attrs["units"] = "months since 2000-01-01" + expected.time.attrs["calendar"] = "standard" + expected.time.encoding = { + "source": result.time.encoding["source"], "dtype": np.dtype(np.int64), - "original_shape": expected_ds.time.data.shape, + "original_shape": expected.time.data.shape, "units": "months since 2000-01-01", - "calendar": "proleptic_gregorian", + "calendar": "standard", } - # Check that non-cf compliant time was decoded and bounds were generated. - assert result_ds.identical(expected_ds) + assert result.identical(expected) + assert result.time.encoding == expected.time.encoding def test_preserves_lat_and_lon_bounds_if_they_exist(self): - # Generate two dummy datasets. ds1 = generate_dataset(cf_compliant=True, has_bounds=True) ds2 = generate_dataset(cf_compliant=True, has_bounds=True) ds2 = ds2.rename_vars({"ts": "tas"}) @@ -262,172 +193,421 @@ def test_preserves_lat_and_lon_bounds_if_they_exist(self): ds1.to_netcdf(self.file_path1) ds2.to_netcdf(self.file_path2) - # Generate expected dataset, which is a combination of the two datasets. - expected_ds = generate_dataset(cf_compliant=True, has_bounds=True) - expected_ds.attrs["xcdat_infer"] = "ts" - # Check that the result is identical to the expected. - result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") - assert result_ds.identical(expected_ds) + expected = generate_dataset(cf_compliant=True, has_bounds=True) + expected.attrs["xcdat_infer"] = "ts" + result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") + assert result.identical(expected) def test_generates_lat_and_lon_bounds_if_they_dont_exist(self): - # Generate two dummy datasets. ds1 = generate_dataset(cf_compliant=True, has_bounds=False) ds1.to_netcdf(self.file_path1) - ds2 = generate_dataset(cf_compliant=True, has_bounds=False) ds2 = ds2.rename_vars({"ts": "tas"}) ds2.to_netcdf(self.file_path2) - # Make sure no bounds exist in the input file. data_vars1 = list(ds1.data_vars.keys()) data_vars2 = list(ds2.data_vars.keys()) assert "lat_bnds" not in data_vars1 + data_vars2 assert "lon_bnds" not in data_vars1 + data_vars2 - # Check that bounds were generated. result = open_dataset(self.file_path1, data_var="ts") result_data_vars = list(result.data_vars.keys()) assert "lat_bnds" in result_data_vars assert "lon_bnds" in result_data_vars -class TestDecodeTimeUnits: +class TestHasCFCompliantTime: + @pytest.fixture(autouse=True) + def setUp(self, tmp_path): + # Create temporary directory to save files. + self.dir = tmp_path / "input_data" + self.dir.mkdir() + + # Paths to the dummy datasets. 
+ self.file_path = f"{self.dir}/file.nc" + + def test_non_cf_compliant_time(self): + # Generate dummy dataset with non-CF compliant time units + ds = generate_dataset(cf_compliant=False, has_bounds=False) + ds.to_netcdf(self.file_path) + + result = has_cf_compliant_time(self.file_path) + + # Check that False is returned when the dataset has non-cf_compliant time + assert result is False + + def test_cf_compliant_time(self): + # Generate dummy dataset with CF compliant time units + ds = generate_dataset(cf_compliant=True, has_bounds=False) + ds.to_netcdf(self.file_path) + + result = has_cf_compliant_time(self.file_path) + + # Check that True is returned when the dataset has cf_compliant time + assert result is True + + def test_no_time_axis(self): + # Generate dummy dataset with CF compliant time + ds = generate_dataset(cf_compliant=True, has_bounds=False) + # remove time axis + ds = ds.isel(time=0) + ds = ds.squeeze(drop=True) + ds = ds.reset_coords() + ds = ds.drop_vars("time") + ds.to_netcdf(self.file_path) + + result = has_cf_compliant_time(self.file_path) + + # Check that None is returned when there is no time axis + assert result is None + + def test_glob_cf_compliant_time(self): + # Generate dummy datasets with CF compliant time + ds = generate_dataset(cf_compliant=True, has_bounds=False) + ds.to_netcdf(self.file_path) + + result = has_cf_compliant_time(f"{self.dir}" + "/*.nc") + + # Check that the wildcard path input is correctly evaluated + assert result is True + + def test_list_cf_compliant_time(self): + # Generate dummy datasets with CF compliant time units + ds = generate_dataset(cf_compliant=True, has_bounds=False) + ds.to_netcdf(self.file_path) + + flist = [self.file_path, self.file_path, self.file_path] + result = has_cf_compliant_time(flist) + + # Check that the list input is correctly evaluated + assert result is True + + +class TestDecodeNonCFTimeUnits: @pytest.fixture(autouse=True) def setup(self): - # Common attributes for the time coordinate. Units are overriden based - # on the unit that needs to be tested (days (CF compliant) or months - # (non-CF compliant). - self.time_attrs = { - "bounds": "time_bnds", - "axis": "T", - "long_name": "time", - "standard_name": "time", + time = xr.DataArray( + name="time", + data=[1, 2, 3], + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + "calendar": "noleap", + }, + ) + time_bnds = xr.DataArray( + name="time_bnds", + data=[[0, 1], [1, 2], [2, 3]], + dims=["time", "bnds"], + ) + time_bnds.encoding = { + "zlib": False, + "shuffle": False, + "complevel": 0, + "fletcher32": False, + "contiguous": False, + "chunksizes": (1, 2), + "source": "None", + "original_shape": (1980, 2), + "dtype": np.dtype("float64"), } + self.ds = xr.Dataset({"time": time, "time_bnds": time_bnds}) - def test_throws_error_if_function_is_called_on_already_decoded_cf_compliant_dataset( + def test_raises_error_if_function_is_called_on_already_decoded_cf_compliant_dataset( self, ): ds = generate_dataset(cf_compliant=True, has_bounds=True) with pytest.raises(KeyError): - decode_time_units(ds) + decode_non_cf_time(ds) - def test_decodes_cf_compliant_time_units(self): - # Create a dummy dataset with CF compliant time units. - time_attrs = self.time_attrs + def test_decodes_months_with_a_reference_date_at_the_start_of_the_month(self): + ds = self.ds.copy() + ds.time.attrs["units"] = "months since 2000-01-01" - # Create an expected dataset with properly decoded time units. 
- expected_ds = xr.Dataset( + result = decode_non_cf_time(ds) + expected = xr.Dataset( { "time": xr.DataArray( name="time", - data=[ - np.datetime64("2000-01-01"), - np.datetime64("2000-01-02"), - np.datetime64("2000-01-03"), - ], + data=np.array( + [ + "2000-02-01", + "2000-03-01", + "2000-04-01", + ], + dtype="datetime64", + ), dims=["time"], - attrs=time_attrs, - ) + attrs=ds.time.attrs, + ), + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-01-01", "2000-02-01"], + ["2000-02-01", "2000-03-01"], + ["2000-03-01", "2000-04-01"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), } ) + assert result.identical(expected) - # Update the time attrs to mimic decode_times=False - time_attrs.update({"units": "days since 2000-01-01"}) - time_coord = xr.DataArray( - name="time", data=[0, 1, 2], dims=["time"], attrs=time_attrs - ) - input_ds = xr.Dataset({"time": time_coord}) + expected.time.encoding = { + "source": "None", + "dtype": np.dtype(np.int64), + "original_shape": expected.time.data.shape, + "units": ds.time.attrs["units"], + "calendar": ds.time.attrs["calendar"], + } + expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding - # Check the resulting dataset is identical to the expected. - result_ds = decode_time_units(input_ds) - assert result_ds.identical(expected_ds) + def test_decodes_months_with_a_reference_date_at_the_middle_of_the_month(self): + ds = self.ds.copy() + ds.time.attrs["units"] = "months since 2000-01-15" + + result = decode_non_cf_time(ds) + expected = xr.Dataset( + { + "time": xr.DataArray( + name="time", + data=np.array( + [ + "2000-02-15", + "2000-03-15", + "2000-04-15", + ], + dtype="datetime64", + ), + dims=["time"], + attrs=ds.time.attrs, + ), + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-01-15", "2000-02-15"], + ["2000-02-15", "2000-03-15"], + ["2000-03-15", "2000-04-15"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), + } + ) + assert result.identical(expected) - # Check the encodings are the same. - expected_ds.time.encoding = { - # Default entries when `decode_times=True` + expected.time.encoding = { + "source": "None", "dtype": np.dtype(np.int64), - "units": time_attrs["units"], + "original_shape": expected.time.data.shape, + "units": ds.time.attrs["units"], + "calendar": ds.time.attrs["calendar"], } - assert result_ds.time.encoding == expected_ds.time.encoding - - def test_decodes_non_cf_compliant_time_units_months(self): - # Create a dummy dataset with non-CF compliant time units. - time_attrs = self.time_attrs - time_attrs.update({"units": "months since 2000-01-01"}) - time_coord = xr.DataArray( - name="time", data=[0, 1, 2], dims=["time"], attrs=time_attrs - ) - input_ds = xr.Dataset({"time": time_coord}) + expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding - # Create an expected dataset with properly decoded time units. 
- expected_ds = xr.Dataset( + def test_decodes_months_with_a_reference_date_at_the_end_of_the_month(self): + ds = self.ds.copy() + ds.time.attrs["units"] = "months since 1999-12-31" + + result = decode_non_cf_time(ds) + expected = xr.Dataset( { "time": xr.DataArray( name="time", - data=[ - np.datetime64("2000-01-01"), - np.datetime64("2000-02-01"), - np.datetime64("2000-03-01"), - ], + data=np.array( + [ + "2000-01-31", + "2000-02-29", + "2000-03-31", + ], + dtype="datetime64", + ), dims=["time"], - attrs=time_attrs, - ) + attrs=ds.time.attrs, + ), + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["1999-12-31", "2000-01-31"], + ["2000-01-31", "2000-02-29"], + ["2000-02-29", "2000-03-31"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), } ) + assert result.identical(expected) - # Check the resulting dataset is identical to the expected. - result_ds = decode_time_units(input_ds) - assert result_ds.identical(expected_ds) + expected.time.encoding = { + "source": "None", + "dtype": np.dtype(np.int64), + "original_shape": expected.time.data.shape, + "units": ds.time.attrs["units"], + "calendar": ds.time.attrs["calendar"], + } + expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding - # Check result and expected time coordinate encodings are the same. - expected_ds.time.encoding = { - "source": None, + def test_decodes_months_with_a_reference_date_on_a_leap_year(self): + ds = self.ds.copy() + ds.time.attrs["units"] = "months since 2000-02-29" + + result = decode_non_cf_time(ds) + expected = xr.Dataset( + { + "time": xr.DataArray( + name="time", + data=np.array( + [ + "2000-03-29", + "2000-04-29", + "2000-05-29", + ], + dtype="datetime64", + ), + dims=["time"], + attrs=ds.time.attrs, + ), + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-02-29", "2000-03-29"], + ["2000-03-29", "2000-04-29"], + ["2000-04-29", "2000-05-29"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), + } + ) + assert result.identical(expected) + + expected.time.encoding = { + "source": "None", "dtype": np.dtype(np.int64), - "original_shape": expected_ds.time.data.shape, - "units": time_attrs["units"], - "calendar": "proleptic_gregorian", + "original_shape": expected.time.data.shape, + "units": ds.time.attrs["units"], + "calendar": ds.time.attrs["calendar"], } - assert result_ds.time.encoding == expected_ds.time.encoding - - def test_decodes_non_cf_compliant_time_units_years(self): - # Create a dummy dataset with non-CF compliant time units. 
- time_attrs = self.time_attrs - time_attrs.update({"units": "years since 2000-01-01"}) - time_coord = xr.DataArray( - name="time", data=[0, 1, 2], dims=["time"], attrs=time_attrs + expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding + + def test_decodes_years_with_a_reference_date_at_the_middle_of_the_year(self): + ds = self.ds.copy() + ds.time.attrs["units"] = "years since 2000-06-01" + + result = decode_non_cf_time(ds) + expected = xr.Dataset( + { + "time": xr.DataArray( + name="time", + data=np.array( + [ + "2001-06-01", + "2002-06-01", + "2003-06-01", + ], + dtype="datetime64", + ), + dims=["time"], + attrs=ds.time.attrs, + ), + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-06-01", "2001-06-01"], + ["2001-06-01", "2002-06-01"], + ["2002-06-01", "2003-06-01"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), + } ) - input_ds = xr.Dataset({"time": time_coord}) + assert result.identical(expected) + + expected.time.encoding = { + "source": "None", + "dtype": np.dtype(np.int64), + "original_shape": expected.time.data.shape, + "units": ds.time.attrs["units"], + "calendar": ds.time.attrs["calendar"], + } + expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding + + def test_decodes_years_with_a_reference_date_on_a_leap_year(self): + ds = self.ds.copy() + ds.time.attrs["units"] = "years since 2000-02-29" - # Create an expected dataset with properly decoded time units. - expected_ds = xr.Dataset( + result = decode_non_cf_time(ds) + expected = xr.Dataset( { "time": xr.DataArray( name="time", data=[ - np.datetime64("2000-01-01"), - np.datetime64("2001-01-01"), - np.datetime64("2002-01-01"), + np.datetime64("2001-02-28"), + np.datetime64("2002-02-28"), + np.datetime64("2003-02-28"), ], dims=["time"], - attrs=time_attrs, - ) + ), + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-02-29", "2001-02-28"], + ["2001-02-28", "2002-02-28"], + ["2002-02-28", "2003-02-28"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), } ) + expected.time.attrs = ds.time.attrs + assert result.identical(expected) - # Check the resulting dataset is identical to the expected. - result_ds = decode_time_units(input_ds) - assert result_ds.identical(expected_ds) - - # Check result and expected time coordinate encodings are the same. 
- expected_ds.time.encoding = { - "source": None, + expected.time.encoding = { + "source": "None", "dtype": np.dtype(np.int64), - "original_shape": expected_ds.time.data.shape, - "units": time_attrs["units"], - "calendar": "proleptic_gregorian", + "original_shape": expected.time.data.shape, + "units": ds.time.attrs["units"], + "calendar": ds.time.attrs["calendar"], } - assert result_ds.time.encoding == expected_ds.time.encoding + expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding class TestInferOrKeepVar: @@ -472,10 +652,9 @@ def test_returns_dataset_if_it_contains_multiple_non_bounds_data_var_with_logger ds = self.ds_mod.copy() result = infer_or_keep_var(ds, data_var=None) expected = ds.copy() - expected.attrs["xcdat_infer"] = None + expected.attrs["xcdat_infer"] = "None" assert result.identical(expected) - assert result.attrs.get("xcdat_infer") is None assert ( "This dataset contains more than one regular data variable ('tas', 'ts'). " "If desired, pass the `data_var` kwarg to reduce down to one regular data var." @@ -533,3 +712,59 @@ def test_returns_inferred_data_var(self, caplog): "The data variable 'ts' was inferred from the Dataset attr 'xcdat_infer' " "for this operation." ) in caplog.text + + +class TestPreProcessNonCFDataset: + @pytest.fixture(autouse=True) + def setup(self): + self.ds = generate_dataset(cf_compliant=False, has_bounds=True) + + def test_user_specified_callable_results_in_subsetting_dataset_on_time_slice(self): + def callable(ds): + return ds.isel(time=slice(0, 1)) + + ds = self.ds.copy() + + result = _preprocess_non_cf_dataset(ds, callable) + expected = ds.copy().isel(time=slice(0, 1)) + expected["time"] = xr.DataArray( + name="time", + data=np.array( + [ + "2000-01-01", + ], + dtype="datetime64", + ), + dims=["time"], + ) + expected["time_bnds"] = xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["1999-12-01", "2000-01-01"], + ], + dtype="datetime64", + ), + dims=["time", "bnds"], + ) + + expected.time.attrs = ds.time.attrs + expected.time_bnds.attrs = ds.time_bnds.attrs + assert result.identical(expected) + + +class TestSplitTimeUnitsAttr: + def test_raises_error_if_units_attr_is_none(self): + with pytest.raises(KeyError): + _split_time_units_attr(None) # type: ignore + + def test_splits_units_attr_to_unit_and_reference_date(self): + assert _split_time_units_attr("months since 1800") == ("months", "1800") + assert _split_time_units_attr("months since 1800-01-01") == ( + "months", + "1800-01-01", + ) + assert _split_time_units_attr("months since 1800-01-01 00:00:00") == ( + "months", + "1800-01-01 00:00:00", + ) diff --git a/xcdat/__init__.py b/xcdat/__init__.py index 31515ed2..b6efe0fc 100644 --- a/xcdat/__init__.py +++ b/xcdat/__init__.py @@ -1,6 +1,13 @@ """Top-level package for xcdat.""" from xcdat.bounds import BoundsAccessor # noqa: F401 -from xcdat.dataset import decode_time_units, open_dataset, open_mfdataset # noqa: F401 +from xcdat.dataset import ( # noqa: F401 + decode_non_cf_time, + get_inferred_var, + has_cf_compliant_time, + infer_or_keep_var, + open_dataset, + open_mfdataset, +) from xcdat.spatial_avg import SpatialAverageAccessor # noqa: F401 from xcdat.xcdat import XCDATAccessor # noqa: F401 diff --git a/xcdat/dataset.py b/xcdat/dataset.py index c38c5747..a2f70abf 100644 --- a/xcdat/dataset.py +++ b/xcdat/dataset.py @@ -1,6 +1,7 @@ """Dataset module for functions related to an xarray.Dataset.""" +from functools 
import partial from glob import glob -from typing import Any, Dict, Hashable, List, Optional, Union +from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union import pandas as pd import xarray as xr @@ -11,30 +12,36 @@ logger = setup_custom_logger(__name__) +#: List of non-CF compliant time units. +NON_CF_TIME_UNITS: List[str] = ["months", "years"] + def open_dataset( - path: str, data_var: Optional[str] = None, **kwargs: Dict[str, Any] + path: str, + data_var: Optional[str] = None, + decode_times: bool = True, + **kwargs: Dict[str, Any], ) -> xr.Dataset: """Wrapper for ``xarray.open_dataset()`` that applies common operations. Operations include: - - Decode both CF and non-CF compliant time units if the Dataset has a time - dimension + - Optional decoding of time coordinates with CF or non-CF compliant units if + the Dataset has a time dimension - Add missing bounds for supported axis - Option to limit the Dataset to a single regular (non-bounds) data variable, while retaining any bounds data variables - ``decode_times`` is statically set to ``False``. This enables a check - for whether the units in the time dimension (if it exists) contains CF or - non-CF compliant units, which determines if manual decoding is necessary. - Parameters ---------- path : str Path to Dataset. data_var: Optional[str], optional The key of the data variable to keep in the Dataset, by default None. + decode_times: bool + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + This keyword may not be supported by all the backends, by default True. kwargs : Dict[str, Any] Additional arguments passed on to ``xarray.open_dataset``. Refer to the [1]_ xarray docs for accepted keyword arguments. @@ -67,18 +74,19 @@ def open_dataset( >>> from xcdat.dataset import open_dataset >>> ds = open_dataset("file_path", data_var="tas") - - Keep multiple variables in the Dataset: - - >>> from xcdat.dataset import open_dataset - >>> ds = open_dataset("file_path", data_var=["ts", "tas"]) """ - ds = xr.open_dataset(path, decode_times=False, **kwargs) - ds = infer_or_keep_var(ds, data_var) - - if ds.cf.dims.get("T") is not None: - ds = decode_time_units(ds) + if decode_times: + cf_compliant_time: Optional[bool] = has_cf_compliant_time(path) + if cf_compliant_time is False: + # XCDAT handles decoding time values with non-CF units. 
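+            # (e.g., units like "months since 2000-01-01", which ``cftime``
+            # cannot decode)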
+            ds = xr.open_dataset(path, decode_times=False, **kwargs)
+            ds = decode_non_cf_time(ds)
+        else:
+            ds = xr.open_dataset(path, decode_times=True, **kwargs)
+    else:
+        ds = xr.open_dataset(path, decode_times=False, **kwargs)
+    ds = infer_or_keep_var(ds, data_var)
     ds = ds.bounds.add_missing_bounds()
     return ds
@@ -86,6 +94,8 @@
 def open_mfdataset(
     paths: Union[str, List[str]],
     data_var: Optional[str] = None,
+    preprocess: Optional[Callable] = None,
+    decode_times: bool = True,
     data_vars: Union[Literal["minimal", "different", "all"], List[str]] = "minimal",
     **kwargs: Dict[str, Any],
 ) -> xr.Dataset:
@@ -93,9 +103,9 @@
 
     Operations include:
 
-    - Decode both CF and non-CF compliant time units if the Dataset has a time
-      dimension
-    - Fill missing bounds for supported axis
+    - Optional decoding of time coordinates with CF or non-CF compliant units if
+      the Dataset has a time dimension
+    - Add missing bounds for supported axis
     - Option to limit the Dataset to a single regular (non-bounds) data
       variable, while retaining any bounds data variables
 
@@ -106,10 +116,6 @@
     `"minimal"` is required for some XCDAT functions, including spatial
     averaging where a reduction is performed using the lat/lon bounds.
 
-    ``decode_times`` is statically set to ``False``. This enables a check
-    for whether the units in the time dimension (if it exists) contains CF or
-    non-CF compliant units, which determines if manual decoding is necessary.
-
     Parameters
     ----------
     path : Union[str, List[str]]
@@ -120,6 +126,14 @@
         for details). (A string glob will be expanded to a 1-dimensional list.)
     data_var: Optional[str], optional
         The key of the data variable to keep in the Dataset, by default None.
+    preprocess : Optional[Callable], optional
+        If provided, call this function on each dataset prior to concatenation.
+        You can find the file-name from which each dataset was loaded in
+        ``ds.encoding["source"]``.
+    decode_times: bool
+        If True, decode times encoded in the standard NetCDF datetime format
+        into datetime objects. Otherwise, leave them encoded as numbers.
+        This keyword may not be supported by all the backends, by default True.
     data_vars: Union[Literal["minimal", "different", "all"], List[str]], optional
         These data variables will be concatenated together:
           * "minimal": Only data variables in which the dimension already
@@ -162,145 +176,130 @@
 
     Keep a single variable in the Dataset:
 
-    >>> from xcdat.dataset import open_dataset
+    >>> from xcdat.dataset import open_mfdataset
     >>> ds = open_mfdataset(["file_path1", "file_path2"], data_var="tas")
-
-    Keep multiple variables in the Dataset:
-
-    >>> from xcdat.dataset import open_dataset
-    >>> ds = open_mfdataset(["file_path1", "file_path2"], data_var=["ts", "tas"])
     """
-    # check if time axis is cf_compliant
-    cf_compliant = _check_dataset_for_cf_compliant_time(paths)
-
-    # if cf_compliant, let xarray decode the time units
-    # otherwise, decode using decode_time_units
-    if cf_compliant:
-        ds = xr.open_mfdataset(paths, decode_times=True, data_vars=data_vars, **kwargs)
-    else:
-        ds = xr.open_mfdataset(paths, decode_times=False, data_vars=data_vars, **kwargs)
-        if ds.cf.dims.get("T") is not None:
-            ds = decode_time_units(ds)
-
+    if decode_times:
+        cf_compliant_time: Optional[bool] = has_cf_compliant_time(paths)
+        # XCDAT handles decoding time values with non-CF units using
+        # the preprocess kwarg.
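+        # For example, a user-supplied preprocess (hypothetical) such as
+        # `lambda ds: ds.isel(time=slice(0, 1))` still runs first; each
+        # dataset's time is then decoded individually before concatenation.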
+        if cf_compliant_time is False:
+            decode_times = False
+            preprocess = partial(_preprocess_non_cf_dataset, callable=preprocess)
+
+    ds = xr.open_mfdataset(
+        paths,
+        decode_times=decode_times,
+        data_vars=data_vars,
+        preprocess=preprocess,
+        **kwargs,
+    )
     ds = infer_or_keep_var(ds, data_var)
-
     ds = ds.bounds.add_missing_bounds()
     return ds
 
 
-def infer_or_keep_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Dataset:
-    """Infer the data variable(s) or keep a specific one in the Dataset.
+def has_cf_compliant_time(path: Union[str, List[str]]) -> Optional[bool]:
+    """Determine if a dataset has time coordinates with CF compliant units.
 
-    If ``data_var`` is None, then this function checks the number of
-    regular (non-bounds) data variables in the Dataset. If there is a single
-    regular data var, then it will add an 'xcdat_infer' attr pointing to it in
-    the Dataset. XCDAT APIs can then call `get_inferred_var()` to get the data
-    var linked to the 'xcdat_infer' attr. If there are multiple regular data
-    variables, the 'xcdat_infer' attr is not set and the Dataset is returned
-    as is.
-
-    If ``data_var`` is not None, then this function checks if the ``data_var``
-    exists in the Dataset and if it is a regular data var. If those checks pass,
-    it will subset the Dataset to retain that ``data_var`` and all bounds data
-    vars. An 'xcdat_infer' attr pointing to the ``data_var`` is also added
-    to the Dataset.
-
-    This utility function is useful for designing XCDAT APIs with an optional
-    ``data_var`` kwarg. If ``data_var`` is None, an inference to the desired
-    data var is performed with a call to this function. Otherwise, perform the
-    API operation explicitly on ``data_var``.
+    This function opens a dataset from a single path, or from the first path
+    in a list of paths (for a multi-file dataset). If the dataset does not
+    contain a time dimension, None is returned. Otherwise, the units attribute
+    is extracted from the time coordinates to determine whether it is CF or
+    non-CF compliant.
 
     Parameters
     ----------
-    dataset : xr.Dataset
-        The Dataset.
-    data_var: Optional[str], optional
-        The key of the data variable to keep in the Dataset.
+    path : Union[str, List[str]]
+        Either a file (``"file.nc"``), a string glob in the form
+        ``"path/to/my/files/*.nc"``, or an explicit list of files to open.
+        Paths can be given as strings or as pathlib Paths. If concatenation
+        along more than one dimension is desired, then ``paths`` must be a
+        nested list-of-lists (see ``combine_nested`` for details). (A string
+        glob will be expanded to a 1-dimensional list.)
 
     Returns
     -------
-    xr.Dataset
-        The Dataset.
+    Optional[bool]
+        None if the time dimension does not exist, True if CF compliant, or
+        False if non-CF compliant.
 
-    Raises
-    ------
-    KeyError
-        If the specified data variable is not found in the Dataset.
-    KeyError
-        If the user specifies a bounds variable to keep.
+    Notes
+    -----
+    This function only checks one file for multi-file datasets to optimize
+    performance, because it is slower to combine all files and then check for
+    CF compliance.
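+
+    Examples
+    --------
+    Check a dataset for CF compliant time units (illustrative paths and
+    return values; the result depends on the file's "units" attribute):
+
+    >>> has_cf_compliant_time("file.nc")
+    True
+    >>> has_cf_compliant_time("path/to/files/*.nc")
+    True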
    """
-    ds = dataset.copy()
-    # Make sure the "xcdat_infer" attr is None because a Dataset may be written
-    # with this attr already set.
-    ds.attrs["xcdat_infer"] = None
-
-    all_vars = ds.data_vars.keys()
-    bounds_vars = ds.bounds.names
-    regular_vars: List[Hashable] = list(set(all_vars) ^ set(bounds_vars))
-
-    if len(regular_vars) == 0:
-        logger.debug("This dataset only contains bounds data variables.")
-
-    if data_var is None:
-        if len(regular_vars) == 1:
-            ds.attrs["xcdat_infer"] = regular_vars[0]
-        elif len(regular_vars) > 1:
-            regular_vars_str = ", ".join(
-                f"'{var}'" for var in sorted(regular_vars)  # type:ignore
-            )
-            logger.debug(
-                "This dataset contains more than one regular data variable "
-                f"({regular_vars_str}). If desired, pass the `data_var` kwarg to "
-                "reduce down to one regular data var."
-            )
-    if data_var is not None:
-        if data_var not in all_vars:
-            raise KeyError(
-                f"The data variable '{data_var}' does not exist in the dataset."
-            )
-        if data_var in bounds_vars:
-            raise KeyError("Please specify a regular (non-bounds) data variable.")
-
-        ds = dataset[[data_var] + bounds_vars]
-        ds.attrs["xcdat_infer"] = data_var
+    # FIXME: This doesn't handle pathlib paths or a list of lists
+    if type(path) == str:
+        if "*" in path:
+            first_file = glob(path)[0]
+        else:
+            first_file = path
+    else:
+        first_file = path[0]
 
-    return ds
+    ds = xr.open_dataset(first_file, decode_times=False)
+    if ds.cf.dims.get("T") is None:
+        return None
+
+    time = ds.cf["T"]
+    units = _split_time_units_attr(time.attrs.get("units"))[0]
+    cf_compliant = units not in NON_CF_TIME_UNITS
+    return cf_compliant
 
 
-def decode_time_units(dataset: xr.Dataset):
-    """Decodes both CF and non-CF compliant time units.
-
-    ``xarray`` uses the ``cftime`` module, which only supports CF compliant
-    time units [4]_. As a result, opening datasets with non-CF compliant
-    time units (months and years) will throw an error if ``decode_times=True``.
+def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset:
+    """Decodes time coordinates and time bounds with non-CF compliant units.
 
-    This function works around this issue by first checking if the time units
-    are CF or non-CF compliant. Datasets with CF compliant time units are passed
-    to ``xarray.decode_cf``. Datasets with non-CF compliant time units are
-    manually decoded by extracting the units and reference date, which are used
-    to generate an array of datetime values.
+    By default, ``xarray`` uses the ``cftime`` module, which only supports
+    decoding time with CF compliant units [3]_. This function fills the gap in
+    xarray by being able to decode time with non-CF compliant units such as
+    "months since ..." and "years since ...". It extracts the units and
+    reference date from the "units" attribute, which are used to convert the
+    numerically encoded time values (representing the offset from the reference
+    date) to pandas DateOffset objects. These offset values are added to the
+    reference date, forming DataArrays of datetime objects that replace the
+    time coordinate and time bounds (if they exist) in the Dataset.
 
     Parameters
     ----------
     dataset : xr.Dataset
-        Dataset with non-decoded CF/non-CF compliant time units.
+        Dataset with numerically encoded time coordinates and time bounds (if
+        they exist).
 
     Returns
     -------
     xr.Dataset
-        Dataset with decoded time units.
+        Dataset with decoded time coordinates and time bounds (if they exist)
+        as datetime objects.
 
     Notes
     -----
-    .. [4] https://unidata.github.io/cftime/api.html#cftime.num2date
+    The pandas ``DateOffset`` object [4]_ is a time duration relative to a
+    reference date that respects calendar arithmetic. This means it considers
+    CF calendar types with or without leap years when adding the offsets to the
+    reference date.
+
+    DateOffset is used instead of timedelta64 because timedelta64 does
+    not respect calendar arithmetic. One downside of DateOffset (unlike
+    timedelta64) is that there is currently no simple way of vectorizing the
+    addition of DateOffset objects to Timestamp/datetime64 objects. However, the
+    performance of element-wise iteration should be sufficient for datasets
+    that have "months" and "years" time units, since the size of the time
+    coordinates isn't expected to be large in comparison to "days" or "hours".
+
+    References
+    ----------
+    .. [3] https://cfconventions.org/cf-conventions/cf-conventions.html#time-coordinate
+    .. [4] https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#dateoffset-objects
 
     Examples
     --------
     Decode the time coordinates with non-CF units in a Dataset:
 
-    >>> from xcdat.dataset import decode_time_units
-    >>> ds = xr.open_dataset("file_path", decode_times=False)
+    >>> from xcdat.dataset import decode_non_cf_time
     >>> ds.time
     array([0, 1, 2])
     Coordinates:
     * time     (time) int64 0 1 2
     Attributes:
         units: years since 2000-01-01
@@ -312,8 +311,10 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset:
         axis: T
         long_name: time
         standard_name: time
-    >>> ds = decode_time_units(ds)
-    >>> ds.time
+        calendar: noleap
+    >>>
+    >>> ds_decoded = decode_non_cf_time(ds)
+    >>> ds_decoded.time
     array(['2000-01-01T00:00:00.000000000', '2001-01-01T00:00:00.000000000',
            '2002-01-01T00:00:00.000000000'], dtype='datetime64[ns]')
     Coordinates:
     * time     (time) datetime64[ns] 2000-01-01 2001-01-01 2002-01-01
     Attributes:
         units: years since 2000-01-01
         axis: T
         long_name: time
         standard_name: time
+        calendar: noleap
 
-    View time coordinate encoding information:
+    View time encoding information:
 
-    >>> ds.time.encoding
+    >>> ds_decoded.time.encoding
     {'source': None, 'dtype': dtype('int64'), 'original_shape': (3,), 'units':
-    'years since 2000-01-01', 'calendar': 'proleptic_gregorian'}
+    'years since 2000-01-01', 'calendar': 'noleap'}
     """
-    time = dataset["time"]
+    time = dataset.cf["T"]
+    time_bounds = dataset.get(time.attrs.get("bounds"), None)
     units_attr = time.attrs.get("units")
-
-    if units_attr is None:
-        raise KeyError(
-            "No 'units' attribute found for time coordinate. Make sure to open "
-            "the dataset with `decode_times=False`."
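+    # Example (mirrors the unit tests): with units "months since 2000-01-01"
+    # and encoded time values [1, 2, 3], the reference date 2000-01-01 plus
+    # pd.DateOffset(months=1) yields Timestamp("2000-02-01"), so the decoded
+    # coordinates become 2000-02-01, 2000-03-01, and 2000-04-01.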
+ units, ref_date = _split_time_units_attr(units_attr) + ref_date = pd.to_datetime(ref_date) + + data = [ref_date + pd.DateOffset(**{units: offset}) for offset in time.data] + decoded_time = xr.DataArray( + name=time.name, + data=data, + dims=time.dims, + coords={time.name: data}, + attrs=time.attrs, + ) + decoded_time.encoding = { + "source": dataset.encoding.get("source", "None"), + "dtype": time.dtype, + "original_shape": time.shape, + "units": units_attr, + "calendar": time.attrs.get("calendar", "none"), + } + dataset = dataset.assign_coords({time.name: decoded_time}) + + if time_bounds is not None: + data_bounds = [ + [ + ref_date + pd.DateOffset(**{units: lower}), + ref_date + pd.DateOffset(**{units: upper}), + ] + for [lower, upper] in time_bounds.data + ] + decoded_time_bnds = xr.DataArray( + name=time_bounds.name, + data=data_bounds, + dims=time_bounds.dims, + coords=time_bounds.coords, + attrs=time_bounds.attrs, ) + decoded_time_bnds.coords[time.name] = decoded_time + decoded_time_bnds.encoding = time_bounds.encoding + dataset = dataset.assign({time_bounds.name: decoded_time_bnds}) - units, reference_date = units_attr.split(" since ") - non_cf_units_to_freq = {"months": "MS", "years": "YS"} - - cf_compliant = units not in non_cf_units_to_freq.keys() - if cf_compliant: - dataset = xr.decode_cf(dataset, decode_times=True) - else: - # NOTE: The "calendar" attribute for units consisting of "months" or - # "years" is not factored when generating date ranges. The number of - # days in a month is not factored. - decoded_time = xr.DataArray( - data=pd.date_range( - start=reference_date, - periods=time.size, - freq=non_cf_units_to_freq[units], - ), - dims=["time"], - attrs=dataset["time"].attrs, - ) - decoded_time.encoding = { - "source": dataset.encoding.get("source"), - "dtype": time.dtype, - "original_shape": decoded_time.shape, - "units": units_attr, - # pandas.date_range() returns "proleptic_gregorian" by default - "calendar": "proleptic_gregorian", - } - - dataset = dataset.assign_coords({"time": decoded_time}) return dataset -def _check_dataset_for_cf_compliant_time(path: Union[str, List[str]]): - """Determine if a dataset has cf_compliant time +def infer_or_keep_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Dataset: + """Infer the data variable(s) or keep a specific one in the Dataset. - Operations include: + If ``data_var`` is None, then this function checks the number of + regular (non-bounds) data variables in the Dataset. If there is a single + regular data var, then it will add an 'xcdat_infer' attr pointing to it in + the Dataset. XCDAT APIs can then call `get_inferred_var()` to get the data + var linked to the 'xcdat_infer' attr. If there are multiple regular data + variables, the 'xcdat_infer' attr is not set and the Dataset is returned + as is. + + If ``data_var`` is not None, then this function checks if the ``data_var`` + exists in the Dataset and if it is a regular data var. If those checks pass, + it will subset the Dataset to retain that ``data_var`` and all bounds data + vars. An 'xcdat_infer' attr pointing to the ``data_var`` is also added + to the Dataset. - - Open the file / dataset (in the case of multi-file datasets, only open - one file) - - Determine the time units and whether they are cf-compliant - - Return a Boolean (None if the time axis or time units do not exist) + This utility function is useful for designing XCDAT APIs with an optional + ``data_var`` kwarg. 
If ``data_var`` is None, an inference to the desired + data var is performed with a call to this function. Otherwise, perform the + API operation explicitly on ``data_var``. Parameters ---------- - path : Union[str, List[str]] - Either a file (``"file.nc"``), a string glob in the form - ``"path/to/my/files/*.nc"``, or an explicit list of files to open. - Paths can be given as strings or as pathlib Paths. If concatenation - along more than one dimension is desired, then ``paths`` must be a - nested list-of-lists (see ``combine_nested`` for details). (A string - glob will be expanded to a 1-dimensional list.) + dataset : xr.Dataset + The Dataset. + data_var: Optional[str], optional + The key of the data variable to keep in the Dataset. Returns ------- - Boolean - True if dataset is cf_compliant or False if not - Returns None if time or time units are not present - - Notes - ----- - This function only checks one file of multifile datasets (for performance). + xr.Dataset + The Dataset. + Raises + ------ + KeyError + If the specified data variable is not found in the Dataset. + KeyError + If the user specifies a bounds variable to keep. """ - # non-cf compliant units handled by xcdat - # Note: Should this be defined more globally? Is it possible to do the - # opposite (e.g., get the list of cf_compliant units and check that)? - non_cf_units_to_freq = ["months", "years"] + ds = dataset.copy() + # Make sure the "xcdat_infer" attr is "None" because a Dataset may be + # written with this attr already set. + ds.attrs["xcdat_infer"] = "None" - # Get one example file to check - # Note: This doesn't handle pathlib paths or a list of lists - if type(path) == str: - if "*" in path: - fn1 = glob(path)[0] - else: - fn1 = path - else: - fn1 = path[0] + all_vars = ds.data_vars.keys() + bounds_vars = ds.bounds.names + regular_vars: List[Hashable] = list(set(all_vars) ^ set(bounds_vars)) - # Open one file - ds = xr.open_dataset(fn1, decode_times=False) - # if there is no time dimension return None for the time units - # else get the time units - if ds.cf.dims.get("T") is None: - cf_compliant = None - else: - time = ds["time"] - units_attr = time.attrs.get("units") - units, reference_date = units_attr.split(" since ") - cf_compliant = units not in non_cf_units_to_freq - ds.close() + if len(regular_vars) == 0: + logger.debug("This dataset only contains bounds data variables.") - return cf_compliant + if data_var is None: + if len(regular_vars) == 1: + ds.attrs["xcdat_infer"] = regular_vars[0] + elif len(regular_vars) > 1: + regular_vars_str = ", ".join( + f"'{var}'" for var in sorted(regular_vars) # type:ignore + ) + logger.debug( + "This dataset contains more than one regular data variable " + f"({regular_vars_str}). If desired, pass the `data_var` kwarg to " + "reduce down to one regular data var." + ) + if data_var is not None: + if data_var not in all_vars: + raise KeyError( + f"The data variable '{data_var}' does not exist in the dataset." + ) + if data_var in bounds_vars: + raise KeyError("Please specify a regular (non-bounds) data variable.") + + ds = dataset[[data_var] + bounds_vars] + ds.attrs["xcdat_infer"] = data_var + + return ds def get_inferred_var(dataset: xr.Dataset) -> xr.DataArray: @@ -494,3 +517,69 @@ def get_inferred_var(dataset: xr.Dataset) -> xr.DataArray: "'xcdat_infer' for this operation." 
         )
     return data_var.copy()
+
+
+def _preprocess_non_cf_dataset(
+    ds: xr.Dataset, callable: Optional[Callable] = None
+) -> xr.Dataset:
+    """Preprocessing for each non-CF compliant dataset in ``open_mfdataset()``.
+
+    This function allows a user-specified preprocess function to run in
+    addition to XCDAT's preprocessing.
+
+    ``decode_non_cf_time()`` is called on each dataset to decode its time
+    coordinates and time bounds (if they exist) with non-CF compliant units.
+    By default, if ``decode_times=False`` is passed, xarray will concatenate
+    time values using the first dataset's "units" attribute. This is an issue
+    for cases where the numerically encoded time values are the same but the
+    "units" attribute differs between datasets. For example, two files may
+    have the same encoded time values, but the units of the first file are
+    "months since 2000-01-01" and those of the second are
+    "months since 2001-01-01". Since the first dataset's units are used in
+    xarray for concatenating datasets, the time values corresponding to the
+    second file will be dropped because they appear to be the same as the
+    first file's. Calling ``decode_non_cf_time()`` on each dataset
+    individually before concatenating solves this issue.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The Dataset.
+    callable : Optional[Callable], optional
+        A user-specified callable function for preprocessing, by default None.
+
+    Returns
+    -------
+    xr.Dataset
+        The preprocessed Dataset.
+    """
+    ds_new = ds.copy()
+    if callable:
+        ds_new = callable(ds)
+    ds_new = decode_non_cf_time(ds_new)
+    return ds_new
+
+
+def _split_time_units_attr(units_attr: str) -> Tuple[str, str]:
+    """Splits the time coordinates' units attr into units and reference date.
+
+    Parameters
+    ----------
+    units_attr : str
+        The units attribute (e.g., "months since 1800-01-01").
+
+    Returns
+    -------
+    Tuple[str, str]
+        The units ("months") and the reference date ("1800-01-01").
+
+    Raises
+    ------
+    KeyError
+        If the units attribute doesn't exist for the time coordinates.
+    """
+    if units_attr is None:
+        raise KeyError("No 'units' attribute found for the dataset's time coordinates.")
+
+    units, reference_date = units_attr.split(" since ")
+    return units, reference_date