Skip to content

Fix typing errors in test_consolidated #3047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def to_dict(self) -> dict[str, JSON]:
}

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> ConsolidatedMetadata:
def from_dict(cls, data: Mapping[str, JSON]) -> ConsolidatedMetadata:
data = dict(data)

kind = data.get("kind")
Expand Down
75 changes: 44 additions & 31 deletions tests/test_metadata/test_consolidated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
open,
open_consolidated,
)
from zarr.api.synchronous import Group
from zarr.core.buffer import cpu, default_buffer_prototype
from zarr.core.group import ConsolidatedMetadata, GroupMetadata
from zarr.core.metadata import ArrayV3Metadata
Expand All @@ -25,11 +26,11 @@

if TYPE_CHECKING:
from zarr.abc.store import Store
from zarr.core.common import ZarrFormat
from zarr.core.common import JSON, ZarrFormat


@pytest.fixture
async def memory_store_with_hierarchy(memory_store: Store) -> None:
async def memory_store_with_hierarchy(memory_store: Store) -> Store:
g = await group(store=memory_store, attributes={"foo": "bar"})
dtype = "uint8"
await g.create_array(name="air", shape=(1, 2, 3), dtype=dtype)
Expand All @@ -49,15 +50,15 @@ async def memory_store_with_hierarchy(memory_store: Store) -> None:


class TestConsolidated:
async def test_open_consolidated_false_raises(self):
async def test_open_consolidated_false_raises(self) -> None:
store = zarr.storage.MemoryStore()
with pytest.raises(TypeError, match="use_consolidated"):
await zarr.api.asynchronous.open_consolidated(store, use_consolidated=False)
await zarr.api.asynchronous.open_consolidated(store, use_consolidated=False) # type: ignore[arg-type]

def test_open_consolidated_false_raises_sync(self):
def test_open_consolidated_false_raises_sync(self) -> None:
store = zarr.storage.MemoryStore()
with pytest.raises(TypeError, match="use_consolidated"):
zarr.open_consolidated(store, use_consolidated=False)
zarr.open_consolidated(store, use_consolidated=False) # type: ignore[arg-type]

async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None:
# TODO: Figure out desired keys in
Expand All @@ -69,7 +70,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None:
await consolidate_metadata(memory_store_with_hierarchy)
group2 = await AsyncGroup.open(memory_store_with_hierarchy)

