Skip to content

Commit

Permalink
Issue #429 regridding cache (#933)
Browse files Browse the repository at this point in the history
Fixes #429

# Description
Now caches weight sets instead of regridders when regridding models.
This saves re-computing weights, because the same weights are used for
different regridders (if the regridder type is the same and only the
regridding method is different).
For example if we have 2 overlap regridders, but one is of method "mean"
and the other one is of method "geometric_mean" they will use the same
weights.

# Checklist

- [X] Links to correct issue
- [X] Update changelog, if changes affect users
- [X] PR title starts with ``Issue #nr``, e.g. ``Issue #737``
- [X] Unit tests were added
- [ ] **If feature added**: Added/extended example

---------

Co-authored-by: Joeri van Engelen <[email protected]>
  • Loading branch information
luitjansl and JoerivanEngelen authored Mar 28, 2024
1 parent 8c997f8 commit 07bbfa7
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 106 deletions.
3 changes: 3 additions & 0 deletions docs/api/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ Fixed
- :meth:`imod.mf6.Modflow6Simulation.write` failed after splitting the simulation. This has been fixed.
- modflow options like "print flow", "save flow" and "print input" can now be set on
:class:`imod.mf6.Well`
- When regridding a :class:`imod.mf6.Modflow6Simulation`, :class:`imod.mf6.GroundwaterFlowModel`, :class:`imod.mf6.GroundwaterTransportModel`
or a :class:`imod.mf6.package`, regridding weights are now cached and can be re-used across the different objects that are regridded.
This improves performance considerably in most use cases: when regridding is applied multiple times over the same grid cells with the same regridder type, but with different values/methods.

Added
~~~~~
Expand Down
18 changes: 13 additions & 5 deletions examples/mf6/different_ways_to_regrid_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import xarray as xr
from example_models import create_twri_simulation

from imod.mf6.utilities.regrid import RegridderType
from imod.mf6.utilities.regrid import RegridderType, RegridderWeightsCache

# %%
# Now we create the twri simulation itself. It yields a simulation of a flow problem, with a grid of 3 layers and 15 cells in both x and y directions.
Expand Down Expand Up @@ -89,13 +89,21 @@
regridded_k_2.sel(layer=1).plot(y="y", yincrease=False, ax=ax)


# %%
# Finally, we can regrid package per package. This allows us to choose the regridding method as well.
# in this example we'll regrid the npf package manually and the rest of the packages using default methods.
# %% Finally, we can regrid package per package. This allows us to choose the
# regridding method as well. In this example we'll regrid the npf package
# manually and the rest of the packages using default methods.
#
# Note that we create a RegridderWeightsCache here. This will store the weights
# of the regridder. Using the same cache to regrid another package will lead to
# a performance increase if that package uses the same regridder type (even
# with a different regridding method), because computing the weights is costly.

regridder_types = {"k": (RegridderType.CENTROIDLOCATOR, None)}
regrid_context = RegridderWeightsCache(model["npf"]["k"], target_grid)
npf_regridded = model["npf"].regrid_like(
target_grid=target_grid, regridder_types=regridder_types
target_grid=target_grid,
regrid_context=regrid_context,
regridder_types=regridder_types,
)
new_model["npf"] = npf_regridded

