Skip to content

Commit

Permalink
fix: support new datatree and numpy 2.0 (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang authored Jul 10, 2024
1 parent 388f87c commit 536709c
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 14 deletions.
4 changes: 2 additions & 2 deletions tests/models/test_complex_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_fit(ceof_model):
assert hasattr(
ceof_rotator, "data"
), 'The attribute "data" should be populated after fitting.'
assert type(ceof_rotator.model) == ComplexEOF
assert type(ceof_rotator.data) == DataContainer
assert isinstance(ceof_rotator.model, ComplexEOF)
assert isinstance(ceof_rotator.data, DataContainer)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ def test_random_state(
U2 = decomposer.U_.data

# Check that the results are the same
assert np.alltrue(U1 == U2)
assert np.all(U1 == U2)
4 changes: 2 additions & 2 deletions tests/models/test_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def test_fit(eof_model):
assert hasattr(
eof_rotator, "data"
), 'The attribute "data" should be populated after fitting.'
assert type(eof_rotator.model) == EOF
assert type(eof_rotator.data) == DataContainer
assert isinstance(eof_rotator.model, EOF)
assert isinstance(eof_rotator.data, DataContainer)


@pytest.mark.parametrize(
Expand Down
6 changes: 5 additions & 1 deletion xeofs/data_container/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

import dask
from dask.diagnostics.progress import ProgressBar
from datatree import DataTree

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree

from ..utils.data_types import DataArray

Expand Down
6 changes: 5 additions & 1 deletion xeofs/models/_base_cross_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

import dask
import xarray as xr
from datatree import DataTree
from dask.diagnostics.progress import ProgressBar

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree

from .eof import EOF
from ..preprocessing.preprocessor import Preprocessor
from ..data_container import DataContainer
Expand Down
6 changes: 5 additions & 1 deletion xeofs/models/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@

import dask
import xarray as xr
from datatree import DataTree
from dask.diagnostics.progress import ProgressBar

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree

from ..preprocessing.preprocessor import Preprocessor
from ..data_container import DataContainer
from ..utils.data_types import DataObject, Data, DataArray
Expand Down
9 changes: 7 additions & 2 deletions xeofs/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from typing_extensions import Self

import numpy as np
from datatree import DataTree

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree


from .list_processor import GenericListTransformer
from .dimension_renamer import DimensionRenamer
Expand Down Expand Up @@ -48,7 +53,7 @@ def extract_new_dim_names(X: List[DimensionRenamer]) -> Tuple[Dims, DimsList]:
for x in X:
new_sample_dims.append(x.sample_dims_after)
new_feature_dims.append(x.feature_dims_after)
new_sample_dims: Dims = tuple(np.unique(np.asarray(new_sample_dims)))
new_sample_dims: Dims = tuple(np.unique(np.asarray(new_sample_dims)).tolist())
return new_sample_dims, new_feature_dims


Expand Down
6 changes: 5 additions & 1 deletion xeofs/preprocessing/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

import pandas as pd
import xarray as xr
from datatree import DataTree
from sklearn.base import BaseEstimator, TransformerMixin

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree

from ..utils.data_types import Dims, DataArray, DataSet, Data


Expand Down
13 changes: 10 additions & 3 deletions xeofs/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import numpy as np
import xarray as xr
from datatree import DataTree, open_datatree

try:
from xarray.core.datatree import DataTree
from xarray.backends.api import open_datatree
except ImportError:
from datatree import DataTree, open_datatree


def write_model_tree(
Expand All @@ -20,9 +25,11 @@ def write_model_tree(
raise ValueError(f"Unknown engine {engine}")


def open_model_tree(path: str, engine: str = "zarr", chunks={}, **kwargs) -> DataTree:
def open_model_tree(path: str, engine: str = "zarr", **kwargs) -> DataTree:
"""Open a DataTree from a file."""
dt = open_datatree(path, engine=engine, chunks=chunks, **kwargs)
if engine == "zarr" and "chunks" not in kwargs:
kwargs["chunks"] = {}
dt = open_datatree(path, engine=engine, **kwargs)
if engine in ["netcdf4", "h5netcdf"]:
dt = _desanitize_attrs_nc(dt)
return dt
Expand Down

0 comments on commit 536709c

Please sign in to comment.