# implement store.list_prefix and store._set_many (#2064)

Changes from 2 commits.
```diff
@@ -193,14 +193,10 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
         -------
         AsyncGenerator[str, None]
         """
-        for p in (self.root / prefix).rglob("*"):
-            if p.is_file():
-                yield str(p)
+        to_strip = str(self.root) + "/"
+        for p in (self.root / prefix).rglob("*"):
+            if p.is_file():
-                yield str(p).replace(to_strip, "")
+                yield str(p).removeprefix(to_strip)

     async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
         """
```

**Review thread:**
- Reviewer: Were we getting duplicates?
- Author: I think we were. This code path was not tested until this PR.
- Reviewer (on the change from `replace` to `removeprefix`): +1
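A quick illustration (with made-up paths, not from the PR) of why the switch from `str.replace` to `str.removeprefix` matters here: `replace` strips every occurrence of the root string, not just the leading one, so a key that happens to repeat the root's name gets mangled. Note that `str.removeprefix` requires Python 3.9 or later.

```python
# Hypothetical paths, for illustration only.
root = "data"                   # store root directory
p = "data/foo/data/bar"         # a file whose key repeats the root's name
to_strip = root + "/"

print(p.replace(to_strip, ""))      # 'foobar': every occurrence is removed
print(p.removeprefix(to_strip))     # 'foo/data/bar': only the leading prefix
```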
```diff
@@ -205,5 +205,6 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
             yield onefile

     async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
-        for onefile in await self._fs._ls(prefix, detail=False):
-            yield onefile
+        find_str = "/".join([self.path, prefix])
+        for onefile in await self._fs._find(find_str):
+            yield onefile.removeprefix(find_str)
```

**Review thread:**
- Reviewer: The defaults for `find` are `maxdepth=None, withdirs=False, detail=False`; maybe good to be specific. Why is `find()` better than `ls()`? The former will return all child files, not just one level deep. Is that the intent? If not, `ls()` ought to be generally more efficient.
- Author: Using `find` …
- Reviewer: It depends on whether you want one directory level or everything below it. When I wrote the original, I didn't know the intent.
- Author: I believe the intent here is to list everything below the prefix.
- Author: I misunderstood your first comment. Since the intent is to use the recursive behavior of `find`, …
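A small sketch of the `find()`/`ls()` distinction discussed above, using fsspec's in-memory filesystem (the bucket and paths are made up): `ls` lists a single directory level, while `find` recurses and returns only files by default.

```python
import fsspec

# In-memory filesystem; paths are invented for the example.
fs = fsspec.filesystem("memory")
fs.pipe("/bucket/a/zarr.json", b"{}")
fs.pipe("/bucket/a/c/0", b"\x00")

# ls: one level deep, so the directory 'c' shows up as an entry.
print(fs.ls("/bucket/a", detail=False))
# find: every file below the path, recursively, with no directories.
print(fs.find("/bucket/a"))
```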
```diff
@@ -114,6 +114,23 @@ def _get_loop() -> asyncio.AbstractEventLoop:
     return loop[0]


+async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
+    """
+    Collect an entire async iterator into a tuple
+    """
+    result = []
+    async for x in data:
+        result.append(x)
+    return tuple(result)
+
+
+def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
+    """
+    Synchronously collect an entire async iterator into a tuple.
+    """
+    return sync(_collect_aiterator(data))
+
+
 class SyncMixin:
     def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
         # TODO: refactor this to to take *args and **kwargs and pass those to the method
```

**Review thread:**
- Reviewer: `asyncio.gather`? Like above, not much point in having coroutines if we serially wait for them.
- Author: I don't think we can use `gather` on an async iterator …
- Author: And for clarification, …
- Reviewer: We should probably re-examine the use of async iterators, though. If we can't `gather()` on them (seems to be true?), then they are the wrong abstraction, since `gather()` is probably always what we actually want.
- Reviewer: On second thoughts, maybe I'm wrong: does …
- Author: I don't think it schedules them all at once. In …
- Author: The idea with the generators is to a) support seamless pagination and b) support pipelining (…).
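A sketch of the trade-off debated above, using made-up coroutines: `asyncio.gather` needs a concrete collection of awaitables up front and runs them concurrently, while an async iterator yields items one at a time and cannot be handed to `gather` directly; what the generators buy instead is lazy, paginated production of items.

```python
import asyncio
from collections.abc import AsyncIterator


async def fetch(key: str) -> str:
    # Stand-in for an I/O-bound store operation.
    await asyncio.sleep(0.01)
    return key


async def lazy_keys() -> AsyncIterator[str]:
    # Async generator: items become available one at a time (e.g. pagination).
    for k in ("a", "b", "c"):
        yield k


async def main() -> None:
    # gather: all three coroutines are scheduled at once and run concurrently.
    print(await asyncio.gather(*(fetch(k) for k in "abc")))
    # An async iterator can only be consumed serially, as items are yielded.
    print([k async for k in lazy_keys()])


asyncio.run(main())
```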
```diff
@@ -5,6 +5,7 @@
 from zarr.abc.store import AccessMode, Store
 from zarr.buffer import Buffer, default_buffer_prototype
 from zarr.store.utils import _normalize_interval_index
+from zarr.sync import _collect_aiterator
 from zarr.testing.utils import assert_bytes_equal

 S = TypeVar("S", bound=Store)

@@ -103,6 +104,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
         observed = self.get(store, key)
         assert_bytes_equal(observed, data_buf)

+    async def test_set_dict(self, store: S) -> None:
+        """
+        Test that a dict of key : value pairs can be inserted into the store via the
+        `_set_dict` method.
+        """
+        keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]
+        data_buf = [Buffer.from_bytes(k.encode()) for k in keys]
+        store_dict = dict(zip(keys, data_buf, strict=True))
+        await store._set_dict(store_dict)
+        for k, v in store_dict.items():
+            assert self.get(store, k).to_bytes() == v.to_bytes()
+
     @pytest.mark.parametrize(
         "key_ranges",
         (
```

**Review thread:**
- Reviewer: If `x.to_bytes() == y.to_bytes()`, does `x == y`? Isn't there a multiple get? Maybe not important here.
- Author: No, and I suspect this might be deliberate, since in principle … There is no multiple get (nor a multiple set, nor a multiple delete).
- Author: xref: `src/zarr/buffer.py` in #2006
```diff
@@ -165,37 +178,55 @@ async def test_clear(self, store: S) -> None:
         assert await store.empty()

     async def test_list(self, store: S) -> None:
-        assert [k async for k in store.list()] == []
-        await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
-        keys = [k async for k in store.list()]
-        assert keys == ["foo/zarr.json"], keys
-
-        expected = ["foo/zarr.json"]
-        for i in range(10):
-            key = f"foo/c/{i}"
-            expected.append(key)
-            await store.set(
-                f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little"))
-            )
+        assert await _collect_aiterator(store.list()) == ()
+        prefix = "foo"
+        data = Buffer.from_bytes(b"")
+        store_dict = {
+            prefix + "/zarr.json": data,
+            **{prefix + f"/c/{idx}": data for idx in range(10)},
+        }
+        await store._set_dict(store_dict)
+        expected_sorted = sorted(store_dict.keys())
+        observed = await _collect_aiterator(store.list())
+        observed_sorted = sorted(observed)
+        assert observed_sorted == expected_sorted

-    @pytest.mark.xfail
     async def test_list_prefix(self, store: S) -> None:
-        # TODO: we currently don't use list_prefix anywhere
-        raise NotImplementedError
+        """
+        Test that the `list_prefix` method works as intended. Given a prefix, it should return
+        all the keys in storage that start with this prefix. Keys should be returned with the shared
+        prefix removed.
+        """
+        prefixes = ("", "a/", "a/b/", "a/b/c/")
+        data = Buffer.from_bytes(b"")
+        fname = "zarr.json"
+        store_dict = {p + fname: data for p in prefixes}
+        await store._set_dict(store_dict)
+        for p in prefixes:
+            observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p))))
+            expected: tuple[str, ...] = ()
+            for k in store_dict.keys():
+                if k.startswith(p):
+                    expected += (k.removeprefix(p),)
+            expected = tuple(sorted(expected))
+            assert observed == expected

     async def test_list_dir(self, store: S) -> None:
-        out = [k async for k in store.list_dir("")]
-        assert out == []
-        assert [k async for k in store.list_dir("foo")] == []
-        await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
-        await store.set("foo/c/1", Buffer.from_bytes(b"\x01"))
+        root = "foo"
+        store_dict = {
+            root + "/zarr.json": Buffer.from_bytes(b"bar"),
+            root + "/c/1": Buffer.from_bytes(b"\x01"),
+        }
+
+        assert await _collect_aiterator(store.list_dir("")) == ()
+        assert await _collect_aiterator(store.list_dir(root)) == ()
+
+        await store._set_dict(store_dict)

-        keys_expected = ["zarr.json", "c"]
-        keys_observed = [k async for k in store.list_dir("foo")]
+        keys_observed = await _collect_aiterator(store.list_dir(root))
+        keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()}

-        assert len(keys_observed) == len(keys_expected), keys_observed
-        assert set(keys_observed) == set(keys_expected), keys_observed
+        assert sorted(keys_observed) == sorted(keys_expected)

-        keys_observed = [k async for k in store.list_dir("foo/")]
-        assert len(keys_expected) == len(keys_observed), keys_observed
-        assert set(keys_observed) == set(keys_expected), keys_observed
+        keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
+        assert sorted(keys_expected) == sorted(keys_observed)
```
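For clarity, the expected-keys expression in `test_list_dir` reduces each key to its first path component under the root; here is that expression run standalone with the same inputs:

```python
root = "foo"
store_dict = {"foo/zarr.json": b"bar", "foo/c/1": b"\x01"}

# Strip the root prefix, then keep only the first path component.
keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict}
print(keys_expected)  # {'zarr.json', 'c'}
```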
```diff
@@ -1,12 +1,18 @@
 from __future__ import annotations

 import os
+from collections.abc import Generator

 import fsspec
 import pytest
+from botocore.client import BaseClient
+from botocore.session import Session
+from s3fs import S3FileSystem
+from upath import UPath

 from zarr.buffer import Buffer, default_buffer_prototype
 from zarr.store import RemoteStore
-from zarr.sync import sync
+from zarr.sync import _collect_aiterator, sync
 from zarr.testing.store import StoreTests

 s3fs = pytest.importorskip("s3fs")

@@ -22,7 +28,7 @@

 @pytest.fixture(scope="module")
-def s3_base():
+def s3_base() -> Generator[None, None, None]:
     # writable local S3 system

     # This fixture is module-scoped, meaning that we can reuse the MotoServer across all tests

@@ -37,16 +43,14 @@ def s3_base():
     server.stop()


-def get_boto3_client():
-    from botocore.session import Session
-
+def get_boto3_client() -> BaseClient:
     # NB: we use the sync botocore client for setup
     session = Session()
     return session.create_client("s3", endpoint_url=endpoint_url)


 @pytest.fixture(autouse=True, scope="function")
-def s3(s3_base):
+def s3(s3_base: Generator[None, None, None]) -> Generator[S3FileSystem, None, None]:
     """
     Quoting Martin Durant:
     pytest-asyncio creates a new event loop for each async test.

@@ -71,21 +75,11 @@ def s3(s3_base):
     sync(session.close())


 # ### end from s3fs ### #


-async def alist(it):
-    out = []
-    async for a in it:
-        out.append(a)
-    return out
-
-
-async def test_basic():
+async def test_basic() -> None:
     store = await RemoteStore.open(
         f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False
     )
-    assert not await alist(store.list())
+    assert await _collect_aiterator(store.list()) == ()
     assert not await store.exists("foo")
     data = b"hello"
     await store.set("foo", Buffer.from_bytes(data))

@@ -101,7 +95,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore]):
     store_cls = RemoteStore

     @pytest.fixture(scope="function", params=("use_upath", "use_str"))
-    def store_kwargs(self, request) -> dict[str, str | bool]:
+    def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]:  # type: ignore
         url = f"s3://{test_bucket_name}"
         anon = False
         mode = "r+"

@@ -113,8 +107,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
             raise AssertionError

     @pytest.fixture(scope="function")
-    def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
-        url = store_kwargs["url"]
+    async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore:
+        url: str | UPath = store_kwargs["url"]
         mode = store_kwargs["mode"]
         if isinstance(url, UPath):
             out = self.store_cls(url=url, mode=mode)
```

**Review thread (on the `store` fixture):**
- Reviewer: This isn't actually async.
- Author: Correct, but the class we are inheriting from defines this as an async method.
**Review thread (on naming):**
- Reviewer: `set_many()` (analogous to `insert_many`)?
- Author: I moved away from `set_dict` and switched to `_set_many`.
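A hedged sketch of what a `_set_many` could look like (an assumption about its shape, not the PR's actual implementation): schedule one `set()` per pair and await them concurrently with `asyncio.gather`, which also addresses the serial-await concern raised earlier in the review.

```python
import asyncio
from collections.abc import Iterable


async def _set_many(store, values: Iterable[tuple[str, bytes]]) -> None:
    """Insert (key, value) pairs into `store` concurrently via its `set` method."""
    # Hypothetical: assumes `store.set(key, value)` is a coroutine, as in the PR.
    await asyncio.gather(*(store.set(key, value) for key, value in values))
```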