Skip to content

Commit

Permalink
remove apply_default_restriction rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 10, 2024
1 parent f9412a0 commit 4ced3a7
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 23 deletions.
76 changes: 60 additions & 16 deletions ufl/algorithms/apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,23 @@
class RestrictionPropagator(MultiFunction):
"""Restriction propagator."""

def __init__(self, side=None, assume_single_integral_type=True, apply_default=True):
def __init__(self, side=None, assume_single_integral_type=True, apply_default=True, default_restriction=None):
"""Initialise."""
MultiFunction.__init__(self)
self.current_restriction = side
self.default_restriction = "+" if assume_single_integral_type else "?"
if default_restriction is None:
default_restriction = "+" if assume_single_integral_type else "?"
self.default_restriction = default_restriction
self.apply_default = apply_default
# Caches for propagating the restriction with map_expr_dag
self.vcaches = {"+": {}, "-": {}, "|": {}, "?": {}}
self.rcaches = {"+": {}, "-": {}, "|": {}, "?": {}}
if self.current_restriction is None:
self._rp = {
"+": RestrictionPropagator("+", assume_single_integral_type, apply_default),
"-": RestrictionPropagator("-", assume_single_integral_type, apply_default),
"|": RestrictionPropagator("|", assume_single_integral_type, apply_default),
"?": RestrictionPropagator("?", assume_single_integral_type, apply_default),
"+": RestrictionPropagator("+", assume_single_integral_type, apply_default, default_restriction),
"-": RestrictionPropagator("-", assume_single_integral_type, apply_default, default_restriction),
"|": RestrictionPropagator("|", assume_single_integral_type, apply_default, default_restriction),
"?": RestrictionPropagator("?", assume_single_integral_type, apply_default, default_restriction),
}
self.assume_single_integral_type = assume_single_integral_type

Expand Down Expand Up @@ -71,6 +73,9 @@ def _require_restriction(self, o):
if self.current_restriction is not None:
return o(self.current_restriction)
elif not self.assume_single_integral_type:
# If integration if over interior facet of meshA and exterior facet of meshB,
# arguments (say) on meshA must be restricted, but those on meshB do not
# need to be.
return o
else:
raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.")
Expand All @@ -84,7 +89,19 @@ def _default_restricted(self, o):
domain = extract_unique_domain(o, expand_mixed_mesh=False)
if isinstance(domain, MixedMesh):
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(o)}")
return o(self.default_restriction[domain])
if isinstance(self.default_restriction, dict):
if domain not in self.default_restriction:
raise RuntimeError(f"Integral type on {domain} not known")
r = self.default_restriction[domain]
if r is None:
return o
elif r in ["+", "-"]:
return o(r)
else:
raise RuntimeError(f"Unknown default restriction {r} on domain {domain}")
else:
# conventional "+" default:
return o(self.default_restriction)
else:
return o

Expand All @@ -93,12 +110,26 @@ def _opposite(self, o):
If the current restriction is different swap the sign, require a side to be set.
"""
if self.current_restriction is None:
raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.")
elif self.current_restriction == self.default_restriction:
return o(self.default_restriction)
if isinstance(self.default_restriction, dict):
domain = extract_unique_domain(o, expand_mixed_mesh=False)
if isinstance(domain, MixedMesh):
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(o)}")
if domain not in self.default_restriction:
raise RuntimeError(f"Integral type on {domain} not known")
r = self.default_restriction[domain]
else:
return -o(self.default_restriction)
r = self.default_restriction
if r is None:
if self.current_restriction is not None:
raise ValueError(f"Expecting current_restriction None: got {self.current_restriction}")
return o
else:
if self.current_restriction is None:
raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.")
elif self.current_restriction == r:
return o(self.default_restriction)
else:
return -o(self.default_restriction)

def _missing_rule(self, o):
"""Raise an error."""
Expand Down Expand Up @@ -206,7 +237,7 @@ def facet_normal(self, o):
return self._require_restriction(o)


def apply_restrictions(expression, assume_single_integral_type=True, apply_default=True):
def apply_restrictions(expression, assume_single_integral_type=True, apply_default=True, default_restriction=None):
"""Propagate restriction nodes to wrap differential terminals directly."""
if assume_single_integral_type:
integral_types = [
Expand All @@ -217,7 +248,7 @@ def apply_restrictions(expression, assume_single_integral_type=True, apply_defau
# the integral type of a given function; e.g., the former can be
# ``exterior_facet`` and the latter ``interior_facet``.
integral_types = None
rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, apply_default=apply_default)
rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, apply_default=apply_default, default_restriction=default_restriction)
if isinstance(expression, FormData):
for integral_data in expression.integral_data:
integral_data.integrals = tuple(
Expand Down Expand Up @@ -347,15 +378,28 @@ def to_be_restricted(self, o):
return mt
elif integral_type == "exterior_facet":
return SingleValueRestricted(mt)
elif integral_type == "interial_facet":
elif integral_type == "interior_facet":
return PositiveRestricted(mt)
else:
raise RuntimeError(f"Unknown integral type: {integral_type}")


def replace_to_be_restricted(integral_data):
new_integrals = []
rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map)
#rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map)
rule = RestrictionPropagator(
side=None,
assume_single_integral_type=False,
apply_default=True,
default_restriction={
domain: {
"cell": None,
"exterior_facet": None,
"interior_facet": "+",
}[integral_type]
for domain, integral_type in integral_data.domain_integral_type_map.items()
},
)
for integral in integral_data.integrals:
integrand = map_expr_dag(rule, integral.integrand())
new_integrals.append(integral.reconstruct(integrand=integrand))
Expand Down
39 changes: 32 additions & 7 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,6 @@ def compute_form_data(

form = apply_coordinate_derivatives(form)

# Propagate restrictions to terminals
if do_apply_restrictions:
if do_assume_single_integral_type:
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)
else:
form = apply_restrictions(form, assume_single_integral_type=have_single_domain, apply_default=False)

# If in real mode, remove any complex nodes introduced during form processing.
if not complex_mode:
form = remove_complex_nodes(form)
Expand All @@ -353,6 +346,38 @@ def compute_form_data(
# Most of the heavy lifting is done above in group_form_integrals.
self.integral_data = build_integral_data(form.integrals())

# Propagate restrictions to terminals
if do_apply_restrictions:
if do_assume_single_integral_type or have_single_domain:
for itg_data in self.integral_data:
new_integrals = []
for integral in itg_data.integrals:
new_integral = apply_restrictions(
integral,
apply_default=do_apply_default_restrictions,
default_restriction={
itg_data.domain: {
"cell": None,
"exterior_facet": None,
"interior_facet": "+",
}[itg_data.integral_type]
},
)
new_integrals.append(new_integral)
itg_data.integrals = new_integrals
else:
#form = apply_restrictions(form, assume_single_integral_type=have_single_domain, apply_default=False)
for itg_data in self.integral_data:
new_integrals = []
for integral in itg_data.integrals:
new_integral = apply_restrictions(
integral,
assume_single_integral_type=have_single_domain,
apply_default=False,
)
new_integrals.append(new_integral)
itg_data.integrals = new_integrals

# --- Create replacements for arguments and coefficients

# Figure out which form coefficients each integral should enable
Expand Down

0 comments on commit 4ced3a7

Please sign in to comment.