-
-
Notifications
You must be signed in to change notification settings - Fork 308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement store.list_prefix
and store._set_many
#2064
Changes from all commits
ebbfbe0
da6083e
e4101b7
70f9ceb
6eadb0c
0b54e4c
ddbbd60
49b4c1a
c139c10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,9 +3,9 @@ | |
|
||
import pytest | ||
|
||
import zarr.api.asynchronous | ||
from zarr.abc.store import AccessMode, Store | ||
from zarr.core.buffer import Buffer, default_buffer_prototype | ||
from zarr.core.sync import _collect_aiterator | ||
from zarr.store._utils import _normalize_interval_index | ||
from zarr.testing.utils import assert_bytes_equal | ||
|
||
|
@@ -123,6 +123,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: | |
observed = self.get(store, key) | ||
assert_bytes_equal(observed, data_buf) | ||
|
||
async def test_set_many(self, store: S) -> None: | ||
""" | ||
Test that a dict of key : value pairs can be inserted into the store via the | ||
`_set_many` method. | ||
""" | ||
keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] | ||
data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys] | ||
store_dict = dict(zip(keys, data_buf, strict=True)) | ||
await store._set_many(store_dict.items()) | ||
for k, v in store_dict.items(): | ||
assert self.get(store, k).to_bytes() == v.to_bytes() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If x.to_bytes() == y.to_bytes(), does x == y? Isn't there a multiple get? Maybe not important here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
No, and I suspect this might be deliberate since in principle
there is no multiple get (nor a multiple set, nor a multiple delete). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. xref: src/zarr/buffer.py in #2006 |
||
|
||
@pytest.mark.parametrize( | ||
"key_ranges", | ||
( | ||
|
@@ -185,76 +197,57 @@ async def test_clear(self, store: S) -> None: | |
assert await store.empty() | ||
|
||
async def test_list(self, store: S) -> None: | ||
assert [k async for k in store.list()] == [] | ||
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) | ||
keys = [k async for k in store.list()] | ||
assert keys == ["foo/zarr.json"], keys | ||
|
||
expected = ["foo/zarr.json"] | ||
for i in range(10): | ||
key = f"foo/c/{i}" | ||
expected.append(key) | ||
await store.set( | ||
f"foo/c/{i}", self.buffer_cls.from_bytes(i.to_bytes(length=3, byteorder="little")) | ||
) | ||
assert await _collect_aiterator(store.list()) == () | ||
prefix = "foo" | ||
data = self.buffer_cls.from_bytes(b"") | ||
store_dict = { | ||
prefix + "/zarr.json": data, | ||
**{prefix + f"/c/{idx}": data for idx in range(10)}, | ||
} | ||
await store._set_many(store_dict.items()) | ||
expected_sorted = sorted(store_dict.keys()) | ||
observed = await _collect_aiterator(store.list()) | ||
observed_sorted = sorted(observed) | ||
assert observed_sorted == expected_sorted | ||
|
||
@pytest.mark.xfail | ||
async def test_list_prefix(self, store: S) -> None: | ||
# TODO: we currently don't use list_prefix anywhere | ||
raise NotImplementedError | ||
""" | ||
Test that the `list_prefix` method works as intended. Given a prefix, it should return | ||
all the keys in storage that start with this prefix. Keys should be returned with the shared | ||
prefix removed. | ||
""" | ||
prefixes = ("", "a/", "a/b/", "a/b/c/") | ||
data = self.buffer_cls.from_bytes(b"") | ||
fname = "zarr.json" | ||
store_dict = {p + fname: data for p in prefixes} | ||
|
||
await store._set_many(store_dict.items()) | ||
|
||
for prefix in prefixes: | ||
observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix)))) | ||
expected: tuple[str, ...] = () | ||
for key in store_dict.keys(): | ||
if key.startswith(prefix): | ||
expected += (key.removeprefix(prefix),) | ||
expected = tuple(sorted(expected)) | ||
assert observed == expected | ||
|
||
async def test_list_dir(self, store: S) -> None: | ||
out = [k async for k in store.list_dir("")] | ||
assert out == [] | ||
assert [k async for k in store.list_dir("foo")] == [] | ||
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) | ||
await store.set("group-0/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group | ||
await store.set("group-0/group-1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group | ||
await store.set("group-0/group-1/a1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) | ||
await store.set("group-0/group-1/a2/zarr.json", self.buffer_cls.from_bytes(b"\x01")) | ||
await store.set("group-0/group-1/a3/zarr.json", self.buffer_cls.from_bytes(b"\x01")) | ||
|
||
keys_expected = ["foo", "group-0"] | ||
keys_observed = [k async for k in store.list_dir("")] | ||
assert set(keys_observed) == set(keys_expected) | ||
|
||
keys_expected = ["zarr.json"] | ||
keys_observed = [k async for k in store.list_dir("foo")] | ||
|
||
assert len(keys_observed) == len(keys_expected), keys_observed | ||
assert set(keys_observed) == set(keys_expected), keys_observed | ||
|
||
keys_observed = [k async for k in store.list_dir("foo/")] | ||
assert len(keys_expected) == len(keys_observed), keys_observed | ||
assert set(keys_observed) == set(keys_expected), keys_observed | ||
|
||
keys_observed = [k async for k in store.list_dir("group-0")] | ||
keys_expected = ["zarr.json", "group-1"] | ||
|
||
assert len(keys_observed) == len(keys_expected), keys_observed | ||
assert set(keys_observed) == set(keys_expected), keys_observed | ||
|
||
keys_observed = [k async for k in store.list_dir("group-0/")] | ||
assert len(keys_expected) == len(keys_observed), keys_observed | ||
assert set(keys_observed) == set(keys_expected), keys_observed | ||
root = "foo" | ||
store_dict = { | ||
root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"), | ||
root + "/c/1": self.buffer_cls.from_bytes(b"\x01"), | ||
} | ||
|
||
keys_observed = [k async for k in store.list_dir("group-0/group-1")] | ||
keys_expected = ["zarr.json", "a1", "a2", "a3"] | ||
assert await _collect_aiterator(store.list_dir("")) == () | ||
assert await _collect_aiterator(store.list_dir(root)) == () | ||
|
||
assert len(keys_observed) == len(keys_expected), keys_observed | ||
assert set(keys_observed) == set(keys_expected), keys_observed | ||
await store._set_many(store_dict.items()) | ||
|
||
keys_observed = [k async for k in store.list_dir("group-0/group-1")] | ||
assert len(keys_expected) == len(keys_observed), keys_observed | ||
assert set(keys_observed) == set(keys_expected), keys_observed | ||
keys_observed = await _collect_aiterator(store.list_dir(root)) | ||
keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} | ||
|
||
async def test_set_get(self, store_kwargs: dict[str, Any]) -> None: | ||
kwargs = {**store_kwargs, **{"mode": "w"}} | ||
store = self.store_cls(**kwargs) | ||
await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,)) | ||
keys = [x async for x in store.list()] | ||
assert keys == ["a/zarr.json"] | ||
assert sorted(keys_observed) == sorted(keys_expected) | ||
|
||
# no errors | ||
await zarr.api.asynchronous.open_array(store=store, path="a", mode="r") | ||
await zarr.api.asynchronous.open_array(store=store, path="a", mode="a") | ||
keys_observed = await _collect_aiterator(store.list_dir(root + "/")) | ||
assert sorted(keys_expected) == sorted(keys_observed) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Generator | ||
|
||
import botocore.client | ||
|
||
import os | ||
from collections.abc import Generator | ||
|
||
import botocore.client | ||
import fsspec | ||
import pytest | ||
from botocore.session import Session | ||
from upath import UPath | ||
|
||
from zarr.core.buffer import Buffer, cpu, default_buffer_prototype | ||
from zarr.core.sync import sync | ||
from zarr.core.sync import _collect_aiterator, sync | ||
from zarr.store import RemoteStore | ||
from zarr.testing.store import StoreTests | ||
|
||
|
@@ -40,8 +48,6 @@ def s3_base() -> Generator[None, None, None]: | |
|
||
|
||
def get_boto3_client() -> botocore.client.BaseClient: | ||
from botocore.session import Session | ||
|
||
# NB: we use the sync botocore client for setup | ||
session = Session() | ||
return session.create_client("s3", endpoint_url=endpoint_url) | ||
|
@@ -87,7 +93,7 @@ async def test_basic() -> None: | |
store = await RemoteStore.open( | ||
f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False | ||
) | ||
assert not await alist(store.list()) | ||
assert await _collect_aiterator(store.list()) == () | ||
assert not await store.exists("foo") | ||
data = b"hello" | ||
await store.set("foo", cpu.Buffer.from_bytes(data)) | ||
|
@@ -104,7 +110,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore, cpu.Buffer]): | |
buffer_cls = cpu.Buffer | ||
|
||
@pytest.fixture(scope="function", params=("use_upath", "use_str")) | ||
def store_kwargs(self, request) -> dict[str, str | bool]: | ||
def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore | ||
url = f"s3://{test_bucket_name}" | ||
anon = False | ||
mode = "r+" | ||
|
@@ -116,8 +122,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]: | |
raise AssertionError | ||
|
||
@pytest.fixture(scope="function") | ||
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore: | ||
url = store_kwargs["url"] | ||
async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This isn't actually async. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Correct, but the class we are inheriting from defines this as an async method. |
||
url: str | UPath = store_kwargs["url"] | ||
mode = store_kwargs["mode"] | ||
if isinstance(url, UPath): | ||
out = self.store_cls(url=url, mode=mode) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Were we getting duplicates?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we were. This code path was not tested until this PR.