array_metadata = {
array_metadata: dict[str, JSON] = {
"attributes": {},
"chunk_key_encoding": {
"configuration": {"separator": "/"},
Expand Down Expand Up @@ -186,13 +187,11 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None:
group4 = await open_consolidated(store=memory_store_with_hierarchy)
assert group4.metadata == expected

result_raw = json.loads(
(
await memory_store_with_hierarchy.get(
"zarr.json", prototype=default_buffer_prototype()
)
).to_bytes()
)["consolidated_metadata"]
val = await memory_store_with_hierarchy.get(
"zarr.json", prototype=default_buffer_prototype()
)
assert val is not None
result_raw = json.loads((val).to_bytes())["consolidated_metadata"]
assert result_raw["kind"] == "inline"
assert sorted(result_raw["metadata"]) == [
"air",
Expand All @@ -206,7 +205,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None:
"time",
]

def test_consolidated_sync(self, memory_store):
def test_consolidated_sync(self, memory_store: zarr.storage.MemoryStore) -> None:
g = zarr.api.synchronous.group(store=memory_store, attributes={"foo": "bar"})
dtype = "uint8"
g.create_array(name="air", shape=(1, 2, 3), dtype=dtype)
Expand All @@ -215,9 +214,9 @@ def test_consolidated_sync(self, memory_store):
g.create_array(name="time", shape=(3,), dtype=dtype)

zarr.api.synchronous.consolidate_metadata(memory_store)
group2 = zarr.api.synchronous.Group.open(memory_store)
group2 = Group.open(memory_store)

array_metadata = {
array_metadata: dict[str, JSON] = {
"attributes": {},
"chunk_key_encoding": {
"configuration": {"separator": "/"},
Expand Down Expand Up @@ -306,8 +305,8 @@ async def test_non_root_node(self, memory_store_with_hierarchy: Store) -> None:
assert "air" not in child.metadata.consolidated_metadata.metadata
assert "grandchild" in child.metadata.consolidated_metadata.metadata

def test_consolidated_metadata_from_dict(self):
data = {"must_understand": False}
def test_consolidated_metadata_from_dict(self) -> None:
data: dict[str, JSON] = {"must_understand": False}

# missing kind
with pytest.raises(ValueError, match="kind='None'"):
Expand All @@ -329,16 +328,16 @@ def test_consolidated_metadata_from_dict(self):
data["metadata"] = {}
ConsolidatedMetadata.from_dict(data)

def test_flatten(self):
array_metadata = {
def test_flatten(self) -> None:
array_metadata: dict[str, JSON] = {
"attributes": {},
"chunk_key_encoding": {
"configuration": {"separator": "/"},
"name": "default",
},
"codecs": ({"configuration": {"endian": "little"}, "name": "bytes"},),
"data_type": "float64",
"fill_value": np.float64(0.0),
"fill_value": 0,
"node_type": "array",
# "shape": (1, 2, 3),
"zarr_format": 3,
Expand Down Expand Up @@ -407,6 +406,17 @@ def test_flatten(self):
},
)
result = metadata.flattened_metadata
assert isinstance(metadata.metadata["child"], GroupMetadata)
assert isinstance(metadata.metadata["child"].consolidated_metadata, ConsolidatedMetadata)
assert isinstance(
metadata.metadata["child"].consolidated_metadata.metadata["grandchild"], GroupMetadata
)
assert isinstance(
metadata.metadata["child"]
.consolidated_metadata.metadata["grandchild"]
.consolidated_metadata,
ConsolidatedMetadata,
)
expected = {
"air": metadata.metadata["air"],
"lat": metadata.metadata["lat"],
Expand All @@ -426,7 +436,7 @@ def test_flatten(self):
}
assert result == expected

def test_invalid_metadata_raises(self):
def test_invalid_metadata_raises(self) -> None:
payload = {
"kind": "inline",
"must_understand": False,
Expand All @@ -436,9 +446,9 @@ def test_invalid_metadata_raises(self):
}

with pytest.raises(TypeError, match="key='foo', type='list'"):
ConsolidatedMetadata.from_dict(payload)
ConsolidatedMetadata.from_dict(payload) # type: ignore[arg-type]

def test_to_dict_empty(self):
def test_to_dict_empty(self) -> None:
meta = ConsolidatedMetadata(
metadata={
"empty": GroupMetadata(
Expand Down Expand Up @@ -467,7 +477,7 @@ def test_to_dict_empty(self):
assert result == expected

@pytest.mark.parametrize("zarr_format", [2, 3])
async def test_open_consolidated_raises_async(self, zarr_format: ZarrFormat):
async def test_open_consolidated_raises_async(self, zarr_format: ZarrFormat) -> None:
store = zarr.storage.MemoryStore()
await AsyncGroup.from_store(store, zarr_format=zarr_format)
with pytest.raises(ValueError):
Expand All @@ -485,12 +495,15 @@ async def v2_consolidated_metadata_empty_dataset(
b'{"metadata":{".zgroup":{"zarr_format":2}},"zarr_consolidated_format":1}'
)
return AsyncGroup._from_bytes_v2(
None, zgroup_bytes, zattrs_bytes=None, consolidated_metadata_bytes=zmetadata_bytes
None, # type: ignore[arg-type]
zgroup_bytes,
zattrs_bytes=None,
consolidated_metadata_bytes=zmetadata_bytes,
)

async def test_consolidated_metadata_backwards_compatibility(
self, v2_consolidated_metadata_empty_dataset
):
self, v2_consolidated_metadata_empty_dataset: AsyncGroup
) -> None:
"""
Test that consolidated metadata handles a missing .zattrs key. This is necessary for backwards compatibility with zarr-python 2.x. See https://github.com/zarr-developers/zarr-python/issues/2694
"""
Expand All @@ -500,7 +513,7 @@ async def test_consolidated_metadata_backwards_compatibility(
result = await zarr.api.asynchronous.open_consolidated(store, zarr_format=2)
assert result.metadata == v2_consolidated_metadata_empty_dataset.metadata

async def test_consolidated_metadata_v2(self):
async def test_consolidated_metadata_v2(self) -> None:
store = zarr.storage.MemoryStore()
g = await AsyncGroup.from_store(store, attributes={"key": "root"}, zarr_format=2)
dtype = "uint8"
Expand Down Expand Up @@ -622,7 +635,7 @@ async def test_use_consolidated_for_children_members(
@pytest.mark.parametrize("fill_value", [np.nan, np.inf, -np.inf])
async def test_consolidated_metadata_encodes_special_chars(
memory_store: Store, zarr_format: ZarrFormat, fill_value: float
):
) -> None:
root = await group(store=memory_store, zarr_format=zarr_format)
_child = await root.create_group("child", attributes={"test": fill_value})
_time = await root.create_array("time", shape=(12,), dtype=np.float64, fill_value=fill_value)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_metadata/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import zarr.api.asynchronous
import zarr.storage
from zarr.core.array import AsyncArray
from zarr.core.buffer import cpu
from zarr.core.buffer.core import default_buffer_prototype
from zarr.core.group import ConsolidatedMetadata, GroupMetadata
Expand All @@ -18,6 +19,8 @@
from typing import Any

from zarr.abc.codec import Codec
from zarr.core.common import JSON


import numcodecs

Expand Down Expand Up @@ -104,7 +107,7 @@ class TestConsolidated:
async def v2_consolidated_metadata(
self, memory_store: zarr.storage.MemoryStore
) -> zarr.storage.MemoryStore:
zmetadata = {
zmetadata: dict[str, JSON] = {
"metadata": {
".zattrs": {
"Conventions": "COARDS",
Expand Down Expand Up @@ -274,6 +277,7 @@ async def test_getitem_consolidated(self, v2_consolidated_metadata):
store = v2_consolidated_metadata
group = await zarr.api.asynchronous.open_consolidated(store=store, zarr_format=2)
air = await group.getitem("air")
assert isinstance(air, AsyncArray[ArrayV2Metadata])
assert air.metadata.shape == (730,)


Expand Down Expand Up @@ -335,6 +339,7 @@ def test_structured_dtype_fill_value_serialization(tmp_path, fill_value):

zarr.consolidate_metadata(root_group.store, zarr_format=2)
root_group = zarr.open_group(group_path, mode="r")
assert isinstance(root_group.metadata.consolidated_metadata, ConsolidatedMetadata)
assert (
root_group.metadata.consolidated_metadata.to_dict()["metadata"]["structured_dtype"][
"fill_value"
Expand Down
14 changes: 7 additions & 7 deletions tests/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_default_fill_value(dtype_str: str) -> None:
(0j, "complex64"),
],
)
def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None:
def test_parse_fill_value_valid(fill_value: bool | float, dtype_str: str) -> None:
"""
Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype.
"""
Expand All @@ -156,18 +156,18 @@ def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None:

@pytest.mark.parametrize("fill_value", ["not a valid value"])
@pytest.mark.parametrize("dtype_str", [*int_dtypes, *float_dtypes, *complex_dtypes])
def test_parse_fill_value_invalid_value(fill_value: Any, dtype_str: str) -> None:
def test_parse_fill_value_invalid_value(fill_value: str, dtype_str: str) -> None:
"""
Test that parse_fill_value(fill_value, dtype) raises ValueError for invalid values.
This test excludes bool because the bool constructor takes anything.
"""
with pytest.raises(ValueError):
parse_fill_value(fill_value, dtype_str)
parse_fill_value(fill_value, dtype_str) # type: ignore[arg-type]


@pytest.mark.parametrize("fill_value", [[1.0, 0.0], [0, 1], complex(1, 1), np.complex64(0)])
@pytest.mark.parametrize("dtype_str", [*complex_dtypes])
def test_parse_fill_value_complex(fill_value: Any, dtype_str: str) -> None:
def test_parse_fill_value_complex(fill_value: list[int] | complex, dtype_str: str) -> None:
"""
Test that parse_fill_value(fill_value, dtype) correctly handles complex values represented
as length-2 sequences
Expand All @@ -193,18 +193,18 @@ def test_parse_fill_value_complex_invalid(fill_value: Any, dtype_str: str) -> No
f"length {len(fill_value)}."
)
with pytest.raises(ValueError, match=re.escape(match)):
parse_fill_value(fill_value=fill_value, dtype=dtype_str)
parse_fill_value(fill_value=fill_value, dtype=dtype_str) # type; ignore[arg-type]


@pytest.mark.parametrize("fill_value", [{"foo": 10}])
@pytest.mark.parametrize("dtype_str", [*int_dtypes, *float_dtypes, *complex_dtypes])
def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None:
def test_parse_fill_value_invalid_type(fill_value: dict[str, int], dtype_str: str) -> None:
"""
Test that parse_fill_value(fill_value, dtype) raises TypeError for invalid non-sequential types.
This test excludes bool because the bool constructor takes anything.
"""
with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"):
parse_fill_value(fill_value, dtype_str)
parse_fill_value(fill_value, dtype_str) # type: ignore[arg-type]


@pytest.mark.parametrize(
Expand Down
Loading