From 536709ce171ca5c07ca9a83b06cec3bb6a1eded7 Mon Sep 17 00:00:00 2001 From: Sam Levang <39069044+slevang@users.noreply.github.com> Date: Wed, 10 Jul 2024 17:54:07 +0000 Subject: [PATCH] fix: support new datatree and numpy 2.0 (#169) --- tests/models/test_complex_eof_rotator.py | 4 ++-- tests/models/test_decomposer.py | 2 +- tests/models/test_eof_rotator.py | 4 ++-- xeofs/data_container/data_container.py | 6 +++++- xeofs/models/_base_cross_model.py | 6 +++++- xeofs/models/_base_model.py | 6 +++++- xeofs/preprocessing/preprocessor.py | 9 +++++++-- xeofs/preprocessing/transformer.py | 6 +++++- xeofs/utils/io.py | 13 ++++++++++--- 9 files changed, 42 insertions(+), 14 deletions(-) diff --git a/tests/models/test_complex_eof_rotator.py b/tests/models/test_complex_eof_rotator.py index 8e195dca..166c777f 100644 --- a/tests/models/test_complex_eof_rotator.py +++ b/tests/models/test_complex_eof_rotator.py @@ -47,8 +47,8 @@ def test_fit(ceof_model): assert hasattr( ceof_rotator, "data" ), 'The attribute "data" should be populated after fitting.' - assert type(ceof_rotator.model) == ComplexEOF - assert type(ceof_rotator.data) == DataContainer + assert isinstance(ceof_rotator.model, ComplexEOF) + assert isinstance(ceof_rotator.data, DataContainer) @pytest.mark.parametrize( diff --git a/tests/models/test_decomposer.py b/tests/models/test_decomposer.py index c1dcef9e..d6f07f53 100644 --- a/tests/models/test_decomposer.py +++ b/tests/models/test_decomposer.py @@ -176,4 +176,4 @@ def test_random_state( U2 = decomposer.U_.data # Check that the results are the same - assert np.alltrue(U1 == U2) + assert np.all(U1 == U2) diff --git a/tests/models/test_eof_rotator.py b/tests/models/test_eof_rotator.py index b95589ff..91eaaa27 100644 --- a/tests/models/test_eof_rotator.py +++ b/tests/models/test_eof_rotator.py @@ -49,8 +49,8 @@ def test_fit(eof_model): assert hasattr( eof_rotator, "data" ), 'The attribute "data" should be populated after fitting.' - assert type(eof_rotator.model) == EOF - assert type(eof_rotator.data) == DataContainer + assert isinstance(eof_rotator.model, EOF) + assert isinstance(eof_rotator.data, DataContainer) @pytest.mark.parametrize( diff --git a/xeofs/data_container/data_container.py b/xeofs/data_container/data_container.py index 2c30d4f4..e70e4e5f 100644 --- a/xeofs/data_container/data_container.py +++ b/xeofs/data_container/data_container.py @@ -3,7 +3,11 @@ import dask from dask.diagnostics.progress import ProgressBar -from datatree import DataTree + +try: + from xarray.core.datatree import DataTree +except ImportError: + from datatree import DataTree from ..utils.data_types import DataArray diff --git a/xeofs/models/_base_cross_model.py b/xeofs/models/_base_cross_model.py index a8fa74c8..0ddff7bf 100644 --- a/xeofs/models/_base_cross_model.py +++ b/xeofs/models/_base_cross_model.py @@ -5,9 +5,13 @@ import dask import xarray as xr -from datatree import DataTree from dask.diagnostics.progress import ProgressBar +try: + from xarray.core.datatree import DataTree +except ImportError: + from datatree import DataTree + from .eof import EOF from ..preprocessing.preprocessor import Preprocessor from ..data_container import DataContainer diff --git a/xeofs/models/_base_model.py b/xeofs/models/_base_model.py index 424c8bd1..922de4bc 100644 --- a/xeofs/models/_base_model.py +++ b/xeofs/models/_base_model.py @@ -14,9 +14,13 @@ import dask import xarray as xr -from datatree import DataTree from dask.diagnostics.progress import ProgressBar +try: + from xarray.core.datatree import DataTree +except ImportError: + from datatree import DataTree + from ..preprocessing.preprocessor import Preprocessor from ..data_container import DataContainer from ..utils.data_types import DataObject, Data, DataArray diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py index 22fc2cf7..64ee279d 100644 --- a/xeofs/preprocessing/preprocessor.py +++ b/xeofs/preprocessing/preprocessor.py @@ -2,7 +2,12 @@ from typing_extensions import Self import numpy as np -from datatree import DataTree + +try: + from xarray.core.datatree import DataTree +except ImportError: + from datatree import DataTree + from .list_processor import GenericListTransformer from .dimension_renamer import DimensionRenamer @@ -48,7 +53,7 @@ def extract_new_dim_names(X: List[DimensionRenamer]) -> Tuple[Dims, DimsList]: for x in X: new_sample_dims.append(x.sample_dims_after) new_feature_dims.append(x.feature_dims_after) - new_sample_dims: Dims = tuple(np.unique(np.asarray(new_sample_dims))) + new_sample_dims: Dims = tuple(np.unique(np.asarray(new_sample_dims)).tolist()) return new_sample_dims, new_feature_dims diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py index 354ad1bd..5ff39824 100644 --- a/xeofs/preprocessing/transformer.py +++ b/xeofs/preprocessing/transformer.py @@ -5,9 +5,13 @@ import pandas as pd import xarray as xr -from datatree import DataTree from sklearn.base import BaseEstimator, TransformerMixin +try: + from xarray.core.datatree import DataTree +except ImportError: + from datatree import DataTree + from ..utils.data_types import Dims, DataArray, DataSet, Data diff --git a/xeofs/utils/io.py b/xeofs/utils/io.py index e5265e8e..0cd69321 100644 --- a/xeofs/utils/io.py +++ b/xeofs/utils/io.py @@ -3,7 +3,12 @@ import numpy as np import xarray as xr -from datatree import DataTree, open_datatree + +try: + from xarray.core.datatree import DataTree + from xarray.backends.api import open_datatree +except ImportError: + from datatree import DataTree, open_datatree def write_model_tree( @@ -20,9 +25,11 @@ def write_model_tree( raise ValueError(f"Unknown engine {engine}") -def open_model_tree(path: str, engine: str = "zarr", chunks={}, **kwargs) -> DataTree: +def open_model_tree(path: str, engine: str = "zarr", **kwargs) -> DataTree: """Open a DataTree from a file.""" - dt = open_datatree(path, engine=engine, chunks=chunks, **kwargs) + if engine == "zarr" and "chunks" not in kwargs: + kwargs["chunks"] = {} + dt = open_datatree(path, engine=engine, **kwargs) if engine in ["netcdf4", "h5netcdf"]: dt = _desanitize_attrs_nc(dt) return dt