Skip to content

Commit

Permalink
Merge pull request #30 from TomNicholas/equality
Browse files Browse the repository at this point in the history
Equality checking
  • Loading branch information
TomNicholas authored Mar 14, 2024
2 parents 28e02ce + 6595961 commit 6cac49e
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 17 deletions.
40 changes: 40 additions & 0 deletions virtualizarr/manifests/array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from typing import Any, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -143,6 +144,8 @@ def __array_function__(self, func, types, args, kwargs) -> Any:

return MANIFESTARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)

# Everything beyond here is basically to make this array class wrappable by xarray #

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
    """
    Decline every numpy ufunc call.

    This hook exists only so that xarray recognises the class as a duck array;
    ufuncs are never supported, so ``NotImplemented`` is always returned and
    none of the arguments are inspected.
    """
    return NotImplemented
Expand All @@ -152,6 +155,43 @@ def __array__(self) -> np.ndarray:
"ManifestArrays can't be converted into numpy arrays or pandas Index objects"
)

def __eq__(  # type: ignore[override]
    self,
    other: Union[int, float, bool, np.ndarray, "ManifestArray"],
) -> np.ndarray:
    """
    Element-wise equality checking.

    Returns a numpy array of booleans with the same shape as this array.

    Parameters
    ----------
    other
        The object to compare against. Python scalars, numpy numeric/boolean
        scalars and numpy arrays are accepted but always compare as unequal,
        because this method never looks at element values. Another
        ManifestArray is compared via its ``zarray`` and ``manifest`` attributes.

    Returns
    -------
    np.ndarray
        Boolean array of shape ``self.shape``. It is all ``True`` only when both
        the ``zarray`` and the ``manifest`` of the two arrays compare equal,
        and all ``False`` otherwise.

    Raises
    ------
    TypeError
        If ``other`` is of a type that is not handled here.
    NotImplementedError
        If ``other`` is a ManifestArray with a different shape (no broadcasting).

    Warns
    -----
    UserWarning
        If the two manifests differ, since the all-``False`` result is then
        over-cautious (some chunks may in fact be identical).
    """
    if isinstance(other, (int, float, bool, np.number, np.bool_, np.ndarray)):
        # numpy arrays and numpy scalars are handled like python scalars so that the
        # behaviour matches the advertised signature instead of raising TypeError.
        # TODO what should this do when comparing against numpy arrays?
        return np.full(shape=self.shape, fill_value=False, dtype=np.dtype(bool))

    if not isinstance(other, ManifestArray):
        raise TypeError(
            f"Cannot check equality between a ManifestArray and an object of type {type(other)}"
        )

    if self.shape != other.shape:
        raise NotImplementedError("Unsure how to handle broadcasting like this")

    # Differing array metadata means the arrays cannot be considered equal anywhere.
    if self.zarray != other.zarray:
        return np.full(shape=self.shape, fill_value=False, dtype=np.dtype(bool))

    # Identical metadata and identical chunk references: equal everywhere.
    if self.manifest == other.manifest:
        return np.full(shape=self.shape, fill_value=True, dtype=np.dtype(bool))

    # TODO this doesn't yet do what it should - it simply returns all False if any of the chunk entries are different.
    # What it should do is return True for the locations where the chunk entries are the same.
    warnings.warn(
        "__eq__ currently is over-cautious, returning an array of all False if any of the chunk entries don't match.",
        UserWarning,
        stacklevel=2,  # attribute the warning to the code performing the comparison
    )

    # TODO do chunk-wise comparison
    # TODO expand it into an element-wise result
    return np.full(shape=self.shape, fill_value=False, dtype=np.dtype(bool))

