Skip to content

Commit

Permalink
FEniCS-style bcs (#3995)
Browse files Browse the repository at this point in the history
* FEniCS-style bcs

* Delay Dirichlet Lifting for LVP, MG, and Fieldsplit

* LinearSolver: support pre_apply_bcs

* Update firedrake/assemble.py

Co-authored-by: ksagiyam <[email protected]>

---------

Co-authored-by: ksagiyam <[email protected]>
  • Loading branch information
pbrubeck and ksagiyam authored Feb 12, 2025
1 parent 3331ed1 commit f473388
Show file tree
Hide file tree
Showing 11 changed files with 277 additions and 146 deletions.
14 changes: 2 additions & 12 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,10 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
# Homogenize and apply boundary conditions on adj_dFdu.
bcs = self._homogenize_bcs()
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs
Expand Down Expand Up @@ -526,18 +523,11 @@ def _forward_solve(self, lhs, rhs, func, bcs):
return func

def _assembled_solve(self, lhs, rhs, func, bcs, **kwargs):
rhs_func = rhs.riesz_representation(riesz_map="l2")
for bc in bcs:
bc.apply(rhs_func)
rhs.assign(rhs_func.riesz_representation(riesz_map="l2"))
firedrake.solve(lhs, func, rhs, **kwargs)
return func

def recompute_component(self, inputs, block_variable, idx, prepared):
lhs = prepared[0]
rhs = prepared[1]
func = prepared[2]
bcs = prepared[3]
lhs, rhs, func, bcs = prepared
result = self._forward_solve(lhs, rhs, func, bcs)
if isinstance(block_variable.checkpoint, firedrake.Function):
result = block_variable.checkpoint.assign(result)
Expand Down
77 changes: 51 additions & 26 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def assemble(expr, *args, **kwargs):
`matrix.Matrix`.
is_base_form_preprocessed : bool
If `True`, skip preprocessing of the form.
current_state : firedrake.function.Function or None
If provided and ``zero_bc_nodes == False``, the boundary condition
nodes of the output are set to the residual of the boundary conditions
computed as ``current_state`` minus the boundary condition value.
Returns
-------
Expand Down Expand Up @@ -130,16 +134,21 @@ def assemble(expr, *args, **kwargs):
"""
if args:
raise RuntimeError(f"Got unexpected args: {args}")
tensor = kwargs.pop("tensor", None)
return get_assembler(expr, *args, **kwargs).assemble(tensor=tensor)

assemble_kwargs = {}
for key in ("tensor", "current_state"):
if key in kwargs:
assemble_kwargs[key] = kwargs.pop(key, None)
return get_assembler(expr, *args, **kwargs).assemble(**assemble_kwargs)


def get_assembler(form, *args, **kwargs):
"""Create an assembler.
Notes
-----
See `assemble` for descriptions of the parameters. ``tensor`` should not be passed to this function.
See `assemble` for descriptions of the parameters. ``tensor`` and
``current_state`` should not be passed to this function.
"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
Expand Down Expand Up @@ -187,13 +196,15 @@ class ExprAssembler(object):
def __init__(self, expr):
self._expr = expr

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the pointwise expression.
Parameters
----------
tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor.
current_state : None
Ignored by this class.
Returns
-------
Expand All @@ -205,6 +216,7 @@ def assemble(self, tensor=None):
from ufl.checks import is_scalar_constant_expression

assert tensor is None
assert current_state is None
expr = self._expr
# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
base_form_operators = extract_base_form_operators(expr)
Expand Down Expand Up @@ -274,13 +286,16 @@ def allocate(self):
"""Allocate memory for the output tensor."""

@abc.abstractmethod
def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
----------
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual
computed as ``current_state`` minus the boundary condition value.
Returns
-------
Expand Down Expand Up @@ -358,13 +373,16 @@ def allocation_integral_types(self):
else:
return self._allocation_integral_types

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
----------
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Output tensor to contain the result of assembly.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual
computed as ``current_state`` minus the boundary condition value.
Returns
-------
Expand All @@ -389,7 +407,7 @@ def visitor(e, *operands):
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)
OneFormAssembler._apply_bc(self, result, bc, u=current_state)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -968,13 +986,16 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._needs_zeroing = needs_zeroing

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
----------
tensor : firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual
computed as ``current_state`` minus the boundary condition value.
Returns
-------
Expand All @@ -998,12 +1019,12 @@ def assemble(self, tensor=None):
self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)
self._apply_bc(tensor, bc, u=current_state)

