Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dataclasses for ByteRangeRequests #2585

Merged
merged 30 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
125a729
Use TypedDicts for more literate ByteRangeRequests
maxrjones Dec 22, 2024
608f390
Update utility function
maxrjones Dec 22, 2024
4a73b70
fixes sharding
normanrz Dec 30, 2024
b1b38f9
Merge branch 'main' into literate-byte-ranges
maxrjones Dec 30, 2024
5d06965
Merge branch 'main' into literate-byte-ranges
maxrjones Jan 6, 2025
70a81ec
Merge branch 'main' into literate-byte-ranges
maxrjones Jan 6, 2025
c4e6625
Ignore mypy errors
maxrjones Jan 6, 2025
395b0da
Merge branch 'main' into literate-byte-ranges
maxrjones Jan 6, 2025
f8dc6e5
Fix offset in _normalize_byte_range_index
maxrjones Jan 6, 2025
78dfa76
Update get_partial_values for FsspecStore
maxrjones Jan 6, 2025
66a8b81
Merge branch 'main' into literate-byte-ranges
maxrjones Jan 6, 2025
61035c6
Re-add fs._cat_ranges argument
maxrjones Jan 7, 2025
76ba672
Simplify typing
maxrjones Jan 7, 2025
68a6df3
Update _normalize to return start, stop
maxrjones Jan 7, 2025
bd92bae
Merge branch 'main' into literate-byte-ranges
maxrjones Jan 7, 2025
650fb38
Use explicit range
maxrjones Jan 7, 2025
46070f4
Use dataclasses
maxrjones Jan 7, 2025
8464094
Update typing
maxrjones Jan 7, 2025
646454e
Merge branch 'byterange-dataclass' into literate-byte-ranges
maxrjones Jan 8, 2025
4cf6e11
Update docstring
maxrjones Jan 8, 2025
af2b06a
Merge branch 'main' into literate-byte-ranges
maxrjones Jan 8, 2025
e084313
Rename ExplicitRange to ExplicitByteRequest
maxrjones Jan 8, 2025
7659be4
Rename OffsetRange to OffsetByteRequest
maxrjones Jan 8, 2025
fff58dc
Rename SuffixRange to SuffixByteRequest
maxrjones Jan 8, 2025
a7d35f8
Use match; case instead of if; elif
maxrjones Jan 8, 2025
be6324f
Revert "Use match; case instead of if; elif"
maxrjones Jan 8, 2025
7191c84
Update ByteRangeRequest to ByteRequest
maxrjones Jan 8, 2025
a8ea2da
Remove ByteRange definition from common
maxrjones Jan 8, 2025
e7d29c5
Rename ExplicitByteRequest to RangeByteRequest
maxrjones Jan 8, 2025
e6120bf
Provide more informative error message
maxrjones Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod
from asyncio import gather
from dataclasses import dataclass
from itertools import starmap
from typing import TYPE_CHECKING, Protocol, runtime_checkable

Expand All @@ -19,7 +20,34 @@

__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]

ByteRangeRequest: TypeAlias = tuple[int | None, int | None]

@dataclass
class RangeByteRequest:
"""Request a specific byte range"""

start: int
"""The start of the byte range request (inclusive)."""
end: int
"""The end of the byte range request (exclusive)."""


@dataclass
class OffsetByteRequest:
"""Request all bytes starting from a given byte offset"""

offset: int
"""The byte offset for the offset range request."""


@dataclass
class SuffixByteRequest:
"""Request up to the last `n` bytes"""

suffix: int
"""The number of bytes from the suffix to request."""


ByteRequest: TypeAlias = RangeByteRequest | OffsetByteRequest | SuffixByteRequest


