From 1e5045a8c56a93d83d9347abcc020be2bb3cf622 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 07:11:19 -0700 Subject: [PATCH] Fix zarr upstream tests (#9927) Co-authored-by: Matthew Iannucci --- xarray/backends/zarr.py | 21 ++++++----- xarray/tests/test_backends.py | 67 ++++++++++++++++------------------- 2 files changed, 40 insertions(+), 48 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index f7f30272941..383c385e1d5 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -447,10 +447,11 @@ def extract_zarr_variable_encoding( safe_to_drop = {"source", "original_shape", "preferred_chunks"} valid_encodings = { - "codecs", "chunks", - "compressor", + "compressor", # TODO: delete when min zarr >=3 + "compressors", "filters", + "serializer", "cache_metadata", "write_empty_chunks", } @@ -480,6 +481,8 @@ def extract_zarr_variable_encoding( mode=mode, shape=shape, ) + if _zarr_v3() and chunks is None: + chunks = "auto" encoding["chunks"] = chunks return encoding @@ -816,24 +819,20 @@ def open_store_variable(self, name): ) attributes = dict(attributes) - # TODO: this should not be needed once - # https://github.com/zarr-developers/zarr-python/issues/1269 is resolved. - attributes.pop("filters", None) - encoding = { "chunks": zarr_array.chunks, "preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)), } - if _zarr_v3() and zarr_array.metadata.zarr_format == 3: - encoding["codecs"] = [x.to_dict() for x in zarr_array.metadata.codecs] - elif _zarr_v3(): + if _zarr_v3(): encoding.update( { - "compressor": zarr_array.metadata.compressor, - "filters": zarr_array.metadata.filters, + "compressors": zarr_array.compressors, + "filters": zarr_array.filters, } ) + if self.zarr_group.metadata.zarr_format == 3: + encoding.update({"serializer": zarr_array.serializer}) else: encoding.update( { diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 330dd1dac1f..cfca5e69048 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -134,21 +134,21 @@ @pytest.fixture(scope="module", params=ZARR_FORMATS) -def default_zarr_version(request) -> Generator[None, None]: +def default_zarr_format(request) -> Generator[None, None]: if has_zarr_v3: - with zarr.config.set(default_zarr_version=request.param): + with zarr.config.set(default_zarr_format=request.param): yield else: yield def skip_if_zarr_format_3(reason: str): - if has_zarr_v3 and zarr.config["default_zarr_version"] == 3: + if has_zarr_v3 and zarr.config["default_zarr_format"] == 3: pytest.skip(reason=f"Unsupported with zarr_format=3: {reason}") def skip_if_zarr_format_2(reason: str): - if not has_zarr_v3 or (zarr.config["default_zarr_version"] == 2): + if not has_zarr_v3 or (zarr.config["default_zarr_format"] == 2): pytest.skip(reason=f"Unsupported with zarr_format=2: {reason}") @@ -2270,7 +2270,7 @@ def test_roundtrip_coordinates(self) -> None: @requires_zarr -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") class ZarrBase(CFEncodedBase): DIMENSION_KEY = "_ARRAY_DIMENSIONS" zarr_version = 2 @@ -2439,7 +2439,7 @@ def test_warning_on_bad_chunks(self) -> None: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message=".*Zarr version 3 specification.*", + message=".*Zarr format 3 specification.*", category=UserWarning, ) with self.roundtrip(original, open_kwargs=kwargs) as actual: @@ -2675,15 +2675,14 @@ def test_write_persistence_modes(self, group) -> None: assert_identical(original, actual) def test_compressor_encoding(self) -> None: - original = create_test_data() # specify a custom compressor - - if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3: - encoding_key = "codecs" + original = create_test_data() + if has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3: + encoding_key = "compressors" # all parameters need to be explicitly specified in order for the comparison to pass below encoding = { + "serializer": zarr.codecs.BytesCodec(endian="little"), encoding_key: ( - zarr.codecs.BytesCodec(endian="little"), zarr.codecs.BloscCodec( cname="zstd", clevel=3, @@ -2691,24 +2690,20 @@ def test_compressor_encoding(self) -> None: typesize=8, blocksize=0, ), - ) + ), } else: from numcodecs.blosc import Blosc - encoding_key = "compressor" - encoding = {encoding_key: Blosc(cname="zstd", clevel=3, shuffle=2)} + encoding_key = "compressors" if has_zarr_v3 else "compressor" + comp = Blosc(cname="zstd", clevel=3, shuffle=2) + encoding = {encoding_key: (comp,) if has_zarr_v3 else comp} save_kwargs = dict(encoding={"var1": encoding}) with self.roundtrip(original, save_kwargs=save_kwargs) as ds: enc = ds["var1"].encoding[encoding_key] - if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3: - # TODO: figure out a cleaner way to do this comparison - codecs = zarr.core.metadata.v3.parse_codecs(enc) - assert codecs == encoding[encoding_key] - else: - assert enc == encoding[encoding_key] + assert enc == encoding[encoding_key] def test_group(self) -> None: original = create_test_data() @@ -2846,14 +2841,12 @@ def test_check_encoding_is_consistent_after_append(self) -> None: import numcodecs encoding_value: Any - if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3: + if has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3: compressor = zarr.codecs.BloscCodec() - encoding_key = "codecs" - encoding_value = [zarr.codecs.BytesCodec(), compressor] else: compressor = numcodecs.Blosc() - encoding_key = "compressor" - encoding_value = compressor + encoding_key = "compressors" if has_zarr_v3 else "compressor" + encoding_value = (compressor,) if has_zarr_v3 else compressor encoding = {"da": {encoding_key: encoding_value}} ds.to_zarr(store_target, mode="w", encoding=encoding, **self.version_kwargs) @@ -2995,7 +2988,7 @@ def test_no_warning_from_open_emptydim_with_chunks(self) -> None: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message=".*Zarr version 3 specification.*", + message=".*Zarr format 3 specification.*", category=UserWarning, ) with self.roundtrip(ds, open_kwargs=dict(chunks={"a": 1})) as ds_reload: @@ -5479,7 +5472,7 @@ def test_dataarray_to_netcdf_no_name_pathlib(self) -> None: @requires_zarr class TestDataArrayToZarr: def skip_if_zarr_python_3_and_zip_store(self, store) -> None: - if has_zarr_v3 and isinstance(store, zarr.storage.zip.ZipStore): + if has_zarr_v3 and isinstance(store, zarr.storage.ZipStore): pytest.skip( reason="zarr-python 3.x doesn't support reopening ZipStore with a new mode." ) @@ -5786,7 +5779,7 @@ def test_extract_zarr_variable_encoding() -> None: var = xr.Variable("x", [1, 2]) actual = backends.zarr.extract_zarr_variable_encoding(var) assert "chunks" in actual - assert actual["chunks"] is None + assert actual["chunks"] == ("auto" if has_zarr_v3 else None) var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)}) actual = backends.zarr.extract_zarr_variable_encoding(var) @@ -6092,14 +6085,14 @@ def test_raise_writing_to_nczarr(self, mode) -> None: @requires_netCDF4 @requires_dask -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_pickle_open_mfdataset_dataset(): with open_example_mfdataset(["bears.nc"]) as ds: assert_identical(ds, pickle.loads(pickle.dumps(ds))) @requires_zarr -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_closing_internal_zip_store(): store_name = "tmp.zarr.zip" original_da = DataArray(np.arange(12).reshape((3, 4))) @@ -6110,7 +6103,7 @@ def test_zarr_closing_internal_zip_store(): @requires_zarr -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") class TestZarrRegionAuto: def test_zarr_region_auto_all(self, tmp_path): x = np.arange(0, 50, 10) @@ -6286,7 +6279,7 @@ def test_zarr_region_append(self, tmp_path): @requires_zarr -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_region(tmp_path): x = np.arange(0, 50, 10) y = np.arange(0, 20, 2) @@ -6315,7 +6308,7 @@ def test_zarr_region(tmp_path): @requires_zarr @requires_dask -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_region_chunk_partial(tmp_path): """ Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`. @@ -6336,7 +6329,7 @@ def test_zarr_region_chunk_partial(tmp_path): @requires_zarr @requires_dask -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_append_chunk_partial(tmp_path): t_coords = np.array([np.datetime64("2020-01-01").astype("datetime64[ns]")]) data = np.ones((10, 10)) @@ -6374,7 +6367,7 @@ def test_zarr_append_chunk_partial(tmp_path): @requires_zarr @requires_dask -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_region_chunk_partial_offset(tmp_path): # https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545 store = tmp_path / "foo.zarr" @@ -6394,7 +6387,7 @@ def test_zarr_region_chunk_partial_offset(tmp_path): @requires_zarr @requires_dask -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_safe_chunk_append_dim(tmp_path): store = tmp_path / "foo.zarr" data = np.ones((20,)) @@ -6445,7 +6438,7 @@ def test_zarr_safe_chunk_append_dim(tmp_path): @requires_zarr @requires_dask -@pytest.mark.usefixtures("default_zarr_version") +@pytest.mark.usefixtures("default_zarr_format") def test_zarr_safe_chunk_region(tmp_path): store = tmp_path / "foo.zarr"