Skip to content

Commit

Permalink
Merge branch 'ryan/legacy-vlen' into xarray-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Oct 7, 2024
2 parents 8ad1554 + 8e61a18 commit 3a15e1d
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/zarr/codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zarr.codecs.pipeline import BatchedCodecPipeline
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
from zarr.codecs.transpose import TransposeCodec
from zarr.codecs.vlen_utf8 import VLenUTF8Codec
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
from zarr.codecs.zstd import ZstdCodec

__all__ = [
Expand All @@ -23,5 +23,6 @@
"ShardingCodecIndexLocation",
"TransposeCodec",
"VLenUTF8Codec",
"VLenBytesCodec",
"ZstdCodec",
]
48 changes: 47 additions & 1 deletion src/zarr/codecs/vlen_utf8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

import numpy as np
from numcodecs.vlen import VLenUTF8
from numcodecs.vlen import VLenBytes, VLenUTF8

from zarr.abc.codec import ArrayBytesCodec
from zarr.core.buffer import Buffer, NDBuffer
Expand All @@ -20,6 +20,7 @@

# Module-level numcodecs codec instances, reused for every encode/decode call.
# A global is fine here because these codecs are constructed with no parameters.
vlen_utf8_codec = VLenUTF8()
vlen_bytes_codec = VLenBytes()


@dataclass(frozen=True)
Expand Down Expand Up @@ -68,4 +69,49 @@ def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -
raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")


@dataclass(frozen=True)
class VLenBytesCodec(ArrayBytesCodec):
    """Array-to-bytes codec for chunks of variable-length byte strings.

    Encoding and decoding are delegated to the module-level numcodecs
    ``VLenBytes`` instance (``vlen_bytes_codec``). The codec declares no
    fields, so it carries no configuration of its own.
    """

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        """Build the codec from its named-configuration metadata.

        ``data`` is parsed as a ``"vlen-bytes"`` named configuration. The
        ``configuration`` entry is optional; a missing or empty one is
        treated as ``{}``.
        """
        _, configuration_parsed = parse_named_configuration(
            data, "vlen-bytes", require_configuration=False
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)

    def to_dict(self) -> dict[str, JSON]:
        """Return the metadata form of this codec (name plus empty configuration)."""
        return {"name": "vlen-bytes", "configuration": {}}

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
        """Return ``self`` unchanged; nothing here depends on the array spec."""
        return self

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        """Decode one encoded chunk into an object-dtype ``NDBuffer``.

        The decoded values are reshaped to ``chunk_spec.shape`` and wrapped
        with the ND-buffer class of ``chunk_spec.prototype``.
        """
        assert isinstance(chunk_bytes, Buffer)

        raw_bytes = chunk_bytes.as_array_like()
        decoded = vlen_bytes_codec.decode(raw_bytes)
        assert decoded.dtype == np.object_
        # Use reshape() instead of assigning to ``.shape``: NumPy discourages
        # in-place shape assignment and recommends reshape as the replacement.
        decoded = decoded.reshape(chunk_spec.shape)
        return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded)

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        """Encode one chunk to bytes and wrap it in the prototype's ``Buffer``."""
        assert isinstance(chunk_array, NDBuffer)
        return chunk_spec.prototype.buffer.from_bytes(
            vlen_bytes_codec.encode(chunk_array.as_numpy_array())
        )

    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        """Always raise ``NotImplementedError``; not implemented for VLen codecs."""
        # what is input_byte_length for an object dtype?
        raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")


# Register both VLen codec classes under their codec names.
register_codec("vlen-utf8", VLenUTF8Codec)
register_codec("vlen-bytes", VLenBytesCodec)
1 change: 1 addition & 0 deletions src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def reset(self) -> None:
"sharding_indexed": "zarr.codecs.sharding.ShardingCodec",
"transpose": "zarr.codecs.transpose.TransposeCodec",
"vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
"vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec",
},
"buffer": "zarr.core.buffer.cpu.Buffer",
"ndbuffer": "zarr.core.buffer.cpu.NDBuffer",
Expand Down
11 changes: 10 additions & 1 deletion src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def parse_fill_value(
"""
if fill_value is None:
return dtype.type(0)
if dtype.kind == "O":
return fill_value
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
if dtype.type in (np.complex64, np.complex128):
dtype = cast(COMPLEX_DTYPE, dtype)
Expand Down Expand Up @@ -458,6 +460,7 @@ class DataType(Enum):
complex64 = "complex64"
complex128 = "complex128"
string = "string"
bytes = "bytes"

@property
def byte_count(self) -> int:
Expand Down Expand Up @@ -506,13 +509,19 @@ def to_numpy_shortname(self) -> str:
def to_numpy(self) -> np.dtype[Any]:
    """Return the NumPy dtype used to hold values of this data type."""
    # Strings have a dedicated dtype constant.
    if self == DataType.string:
        return STRING_DTYPE
    # for now always use object dtype for bytestrings
    # TODO: consider whether we can use fixed-width types (e.g. '|S5') instead
    # Every other member goes through its NumPy short name.
    dtype_spec = "O" if self == DataType.bytes else self.to_numpy_shortname()
    return np.dtype(dtype_spec)

@classmethod
def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
if np.issubdtype(np.str_, dtype):
if dtype.kind in "UT":
return DataType.string
elif dtype.kind == "S":
return DataType.bytes
dtype_to_data_type = {
"|b1": "bool",
"bool": "bool",
Expand Down
32 changes: 31 additions & 1 deletion tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from zarr import Array
from zarr.abc.store import Store
from zarr.codecs import VLenUTF8Codec
from zarr.codecs import VLenBytesCodec, VLenUTF8Codec
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
from zarr.storage.common import StorePath
from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING
Expand Down Expand Up @@ -49,3 +49,33 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:
assert np.array_equal(data, b[:, :])
assert b.metadata.data_type == DataType.string
assert a.dtype == expected_zarr_string_dtype


@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"])
async def test_vlen_bytes(store: Store) -> None:
    """Round-trip a 2x3 array of byte strings through ``VLenBytesCodec``.

    Checks the data, the stored ``DataType`` and the in-memory dtype both on
    the array that was created and on the same array reopened from the store.
    """
    bstrings = [b"hello", b"world", b"this", b"is", b"a", b"test"]
    data = np.array(bstrings).reshape((2, 3))
    # NumPy infers a fixed-width bytes dtype sized to the longest element.
    assert data.dtype == "|S5"

    sp = StorePath(store, path="string")
    a = Array.create(
        sp,
        shape=data.shape,
        chunk_shape=data.shape,
        dtype=data.dtype,
        fill_value=b"",
        codecs=[VLenBytesCodec()],
    )
    assert isinstance(a.metadata, ArrayV3Metadata)  # needed for mypy

    a[:, :] = data
    assert np.array_equal(data, a[:, :])
    assert a.metadata.data_type == DataType.bytes
    assert a.dtype == "O"

    # test round trip
    b = Array.open(sp)
    assert isinstance(b.metadata, ArrayV3Metadata)  # needed for mypy
    assert np.array_equal(data, b[:, :])
    assert b.metadata.data_type == DataType.bytes
    # Check the reopened array here, not the one created above.
    assert b.dtype == "O"
1 change: 1 addition & 0 deletions tests/v3/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_config_defaults_set() -> None:
"sharding_indexed": "zarr.codecs.sharding.ShardingCodec",
"transpose": "zarr.codecs.transpose.TransposeCodec",
"vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
"vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec",
},
}
]
Expand Down

0 comments on commit 3a15e1d

Please sign in to comment.