class Store(ABC):
Expand Down Expand Up @@ -141,14 +169,20 @@ async def get(
self,
key: str,
prototype: BufferPrototype,
byte_range: ByteRangeRequest | None = None,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
"""Retrieve the value associated with a given key.

Parameters
----------
key : str
byte_range : tuple[int | None, int | None], optional
byte_range : ByteRequest, optional

ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved.

- RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned.
- OffsetByteRequest(int): Request all bytes starting from a given byte offset. This is equivalent to bytes={int}- as an HTTP header.
- SuffixByteRequest(int): Request the last int bytes. Note that here, int is the size of the request, not the byte offset. This is equivalent to bytes=-{int} as an HTTP header.

Returns
-------
Expand All @@ -160,7 +194,7 @@ async def get(
async def get_partial_values(
self,
prototype: BufferPrototype,
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
"""Retrieve possibly partial values from given key_ranges.

Expand Down Expand Up @@ -338,7 +372,7 @@ def close(self) -> None:
self._is_open = False

async def _get_many(
self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]]
self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]]
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
"""
Retrieve a collection of objects from storage. In general this method does not guarantee
Expand Down Expand Up @@ -416,17 +450,17 @@ async def getsize_prefix(self, prefix: str) -> int:
@runtime_checkable
class ByteGetter(Protocol):
async def get(
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None: ...


@runtime_checkable
class ByteSetter(Protocol):
async def get(
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None: ...

async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: ...
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: ...

async def delete(self) -> None: ...

Expand Down
24 changes: 16 additions & 8 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
Codec,
CodecPipeline,
)
from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter
from zarr.abc.store import (
ByteGetter,
ByteRequest,
ByteSetter,
RangeByteRequest,
SuffixByteRequest,
)
from zarr.codecs.bytes import BytesCodec
from zarr.codecs.crc32c_ import Crc32cCodec
from zarr.core.array_spec import ArrayConfig, ArraySpec
Expand Down Expand Up @@ -77,7 +83,7 @@ class _ShardingByteGetter(ByteGetter):
chunk_coords: ChunkCoords

async def get(
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
assert byte_range is None, "byte_range is not supported within shards"
assert (
Expand All @@ -90,7 +96,7 @@ async def get(
class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
shard_dict: ShardMutableMapping

async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None:
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
assert byte_range is None, "byte_range is not supported within shards"
self.shard_dict[self.chunk_coords] = value

Expand Down Expand Up @@ -129,7 +135,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
return None
else:
return (int(chunk_start), int(chunk_len))
return (int(chunk_start), int(chunk_start + chunk_len))

def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None:
localized_chunk = self._localize_chunk(chunk_coords)
Expand Down Expand Up @@ -203,7 +209,7 @@ def create_empty(
def __getitem__(self, chunk_coords: ChunkCoords) -> Buffer:
chunk_byte_slice = self.index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
return self.buf[chunk_byte_slice[0] : (chunk_byte_slice[0] + chunk_byte_slice[1])]
return self.buf[chunk_byte_slice[0] : chunk_byte_slice[1]]
raise KeyError

def __len__(self) -> int:
Expand Down Expand Up @@ -504,7 +510,8 @@ async def _decode_partial_single(
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
chunk_bytes = await byte_getter.get(
prototype=chunk_spec.prototype, byte_range=chunk_byte_slice
prototype=chunk_spec.prototype,
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
)
if chunk_bytes:
shard_dict[chunk_coords] = chunk_bytes
Expand Down Expand Up @@ -696,11 +703,12 @@ async def _load_shard_index_maybe(
shard_index_size = self._shard_index_size(chunks_per_shard)
if self.index_location == ShardingCodecIndexLocation.start:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=(0, shard_index_size)
prototype=numpy_buffer_prototype(),
byte_range=RangeByteRequest(0, shard_index_size),
)
else:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=(-shard_index_size, None)
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
)
if index_bytes is not None:
return await self._decode_shard_index(index_bytes, chunks_per_shard)
Expand Down
1 change: 0 additions & 1 deletion src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
ZATTRS_JSON = ".zattrs"
ZMETADATA_V2_JSON = ".zmetadata"

ByteRangeRequest = tuple[int | None, int | None]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% confident in this change, but couldn't otherwise quickly find a way to avoid circular imports. @d-v-b can you please confirm that this removal won't cause any issues?

BytesLike = bytes | bytearray | memoryview
ShapeLike = tuple[int, ...] | int
ChunkCoords = tuple[int, ...]
Expand Down
10 changes: 5 additions & 5 deletions src/zarr/storage/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from zarr.abc.store import ByteRangeRequest, Store
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, AccessModeLiteral, ZarrFormat
from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError
Expand Down Expand Up @@ -102,7 +102,7 @@ async def open(
async def get(
self,
prototype: BufferPrototype | None = None,
byte_range: ByteRangeRequest | None = None,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
"""
Read bytes from the store.
Expand All @@ -111,7 +111,7 @@ async def get(
----------
prototype : BufferPrototype, optional
The buffer prototype to use when reading the bytes.
byte_range : ByteRangeRequest, optional
byte_range : ByteRequest, optional
The range of bytes to read.

Returns
Expand All @@ -123,15 +123,15 @@ async def get(
prototype = default_buffer_prototype()
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)

async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None:
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
"""
Write bytes to the store.

Parameters
----------
value : Buffer
The buffer to write.
byte_range : ByteRangeRequest, optional
byte_range : ByteRequest, optional
The range of bytes to write. If None, the entire buffer is written.

Raises
Expand Down
81 changes: 50 additions & 31 deletions src/zarr/storage/_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import warnings
from typing import TYPE_CHECKING, Any

from zarr.abc.store import ByteRangeRequest, Store
from zarr.abc.store import (
ByteRequest,
OffsetByteRequest,
RangeByteRequest,
Store,
SuffixByteRequest,
)
from zarr.storage._common import _dereference_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,31 +205,34 @@ async def get(
self,
key: str,
prototype: BufferPrototype,
byte_range: ByteRangeRequest | None = None,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
# docstring inherited
if not self._is_open:
await self._open()
path = _dereference_path(self.path, key)

try:
if byte_range:
# fsspec uses start/end, not start/length
start, length = byte_range
if start is not None and length is not None:
end = start + length
elif length is not None:
end = length
else:
end = None
value = prototype.buffer.from_bytes(
await (
self.fs._cat_file(path, start=byte_range[0], end=end)
if byte_range
else self.fs._cat_file(path)
if byte_range is None:
value = prototype.buffer.from_bytes(await self.fs._cat_file(path))
elif isinstance(byte_range, RangeByteRequest):
value = prototype.buffer.from_bytes(
await self.fs._cat_file(
path,
start=byte_range.start,
end=byte_range.end,
)
)
)

elif isinstance(byte_range, OffsetByteRequest):
value = prototype.buffer.from_bytes(
await self.fs._cat_file(path, start=byte_range.offset, end=None)
)
elif isinstance(byte_range, SuffixByteRequest):
value = prototype.buffer.from_bytes(
await self.fs._cat_file(path, start=-byte_range.suffix, end=None)
)
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
except self.allowed_exceptions:
return None
except OSError as e:
Expand Down Expand Up @@ -270,25 +279,35 @@ async def exists(self, key: str) -> bool:
async def get_partial_values(
self,
prototype: BufferPrototype,
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
# docstring inherited
if key_ranges:
paths, starts, stops = zip(
*(
(
_dereference_path(self.path, k[0]),
k[1][0],
((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None,
)
for k in key_ranges
),
strict=False,
)
# _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest.
key_ranges = list(key_ranges)
paths: list[str] = []
starts: list[int | None] = []
stops: list[int | None] = []
for key, byte_range in key_ranges:
paths.append(_dereference_path(self.path, key))
if byte_range is None:
starts.append(None)
stops.append(None)
elif isinstance(byte_range, RangeByteRequest):
starts.append(byte_range.start)
stops.append(byte_range.end)
elif isinstance(byte_range, OffsetByteRequest):
starts.append(byte_range.offset)
stops.append(None)
elif isinstance(byte_range, SuffixByteRequest):
starts.append(-byte_range.suffix)
stops.append(None)
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
else:
return []
# TODO: expectations for exceptions or missing keys?
res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return")
res = await self.fs._cat_ranges(paths, starts, stops, on_error="return")
# the following is an s3-specific condition we probably don't want to leak
res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res]
for r in res:
Expand Down
Loading
Loading