diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index d01d334142..1344207e18 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -15,12 +15,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from contextlib import nullcontext +from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any import numpy as np import torch -from monai.utils import ensure_tuple_size, optional_import, require_pkg +from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq if TYPE_CHECKING: import zarr @@ -233,7 +234,7 @@ def __init__( store: zarr.storage.Store | str = "merged.zarr", value_store: zarr.storage.Store | str | None = None, count_store: zarr.storage.Store | str | None = None, - compressor: str = "default", + compressor: str | None = None, value_compressor: str | None = None, count_compressor: str | None = None, chunks: Sequence[int] | bool = True, @@ -246,8 +247,22 @@ def __init__( self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store - self.value_store = zarr.storage.TempStore() if value_store is None else value_store - self.count_store = zarr.storage.TempStore() if count_store is None else count_store + self.tmpdir: TemporaryDirectory | None + if version_geq(get_package_version("zarr"), "3.0.0"): + if value_store is None: + self.tmpdir = TemporaryDirectory() + self.value_store = zarr.storage.LocalStore(self.tmpdir.name) + else: + self.value_store = value_store + if count_store is None: + self.tmpdir = TemporaryDirectory() + self.count_store = zarr.storage.LocalStore(self.tmpdir.name) + else: + self.count_store = count_store + else: + self.tmpdir = None + self.value_store = zarr.storage.TempStore() if value_store is None else value_store + self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks self.compressor = compressor self.value_compressor = value_compressor diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index a52dbceb4c..3c89e4fb03 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -287,15 +287,16 @@ class ZarrAvgMergerTests(unittest.TestCase): ] ) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + codec_reg = numcodecs.registry.codec_registry if "compressor" in arguments: if arguments["compressor"] != "default": - arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]() + arguments["compressor"] = codec_reg[arguments["compressor"].lower()]() if "value_compressor" in arguments: if arguments["value_compressor"] != "default": - arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]() + arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]() if "count_compressor" in arguments: if arguments["count_compressor"] != "default": - arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]() + arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]() merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1])