def __getitem__(
self,
key,
Expand Down
67 changes: 51 additions & 16 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Callable, Dict, Iterable, List, Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Tuple, Union

import numpy as np

from ..zarr import ZArray, Codec
from ..zarr import Codec, ZArray
from .manifest import concat_manifests, stack_manifests

if TYPE_CHECKING:
Expand All @@ -24,7 +24,7 @@ def decorator(func):
return decorator


def check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:
def _check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:
"""
The downside of the ManifestArray approach compared to the VirtualZarrArray concatenation proposal is that
the result must also be a single valid zarr array, implying that the inputs must have the same dtype, codec etc.
Expand Down Expand Up @@ -73,9 +73,22 @@ def _check_same_chunk_shapes(chunks_list: List[Tuple[int, ...]]) -> None:
)


@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """
    Return the single dtype shared by every argument.

    Called by xarray to ensure all arguments to concat have the same dtype.
    Each argument is passed through ``np.dtype``; no type promotion is
    attempted, so any disagreement raises ``ValueError``.
    """
    dtypes = [np.dtype(item) for item in arrays_and_dtypes]
    reference, *remaining = dtypes
    if any(candidate != reference for candidate in remaining):
        raise ValueError("dtypes not all consistent")
    return reference


@implements(np.concatenate)
def concatenate(
arrays: tuple["ManifestArray", ...] | list["ManifestArray"], /, *, axis: int | None = 0
arrays: tuple["ManifestArray", ...] | list["ManifestArray"],
/,
*,
axis: int | None = 0,
) -> "ManifestArray":
"""
Concatenate ManifestArrays by merging their chunk manifests.
Expand All @@ -92,7 +105,7 @@ def concatenate(
raise TypeError()

# ensure dtypes, shapes, codecs etc. are consistent
check_combineable_zarr_arrays(arrays)
_check_combineable_zarr_arrays(arrays)

_check_same_ndims([arr.ndim for arr in arrays])

Expand Down Expand Up @@ -158,16 +171,6 @@ def _remove_element_at_position(t: tuple[int, ...], pos: int) -> tuple[int, ...]
return tuple(new_l)


@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """
    Return the single dtype shared by every argument.

    Called by xarray to ensure all arguments to concat have the same dtype.
    Each argument is passed through ``np.dtype``; no type promotion is
    attempted, so any disagreement raises ``ValueError``.
    """
    dtypes = [np.dtype(item) for item in arrays_and_dtypes]
    reference, *remaining = dtypes
    if any(candidate != reference for candidate in remaining):
        raise ValueError("dtypes not all consistent")
    return reference


@implements(np.stack)
def stack(
arrays: tuple["ManifestArray", ...] | list["ManifestArray"],
Expand All @@ -186,7 +189,7 @@ def stack(
raise TypeError()

# ensure dtypes, shapes, codecs etc. are consistent
check_combineable_zarr_arrays(arrays)
_check_combineable_zarr_arrays(arrays)

_check_same_ndims([arr.ndim for arr in arrays])
arr_shapes = [arr.shape for arr in arrays]
Expand Down Expand Up @@ -234,3 +237,35 @@ def _check_same_shapes(shapes: List[Tuple[int, ...]]) -> None:
raise ValueError(
f"Cannot concatenate arrays with differing shapes: {first_shape} vs {other_shape}"
)


@implements(np.full_like)
def full_like(
    x: "ManifestArray",
    /,
    fill_value: bool,
    *,
    dtype: Union[np.dtype, None] = None,
) -> np.ndarray:
    """
    Returns a new array filled with fill_value and having the same shape as an input array x.

    Returns a numpy array instead of a ManifestArray.

    Only implemented to get past some checks deep inside xarray, see https://github.com/TomNicholas/VirtualiZarr/issues/29.

    Parameters
    ----------
    x
        Array whose shape (and, when ``dtype`` is not given, dtype) is used for the result.
    fill_value
        Value written to every element of the result.
    dtype
        Dtype of the result. Optional, as in ``np.full_like``; when omitted or
        ``None`` the dtype of ``x`` is used.

    Returns
    -------
    np.ndarray
        In-memory array of shape ``x.shape`` filled with ``fill_value``.
    """
    return np.full(
        shape=x.shape,
        fill_value=fill_value,
        dtype=dtype if dtype is not None else x.dtype,
    )


@implements(np.isnan)
def isnan(x: "ManifestArray", /) -> np.ndarray:
    """
    Report that no element of ``x`` is NaN.

    The result is a boolean numpy array of shape ``x.shape`` filled with False;
    the contents of ``x`` are never examined.

    Only implemented to get past some checks deep inside xarray, see https://github.com/TomNicholas/VirtualiZarr/issues/29.
    """
    result_shape = x.shape
    bool_dtype = np.dtype(bool)
    return np.full(result_shape, False, dtype=bool_dtype)
63 changes: 63 additions & 0 deletions virtualizarr/tests/test_manifests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,69 @@ def test_create_manifestarray_from_kerchunk_refs(self):
assert marr.zarray.order == "C"


class TestEquals:
    """Equality checks between ManifestArray objects."""

    def test_equals(self):
        """Arrays sharing one zarray and one manifest are equal at every element."""
        shape = (5, 2, 20)
        zarray = ZArray(
            chunks=(5, 1, 10),
            compressor="zlib",
            dtype=np.dtype("int32"),
            fill_value=0.0,
            filters=None,
            order="C",
            shape=shape,
            zarr_format=2,
        )
        manifest = ChunkManifest(
            entries={
                "0.0.0": {"path": "s3://bucket/foo.nc", "offset": 100, "length": 100},
                "0.0.1": {"path": "s3://bucket/foo.nc", "offset": 200, "length": 100},
                "0.1.0": {"path": "s3://bucket/foo.nc", "offset": 300, "length": 100},
                "0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100},
            }
        )

        left = ManifestArray(zarray=zarray, chunkmanifest=manifest)
        right = ManifestArray(zarray=zarray, chunkmanifest=manifest)
        comparison = left == right

        assert isinstance(comparison, np.ndarray)
        assert comparison.shape == shape
        assert comparison.dtype == np.dtype(bool)
        assert comparison.all()

    def test_not_equal_chunk_entries(self):
        """Arrays with the same zarray properties but different chunk entries are not equal."""
        # both manifest arrays in this example have the same zarray properties
        zarray = ZArray(
            chunks=(5, 1, 10),
            compressor="zlib",
            dtype=np.dtype("int32"),
            fill_value=0.0,
            filters=None,
            order="C",
            shape=(5, 1, 20),
            zarr_format=2,
        )

        first_entries = {
            "0.0.0": {"path": "foo.nc", "offset": 100, "length": 100},
            "0.0.1": {"path": "foo.nc", "offset": 200, "length": 100},
        }
        first = ManifestArray(
            zarray=zarray, chunkmanifest=ChunkManifest(entries=first_entries)
        )

        second_entries = {
            "0.0.0": {"path": "foo.nc", "offset": 300, "length": 100},
            "0.0.1": {"path": "foo.nc", "offset": 400, "length": 100},
        }
        second = ManifestArray(
            zarray=zarray, chunkmanifest=ChunkManifest(entries=second_entries)
        )

        comparison = first == second
        assert not comparison.all()

    @pytest.mark.skip(reason="Not Implemented")
    def test_partly_equals(self):
        """Placeholder for the case where only some chunk entries match."""
        ...


# TODO we really need some kind of fixtures to generate useful example data
# The hard part is having an alternative way to get to the expected result of concatenation
class TestConcat:
Expand Down
18 changes: 18 additions & 0 deletions virtualizarr/tests/test_manifests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ def test_chunk_grid_info(self):
assert manifest.shape_chunk_grid == (1, 2, 2)


class TestEquals:
    """Equality checks between ChunkManifest objects."""

    def test_equals(self):
        """Manifests whose entries differ in offset must not compare equal."""
        entries_a = {
            "0.0.0": {"path": "foo.nc", "offset": 100, "length": 100},
            "0.0.1": {"path": "foo.nc", "offset": 200, "length": 100},
        }
        entries_b = {
            "0.0.0": {"path": "foo.nc", "offset": 300, "length": 100},
            "0.0.1": {"path": "foo.nc", "offset": 400, "length": 100},
        }
        first = ChunkManifest(entries=entries_a)
        second = ChunkManifest(entries=entries_b)

        # Exercise both the == and the != operators.
        assert not first == second
        assert first != second


# TODO could we use hypothesis to test this?
# Perhaps by testing the property that splitting along a dimension then concatenating the pieces along that dimension should recreate the original manifest?
class TestCombineManifests:
Expand Down
41 changes: 40 additions & 1 deletion virtualizarr/tests/test_xarray.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import xarray as xr

from virtualizarr.manifests import ChunkEntry, ManifestArray
from virtualizarr.manifests import ChunkEntry, ChunkManifest, ManifestArray
from virtualizarr.xarray import dataset_from_kerchunk_refs
from virtualizarr.zarr import ZArray


def test_dataset_from_kerchunk_refs():
Expand Down Expand Up @@ -66,6 +67,44 @@ def test_accessor_to_kerchunk_dict(self):
assert result_ds_refs == expected_ds_refs


class TestEquals:
    """Dataset-level equality for datasets wrapping ManifestArrays."""

    # regression test for GH29 https://github.com/TomNicholas/VirtualiZarr/issues/29
    def test_equals(self):
        """``Dataset.equals`` is True for matching manifests and False for differing ones."""
        zarray = ZArray(
            chunks=(5, 10),
            compressor="zlib",
            dtype=np.dtype("int32"),
            fill_value=0.0,
            filters=None,
            order="C",
            shape=(5, 20),
            zarr_format=2,
        )
        dims = ["x", "y"]

        base_manifest = ChunkManifest(
            entries={
                "0.0": {"path": "foo.nc", "offset": 100, "length": 100},
                "0.1": {"path": "foo.nc", "offset": 200, "length": 100},
            }
        )
        base_array = ManifestArray(zarray=zarray, chunkmanifest=base_manifest)
        base_ds = xr.Dataset({"a": (dims, base_array)})

        # A second array built from the very same manifest must compare equal.
        same_array = ManifestArray(zarray=zarray, chunkmanifest=base_manifest)
        same_ds = xr.Dataset({"a": (dims, same_array)})
        assert base_ds.equals(same_ds)

        # Different byte offsets for each chunk must compare unequal.
        other_manifest = ChunkManifest(
            entries={
                "0.0": {"path": "foo.nc", "offset": 300, "length": 100},
                "0.1": {"path": "foo.nc", "offset": 400, "length": 100},
            }
        )
        other_array = ManifestArray(zarray=zarray, chunkmanifest=other_manifest)
        other_ds = xr.Dataset({"a": (dims, other_array)})
        assert not base_ds.equals(other_ds)


def test_kerchunk_roundtrip_in_memory_no_concat():
# TODO set up example xarray dataset

Expand Down

0 comments on commit 6cac49e

Please sign in to comment.