Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: check that store, array, and group classes are serializable #2006

Merged
merged 13 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def _check_writable(self) -> None:
if self.mode.readonly:
raise ValueError("store mode does not support writing")

@abstractmethod
def __eq__(self, value: object) -> bool:
    """Return whether ``value`` is equal to this store.

    Marked abstract: subclasses are expected to provide the actual
    comparison.
    """

@abstractmethod
async def get(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ def __add__(self, other: Buffer) -> Self:
np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
)

def __eq__(self, other: object) -> bool:
    """Return True when ``other`` has this buffer's type and the same bytes.

    ``other`` must be an instance of ``type(self)``; anything else compares
    unequal. Contents are compared through ``to_bytes()``.
    """
    # NOTE: this was added to support comparing MemoryStore instances that
    # hold Buffer values. If/when buffers are no longer put in memory stores,
    # this method can be removed.
    if not isinstance(other, type(self)):
        return False
    return self.to_bytes() == other.to_bytes()


class NDBuffer:
"""An n-dimensional memory block
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import AsyncGenerator, MutableMapping
from typing import Any

from zarr.abc.store import Store
from zarr.buffer import Buffer, BufferPrototype
Expand Down Expand Up @@ -38,6 +39,19 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``MemoryStore(...)`` around the repr of this store's ``str()`` form."""
    return f"MemoryStore({str(self)!r})"

def __eq__(self, other: object) -> bool:
    """Return True when ``other`` is the same kind of store with equal state.

    Two stores are equal when ``other`` is an instance of ``type(self)`` and
    both the backing ``_store_dict`` and the ``mode`` compare equal.
    """
    if not isinstance(other, type(self)):
        return False
    return self._store_dict == other._store_dict and self.mode == other.mode

def __setstate__(self, state: Any) -> None:
    """Reject restoring state from a pickle.

    Raises
    ------
    NotImplementedError
        Always; ``state`` is never inspected.
    """
    message = f"{type(self)} cannot be pickled"
    raise NotImplementedError(message)

def __getstate__(self) -> None:
    """Reject producing state for a pickle.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = f"{type(self)} cannot be pickled"
    raise NotImplementedError(message)

async def get(
self,
key: str,
Expand Down
10 changes: 10 additions & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
this must not be used.
"""
super().__init__(mode=mode)
self._storage_options = storage_options
if isinstance(url, str):
self._url = url.rstrip("/")
self._fs, _path = fsspec.url_to_fs(url, **storage_options)
Expand Down Expand Up @@ -91,6 +92,15 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``<RemoteStore(FsClassName, path)>`` for this store.

    The first field is the class name of ``self._fs``; the second is
    ``self.path``.
    """
    return f"<RemoteStore({type(self._fs).__name__}, {self.path})>"

def __eq__(self, other: object) -> bool:
    """Return True when ``other`` is the same kind of store at the same location.

    Compares, in order, ``path``, ``mode`` and ``_url``. ``other`` must be an
    instance of ``type(self)``; anything else compares unequal.
    """
    if not isinstance(other, type(self)):
        return False
    # FIXME: ``self._storage_options == other._storage_options`` is left out of
    # this comparison on purpose: it isn't working for some reason.
    return (
        self.path == other.path
        and self.mode == other.mode
        and self._url == other._url
    )

async def get(
self,
key: str,
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from typing import Any, Generic, TypeVar

import pytest
Expand Down Expand Up @@ -42,6 +43,19 @@ def test_store_type(self, store: S) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
    """A store equals itself and a second store built from the same kwargs."""
    # A store must compare equal to itself.
    assert store == store

    # A store constructed from identical inputs must compare equal too;
    # comparing (de)serialized stores depends on this.
    rebuilt = self.store_cls(**store_kwargs)
    assert store == rebuilt

def test_serizalizable_store(self, store: S) -> None:
    """A store restored from its pickle compares equal to the original."""
    restored = pickle.loads(pickle.dumps(store))
    assert restored == store

def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
assert store.mode == AccessMode.from_literal("r+")
assert not store.mode.readonly
Expand Down
36 changes: 35 additions & 1 deletion tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pickle
from typing import Literal

import numpy as np
import pytest

from zarr.array import Array
from zarr.array import Array, AsyncArray
from zarr.common import ZarrFormat
from zarr.errors import ContainsArrayError, ContainsGroupError
from zarr.group import Group
Expand Down Expand Up @@ -136,3 +137,36 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str

assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
assert arr.fill_value.dtype == arr.dtype


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serizalizable_async_array(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """An ``AsyncArray`` compares equal to its pickle round-trip copy."""
    original = await AsyncArray.create(
        store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
    )
    # TODO: uncomment the disabled lines in this test that will be impacted by
    # the config/prototype changes in flight.
    # await original.setitems(list(range(100)))

    roundtripped = pickle.loads(pickle.dumps(original))

    assert roundtripped == original
    # np.testing.assert_array_equal(await roundtripped.getitem(slice(None)), await original.getitem(slice(None)))


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serizalizable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None:
    """An ``Array`` and its data survive a pickle round trip."""
    original = Array.create(
        store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
    )
    # Write data first so the round trip is checked on contents as well.
    original[:] = list(range(100))

    roundtripped = pickle.loads(pickle.dumps(original))

    assert roundtripped == original
    np.testing.assert_array_equal(roundtripped[:], original[:])
24 changes: 24 additions & 0 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pickle
from typing import Any, Literal, cast

import numpy as np
Expand Down Expand Up @@ -653,3 +654,26 @@ async def test_asyncgroup_update_attributes(

agroup_new_attributes = await agroup.update_attributes(attributes_new)
assert agroup_new_attributes.attrs == attributes_new


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serizalizable_async_group(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """An ``AsyncGroup`` compares equal to its pickle round-trip copy."""
    original = await AsyncGroup.create(
        store=store, attributes={"foo": 999}, zarr_format=zarr_format
    )
    roundtripped = pickle.loads(pickle.dumps(original))
    assert roundtripped == original


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serizalizable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None:
    """A ``Group`` compares equal to its pickle round-trip copy."""
    original = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format)
    roundtripped = pickle.loads(pickle.dumps(original))

    assert roundtripped == original
12 changes: 12 additions & 0 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import pytest

from zarr.buffer import Buffer
Expand Down Expand Up @@ -40,3 +42,13 @@ def test_store_supports_partial_writes(self, store: MemoryStore) -> None:

def test_list_prefix(self, store: MemoryStore) -> None:
    """Placeholder that always passes; ``store`` is not used.

    NOTE(review): presumably this replaces an inherited test that does not
    apply to this store -- confirm against the base test class.
    """
    assert True

def test_serizalizable_store(self, store: MemoryStore) -> None:
    """The store refuses pickling: each entry point raises ``NotImplementedError``."""
    # Checked in order: the two state hooks called directly, then a real
    # pickle.dumps call.
    attempts = (
        lambda: store.__getstate__(),
        lambda: store.__setstate__({}),
        lambda: pickle.dumps(store),
    )
    for attempt in attempts:
        with pytest.raises(NotImplementedError):
            attempt()
2 changes: 1 addition & 1 deletion tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
anon = False
mode = "r+"
if request.param == "use_upath":
return {"mode": mode, "url": UPath(url, endpoint_url=endpoint_url, anon=anon)}
return {"url": UPath(url, endpoint_url=endpoint_url, anon=anon), "mode": mode}
elif request.param == "use_str":
return {"url": url, "mode": mode, "anon": anon, "endpoint_url": endpoint_url}

Expand Down