return self.result(tensor)

@abc.abstractmethod
def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
"""Apply boundary condition."""

@abc.abstractmethod
Expand Down Expand Up @@ -1138,7 +1159,7 @@ def allocate(self):
comm=self._form.ufl_domains()[0]._comm
)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
pass

def _check_tensor(self, tensor):
Expand Down Expand Up @@ -1199,26 +1220,29 @@ def allocate(self):
else:
raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}")

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
# TODO Maybe this could be a singledispatchmethod?
if isinstance(bc, DirichletBC):
self._apply_dirichlet_bc(tensor, bc)
if self._diagonal:
bc.set(tensor, self._weight)
elif self._zero_bc_nodes:
bc.zero(tensor)
else:
# The residual belongs to a mixed space that is dual on the boundary nodes
# and primal on the interior nodes. Therefore, this is a type-safe operation.
r = tensor.riesz_representation("l2")
bc.apply(r, u=u)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
OneFormAssembler(bc.f, bcs=bc.bcs,
form_compiler_parameters=self._form_compiler_params,
needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes,
diagonal=self._diagonal,
weight=self._weight).assemble(tensor=tensor, current_state=u)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")
Expand Down Expand Up @@ -1430,7 +1454,8 @@ def _all_assemblers(self):
all_assemblers.extend(_assembler._all_assemblers)
return tuple(all_assemblers)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
assert u is None
op2tensor = tensor.M
spaces = tuple(a.function_space() for a in tensor.a.arguments())
V = bc.function_space()
Expand Down Expand Up @@ -1534,7 +1559,7 @@ def allocate(self):
options_prefix=self._options_prefix,
appctx=self._appctx or {})

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
if tensor is None:
tensor = self.allocate()
else:
Expand Down
28 changes: 14 additions & 14 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -501,15 +501,16 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
# linear
if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form):
J = eq.lhs
L = eq.rhs
Jp = Jp or J
if eq.rhs == 0:
if L == 0 or L.empty():
F = ufl_expr.action(J, u)
else:
if not isinstance(eq.rhs, (ufl.Form, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a Form or Slate Tensor" % type(eq.rhs).__name__)
if len(eq.rhs.arguments()) != 1:
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a BaseForm or Slate Tensor" % type(L).__name__)
if len(L.arguments()) != 1:
raise ValueError("Provided BC RHS is not a linear form")
F = ufl_expr.action(J, u) - eq.rhs
F = ufl_expr.action(J, u) - L
self.is_linear = True
# nonlinear
else:
Expand All @@ -531,9 +532,7 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
# reconstruction for splitting `solving_utils.split`
self.Jp_eq_J = Jp_eq_J
self.is_linear = is_linear
self._F = args[0]
self._J = args[1]
self._Jp = args[2]
self._F, self._J, self._Jp = args[:3]
else:
raise TypeError("Wrong EquationBC arguments")

Expand Down Expand Up @@ -562,7 +561,7 @@ def reconstruct(self, V, subu, u, field, is_linear):
if all([_F is not None, _J is not None, _Jp is not None]):
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -654,19 +653,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
ebc.add(bc_temp)
return ebc

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
if len(self.f.arguments()) != 2:
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
J = self.f
Vcol = J.arguments()[-1].function_space()
u = firedrake.Function(Vcol)
F = ufl_expr.action(J, u)
Vrow = self._function_space
sub_domain = self.sub_domain
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=is_linear) for bc in self.bcs)
lhs = J if is_linear else ufl_expr.action(J, u)
rhs = ufl.Form([]) if is_linear else 0
return EquationBC(lhs == rhs, u, sub_domain, bcs=bcs, J=J, V=Vrow)


@PETSc.Log.EventDecorator()
Expand Down
Loading

0 comments on commit f473388

Please sign in to comment.