diff --git a/docs/api/changelog.rst b/docs/api/changelog.rst index 7b49983d4..ba0284ca9 100644 --- a/docs/api/changelog.rst +++ b/docs/api/changelog.rst @@ -27,6 +27,9 @@ Fixed - :meth'`imod.mf6.Modflow6Simulation.write` failed after splitting the simulation. This has been fixed. - modflow options like "print flow" , "save flow" and "print input" can now be set on :class:`imod.mf6.Well` +- when regridding a :class:`imod.mf6.Modflow6Simulation`, :class:`imod.mf6.GroundwaterFlowModel`, :class:`imod.mf6.GroundwaterTransportModel` + or a :class:`imod.mf6.package`, regridding weights are now cached and can be re-used over the different objects that are regridded. + This improves performance considerably in most use cases: when regridding is applied over the same grid cells with the same regridder type, but with different values/methods, multiple times. Added ~~~~~ diff --git a/examples/mf6/different_ways_to_regrid_models.py b/examples/mf6/different_ways_to_regrid_models.py index 31480696e..d3d33843a 100644 --- a/examples/mf6/different_ways_to_regrid_models.py +++ b/examples/mf6/different_ways_to_regrid_models.py @@ -25,7 +25,7 @@ import xarray as xr from example_models import create_twri_simulation -from imod.mf6.utilities.regrid import RegridderType +from imod.mf6.utilities.regrid import RegridderType, RegridderWeightsCache # %% # Now we create the twri simulation itself. It yields a simulation of a flow problem, with a grid of 3 layers and 15 cells in both x and y directions. @@ -89,13 +89,21 @@ regridded_k_2.sel(layer=1).plot(y="y", yincrease=False, ax=ax) -# %% -# Finally, we can regrid package per package. This allows us to choose the regridding method as well. -# in this example we'll regrid the npf package manually and the rest of the packages using default methods. +# %% Finally, we can regrid package per package. This allows us to choose the +# regridding method as well. in this example we'll regrid the npf package +# manually and the rest of the packages using default methods. +# +# Note that we create a RegridderWeightsCache here. This will store the weights +# of the regridder. Using the same cache to regrid another package will lead to +# a performance increase if that package uses the same regridding method, +# because initializing a regridder is costly. regridder_types = {"k": (RegridderType.CENTROIDLOCATOR, None)} +regrid_context = RegridderWeightsCache(model["npf"]["k"], target_grid) npf_regridded = model["npf"].regrid_like( - target_grid=target_grid, regridder_types=regridder_types + target_grid=target_grid, + regrid_context=regrid_context, + regridder_types=regridder_types, ) new_model["npf"] = npf_regridded diff --git a/imod/mf6/__init__.py b/imod/mf6/__init__.py index 29c8cc891..d72b1bdb8 100644 --- a/imod/mf6/__init__.py +++ b/imod/mf6/__init__.py @@ -47,7 +47,7 @@ from imod.mf6.ssm import SourceSinkMixing from imod.mf6.sto import SpecificStorage, Storage, StorageCoefficient from imod.mf6.timedis import TimeDiscretization -from imod.mf6.utilities.regrid import RegridderInstancesCollection, RegridderType +from imod.mf6.utilities.regrid import RegridderType, RegridderWeightsCache from imod.mf6.uzf import UnsaturatedZoneFlow from imod.mf6.wel import Well, WellDisStructured, WellDisVertices from imod.mf6.write_context import WriteContext diff --git a/imod/mf6/model.py b/imod/mf6/model.py index 9208fd82e..216614d7d 100644 --- a/imod/mf6/model.py +++ b/imod/mf6/model.py @@ -23,9 +23,7 @@ from imod.mf6.package import Package from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase from imod.mf6.utilities.mask import _mask_all_packages -from imod.mf6.utilities.regrid import ( - _regrid_like, -) +from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like from imod.mf6.validation import pkg_errors_to_status_info from imod.mf6.write_context import WriteContext from imod.schemata import ValidationError @@ -468,7 +466,10 @@ def _clip_box_packages( return clipped def regrid_like( - self, target_grid: GridDataArray, validate: bool = True + self, + target_grid: GridDataArray, + validate: bool = True, + regrid_context: Optional[RegridderWeightsCache] = None, ) -> "Modflow6Model": """ Creates a model by regridding the packages of this model to another discretization. @@ -482,13 +483,16 @@ def regrid_like( a grid defined over the same discretization as the one we want to regrid the package to validate: bool set to true to validate the regridded packages + regrid_context: Optional RegridderWeightsCache + stores regridder weights for different regridders. Can be used to speed up regridding, + if the same regridders are used several times for regridding different arrays. Returns ------- a model with similar packages to the input model, and with all the data-arrays regridded to another discretization, similar to the one used in input argument "target_grid" """ - return _regrid_like(self, target_grid, validate) + return _regrid_like(self, target_grid, validate, regrid_context) def mask_all_packages( self, diff --git a/imod/mf6/package.py b/imod/mf6/package.py index 15203e3e6..04a801333 100644 --- a/imod/mf6/package.py +++ b/imod/mf6/package.py @@ -25,6 +25,7 @@ from imod.mf6.utilities.mask import _mask from imod.mf6.utilities.regrid import ( RegridderType, + RegridderWeightsCache, _regrid_like, ) from imod.mf6.utilities.schemata import filter_schemata_dict @@ -554,6 +555,7 @@ def mask(self, mask: GridDataArray) -> Any: def regrid_like( self, target_grid: GridDataArray, + regrid_context: RegridderWeightsCache, regridder_types: Optional[dict[str, Tuple[RegridderType, str]]] = None, ) -> "Package": """ @@ -580,6 +582,9 @@ def regrid_like( regridder_types: dict(str->(regridder type,str)) dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str) this dictionary can be used to override the default mapping method. + regrid_context: Optional RegridderWeightsCache + stores regridder weights for different regridders. Can be used to speed up regridding, + if the same regridders are used several times for regridding different arrays. Returns ------- @@ -587,7 +592,7 @@ def regrid_like( similar to the one used in input argument "target_grid" """ try: - result = _regrid_like(self, target_grid, regridder_types) + result = _regrid_like(self, target_grid, regrid_context, regridder_types) except ValueError as e: raise e except Exception: diff --git a/imod/mf6/utilities/regrid.py b/imod/mf6/utilities/regrid.py index c9bcc43dd..5062cf498 100644 --- a/imod/mf6/utilities/regrid.py +++ b/imod/mf6/utilities/regrid.py @@ -23,13 +23,14 @@ from imod.mf6.utilities.clip import clip_by_grid from imod.mf6.utilities.regridding_types import RegridderType from imod.schemata import ValidationError -from imod.typing.grid import GridDataArray, ones_like +from imod.typing.grid import GridDataArray, get_grid_geometry_hash, ones_like -class RegridderInstancesCollection: +class RegridderWeightsCache: """ This class stores any number of regridders that can regrid a single source grid to a single target grid. By storing the regridders, we make sure the regridders can be re-used for different arrays on the same grid. + Regridders are stored based on their type (`see these docs`_) and planar coordinates (x, y). This is important because computing the regridding weights is a costly affair. """ @@ -37,6 +38,7 @@ def __init__( self, source_grid: Union[xr.DataArray, xu.UgridDataArray], target_grid: Union[xr.DataArray, xu.UgridDataArray], + max_cache_size: int = 6, ) -> None: self.regridder_instances: dict[ tuple[type[BaseRegridder], Optional[str]], BaseRegridder @@ -44,27 +46,8 @@ def __init__( self._source_grid = source_grid self._target_grid = target_grid - def __has_regridder( - self, regridder_type: type[BaseRegridder], method: Optional[str] = None - ) -> bool: - return (regridder_type, method) in self.regridder_instances.keys() - - def __get_existing_regridder( - self, regridder_type: type[BaseRegridder], method: Optional[str] - ) -> BaseRegridder: - if self.__has_regridder(regridder_type, method): - return self.regridder_instances[(regridder_type, method)] - raise ValueError("no existing regridder of type " + str(regridder_type)) - - def __create_regridder( - self, regridder_type: type[BaseRegridder], method: Optional[str] - ) -> BaseRegridder: - method_args = () if method is None else (method,) - - self.regridder_instances[(regridder_type, method)] = regridder_type( - self._source_grid, self._target_grid, *method_args - ) - return self.regridder_instances[(regridder_type, method)] + self.weights_cache = {} + self.max_cache_size = max_cache_size def __get_regridder_class( self, regridder_type: RegridderType | BaseRegridder @@ -82,6 +65,8 @@ def __get_regridder_class( def get_regridder( self, + source_grid: GridDataArray, + target_grid: GridDataArray, regridder_type: Union[RegridderType, BaseRegridder], method: Optional[str] = None, ) -> BaseRegridder: @@ -107,10 +92,31 @@ def get_regridder( """ regridder_class = self.__get_regridder_class(regridder_type) - if not self.__has_regridder(regridder_class, method): - self.__create_regridder(regridder_class, method) + if "layer" not in source_grid.coords and "layer" in target_grid.coords: + target_grid = target_grid.drop_vars("layer") + + source_hash = get_grid_geometry_hash(source_grid) + target_hash = get_grid_geometry_hash(target_grid) + key = (source_hash, target_hash, regridder_class) + if not key in self.weights_cache.keys(): + if len(self.weights_cache) >= self.max_cache_size: + self.remove_first_regridder() + kwargs = {"source": source_grid, "target": target_grid} + if method is not None: + kwargs["method"] = method + regridder = regridder_class(**kwargs) + self.weights_cache[key] = regridder.weights + else: + kwargs = {"weights": self.weights_cache[key], "target": target_grid} + if method is not None: + kwargs["method"] = method + regridder = regridder_class.from_weights(**kwargs) + + return regridder - return self.__get_existing_regridder(regridder_class, method) + def remove_first_regridder(self): + keys = list(self.weights_cache.keys()) + self.weights_cache.pop(keys[0]) def assign_coord_if_present( @@ -131,7 +137,7 @@ def assign_coord_if_present( def _regrid_array( package: IRegridPackage, varname: str, - regridder_collection: RegridderInstancesCollection, + regridder_collection: RegridderWeightsCache, regridder_name: str, regridder_function: str, target_grid: GridDataArray, @@ -167,6 +173,8 @@ def _regrid_array( # obtain an instance of a regridder for the chosen method regridder = regridder_collection.get_regridder( + package.dataset[varname], + target_grid, regridder_name, regridder_function, ) @@ -205,6 +213,7 @@ def _get_unique_regridder_types(model: IModel) -> defaultdict[RegridderType, lis def _regrid_like( package: IRegridPackage, target_grid: GridDataArray, + regrid_context: RegridderWeightsCache, regridder_types: Optional[dict[str, tuple[RegridderType, str]]] = None, ) -> IPackage: """ @@ -231,6 +240,9 @@ def _regrid_like( regridder_types: dict(str->(regridder type,str)) dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str) this dictionary can be used to override the default mapping method. + regrid_context: RegridderWeightsCache + stores regridder weights for different regridders. Can be used to speed up regridding, + if the same regridders are used several times for regridding different arrays. Returns ------- @@ -245,10 +257,6 @@ def _regrid_like( if hasattr(package, "auxiliary_data_fields"): remove_expanded_auxiliary_variables_from_dataset(package) - regridder_collection = RegridderInstancesCollection( - package.dataset, target_grid=target_grid - ) - regridder_settings = package.get_regrid_methods() if regridder_types is not None: regridder_settings.update(regridder_types) @@ -269,7 +277,7 @@ def _regrid_like( new_package_data[varname] = _regrid_array( package, varname, - regridder_collection, + regrid_context, regridder_name, regridder_function, target_grid, @@ -289,7 +297,10 @@ def _regrid_like( @typedispatch # type: ignore[no-redef] def _regrid_like( - model: IModel, target_grid: GridDataArray, validate: bool = True + model: IModel, + target_grid: GridDataArray, + validate: bool = True, + regrid_context: Optional[RegridderWeightsCache] = None, ) -> IModel: """ Creates a model by regridding the packages of this model to another discretization. @@ -303,6 +314,9 @@ def _regrid_like( a grid defined over the same discretization as the one we want to regrid the package to validate: bool set to true to validate the regridded packages + regrid_context: Optional RegridderWeightsCache + stores regridder weights for different regridders. Can be used to speed up regridding, + if the same regridders are used several times for regridding different arrays. Returns ------- @@ -315,17 +329,18 @@ def _regrid_like( f"regridding this model cannot be done due to the presence of package {error_with_object_name}" ) new_model = model.__class__() - + if regrid_context is None: + regrid_context = RegridderWeightsCache(model.domain, target_grid) for pkg_name, pkg in model.items(): if isinstance(pkg, (IRegridPackage, ILineDataPackage, IPointDataPackage)): - new_model[pkg_name] = pkg.regrid_like(target_grid) + new_model[pkg_name] = pkg.regrid_like(target_grid, regrid_context) else: raise NotImplementedError( f"regridding is not implemented for package {pkg_name} of type {type(pkg)}" ) methods = _get_unique_regridder_types(model) - output_domain = _get_regridding_domain(model, target_grid, methods) + output_domain = _get_regridding_domain(model, target_grid, regrid_context, methods) new_model.mask_all_packages(output_domain) new_model.purge_empty_packages() if validate: @@ -371,6 +386,9 @@ def _regrid_like( raise ValueError( "Unable to regrid simulation. Regridding can only be done on simulations that have a single flow model." ) + flow_models = simulation.get_models_of_type("gwf6") + old_grid = list(flow_models.values())[0].domain + regrid_context = RegridderWeightsCache(old_grid, target_grid) models = simulation.get_models() for model_name, model in models.items(): @@ -383,7 +401,7 @@ def _regrid_like( result = simulation.__class__(regridded_simulation_name) for key, item in simulation.items(): if isinstance(item, IModel): - result[key] = item.regrid_like(target_grid, validate) + result[key] = item.regrid_like(target_grid, validate, regrid_context) elif key == "gwtgwf_exchanges": pass elif isinstance(item, IPackage) and not isinstance(item, IRegridPackage): @@ -428,6 +446,7 @@ def _regrid_like(package: object, target_grid: GridDataArray, *_) -> None: def _get_regridding_domain( model: IModel, target_grid: GridDataArray, + regrid_context: RegridderWeightsCache, methods: defaultdict[RegridderType, list[str]], ) -> GridDataArray: """ @@ -436,12 +455,9 @@ def _get_regridding_domain( cells that all regridders consider active. """ idomain = model.domain - regridder_collection = RegridderInstancesCollection( - idomain, target_grid=target_grid - ) included_in_all = ones_like(target_grid) regridders = [ - regridder_collection.get_regridder(regriddertype, function) + regrid_context.get_regridder(idomain, target_grid, regriddertype, function) for regriddertype, functionlist in methods.items() for function in functionlist ] diff --git a/imod/tests/test_mf6/test_mf6_hfb.py b/imod/tests/test_mf6/test_mf6_hfb.py index 5c854d69b..332eb5461 100644 --- a/imod/tests/test_mf6/test_mf6_hfb.py +++ b/imod/tests/test_mf6/test_mf6_hfb.py @@ -17,6 +17,7 @@ LayeredHorizontalFlowBarrierResistance, ) from imod.mf6.hfb import to_connected_cells_dataset +from imod.mf6.utilities.regrid import RegridderWeightsCache from imod.tests.fixtures.flow_basic_fixture import BasicDisSettings from imod.typing.grid import ones_like @@ -136,7 +137,11 @@ def test_to_mf6_creates_mf6_adapter( else: idomain_clipped = idomain.sel(x=slice(None, 54.0)) - hfb_clipped = hfb.regrid_like(idomain_clipped.sel(layer=1)) + regrid_context = RegridderWeightsCache( + idomain.sel(layer=1), idomain_clipped.sel(layer=1) + ) + + hfb_clipped = hfb.regrid_like(idomain_clipped.sel(layer=1), regrid_context) # Assert x, y = hfb_clipped.dataset["geometry"].values[0].xy diff --git a/imod/tests/test_mf6/test_mf6_regrid.py b/imod/tests/test_mf6/test_mf6_regrid.py index 6533dc76d..2fed6c4fe 100644 --- a/imod/tests/test_mf6/test_mf6_regrid.py +++ b/imod/tests/test_mf6/test_mf6_regrid.py @@ -9,6 +9,7 @@ import imod from imod.mf6.package import Package +from imod.mf6.utilities.regrid import RegridderWeightsCache from imod.tests.fixtures.mf6_small_models_fixture import ( grid_data_structured, grid_data_structured_layered, @@ -115,9 +116,12 @@ def test_regrid_structured(): """ structured_grid_packages = create_package_instances(is_structured=True) new_grid = grid_data_structured(np.float64, 12, 2.5) + old_grid = grid_data_structured(np.float64, 1.0e-4, 5.0) + + regrid_context = RegridderWeightsCache(old_grid, new_grid) new_packages = [] for package in structured_grid_packages: - new_packages.append(package.regrid_like(new_grid)) + new_packages.append(package.regrid_like(new_grid, regrid_context)) new_idomain = new_packages[0].dataset["icelltype"] @@ -134,9 +138,12 @@ def test_regrid_unstructured(): """ unstructured_grid_packages = create_package_instances(is_structured=False) new_grid = grid_data_unstructured(np.float64, 12, 2.5) + old_grid = grid_data_unstructured(np.float_, 1.0e-4, 5.0) + regrid_context = RegridderWeightsCache(old_grid, new_grid) + new_packages = [] for package in unstructured_grid_packages: - new_packages.append(package.regrid_like(new_grid)) + new_packages.append(package.regrid_like(new_grid, regrid_context)) new_idomain = new_packages[0].dataset["icelltype"] for new_package in new_packages: @@ -165,12 +172,13 @@ def test_regrid_structured_missing_dx_and_dy(): ) new_grid = grid_data_structured(np.float64, 12, 0.25) - + old_grid = grid_data_unstructured(np.float_, 1.0e-4, 5.0) + regrid_context = RegridderWeightsCache(old_grid, new_grid) with pytest.raises( ValueError, match="DataArray icelltype does not have both a dx and dy coordinates", ): - _ = package.regrid_like(new_grid) + _ = package.regrid_like(new_grid, regrid_context) def test_regrid(tmp_path: Path): @@ -203,8 +211,8 @@ def test_regrid(tmp_path: Path): save_flows=True, alternative_cell_averaging="AMT-HMK", ) - - new_npf = npf.regrid_like(k) + regrid_context = RegridderWeightsCache(k, k) + new_npf = npf.regrid_like(k, regrid_context) # check the rendered versions are the same, they contain the options new_rendered = new_npf.render(tmp_path, "regridded", None, False) @@ -242,7 +250,9 @@ def test_regridding_can_skip_validation(): # Regrid the package to a finer domain new_grid = grid_data_structured(np.float64, 1.0, 0.025) - regridded_package = sto_package.regrid_like(new_grid) + old_grid = grid_data_structured(np.float64, -20.0, 0.25) + regrid_context = RegridderWeightsCache(old_grid, new_grid) + regridded_package = sto_package.regrid_like(new_grid, regrid_context) # Check that write validation still fails for the regridded package new_bottom = deepcopy(new_grid) @@ -286,9 +296,10 @@ def test_regridding_layer_based_array(): save_flows=True, validate=False, ) - + old_grid = grid_data_structured(np.float64, -20.0, 0.25) new_grid = grid_data_structured(np.float64, 1.0, 0.025) - regridded_package = sto_package.regrid_like(new_grid) + regrid_context = RegridderWeightsCache(old_grid, new_grid) + regridded_package = sto_package.regrid_like(new_grid, regrid_context) assert ( regridded_package.dataset.coords["dx"].values[()] diff --git a/imod/tests/test_mf6/test_mf6_regrid_model.py b/imod/tests/test_mf6/test_mf6_regrid_model.py index 1229c44d5..e4e100c94 100644 --- a/imod/tests/test_mf6/test_mf6_regrid_model.py +++ b/imod/tests/test_mf6/test_mf6_regrid_model.py @@ -113,8 +113,9 @@ def test_model_regridding_can_skip_validation( """ # create a sto package with a negative storage coefficient. This would trigger a validation error if it were turned on. - storage_coefficient = grid_data_structured(np.float64, -20.0, 0.25) - specific_yield = grid_data_structured(np.float64, -30.0, 0.25) + storage_coefficient = grid_data_structured(np.float64, -20.0, 2.0) + specific_yield = grid_data_structured(np.float64, -30.0, 2.0) + sto_package = imod.mf6.StorageCoefficient( storage_coefficient, specific_yield, @@ -162,8 +163,8 @@ def test_model_regridding_can_validate( """ # Create a storage package with a negative storage coefficient. This would trigger a validation error if it were turned on. - storage_coefficient = grid_data_structured(np.float64, -20, 0.25) - specific_yield = grid_data_structured(np.float64, -30, 0.25) + storage_coefficient = grid_data_structured(np.float64, -20, 2.0) + specific_yield = grid_data_structured(np.float64, -30, 2.0) sto_package = imod.mf6.StorageCoefficient( storage_coefficient, specific_yield, diff --git a/imod/tests/test_mf6/test_mf6_unsupported_grid_operations.py b/imod/tests/test_mf6/test_mf6_unsupported_grid_operations.py index fa7465e13..38c7e853f 100644 --- a/imod/tests/test_mf6/test_mf6_unsupported_grid_operations.py +++ b/imod/tests/test_mf6/test_mf6_unsupported_grid_operations.py @@ -2,6 +2,7 @@ import pytest import xarray as xr +from imod.mf6.utilities.regrid import RegridderWeightsCache from imod.typing.grid import zeros_like @@ -55,10 +56,11 @@ def test_mf6_model_regrid_with_lakes(rectangle_with_lakes, tmp_path): def test_mf6_package_regrid_with_lakes(rectangle_with_lakes, tmp_path): simulation = rectangle_with_lakes package = simulation["GWF_1"]["lake"] + old_grid = simulation["GWF_1"].domain new_grid = finer_grid(simulation["GWF_1"].domain) - + regrid_context = RegridderWeightsCache(old_grid.sel(layer=1), new_grid.sel(layer=1)) with pytest.raises(ValueError, match="package(.+)not be regridded"): - _ = package.regrid_like(new_grid) + _ = package.regrid_like(new_grid, regrid_context) @pytest.mark.usefixtures("rectangle_with_lakes") diff --git a/imod/tests/test_mf6/test_utilities/test_regrid_utils.py b/imod/tests/test_mf6/test_utilities/test_regrid_utils.py index 4ada857f7..360a551d2 100644 --- a/imod/tests/test_mf6/test_utilities/test_regrid_utils.py +++ b/imod/tests/test_mf6/test_utilities/test_regrid_utils.py @@ -1,66 +1,81 @@ import copy +import pickle import pytest from xugrid import OverlapRegridder from imod.mf6 import Dispersion -from imod.mf6.utilities.regrid import RegridderInstancesCollection, RegridderType - - -def test_instance_collection_returns_same_instance_when_enum_and_method_match( - basic_unstructured_dis, -): - grid, _, _ = basic_unstructured_dis - new_grid = copy.deepcopy(grid) - - collection = RegridderInstancesCollection(grid, new_grid) - - first_instance = collection.get_regridder(RegridderType.OVERLAP, "harmonic_mean") - second_instance = collection.get_regridder(RegridderType.OVERLAP, "harmonic_mean") - - assert first_instance == second_instance - - -def test_instance_collection_combining_different_instantiation_parmeters( +from imod.mf6.utilities.regrid import RegridderType, RegridderWeightsCache + + +def is_equal_regridder(instance_1, instance_2) -> bool: + if type(instance_1) != type(instance_2): + return False + keys = ["__regrid_data", "__regrid_indices", "__regrid_indptr"] + for key in keys: + if hash(pickle.dumps(instance_1.weights[key])) != hash( + pickle.dumps(instance_2.weights[key]) + ): + return False + keys = ["__regrid_n", "__regrid_m", "__regrid_nnz"] + for key in keys: + if instance_1.weights[key].values[()] != instance_2.weights[key].values[()]: + return False + return True + + +def test_regridders_weight_cache_returns_similar_instance_when_enum_and_method_match( basic_unstructured_dis, ): grid, _, _ = basic_unstructured_dis new_grid = copy.deepcopy(grid) - collection = RegridderInstancesCollection(grid, new_grid) + collection = RegridderWeightsCache(grid, new_grid) - first_instance = collection.get_regridder(RegridderType.OVERLAP, "harmonic_mean") - second_instance = collection.get_regridder(OverlapRegridder, "harmonic_mean") + first_instance = collection.get_regridder( + grid, new_grid, RegridderType.OVERLAP, "harmonic_mean" + ) + second_instance = collection.get_regridder( + grid, new_grid, RegridderType.OVERLAP, "harmonic_mean" + ) - assert first_instance == second_instance + assert is_equal_regridder(first_instance, second_instance) -def test_instance_collection_returns_different_instance_when_name_does_not_match( +def test_regridders_weight_cache_combining_different_instantiation_parmeters( basic_unstructured_dis, ): grid, _, _ = basic_unstructured_dis new_grid = copy.deepcopy(grid) - collection = RegridderInstancesCollection(grid, new_grid) + collection = RegridderWeightsCache(grid, new_grid) - first_instance = collection.get_regridder(RegridderType.CENTROIDLOCATOR) - second_instance = collection.get_regridder(RegridderType.BARYCENTRIC) + first_instance = collection.get_regridder( + grid, new_grid, RegridderType.OVERLAP, "harmonic_mean" + ) + second_instance = collection.get_regridder( + grid, new_grid, OverlapRegridder, "harmonic_mean" + ) - assert first_instance != second_instance + assert is_equal_regridder(first_instance, second_instance) -def test_instance_collection_returns_different_instance_when_method_does_not_match( +def test_regridders_weight_cache_returns_different_instance_when_name_does_not_match( basic_unstructured_dis, ): grid, _, _ = basic_unstructured_dis new_grid = copy.deepcopy(grid) - collection = RegridderInstancesCollection(grid, new_grid) + collection = RegridderWeightsCache(grid, new_grid) - first_instance = collection.get_regridder(RegridderType.OVERLAP, "geometric_mean") - second_instance = collection.get_regridder(RegridderType.OVERLAP, "harmonic_mean") + first_instance = collection.get_regridder( + grid, new_grid, RegridderType.CENTROIDLOCATOR + ) + second_instance = collection.get_regridder( + grid, new_grid, RegridderType.BARYCENTRIC + ) - assert first_instance != second_instance + assert not is_equal_regridder(first_instance, second_instance) def test_non_regridder_cannot_be_instantiated( @@ -69,7 +84,7 @@ def test_non_regridder_cannot_be_instantiated( grid, _, _ = basic_unstructured_dis new_grid = copy.deepcopy(grid) - collection = RegridderInstancesCollection(grid, new_grid) + collection = RegridderWeightsCache(grid, new_grid) # we create a class with the same constructor-signature as a regridder has, but it is not a regridder # still, it is an abc.ABCMeta @@ -78,22 +93,44 @@ def __init__(self, sourcegrid, targetgrid, method): pass with pytest.raises(ValueError): - _ = collection.get_regridder(nonregridder, "geometric_mean") + _ = collection.get_regridder(grid, new_grid, nonregridder, "geometric_mean") + + +def test_regridders_weight_cache_grows_up_to_size_limit( + basic_unstructured_dis, +): + grid, _, _ = basic_unstructured_dis + new_grid = copy.deepcopy(grid) + cache_size = 3 + collection = RegridderWeightsCache(grid, new_grid, cache_size) + + _ = collection.get_regridder(grid, new_grid, RegridderType.OVERLAP, "harmonic_mean") + _ = collection.get_regridder(grid, new_grid, RegridderType.BARYCENTRIC) + _ = collection.get_regridder(grid, new_grid, RegridderType.CENTROIDLOCATOR) + assert len(collection.weights_cache) == cache_size + _ = collection.get_regridder( + grid, new_grid, RegridderType.RELATIVEOVERLAP, method="conductance" + ) + assert len(collection.weights_cache) == cache_size def test_error_messages(basic_unstructured_dis): grid, _, _ = basic_unstructured_dis new_grid = copy.deepcopy(grid) - collection = RegridderInstancesCollection(grid, new_grid) + collection = RegridderWeightsCache(grid, new_grid) with pytest.raises(TypeError): _ = collection.get_regridder( + grid, + new_grid, RegridderType.BARYCENTRIC, method="geometric_mean", ) with pytest.raises(ValueError): _ = collection.get_regridder( + grid, + new_grid, RegridderType.OVERLAP, method="non-existing function", ) @@ -105,7 +142,9 @@ def test_create_regridder_from_class_not_enum( grid, _, _ = basic_unstructured_dis new_grid = copy.deepcopy(grid) - collection = RegridderInstancesCollection(grid, new_grid) - regridder = collection.get_regridder(OverlapRegridder, "harmonic_mean") + collection = RegridderWeightsCache(grid, new_grid) + regridder = collection.get_regridder( + grid, new_grid, OverlapRegridder, "harmonic_mean" + ) assert isinstance(regridder, OverlapRegridder) diff --git a/imod/typing/grid.py b/imod/typing/grid.py index 9fc2875cb..a59b0d4d1 100644 --- a/imod/typing/grid.py +++ b/imod/typing/grid.py @@ -316,3 +316,23 @@ def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]: @typedispatch def get_spatial_dimension_names(grid: object) -> list[str]: return [] + + +@typedispatch +def get_grid_geometry_hash(grid: xr.DataArray) -> int: + hash_x = hash(pickle.dumps(grid["x"].values)) + hash_y = hash(pickle.dumps(grid["y"].values)) + return (hash_x, hash_y) + + +@typedispatch +def get_grid_geometry_hash(grid: xu.UgridDataArray) -> int: + hash_x = hash(pickle.dumps(grid.ugrid.grid.node_x)) + hash_y = hash(pickle.dumps(grid.ugrid.grid.node_y)) + hash_connectivity = hash(pickle.dumps(grid.ugrid.grid.node_face_connectivity)) + return (hash_x, hash_y, hash_connectivity) + + +@typedispatch +def get_grid_geometry_hash(grid: object) -> int: + raise ValueError("get_grid_geometry_hash not supported for this object.")