diff --git a/ufl/algorithms/apply_restrictions.py b/ufl/algorithms/apply_restrictions.py index d5928914f..276e3259d 100644 --- a/ufl/algorithms/apply_restrictions.py +++ b/ufl/algorithms/apply_restrictions.py @@ -26,21 +26,23 @@ class RestrictionPropagator(MultiFunction): """Restriction propagator.""" - def __init__(self, side=None, assume_single_integral_type=True, apply_default=True): + def __init__(self, side=None, assume_single_integral_type=True, apply_default=True, default_restriction=None): """Initialise.""" MultiFunction.__init__(self) self.current_restriction = side - self.default_restriction = "+" if assume_single_integral_type else "?" + if default_restriction is None: + default_restriction = "+" if assume_single_integral_type else "?" + self.default_restriction = default_restriction self.apply_default = apply_default # Caches for propagating the restriction with map_expr_dag self.vcaches = {"+": {}, "-": {}, "|": {}, "?": {}} self.rcaches = {"+": {}, "-": {}, "|": {}, "?": {}} if self.current_restriction is None: self._rp = { - "+": RestrictionPropagator("+", assume_single_integral_type, apply_default), - "-": RestrictionPropagator("-", assume_single_integral_type, apply_default), - "|": RestrictionPropagator("|", assume_single_integral_type, apply_default), - "?": RestrictionPropagator("?", assume_single_integral_type, apply_default), + "+": RestrictionPropagator("+", assume_single_integral_type, apply_default, default_restriction), + "-": RestrictionPropagator("-", assume_single_integral_type, apply_default, default_restriction), + "|": RestrictionPropagator("|", assume_single_integral_type, apply_default, default_restriction), + "?": RestrictionPropagator("?", assume_single_integral_type, apply_default, default_restriction), } self.assume_single_integral_type = assume_single_integral_type @@ -71,6 +73,9 @@ def _require_restriction(self, o): if self.current_restriction is not None: return o(self.current_restriction) elif not self.assume_single_integral_type: + # If integration if over interior facet of meshA and exterior facet of meshB, + # arguments (say) on meshA must be restricted, but those on meshB do not + # need to be. return o else: raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.") @@ -84,7 +89,19 @@ def _default_restricted(self, o): domain = extract_unique_domain(o, expand_mixed_mesh=False) if isinstance(domain, MixedMesh): raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(o)}") - return o(self.default_restriction[domain]) + if isinstance(self.default_restriction, dict): + if domain not in self.default_restriction: + raise RuntimeError(f"Integral type on {domain} not known") + r = self.default_restriction[domain] + if r is None: + return o + elif r in ["+", "-"]: + return o(r) + else: + raise RuntimeError(f"Unknown default restriction {r} on domain {domain}") + else: + # conventional "+" default: + return o(self.default_restriction) else: return o @@ -93,12 +110,26 @@ def _opposite(self, o): If the current restriction is different swap the sign, require a side to be set. """ - if self.current_restriction is None: - raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.") - elif self.current_restriction == self.default_restriction: - return o(self.default_restriction) + if isinstance(self.default_restriction, dict): + domain = extract_unique_domain(o, expand_mixed_mesh=False) + if isinstance(domain, MixedMesh): + raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(o)}") + if domain not in self.default_restriction: + raise RuntimeError(f"Integral type on {domain} not known") + r = self.default_restriction[domain] else: - return -o(self.default_restriction) + r = self.default_restriction + if r is None: + if self.current_restriction is not None: + raise ValueError(f"Expecting current_restriction None: got {self.current_restriction}") + return o + else: + if self.current_restriction is None: + raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.") + elif self.current_restriction == r: + return o(self.default_restriction) + else: + return -o(self.default_restriction) def _missing_rule(self, o): """Raise an error.""" @@ -206,7 +237,7 @@ def facet_normal(self, o): return self._require_restriction(o) -def apply_restrictions(expression, assume_single_integral_type=True, apply_default=True): +def apply_restrictions(expression, assume_single_integral_type=True, apply_default=True, default_restriction=None): """Propagate restriction nodes to wrap differential terminals directly.""" if assume_single_integral_type: integral_types = [ @@ -217,7 +248,7 @@ def apply_restrictions(expression, assume_single_integral_type=True, apply_defau # the integral type of a given function; e.g., the former can be # ``exterior_facet`` and the latter ``interior_facet``. integral_types = None - rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, apply_default=apply_default) + rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, apply_default=apply_default, default_restriction=default_restriction) if isinstance(expression, FormData): for integral_data in expression.integral_data: integral_data.integrals = tuple( @@ -347,7 +378,7 @@ def to_be_restricted(self, o): return mt elif integral_type == "exterior_facet": return SingleValueRestricted(mt) - elif integral_type == "interial_facet": + elif integral_type == "interior_facet": return PositiveRestricted(mt) else: raise RuntimeError(f"Unknown integral type: {integral_type}") @@ -355,7 +386,20 @@ def to_be_restricted(self, o): def replace_to_be_restricted(integral_data): new_integrals = [] - rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map) + #rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map) + rule = RestrictionPropagator( + side=None, + assume_single_integral_type=False, + apply_default=True, + default_restriction={ + domain: { + "cell": None, + "exterior_facet": None, + "interior_facet": "+", + }[integral_type] + for domain, integral_type in integral_data.domain_integral_type_map.items() + }, + ) for integral in integral_data.integrals: integrand = map_expr_dag(rule, integral.integrand()) new_integrals.append(integral.reconstruct(integrand=integrand)) diff --git a/ufl/algorithms/compute_form_data.py b/ufl/algorithms/compute_form_data.py index c6f27f07b..8331e36bd 100644 --- a/ufl/algorithms/compute_form_data.py +++ b/ufl/algorithms/compute_form_data.py @@ -338,13 +338,6 @@ def compute_form_data( form = apply_coordinate_derivatives(form) - # Propagate restrictions to terminals - if do_apply_restrictions: - if do_assume_single_integral_type: - form = apply_restrictions(form, apply_default=do_apply_default_restrictions) - else: - form = apply_restrictions(form, assume_single_integral_type=have_single_domain, apply_default=False) - # If in real mode, remove any complex nodes introduced during form processing. if not complex_mode: form = remove_complex_nodes(form) @@ -353,6 +346,38 @@ def compute_form_data( # Most of the heavy lifting is done above in group_form_integrals. self.integral_data = build_integral_data(form.integrals()) + # Propagate restrictions to terminals + if do_apply_restrictions: + if do_assume_single_integral_type or have_single_domain: + for itg_data in self.integral_data: + new_integrals = [] + for integral in itg_data.integrals: + new_integral = apply_restrictions( + integral, + apply_default=do_apply_default_restrictions, + default_restriction={ + itg_data.domain: { + "cell": None, + "exterior_facet": None, + "interior_facet": "+", + }[itg_data.integral_type] + }, + ) + new_integrals.append(new_integral) + itg_data.integrals = new_integrals + else: + #form = apply_restrictions(form, assume_single_integral_type=have_single_domain, apply_default=False) + for itg_data in self.integral_data: + new_integrals = [] + for integral in itg_data.integrals: + new_integral = apply_restrictions( + integral, + assume_single_integral_type=have_single_domain, + apply_default=False, + ) + new_integrals.append(new_integral) + itg_data.integrals = new_integrals + # --- Create replacements for arguments and coefficients # Figure out which form coefficients each integral should enable