
Feature/recursive members #2118

Merged · 11 commits · Aug 29, 2024
70 changes: 60 additions & 10 deletions src/zarr/core/group.py
@@ -424,21 +424,59 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
def __repr__(self) -> str:
return f"<AsyncGroup {self.store_path}>"

async def nmembers(self) -> int:
async def nmembers(
self,
max_depth: int | None = 0,
) -> int:
"""
Count the number of members in this group.

Parameters
----------
max_depth : int or None, default 0
The maximum number of levels of the hierarchy to include. By
default (``max_depth=0``), only immediate children are included. Set
``max_depth=None`` to include all nodes, or a positive integer to
include children within that many levels of the root Group.

Returns
-------
count : int
"""
# TODO: consider using aioitertools.builtins.sum for this
# return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0)
n = 0
async for _ in self.members():
async for _ in self.members(max_depth=max_depth):
n += 1
return n

async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
async def members(
self,
max_depth: int | None = 0,
) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
This method requires that `store_path.store` supports directory listing.

The results are not guaranteed to be ordered.

Parameters
----------
max_depth : int or None, default 0
The maximum number of levels of the hierarchy to include. By
default (``max_depth=0``), only immediate children are included. Set
``max_depth=None`` to include all nodes, or a positive integer to
include children within that many levels of the root Group.

"""
if max_depth is not None and max_depth < 0:
raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead")
async for item in self._members(max_depth=max_depth, current_depth=0):
yield item

async def _members(
self, max_depth: int | None, current_depth: int
) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
if not self.store_path.store.supports_listing:
msg = (
f"The store associated with this group ({type(self.store_path.store)}) "
@@ -456,7 +494,21 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
if key in _skip_keys:
continue
try:
yield (key, await self.getitem(key))
obj = await self.getitem(key)
yield (key, obj)

if (
((max_depth is None) or (current_depth < max_depth))
and hasattr(obj.metadata, "node_type")
and obj.metadata.node_type == "group"
):
# the assert is just for mypy: node_type == "group" means `obj`
# is an AsyncGroup, not an AsyncArray
assert isinstance(obj, AsyncGroup)
async for child_key, val in obj._members(
max_depth=max_depth, current_depth=current_depth + 1
):
yield "/".join([key, child_key]), val
except KeyError:
# KeyError is raised when `key` names an object (in the object-storage
# sense), as opposed to a prefix, in the store under the prefix associated with this group
@@ -628,17 +680,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
self._sync(self._async_group.update_attributes(new_attributes))
return self

@property
def nmembers(self) -> int:
return self._sync(self._async_group.nmembers())
def nmembers(self, max_depth: int | None = 0) -> int:
return self._sync(self._async_group.nmembers(max_depth=max_depth))

@property
def members(self) -> tuple[tuple[str, Array | Group], ...]:
def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], ...]:
"""
Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group)
pairs
"""
_members = self._sync_iter(self._async_group.members())
_members = self._sync_iter(self._async_group.members(max_depth=max_depth))

result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members))
return result
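
For readers skimming the diff, here is a usage sketch of the new keyword (not part of the change itself; the MemoryStore import path and ``mode`` argument are assumptions based on the test suite of this era):

from zarr.core.group import Group
from zarr.store import MemoryStore  # assumed import path

store = MemoryStore(mode="w")  # assumed writable in-memory store
root = Group.create(store)
subgroup = root.create_group("subgroup")
subgroup.create_group("subsubgroup")

print(root.nmembers())                # 1 -- immediate children only
print(root.nmembers(max_depth=None))  # 2 -- the full hierarchy
for name, _node in root.members(max_depth=None):
    print(name)                       # "subgroup", then "subgroup/subsubgroup"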
22 changes: 15 additions & 7 deletions src/zarr/core/metadata.py
@@ -256,13 +256,21 @@ def _json_convert(o: Any) -> Any:
if isinstance(o, np.dtype):
return str(o)
if np.isscalar(o):
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
out: Any
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
# https://github.com/zarr-developers/zarr-python/issues/2119
# `.item()` on a datetime64 scalar might or might not return an
# integer, depending on the value.
# Explicitly view it as int64 first, and then call .item()
out = o.view("i8").item()
else:
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return out
if isinstance(o, Enum):
return o.name
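
The value-dependence that motivates the ``.view("i8")`` workaround can be reproduced directly in numpy (an illustration of issue #2119; the boundary is whether the value is representable at microsecond resolution):

import numpy as np

np.datetime64(1000, "ns").item()  # datetime.datetime(1970, 1, 1, 0, 0, 0, 1)
np.datetime64(123, "ns").item()   # 123 -- plain int, sub-microsecond value

# viewing as int64 first always yields the integer count of the unit:
np.datetime64(1000, "ns").view("i8").item()  # 1000
np.datetime64(123, "ns").view("i8").item()   # 123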
10 changes: 9 additions & 1 deletion src/zarr/testing/strategies.py
@@ -1,3 +1,4 @@
import re
from typing import Any

import hypothesis.extra.numpy as npst
@@ -101,7 +102,14 @@ def arrays(
root = Group.create(store)
fill_value_args: tuple[Any, ...] = tuple()
if nparray.dtype.kind == "M":
fill_value_args = ("ns",)
m = re.search(r"\[(.+)\]", nparray.dtype.str)
if not m:
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}'.")

fill_value_args = (
# e.g. ns, D
m.groups()[0],
)

a = root.create_array(
array_path,
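The regex works because numpy's ``dtype.str`` embeds the datetime unit in brackets; a quick standalone check:

import re

import numpy as np

for dt in (np.dtype("datetime64[ns]"), np.dtype("datetime64[D]")):
    m = re.search(r"\[(.+)\]", dt.str)
    assert m is not None
    print(dt.str, "->", m.groups()[0])  # "<M8[ns]" -> "ns", "<M8[D]" -> "D"
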
78 changes: 74 additions & 4 deletions tests/v3/test_group.py
@@ -88,7 +88,8 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
members_expected["subgroup"] = group.create_group("subgroup")
# make a sub-sub-subgroup, to ensure that the children calculation doesn't go
# too deep in the hierarchy
_ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore
subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore
subsubsubgroup = subsubgroup.create_group("subsubsubgroup") # type: ignore

members_expected["subarray"] = group.create_array(
"subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True
@@ -101,10 +102,25 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
# this creates a directory with a random key in it
# this should not show up as a member
sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000")))
members_observed = group.members
members_observed = group.members()
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

# partial
members_observed = group.members(max_depth=1)
members_expected["subgroup/subsubgroup"] = subsubgroup
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

# total
members_observed = group.members(max_depth=None)
members_expected["subgroup/subsubgroup/subsubsubgroup"] = subsubsubgroup
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

with pytest.raises(ValueError, match="max_depth"):
members_observed = group.members(max_depth=-1)


def test_group(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None:
"""
@@ -349,7 +365,8 @@ def test_group_create_array(
if method == "create_array":
array = group.create_array(name="array", shape=shape, dtype=dtype, data=data)
elif method == "array":
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
with pytest.warns(DeprecationWarning):
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
else:
raise AssertionError

@@ -358,7 +375,7 @@ def test_group_create_array(
with pytest.raises(ContainsArrayError):
group.create_array(name="array", shape=shape, dtype=dtype, data=data)
elif method == "array":
with pytest.raises(ContainsArrayError):
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
group.array(name="array", shape=shape, dtype=dtype, data=data)
assert array.shape == shape
assert array.dtype == np.dtype(dtype)
@@ -653,3 +670,56 @@ async def test_asyncgroup_update_attributes(

agroup_new_attributes = await agroup.update_attributes(attributes_new)
assert agroup_new_attributes.attrs == attributes_new


async def test_group_members_async(store: LocalStore | MemoryStore):
group = AsyncGroup(
GroupMetadata(),
store_path=StorePath(store=store, path="root"),
)
a0 = await group.create_array("a0", (1,))
g0 = await group.create_group("g0")
a1 = await g0.create_array("a1", (1,))
g1 = await g0.create_group("g1")
a2 = await g1.create_array("a2", (1,))
g2 = await g1.create_group("g2")

# immediate children
children = sorted([x async for x in group.members()], key=lambda x: x[0])
assert children == [
("a0", a0),
("g0", g0),
]

nmembers = await group.nmembers()
assert nmembers == 2

# partial
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
expected = [
("a0", a0),
("g0", g0),
("g0/a1", a1),
("g0/g1", g1),
]
assert children == expected
nmembers = await group.nmembers(max_depth=1)
assert nmembers == 4

# all children
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
expected = [
("a0", a0),
("g0", g0),
("g0/a1", a1),
("g0/g1", g1),
("g0/g1/a2", a2),
("g0/g1/g2", g2),
]
assert all_children == expected

nmembers = await group.nmembers(max_depth=None)
assert nmembers == 6

with pytest.raises(ValueError, match="max_depth"):
[x async for x in group.members(max_depth=-1)]
23 changes: 23 additions & 0 deletions tests/v3/test_metadata/test_v3.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Literal

from zarr.abc.codec import Codec
from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding

if TYPE_CHECKING:
@@ -230,3 +232,24 @@ def test_metadata_to_dict(
observed.pop("chunk_key_encoding")
expected.pop("chunk_key_encoding")
assert observed == expected


@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
@pytest.mark.parametrize("precision", ["ns", "D"])
async def test_datetime_metadata(fill_value: int, precision: str):
metadata_dict = {
"zarr_format": 3,
"node_type": "array",
"shape": (1,),
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
"data_type": f"<M8[{precision}]",
"chunk_key_encoding": {"name": "default", "separator": "."},
"codecs": (),
"fill_value": np.datetime64(fill_value, precision),
}
metadata = ArrayV3Metadata.from_dict(metadata_dict)
# ensure there isn't a TypeError here.
d = metadata.to_buffer_dict(default_buffer_prototype())

result = json.loads(d["zarr.json"].to_bytes())
assert result["fill_value"] == fill_value
10 changes: 10 additions & 0 deletions tests/v3/test_properties.py
@@ -18,6 +18,11 @@ def test_roundtrip(data):


@given(data=st.data())
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
# Uncomment the next line to reproduce the original failure.
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=')
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_basic_indexing(data):
zarray = data.draw(arrays())
nparray = zarray[:]
@@ -32,6 +37,11 @@ def test_basic_indexing(data):


@given(data=st.data())
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
# Uncomment the next line to reproduce the original failure.
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=')
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
Member:

@dcherian - would you mind looking into this?

Contributor Author:

It's a minefield 😬 The deepest I got was hypothesis.extra.numpy:ArrayStrategy.set_element. Something about the sequence of operations and the data passed into the ndarray made it "weird", such that array == np.complex128(0.0) raised an InvalidComparison warning. I'm not sure how to interpret the bytes at #2118 (comment).

def test_vindex(data):
zarray = data.draw(arrays())
nparray = zarray[:]
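
As a hedged illustration of what the ``filterwarnings`` mark does in these tests (the overflow below is a stand-in trigger; the original warning came from a numpy comparison inside NDBuffer.all_equal and is hard to reproduce deterministically):

import numpy as np
import pytest

@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_runtimewarning_is_silenced() -> None:
    result = np.float64(1e308) * 10  # numpy emits an "overflow encountered" RuntimeWarning
    assert np.isinf(result)          # the mark keeps the warning from failing strict runs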