Expand Down
2 changes: 1 addition & 1 deletion imod/mf6/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from imod.mf6.ssm import SourceSinkMixing
from imod.mf6.sto import SpecificStorage, Storage, StorageCoefficient
from imod.mf6.timedis import TimeDiscretization
from imod.mf6.utilities.regrid import RegridderInstancesCollection, RegridderType
from imod.mf6.utilities.regrid import RegridderType, RegridderWeightsCache
from imod.mf6.uzf import UnsaturatedZoneFlow
from imod.mf6.wel import Well, WellDisStructured, WellDisVertices
from imod.mf6.write_context import WriteContext
14 changes: 9 additions & 5 deletions imod/mf6/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
from imod.mf6.package import Package
from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
from imod.mf6.utilities.mask import _mask_all_packages
from imod.mf6.utilities.regrid import (
_regrid_like,
)
from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like
from imod.mf6.validation import pkg_errors_to_status_info
from imod.mf6.write_context import WriteContext
from imod.schemata import ValidationError
Expand Down Expand Up @@ -468,7 +466,10 @@ def _clip_box_packages(
return clipped

def regrid_like(
self, target_grid: GridDataArray, validate: bool = True
self,
target_grid: GridDataArray,
validate: bool = True,
regrid_context: Optional[RegridderWeightsCache] = None,
) -> "Modflow6Model":
"""
Creates a model by regridding the packages of this model to another discretization.
Expand All @@ -482,13 +483,16 @@ def regrid_like(
a grid defined over the same discretization as the one we want to regrid the package to
validate: bool
set to true to validate the regridded packages
regrid_context: Optional RegridderWeightsCache
stores regridder weights for different regridders. Can be used to speed up regridding,
if the same regridders are used several times for regridding different arrays.
Returns
-------
a model with similar packages to the input model, and with all the data-arrays regridded to another discretization,
similar to the one used in input argument "target_grid"
"""
return _regrid_like(self, target_grid, validate)
return _regrid_like(self, target_grid, validate, regrid_context)

def mask_all_packages(
self,
Expand Down
7 changes: 6 additions & 1 deletion imod/mf6/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from imod.mf6.utilities.mask import _mask
from imod.mf6.utilities.regrid import (
RegridderType,
RegridderWeightsCache,
_regrid_like,
)
from imod.mf6.utilities.schemata import filter_schemata_dict
Expand Down Expand Up @@ -554,6 +555,7 @@ def mask(self, mask: GridDataArray) -> Any:
def regrid_like(
self,
target_grid: GridDataArray,
regrid_context: RegridderWeightsCache,
regridder_types: Optional[dict[str, Tuple[RegridderType, str]]] = None,
) -> "Package":
"""
Expand All @@ -580,14 +582,17 @@ def regrid_like(
regridder_types: dict(str->(regridder type,str))
dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str)
this dictionary can be used to override the default mapping method.
regrid_context: RegridderWeightsCache
stores regridder weights for different regridders. Can be used to speed up regridding,
if the same regridders are used several times for regridding different arrays.
Returns
-------
a package with the same options as this package, and with all the data-arrays regridded to another discretization,
similar to the one used in input argument "target_grid"
"""
try:
result = _regrid_like(self, target_grid, regridder_types)
result = _regrid_like(self, target_grid, regrid_context, regridder_types)
except ValueError as e:
raise e
except Exception:
Expand Down
98 changes: 57 additions & 41 deletions imod/mf6/utilities/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,48 +23,31 @@
from imod.mf6.utilities.clip import clip_by_grid
from imod.mf6.utilities.regridding_types import RegridderType
from imod.schemata import ValidationError
from imod.typing.grid import GridDataArray, ones_like
from imod.typing.grid import GridDataArray, get_grid_geometry_hash, ones_like


class RegridderInstancesCollection:
class RegridderWeightsCache:
"""
This class stores any number of regridders that can regrid a single source grid to a single target grid.
By storing the regridders, we make sure the regridders can be re-used for different arrays on the same grid.
Regridders are stored based on their type (`see these docs <https://deltares.github.io/xugrid/examples/regridder_overview.html>`_) and planar coordinates (x, y).
This is important because computing the regridding weights is a costly affair.
"""

def __init__(
self,
source_grid: Union[xr.DataArray, xu.UgridDataArray],
target_grid: Union[xr.DataArray, xu.UgridDataArray],
max_cache_size: int = 6,
) -> None:
self.regridder_instances: dict[
tuple[type[BaseRegridder], Optional[str]], BaseRegridder
] = {}
self._source_grid = source_grid
self._target_grid = target_grid

def __has_regridder(
self, regridder_type: type[BaseRegridder], method: Optional[str] = None
) -> bool:
return (regridder_type, method) in self.regridder_instances.keys()

def __get_existing_regridder(
self, regridder_type: type[BaseRegridder], method: Optional[str]
) -> BaseRegridder:
if self.__has_regridder(regridder_type, method):
return self.regridder_instances[(regridder_type, method)]
raise ValueError("no existing regridder of type " + str(regridder_type))

def __create_regridder(
self, regridder_type: type[BaseRegridder], method: Optional[str]
) -> BaseRegridder:
method_args = () if method is None else (method,)

self.regridder_instances[(regridder_type, method)] = regridder_type(
self._source_grid, self._target_grid, *method_args
)
return self.regridder_instances[(regridder_type, method)]
self.weights_cache = {}
self.max_cache_size = max_cache_size

def __get_regridder_class(
self, regridder_type: RegridderType | BaseRegridder
Expand All @@ -82,6 +65,8 @@ def __get_regridder_class(

def get_regridder(
self,
source_grid: GridDataArray,
target_grid: GridDataArray,
regridder_type: Union[RegridderType, BaseRegridder],
method: Optional[str] = None,
) -> BaseRegridder:
Expand All @@ -107,10 +92,31 @@ def get_regridder(
"""
regridder_class = self.__get_regridder_class(regridder_type)

if not self.__has_regridder(regridder_class, method):
self.__create_regridder(regridder_class, method)
if "layer" not in source_grid.coords and "layer" in target_grid.coords:
target_grid = target_grid.drop_vars("layer")

source_hash = get_grid_geometry_hash(source_grid)
target_hash = get_grid_geometry_hash(target_grid)
key = (source_hash, target_hash, regridder_class)
if not key in self.weights_cache.keys():
if len(self.weights_cache) >= self.max_cache_size:
self.remove_first_regridder()
kwargs = {"source": source_grid, "target": target_grid}
if method is not None:
kwargs["method"] = method
regridder = regridder_class(**kwargs)
self.weights_cache[key] = regridder.weights
else:
kwargs = {"weights": self.weights_cache[key], "target": target_grid}
if method is not None:
kwargs["method"] = method
regridder = regridder_class.from_weights(**kwargs)

return regridder

return self.__get_existing_regridder(regridder_class, method)
def remove_first_regridder(self):
    """
    Evict the oldest entry from the weights cache.

    ``weights_cache`` is a plain dict, so iteration follows insertion order
    and the first key is the entry that was stored earliest.

    Calling this on an empty cache is a no-op rather than an ``IndexError``.
    This matters when the cache size limit is zero: the eviction is then
    requested before anything has been stored.
    """
    if not self.weights_cache:
        return
    # Take the first key directly instead of copying all keys into a list.
    oldest_key = next(iter(self.weights_cache))
    del self.weights_cache[oldest_key]


def assign_coord_if_present(
Expand All @@ -131,7 +137,7 @@ def assign_coord_if_present(
def _regrid_array(
package: IRegridPackage,
varname: str,
regridder_collection: RegridderInstancesCollection,
regridder_collection: RegridderWeightsCache,
regridder_name: str,
regridder_function: str,
target_grid: GridDataArray,
Expand Down Expand Up @@ -167,6 +173,8 @@ def _regrid_array(

# obtain an instance of a regridder for the chosen method
regridder = regridder_collection.get_regridder(
package.dataset[varname],
target_grid,
regridder_name,
regridder_function,
)
Expand Down Expand Up @@ -205,6 +213,7 @@ def _get_unique_regridder_types(model: IModel) -> defaultdict[RegridderType, lis
def _regrid_like(
package: IRegridPackage,
target_grid: GridDataArray,
regrid_context: RegridderWeightsCache,
regridder_types: Optional[dict[str, tuple[RegridderType, str]]] = None,
) -> IPackage:
"""
Expand All @@ -231,6 +240,9 @@ def _regrid_like(
regridder_types: dict(str->(regridder type,str))
dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str)
this dictionary can be used to override the default mapping method.
regrid_context: RegridderWeightsCache
stores regridder weights for different regridders. Can be used to speed up regridding,
if the same regridders are used several times for regridding different arrays.
Returns
-------
Expand All @@ -245,10 +257,6 @@ def _regrid_like(
if hasattr(package, "auxiliary_data_fields"):
remove_expanded_auxiliary_variables_from_dataset(package)

regridder_collection = RegridderInstancesCollection(
package.dataset, target_grid=target_grid
)

regridder_settings = package.get_regrid_methods()
if regridder_types is not None:
regridder_settings.update(regridder_types)
Expand All @@ -269,7 +277,7 @@ def _regrid_like(
new_package_data[varname] = _regrid_array(
package,
varname,
regridder_collection,
regrid_context,
regridder_name,
regridder_function,
target_grid,
Expand All @@ -289,7 +297,10 @@ def _regrid_like(

@typedispatch # type: ignore[no-redef]
def _regrid_like(
model: IModel, target_grid: GridDataArray, validate: bool = True
model: IModel,
target_grid: GridDataArray,
validate: bool = True,
regrid_context: Optional[RegridderWeightsCache] = None,
) -> IModel:
"""
Creates a model by regridding the packages of this model to another discretization.
Expand All @@ -303,6 +314,9 @@ def _regrid_like(
a grid defined over the same discretization as the one we want to regrid the package to
validate: bool
set to true to validate the regridded packages
regrid_context: Optional RegridderWeightsCache
stores regridder weights for different regridders. Can be used to speed up regridding,
if the same regridders are used several times for regridding different arrays.
Returns
-------
Expand All @@ -315,17 +329,18 @@ def _regrid_like(
f"regridding this model cannot be done due to the presence of package {error_with_object_name}"
)
new_model = model.__class__()

if regrid_context is None:
regrid_context = RegridderWeightsCache(model.domain, target_grid)
for pkg_name, pkg in model.items():
if isinstance(pkg, (IRegridPackage, ILineDataPackage, IPointDataPackage)):
new_model[pkg_name] = pkg.regrid_like(target_grid)
new_model[pkg_name] = pkg.regrid_like(target_grid, regrid_context)
else:
raise NotImplementedError(
f"regridding is not implemented for package {pkg_name} of type {type(pkg)}"
)

methods = _get_unique_regridder_types(model)
output_domain = _get_regridding_domain(model, target_grid, methods)
output_domain = _get_regridding_domain(model, target_grid, regrid_context, methods)
new_model.mask_all_packages(output_domain)
new_model.purge_empty_packages()
if validate:
Expand Down Expand Up @@ -371,6 +386,9 @@ def _regrid_like(
raise ValueError(
"Unable to regrid simulation. Regridding can only be done on simulations that have a single flow model."
)
flow_models = simulation.get_models_of_type("gwf6")
old_grid = list(flow_models.values())[0].domain
regrid_context = RegridderWeightsCache(old_grid, target_grid)

models = simulation.get_models()
for model_name, model in models.items():
Expand All @@ -383,7 +401,7 @@ def _regrid_like(
result = simulation.__class__(regridded_simulation_name)
for key, item in simulation.items():
if isinstance(item, IModel):
result[key] = item.regrid_like(target_grid, validate)
result[key] = item.regrid_like(target_grid, validate, regrid_context)
elif key == "gwtgwf_exchanges":
pass
elif isinstance(item, IPackage) and not isinstance(item, IRegridPackage):
Expand Down Expand Up @@ -428,6 +446,7 @@ def _regrid_like(package: object, target_grid: GridDataArray, *_) -> None:
def _get_regridding_domain(
model: IModel,
target_grid: GridDataArray,
regrid_context: RegridderWeightsCache,
methods: defaultdict[RegridderType, list[str]],
) -> GridDataArray:
"""
Expand All @@ -436,12 +455,9 @@ def _get_regridding_domain(
cells that all regridders consider active.
"""
idomain = model.domain
regridder_collection = RegridderInstancesCollection(
idomain, target_grid=target_grid
)
included_in_all = ones_like(target_grid)
regridders = [
regridder_collection.get_regridder(regriddertype, function)
regrid_context.get_regridder(idomain, target_grid, regriddertype, function)
for regriddertype, functionlist in methods.items()
for function in functionlist
]
Expand Down
7 changes: 6 additions & 1 deletion imod/tests/test_mf6/test_mf6_hfb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
LayeredHorizontalFlowBarrierResistance,
)
from imod.mf6.hfb import to_connected_cells_dataset
from imod.mf6.utilities.regrid import RegridderWeightsCache
from imod.tests.fixtures.flow_basic_fixture import BasicDisSettings
from imod.typing.grid import ones_like

Expand Down Expand Up @@ -136,7 +137,11 @@ def test_to_mf6_creates_mf6_adapter(
else:
idomain_clipped = idomain.sel(x=slice(None, 54.0))

hfb_clipped = hfb.regrid_like(idomain_clipped.sel(layer=1))
regrid_context = RegridderWeightsCache(
idomain.sel(layer=1), idomain_clipped.sel(layer=1)
)

hfb_clipped = hfb.regrid_like(idomain_clipped.sel(layer=1), regrid_context)

# Assert
x, y = hfb_clipped.dataset["geometry"].values[0].xy
Expand Down
Loading

0 comments on commit 07bbfa7

Please sign in to comment.