Skip to content

Commit

Permalink
BaseForm: ensure that subclasses implement ufl_domains()
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jul 25, 2024
1 parent 004a678 commit 01d5bbd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
10 changes: 9 additions & 1 deletion ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,15 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect domains
self._domains = join_domains(chain.from_iterable(e.ufl_domain() for e in self.ufl_operands))
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)

def ufl_domains(self):
"""Return all domains found in the base form."""
if self._domains is None:
self._analyze_domains()
return self._domains

def equals(self, other):
"""Check if two Actions are equal."""
Expand Down
12 changes: 11 additions & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#
# Modified by Nacime Bouziani, 2021-2022.

from itertools import chain

from ufl.argument import Coargument
from ufl.core.ufl_type import ufl_type
from ufl.form import BaseForm, FormSum, ZeroBaseForm
Expand Down Expand Up @@ -97,7 +99,15 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect unique domains
self._domains = join_domains([e.ufl_domain() for e in self.ufl_operands])
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)

def ufl_domains(self):
"""Return all domains found in the base form."""
if self._domains is None:
self._analyze_domains()
return self._domains

def equals(self, other):
"""Check if two Adjoints are equal."""
Expand Down
33 changes: 6 additions & 27 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,12 @@ def ufl_domain(self):
Fails if multiple domains are found.
"""
if self._domains is None:
self._analyze_domains()

if len(self._domains) > 1:
try:
(domain,) = set(self.ufl_domains())
except ValueError:
raise ValueError("%s must have exactly one domain." % type(self).__name__)
# Return the single geometric domain
return self._domains[0]
# Return the one and only domain
return domain

# --- Operator implementations ---

Expand All @@ -139,7 +138,7 @@ def __radd__(self, other):

def __add__(self, other):
"""Add."""
if isinstance(other, (int, float)) and other == 0:
if isinstance(other, numbers.Number) and other == 0:
# Allow adding 0 or 0.0 as a no-op, needed for sum([a,b])
return self
elif isinstance(other, Zero):
Expand Down Expand Up @@ -329,26 +328,6 @@ def ufl_cell(self):
"""
return self.ufl_domain().ufl_cell()

def ufl_domain(self):
"""Return the single geometric integration domain occuring in the form.
Fails if multiple domains are found.
NB! This does not include domains of coefficients defined on
other meshes, look at form data for that additional information.
"""
# Collect all domains
domains = self.ufl_domains()
# Check that all are equal TODO: don't return more than one if
# all are equal?
if not all(domain == domains[0] for domain in domains):
raise ValueError(
"Calling Form.ufl_domain() is only valid if all integrals share domain."
)

# Return the one and only domain
return domains[0]

def geometric_dimension(self):
"""Return the geometric dimension shared by all domains and functions in this form."""
gdims = tuple(set(domain.geometric_dimension() for domain in self.ufl_domains()))
Expand Down

0 comments on commit 01d5bbd

Please sign in to comment.