Skip to content

Commit

Permalink
add full_like and isnan so that equality checking does work with xarray
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Mar 14, 2024
1 parent 023373c commit 6595961
Showing 1 changed file with 46 additions and 14 deletions.
60 changes: 46 additions & 14 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Tuple, Union

import numpy as np

Expand All @@ -24,7 +24,7 @@ def decorator(func):
return decorator


def check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:
def _check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:
"""
The downside of the ManifestArray approach compared to the VirtualZarrArray concatenation proposal is that
the result must also be a single valid zarr array, implying that the inputs must have the same dtype, codec etc.
Expand Down Expand Up @@ -73,6 +73,16 @@ def _check_same_chunk_shapes(chunks_list: List[Tuple[int, ...]]) -> None:
)


@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
"""Called by xarray to ensure all arguments to concat have the same dtype."""
first_dtype, *other_dtypes = [np.dtype(obj) for obj in arrays_and_dtypes]
for other_dtype in other_dtypes:
if other_dtype != first_dtype:
raise ValueError("dtypes not all consistent")
return first_dtype


@implements(np.concatenate)
def concatenate(
arrays: tuple["ManifestArray", ...] | list["ManifestArray"],
Expand All @@ -95,7 +105,7 @@ def concatenate(
raise TypeError()

# ensure dtypes, shapes, codecs etc. are consistent
check_combineable_zarr_arrays(arrays)
_check_combineable_zarr_arrays(arrays)

_check_same_ndims([arr.ndim for arr in arrays])

Expand Down Expand Up @@ -161,16 +171,6 @@ def _remove_element_at_position(t: tuple[int, ...], pos: int) -> tuple[int, ...]
return tuple(new_l)


@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
"""Called by xarray to ensure all arguments to concat have the same dtype."""
first_dtype, *other_dtypes = [np.dtype(obj) for obj in arrays_and_dtypes]
for other_dtype in other_dtypes:
if other_dtype != first_dtype:
raise ValueError("dtypes not all consistent")
return first_dtype


@implements(np.stack)
def stack(
arrays: tuple["ManifestArray", ...] | list["ManifestArray"],
Expand All @@ -189,7 +189,7 @@ def stack(
raise TypeError()

# ensure dtypes, shapes, codecs etc. are consistent
check_combineable_zarr_arrays(arrays)
_check_combineable_zarr_arrays(arrays)

_check_same_ndims([arr.ndim for arr in arrays])
arr_shapes = [arr.shape for arr in arrays]
Expand Down Expand Up @@ -237,3 +237,35 @@ def _check_same_shapes(shapes: List[Tuple[int, ...]]) -> None:
raise ValueError(
f"Cannot concatenate arrays with differing shapes: {first_shape} vs {other_shape}"
)


@implements(np.full_like)
def full_like(
x: "ManifestArray", /, fill_value: bool, *, dtype: Union[np.dtype, None]
) -> np.ndarray:
"""
Returns a new array filled with fill_value and having the same shape as an input array x.
Returns a numpy array instead of a ManifestArray.
Only implemented to get past some checks deep inside xarray, see https://github.com/TomNicholas/VirtualiZarr/issues/29.
"""
return np.full(
shape=x.shape,
fill_value=fill_value,
dtype=dtype if dtype is not None else x.dtype,
)


@implements(np.isnan)
def isnan(x: "ManifestArray", /) -> np.ndarray:
"""
Returns a numpy array of all False.
Only implemented to get past some checks deep inside xarray, see https://github.com/TomNicholas/VirtualiZarr/issues/29.
"""
return np.full(
shape=x.shape,
fill_value=False,
dtype=np.dtype(bool),
)

0 comments on commit 6595961

Please sign in to comment.