implement store.list_prefix and store._set_many #2064

Merged · 9 commits · Sep 19, 2024

10 changes: 9 additions & 1 deletion src/zarr/abc/store.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Mapping
from typing import Any, NamedTuple, Protocol, runtime_checkable

from typing_extensions import Self
@@ -221,6 +221,14 @@ def close(self) -> None:
self._is_open = False
pass

async def _set_dict(self, dict: Mapping[str, Buffer]) -> None:
@jhamman (Member), Aug 9, 2024:

set_many() (analogous to insert_many)?

Contributor Author:

I moved away from set_dict and switched to _set_many.

"""
Insert objects into storage as defined by a prefix: value mapping.
"""
for key, value in dict.items():
await self.set(key, value)
return None


@runtime_checkable
class ByteGetter(Protocol):
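The set_many() suggestion above also touches on concurrency: the loop in the diff awaits each set in turn. A minimal sketch of a concurrent variant, assuming only that concrete stores provide an awaitable set(key, value), could dispatch all writes at once with asyncio.gather; this is an illustration, not what the diff above implements, and `StoreSketch`/the `Buffer` placeholder are stand-ins.

```python
# Sketch only: a concurrent variant of the serial loop above.
# Assumptions: `Buffer` stands in for zarr.buffer.Buffer, and concrete stores
# provide an awaitable `set(key, value)`.
import asyncio
from collections.abc import Mapping
from typing import Any

Buffer = Any  # placeholder for zarr.buffer.Buffer


class StoreSketch:
    async def set(self, key: str, value: Buffer) -> None:
        raise NotImplementedError  # implemented by concrete stores

    async def _set_many(self, values: Mapping[str, Buffer]) -> None:
        # Schedule every set() at once and wait for all of them to finish,
        # instead of awaiting each one serially.
        await asyncio.gather(*(self.set(k, v) for k, v in values.items()))
```

Whether firing all writes at once is safe depends on the backing store; the serial version in the diff is the conservative default.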
6 changes: 1 addition & 5 deletions src/zarr/store/local.py
@@ -193,14 +193,10 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
-------
AsyncGenerator[str, None]
"""
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p)
Member:

Were we getting duplicates?

Contributor Author:

I think we were; this code path was not tested until this PR.


to_strip = str(self.root) + "/"
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p).replace(to_strip, "")
yield str(p).removeprefix(to_strip)
Member:

+1


async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
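On the .replace → .removeprefix change above (the one that drew the +1): str.replace removes every occurrence of the substring, while str.removeprefix (Python 3.9+) only strips it when it appears at the start, which is the intent when trimming the store root from a path. A small illustration with a made-up path that happens to repeat the root:

```python
to_strip = "/data/root/"
key = "/data/root/nested/data/root/zarr.json"  # hypothetical key repeating the root string

print(key.replace(to_strip, ""))    # 'nestedzarr.json' -- every occurrence stripped
print(key.removeprefix(to_strip))   # 'nested/data/root/zarr.json' -- only the leading occurrence
```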
2 changes: 1 addition & 1 deletion src/zarr/store/memory.py
@@ -101,7 +101,7 @@ async def list(self) -> AsyncGenerator[str, None]:
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for key in self._store_dict:
if key.startswith(prefix):
yield key
yield key.removeprefix(prefix)

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
if prefix.endswith("/"):
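With this change, MemoryStore.list_prefix yields keys relative to the queried prefix rather than the full keys, which is the behaviour the new test_list_prefix below exercises. A synchronous stand-in over a plain dict (the example keys are made up) shows the effect:

```python
store_dict = {"foo/zarr.json": b"", "foo/c/0": b"", "bar/zarr.json": b""}


def list_prefix(prefix: str):
    # Same logic as the MemoryStore change above: match on the full key,
    # then yield it with the shared prefix stripped.
    for key in store_dict:
        if key.startswith(prefix):
            yield key.removeprefix(prefix)


print(sorted(list_prefix("foo/")))  # ['c/0', 'zarr.json']
print(sorted(list_prefix("")))      # ['bar/zarr.json', 'foo/c/0', 'foo/zarr.json']
```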
5 changes: 3 additions & 2 deletions src/zarr/store/remote.py
@@ -205,5 +205,6 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
yield onefile

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for onefile in await self._fs._ls(prefix, detail=False):
yield onefile
find_str = "/".join([self.path, prefix])
for onefile in await self._fs._find(find_str):
Member:

The defaults for find are: maxdepth=None, withdirs=False, detail=False; maybe good to be specific.

Why is find() better than ls()? The former will return all child files, not just one level deep - is that the intent? If not, ls() ought to be generally more efficient.

Contributor Author:

Using find here is merely due to my ignorance of fsspec. I will implement ls as you suggest.

Member:

It depends on whether you want one directory level or everything below it. When I wrote the original, I didn't know the intent.

Contributor Author:

I believe the intent here is to list everything below prefix (at least, that's how I'm using it).

Contributor Author:

I misunderstood your first comment. Since the intent is to use the behavior of _find, I'm keeping it, but adding explicit kwargs as you suggested.

yield onefile.removeprefix(find_str)
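The outcome of the thread above is to keep fsspec's _find (which lists everything below the prefix, unlike the one-level _ls) but to pass its defaults explicitly. A sketch of what that could look like; the keyword names are taken from the reviewer's comment (maxdepth=None, withdirs=False, detail=False) and may not match the exact merged code:

```python
from collections.abc import AsyncGenerator


# Method sketch of RemoteStore.list_prefix with the fsspec defaults spelled
# out, assuming `self._fs` is an async fsspec filesystem and `self.path` is
# the store's root path within it.
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
    find_str = "/".join([self.path, prefix])
    for onefile in await self._fs._find(
        find_str, maxdepth=None, withdirs=False, detail=False
    ):
        yield onefile.removeprefix(find_str)
```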
17 changes: 17 additions & 0 deletions src/zarr/sync.py
@@ -114,6 +114,23 @@ def _get_loop() -> asyncio.AbstractEventLoop:
return loop[0]


async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
"""
Collect an entire async iterator into a tuple
"""
result = []
async for x in data:
result.append(x)
Member:

asyncio.gather? Like above, not much point in having coroutines if we serially wait for them.

Contributor Author:

I don't think we can use asyncio.gather here, because AsyncGenerator is not iterable. Happy to be corrected, since I don't really know asyncio very well.

Contributor Author:

For clarification, _collect_aiterator largely exists for convenience in testing, because I need some way to collect async generators when debugging with pdb. This function is not intended for use in anything performance sensitive.

Member:

We should probably re-examine the use of async-iterators, though. If we can't gather() on them (seems to be true?), then they are the wrong abstraction since gather() is probably always what we actually want.

@martindurant (Member), Aug 3, 2024:

On second thoughts, maybe I'm wrong - does async for schedule all the coroutines at once?? Should be easy to test.

Contributor Author:

I don't think it schedules them all at once. In [x async for x in async_generator], x is not an awaitable; it's already awaited. Since the basic model of the generator is that it's a resumable, stateful iterator, I don't think we can schedule all the tasks at once.

Member:

The idea with the generators is to a) support seamless pagination and b) support pipelining (del_prefix will be able to take advantage of this at some point).

return tuple(result)


def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
"""
Synchronously collect an entire async iterator into a tuple.
"""
return sync(_collect_aiterator(data))


class SyncMixin:
def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
# TODO: refactor this to to take *args and **kwargs and pass those to the method
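To make the gather-versus-generator point above concrete: asyncio.gather fans out awaitables that already exist, whereas an async generator only produces its next item when iterated, so there is nothing to schedule ahead of time. A small self-contained comparison (the sleep-based fetch coroutine is invented for illustration):

```python
import asyncio
from collections.abc import AsyncIterator


async def fetch(i: int) -> int:
    await asyncio.sleep(0.1)
    return i


async def agen() -> AsyncIterator[int]:
    for i in range(5):
        yield await fetch(i)  # each item is produced on demand, one at a time


async def main() -> None:
    # gather: all five coroutines run concurrently (~0.1 s total).
    concurrent = await asyncio.gather(*(fetch(i) for i in range(5)))

    # async for: each item is awaited before the next one starts (~0.5 s total).
    serial = [x async for x in agen()]

    assert concurrent == serial == [0, 1, 2, 3, 4]


asyncio.run(main())
```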
85 changes: 58 additions & 27 deletions src/zarr/testing/store.py
@@ -5,6 +5,7 @@
from zarr.abc.store import AccessMode, Store
from zarr.buffer import Buffer, default_buffer_prototype
from zarr.store.utils import _normalize_interval_index
from zarr.sync import _collect_aiterator
from zarr.testing.utils import assert_bytes_equal

S = TypeVar("S", bound=Store)
@@ -103,6 +104,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
observed = self.get(store, key)
assert_bytes_equal(observed, data_buf)

async def test_set_dict(self, store: S) -> None:
"""
Test that a dict of key : value pairs can be inserted into the store via the
`_set_dict` method.
"""
keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]
data_buf = [Buffer.from_bytes(k.encode()) for k in keys]
store_dict = dict(zip(keys, data_buf, strict=True))
await store._set_dict(store_dict)
for k, v in store_dict.items():
assert self.get(store, k).to_bytes() == v.to_bytes()
Member:

if x.to_bytes() == y.to_bytes(), does x == y?

Isn't there a multiple get? Maybe not important here.

Contributor Author:

> if x.to_bytes() == y.to_bytes(), does x == y?

No, and I suspect this might be deliberate, since in principle Buffer instances can have identical bytes but different devices (e.g., GPU memory vs host memory); thus x == y might only be true if two buffers are bytes-equal and device-equal, but I'm speculating here. @madsbk would have a better answer, I think.

> Isn't there a multiple get? Maybe not important here.

There is no multiple get (nor a multiple set, nor a multiple delete).


@pytest.mark.parametrize(
"key_ranges",
(
@@ -165,37 +178,55 @@ async def test_clear(self, store: S) -> None:
assert await store.empty()

async def test_list(self, store: S) -> None:
assert [k async for k in store.list()] == []
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
keys = [k async for k in store.list()]
assert keys == ["foo/zarr.json"], keys

expected = ["foo/zarr.json"]
for i in range(10):
key = f"foo/c/{i}"
expected.append(key)
await store.set(
f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little"))
)
assert await _collect_aiterator(store.list()) == ()
prefix = "foo"
data = Buffer.from_bytes(b"")
store_dict = {
prefix + "/zarr.json": data,
**{prefix + f"/c/{idx}": data for idx in range(10)},
}
await store._set_dict(store_dict)
expected_sorted = sorted(store_dict.keys())
observed = await _collect_aiterator(store.list())
observed_sorted = sorted(observed)
assert observed_sorted == expected_sorted

@pytest.mark.xfail
async def test_list_prefix(self, store: S) -> None:
# TODO: we currently don't use list_prefix anywhere
raise NotImplementedError
"""
Test that the `list_prefix` method works as intended. Given a prefix, it should return
all the keys in storage that start with this prefix. Keys should be returned with the shared
prefix removed.
"""
prefixes = ("", "a/", "a/b/", "a/b/c/")
data = Buffer.from_bytes(b"")
fname = "zarr.json"
store_dict = {p + fname: data for p in prefixes}
await store._set_dict(store_dict)
for p in prefixes:
observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p))))
expected: tuple[str, ...] = ()
for k in store_dict.keys():
if k.startswith(p):
expected += (k.removeprefix(p),)
expected = tuple(sorted(expected))
assert observed == expected

async def test_list_dir(self, store: S) -> None:
out = [k async for k in store.list_dir("")]
assert out == []
assert [k async for k in store.list_dir("foo")] == []
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
await store.set("foo/c/1", Buffer.from_bytes(b"\x01"))
root = "foo"
store_dict = {
root + "/zarr.json": Buffer.from_bytes(b"bar"),
root + "/c/1": Buffer.from_bytes(b"\x01"),
}

assert await _collect_aiterator(store.list_dir("")) == ()
assert await _collect_aiterator(store.list_dir(root)) == ()

await store._set_dict(store_dict)

keys_expected = ["zarr.json", "c"]
keys_observed = [k async for k in store.list_dir("foo")]
keys_observed = await _collect_aiterator(store.list_dir(root))
keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()}

assert len(keys_observed) == len(keys_expected), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed
assert sorted(keys_observed) == sorted(keys_expected)

keys_observed = [k async for k in store.list_dir("foo/")]
assert len(keys_expected) == len(keys_observed), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed
keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)
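On the "multiple get" question in the Buffer-equality thread above: no bulk read exists at this point. Purely as an illustration (not part of this PR), a hypothetical _get_many counterpart to _set_dict could gather the individual gets; the name, the signature, and the assumption that the store exposes an awaitable get(key) returning the stored Buffer are all invented here (the real Store.get also takes a buffer prototype and byte range):

```python
import asyncio
from collections.abc import Iterable, Mapping
from typing import Any


async def _get_many(store: Any, keys: Iterable[str]) -> Mapping[str, Any]:
    # Hypothetical bulk read: fetch several keys concurrently rather than
    # awaiting each get() in sequence.
    keys = tuple(keys)
    values = await asyncio.gather(*(store.get(k) for k in keys))
    return dict(zip(keys, values, strict=True))
```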
36 changes: 15 additions & 21 deletions tests/v3/test_store/test_remote.py
@@ -1,12 +1,18 @@
from __future__ import annotations

import os
from collections.abc import Generator

import fsspec
import pytest
from botocore.client import BaseClient
from botocore.session import Session
from s3fs import S3FileSystem
from upath import UPath

from zarr.buffer import Buffer, default_buffer_prototype
from zarr.store import RemoteStore
from zarr.sync import sync
from zarr.sync import _collect_aiterator, sync
from zarr.testing.store import StoreTests

s3fs = pytest.importorskip("s3fs")
@@ -22,7 +28,7 @@


@pytest.fixture(scope="module")
def s3_base():
def s3_base() -> Generator[None, None, None]:
# writable local S3 system

# This fixture is module-scoped, meaning that we can reuse the MotoServer across all tests
@@ -37,16 +43,14 @@ def s3_base():
server.stop()


def get_boto3_client():
from botocore.session import Session

def get_boto3_client() -> BaseClient:
# NB: we use the sync botocore client for setup
session = Session()
return session.create_client("s3", endpoint_url=endpoint_url)


@pytest.fixture(autouse=True, scope="function")
def s3(s3_base):
def s3(s3_base: Generator[None, None, None]) -> Generator[S3FileSystem, None, None]:
"""
Quoting Martin Durant:
pytest-asyncio creates a new event loop for each async test.
@@ -71,21 +75,11 @@ def s3(s3_base):
sync(session.close())


# ### end from s3fs ### #


async def alist(it):
out = []
async for a in it:
out.append(a)
return out


async def test_basic():
async def test_basic() -> None:
store = await RemoteStore.open(
f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False
)
assert not await alist(store.list())
assert await _collect_aiterator(store.list()) == ()
assert not await store.exists("foo")
data = b"hello"
await store.set("foo", Buffer.from_bytes(data))
@@ -101,7 +95,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore]):
store_cls = RemoteStore

@pytest.fixture(scope="function", params=("use_upath", "use_str"))
def store_kwargs(self, request) -> dict[str, str | bool]:
def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore
url = f"s3://{test_bucket_name}"
anon = False
mode = "r+"
@@ -113,8 +107,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
raise AssertionError

@pytest.fixture(scope="function")
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
url = store_kwargs["url"]
async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore:
Member:

This isn't actually async

Contributor Author:

Correct, but the class we are inheriting from defines this as an async method.

url: str | UPath = store_kwargs["url"]
mode = store_kwargs["mode"]
if isinstance(url, UPath):
out = self.store_cls(url=url, mode=mode)