diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 86d27e3a97..2c26cac3b1 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -424,21 +424,59 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: def __repr__(self) -> str: return f"" - async def nmembers(self) -> int: + async def nmembers( + self, + max_depth: int | None = 0, + ) -> int: + """ + Count the number of members in this group. + + Parameters + ---------- + max_depth : int, default 0 + The maximum number of levels of the hierarchy to include. By + default, (``max_depth=0``) only immediate children are included. Set + ``max_depth=None`` to include all nodes, and some positive integer + to consider children within that many levels of the root Group. + + Returns + ------- + count : int + """ # TODO: consider using aioitertools.builtins.sum for this # return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0) n = 0 - async for _ in self.members(): + async for _ in self.members(max_depth=max_depth): n += 1 return n - async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: + async def members( + self, + max_depth: int | None = 0, + ) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: """ Returns an AsyncGenerator over the arrays and groups contained in this group. This method requires that `store_path.store` supports directory listing. The results are not guaranteed to be ordered. + + Parameters + ---------- + max_depth : int, default 0 + The maximum number of levels of the hierarchy to include. By + default, (``max_depth=0``) only immediate children are included. Set + ``max_depth=None`` to include all nodes, and some positive integer + to consider children within that many levels of the root Group. + """ + if max_depth is not None and max_depth < 0: + raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead") + async for item in self._members(max_depth=max_depth, current_depth=0): + yield item + + async def _members( + self, max_depth: int | None, current_depth: int + ) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: if not self.store_path.store.supports_listing: msg = ( f"The store associated with this group ({type(self.store_path.store)}) " @@ -456,7 +494,21 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N if key in _skip_keys: continue try: - yield (key, await self.getitem(key)) + obj = await self.getitem(key) + yield (key, obj) + + if ( + ((max_depth is None) or (current_depth < max_depth)) + and hasattr(obj.metadata, "node_type") + and obj.metadata.node_type == "group" + ): + # the assert is just for mypy to know that `obj.metadata.node_type` + # implies an AsyncGroup, not an AsyncArray + assert isinstance(obj, AsyncGroup) + async for child_key, val in obj._members( + max_depth=max_depth, current_depth=current_depth + 1 + ): + yield "/".join([key, child_key]), val except KeyError: # keyerror is raised when `key` names an object (in the object storage sense), # as opposed to a prefix, in the store under the prefix associated with this group @@ -628,17 +680,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group: self._sync(self._async_group.update_attributes(new_attributes)) return self - @property - def nmembers(self) -> int: - return self._sync(self._async_group.nmembers()) + def nmembers(self, max_depth: int | None = 0) -> int: + return self._sync(self._async_group.nmembers(max_depth=max_depth)) - @property - def members(self) -> tuple[tuple[str, Array | Group], ...]: + def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], ...]: """ Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group) pairs """ - _members = self._sync_iter(self._async_group.members()) + _members = self._sync_iter(self._async_group.members(max_depth=max_depth)) result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members)) return result diff --git a/src/zarr/core/metadata.py b/src/zarr/core/metadata.py index d541e43205..72172a2673 100644 --- a/src/zarr/core/metadata.py +++ b/src/zarr/core/metadata.py @@ -256,13 +256,21 @@ def _json_convert(o: Any) -> Any: if isinstance(o, np.dtype): return str(o) if np.isscalar(o): - # convert numpy scalar to python type, and pass - # python types through - out = getattr(o, "item", lambda: o)() - if isinstance(out, complex): - # python complex types are not JSON serializable, so we use the - # serialization defined in the zarr v3 spec - return [out.real, out.imag] + out: Any + if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): + # https://github.com/zarr-developers/zarr-python/issues/2119 + # `.item()` on a datetime type might or might not return an + # integer, depending on the value. + # Explicitly cast to an int first, and then grab .item() + out = o.view("i8").item() + else: + # convert numpy scalar to python type, and pass + # python types through + out = getattr(o, "item", lambda: o)() + if isinstance(out, complex): + # python complex types are not JSON serializable, so we use the + # serialization defined in the zarr v3 spec + return [out.real, out.imag] return out if isinstance(o, Enum): return o.name diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index d2e41c6290..3a460d4fff 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -1,3 +1,4 @@ +import re from typing import Any import hypothesis.extra.numpy as npst @@ -101,7 +102,14 @@ def arrays( root = Group.create(store) fill_value_args: tuple[Any, ...] = tuple() if nparray.dtype.kind == "M": - fill_value_args = ("ns",) + m = re.search(r"\[(.+)\]", nparray.dtype.str) + if not m: + raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.") + + fill_value_args = ( + # e.g. ns, D + m.groups()[0], + ) a = root.create_array( array_path, diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 39921c26d8..eb7b1f30dd 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -88,7 +88,8 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) members_expected["subgroup"] = group.create_group("subgroup") # make a sub-sub-subgroup, to ensure that the children calculation doesn't go # too deep in the hierarchy - _ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore + subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore + subsubsubgroup = subsubgroup.create_group("subsubsubgroup") # type: ignore members_expected["subarray"] = group.create_array( "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True @@ -101,10 +102,25 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) # this creates a directory with a random key in it # this should not show up as a member sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000"))) - members_observed = group.members + members_observed = group.members() # members are not guaranteed to be ordered, so sort before comparing assert sorted(dict(members_observed)) == sorted(members_expected) + # partial + members_observed = group.members(max_depth=1) + members_expected["subgroup/subsubgroup"] = subsubgroup + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) + + # total + members_observed = group.members(max_depth=None) + members_expected["subgroup/subsubgroup/subsubsubgroup"] = subsubsubgroup + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) + + with pytest.raises(ValueError, match="max_depth"): + members_observed = group.members(max_depth=-1) + def test_group(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: """ @@ -349,7 +365,8 @@ def test_group_create_array( if method == "create_array": array = group.create_array(name="array", shape=shape, dtype=dtype, data=data) elif method == "array": - array = group.array(name="array", shape=shape, dtype=dtype, data=data) + with pytest.warns(DeprecationWarning): + array = group.array(name="array", shape=shape, dtype=dtype, data=data) else: raise AssertionError @@ -358,7 +375,7 @@ def test_group_create_array( with pytest.raises(ContainsArrayError): group.create_array(name="array", shape=shape, dtype=dtype, data=data) elif method == "array": - with pytest.raises(ContainsArrayError): + with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): group.array(name="array", shape=shape, dtype=dtype, data=data) assert array.shape == shape assert array.dtype == np.dtype(dtype) @@ -653,3 +670,56 @@ async def test_asyncgroup_update_attributes( agroup_new_attributes = await agroup.update_attributes(attributes_new) assert agroup_new_attributes.attrs == attributes_new + + +async def test_group_members_async(store: LocalStore | MemoryStore): + group = AsyncGroup( + GroupMetadata(), + store_path=StorePath(store=store, path="root"), + ) + a0 = await group.create_array("a0", (1,)) + g0 = await group.create_group("g0") + a1 = await g0.create_array("a1", (1,)) + g1 = await g0.create_group("g1") + a2 = await g1.create_array("a2", (1,)) + g2 = await g1.create_group("g2") + + # immediate children + children = sorted([x async for x in group.members()], key=lambda x: x[0]) + assert children == [ + ("a0", a0), + ("g0", g0), + ] + + nmembers = await group.nmembers() + assert nmembers == 2 + + # partial + children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ] + assert children == expected + nmembers = await group.nmembers(max_depth=1) + assert nmembers == 4 + + # all children + all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ("g0/g1/a2", a2), + ("g0/g1/g2", g2), + ] + assert all_children == expected + + nmembers = await group.nmembers(max_depth=None) + assert nmembers == 6 + + with pytest.raises(ValueError, match="max_depth"): + [x async for x in group.members(max_depth=-1)] diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index eedcdf6234..1a0c5b94d7 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -1,10 +1,12 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Literal from zarr.abc.codec import Codec from zarr.codecs.bytes import BytesCodec +from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding if TYPE_CHECKING: @@ -230,3 +232,24 @@ def test_metadata_to_dict( observed.pop("chunk_key_encoding") expected.pop("chunk_key_encoding") assert observed == expected + + +@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) +@pytest.mark.parametrize("precision", ["ns", "D"]) +async def test_datetime_metadata(fill_value: int, precision: str): + metadata_dict = { + "zarr_format": 3, + "node_type": "array", + "shape": (1,), + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, + "data_type": f"