Skip to content

Commit

Permalink
Replace .at() interpolation in 2D callbacks (#373)
Browse files Browse the repository at this point in the history
- .at() raises an error when the number of processors used is large, use a VertexOnlyMesh instead for interpolation
  • Loading branch information
cpjordan authored Dec 4, 2024
1 parent 3e67547 commit cf66f38
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
25 changes: 21 additions & 4 deletions thetis/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ def __init__(self, solver_obj,
self.field_names = field_names
self._name = name

# initialise interpolation functions using vom_interpolator_functions
self.interp_functions = vom_interpolator_functions(solver_obj, field_names, detector_locations)

@property
def name(self):
return self._name
Expand All @@ -539,7 +542,8 @@ def variable_names(self):

def _values_per_field(self, values):
"""
Given all values evaulated in a detector location, return the values per field"""
Given all values evaulated in a detector location, return the values per field
"""
i = 0
result = []
for dim in self.field_dims:
Expand All @@ -554,7 +558,11 @@ def message_str(self, *args):
for name, values in zip(self.detector_names, args))

def _evaluate_field(self, field_name):
return self.solver_obj.fields[field_name](self.detector_locations)
field = self.solver_obj.fields[field_name]
f_at_points, f_at_input_points = self.interp_functions[field_name]
f_at_points.interpolate(field)
f_at_input_points.interpolate(f_at_points)
return f_at_input_points.dat.data_ro[:]

def __call__(self):
"""
Expand Down Expand Up @@ -684,10 +692,16 @@ def _initialize(self):
xyz = (self.x, self.y, self.z) if self.on_sphere else (self.x, self.y)
self.xyz = numpy.array([xyz])

# initialise interpolation functions using vom_interpolator_functions
self.interp_functions = vom_interpolator_functions(self.solver_obj, self.fieldnames, self.xyz)

# test evaluation
try:
if self.eval_func is None:
self.solver_obj.fields.bathymetry_2d.at(self.xyz, tolerance=self.tolerance)
field = self.solver_obj.fields[self.fieldnames[0]]
f_at_points, f_at_input_points = self.interp_functions[self.fieldnames[0]]
f_at_points.interpolate(field)
f_at_input_points.interpolate(f_at_points)
else:
self.eval_func(self.solver_obj.fields.bathymetry_2d, self.xyz, tolerance=self.tolerance)
except PointNotInDomainError as e:
Expand All @@ -707,7 +721,10 @@ def __call__(self):
try:
field = self.solver_obj.fields[fieldname]
if self.eval_func is None:
val = field.at(self.xyz, tolerance=self.tolerance)
f_at_points, f_at_input_points = self.interp_functions[fieldname]
f_at_points.interpolate(field)
f_at_input_points.interpolate(f_at_points)
val = f_at_input_points.dat.data_ro[:]
else:
val = self.eval_func(field, self.xyz, tolerance=self.tolerance)
arr = numpy.array(val)
Expand Down
35 changes: 35 additions & 0 deletions thetis/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,3 +1154,38 @@ def form2indicator(F):
},
)
return indicator


@PETSc.Log.EventDecorator("thetis.vom_interpolator_functions")
def vom_interpolator_functions(solver_obj, field_names, locations):
r"""
Creates function spaces and associated Functions for interpolation
on a VertexOnlyMesh (VOM) and returns them for reuse.
:arg solver_obj: Thetis solver object
:arg field_names: List of field names to create functions for.
:arg locations: List of locations for interpolation.
:return: A dictionary mapping field names to a tuple of (f_at_points, f_at_input_points)
which are Functions for interpolation.
"""
vom = VertexOnlyMesh(solver_obj.mesh2d, locations, redundant=True)

functions_dict = {}

for field_name in field_names:
field = solver_obj.fields[field_name]

if isinstance(field.function_space().ufl_element(), VectorElement):
P0DG = VectorFunctionSpace(vom, "DG", 0)
P0DG_input_ordering = VectorFunctionSpace(vom.input_ordering, "DG", 0)
else:
P0DG = FunctionSpace(vom, "DG", 0)
P0DG_input_ordering = FunctionSpace(vom.input_ordering, "DG", 0)

f_at_points = Function(P0DG)
f_at_input_points = Function(P0DG_input_ordering)

# Store the Functions in the dictionary keyed by field name
functions_dict[field_name] = (f_at_points, f_at_input_points)

return functions_dict

0 comments on commit cf66f38

Please sign in to comment.