Skip to content

Commit

Permalink
move to new ufl api (#351)
Browse files Browse the repository at this point in the history
* move to new ufl

* One more UFL related fix (I hope)

---------

Co-authored-by: Stephan Kramer <[email protected]>
  • Loading branch information
jhill1 and stephankramer authored Nov 20, 2023
1 parent 1c85fbc commit a51d859
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/tohoku_inversion/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, mesh2d, coord_system, element=None):
self.mesh2d = mesh2d
self.coord_system = coord_system
if element is None:
element = ufl.FiniteElement("Lagrange", mesh2d.ufl_cell(), 1)
element = fd.FiniteElement("Lagrange", mesh2d.ufl_cell(), 1)
self.function_space = utility.get_functionspace(mesh2d, element.family(), element.degree())
self._elev_init = fd.Function(self.function_space, name="Elevation")
self.xy = ufl.SpatialCoordinate(mesh2d)
Expand Down
6 changes: 3 additions & 3 deletions thetis/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def get_visu_space(fs):
"""
mesh = fs.mesh()
family = 'Lagrange' if is_cg(fs) else 'Discontinuous Lagrange'
if len(fs.ufl_element().value_shape()) == 1:
dim = fs.ufl_element().value_shape()[0]
if len(fs.ufl_element().value_shape) == 1:
dim = fs.ufl_element().value_shape[0]
visu_fs = get_functionspace(mesh, family, 1, family, 1,
vector=True, dim=dim)
elif len(fs.ufl_element().value_shape()) == 2:
elif len(fs.ufl_element().value_shape) == 2:
visu_fs = get_functionspace(mesh, family, 1, family, 1,
tensor=True)
else:
Expand Down
2 changes: 1 addition & 1 deletion thetis/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def __init__(self, function_space, coord_system, fill_mode=None,
used. Otherwise a constant fill value will be used (default).
:kwarg float fill_value: Set the fill value (default: NaN)
"""
assert function_space.ufl_element().value_shape() == ()
assert function_space.ufl_element().value_shape == ()

# construct local coordinates
on_sphere = function_space.mesh().geometric_dimension() == 3
Expand Down
7 changes: 3 additions & 4 deletions thetis/limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
from .utility import *
from firedrake import VertexBasedLimiter
import ufl
from pyop2.profiling import timed_stage
import numpy

Expand All @@ -24,12 +23,12 @@ def assert_function_space(fs, family, degree):
if not isinstance(family, list):
fam_list = [family]
ufl_elem = fs.ufl_element()
if isinstance(ufl_elem, ufl.VectorElement):
ufl_elem = ufl_elem.sub_elements()[0]
if isinstance(ufl_elem, firedrake.VectorElement):
ufl_elem = ufl_elem.sub_elements[0]

if ufl_elem.family() == 'TensorProductElement':
# extruded mesh
A, B = ufl_elem.sub_elements()
A, B = ufl_elem.sub_elements
assert A.family() in fam_list, \
'horizontal space must be one of {0:s}'.format(fam_list)
assert B.family() in fam_list, \
Expand Down
22 changes: 11 additions & 11 deletions thetis/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def get_extruded_base_element(ufl_element):
In case of a non-extruded mesh, returns the element itself.
"""
if isinstance(ufl_element, ufl.HDivElement):
if isinstance(ufl_element, firedrake.HDivElement):
ufl_element = ufl_element._element
if isinstance(ufl_element, ufl.MixedElement):
ufl_element = ufl_element.sub_elements()[0]
if isinstance(ufl_element, ufl.VectorElement):
ufl_element = ufl_element.sub_elements()[0] # take the first component
if isinstance(ufl_element, ufl.EnrichedElement):
if isinstance(ufl_element, firedrake.MixedElement):
ufl_element = ufl_element.sub_elements[0]
if isinstance(ufl_element, firedrake.VectorElement):
ufl_element = ufl_element.sub_elements[0] # take the first component
if isinstance(ufl_element, firedrake.EnrichedElement):
ufl_element = ufl_element._elements[0]
return ufl_element

Expand Down Expand Up @@ -220,11 +220,11 @@ def element_continuity(ufl_element):
}

base_element = get_extruded_base_element(ufl_element)
if isinstance(elem, ufl.HDivElement):
if isinstance(elem, firedrake.HDivElement):
horiz_type = 'hdiv'
vert_type = 'hdiv'
elif isinstance(base_element, ufl.TensorProductElement):
a, b = base_element.sub_elements()
elif isinstance(base_element, firedrake.TensorProductElement):
a, b = base_element.sub_elements
horiz_type = elem_types[a.family()]
vert_type = elem_types[b.family()]
else:
Expand Down Expand Up @@ -311,7 +311,7 @@ def get_facet_mask(function_space, facet='bottom'):
assert isinstance(elem, TensorProductElement), \
f'function space must be defined on an extruded 3D mesh: {elem}'
# figure out number of nodes in sub elements
h_elt, v_elt = elem.sub_elements()
h_elt, v_elt = elem.sub_elements
nb_nodes_h = create_finat_element(h_elt).space_dimension()
nb_nodes_v = create_finat_element(v_elt).space_dimension()
# compute top/bottom facet indices
Expand Down Expand Up @@ -484,7 +484,7 @@ def extend_function_to_3d(func, mesh_extruded):
family = ufl_elem.family()
degree = ufl_elem.degree()
name = func.name()
if isinstance(ufl_elem, ufl.VectorElement):
if isinstance(ufl_elem, firedrake.VectorElement):
# vector function space
fs_extended = get_functionspace(mesh_extruded, family, degree, 'R', 0,
dim=2, vector=True)
Expand Down
12 changes: 6 additions & 6 deletions thetis/utility3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,8 @@ def __init__(self, input_2d, output_3d, elem_height=None):

family_2d = self.fs_2d.ufl_element().family()
base_element_3d = get_extruded_base_element(self.fs_3d.ufl_element())
assert isinstance(base_element_3d, ufl.TensorProductElement)
family_3dh = base_element_3d.sub_elements()[0].family()
assert isinstance(base_element_3d, firedrake.TensorProductElement)
family_3dh = base_element_3d.sub_elements[0].family()
if family_2d != family_3dh:
raise Exception('2D and 3D spaces do not match: {0:s} {1:s}'.format(family_2d, family_3dh))
self.do_hdiv_scaling = family_2d in ['Raviart-Thomas', 'RTCF', 'Brezzi-Douglas-Marini', 'BDMCF']
Expand Down Expand Up @@ -625,8 +625,8 @@ def __init__(self, input_3d, output_2d,

family_2d = self.fs_2d.ufl_element().family()
base_element_3d = get_extruded_base_element(self.fs_3d.ufl_element())
assert isinstance(base_element_3d, ufl.TensorProductElement)
family_3dh = base_element_3d.sub_elements()[0].family()
assert isinstance(base_element_3d, firedrake.TensorProductElement)
family_3dh = base_element_3d.sub_elements[0].family()
if family_2d != family_3dh:
raise Exception('2D and 3D spaces do not match: {0:s} {1:s}'.format(family_2d, family_3dh))
self.do_hdiv_scaling = family_2d in ['Raviart-Thomas', 'RTCF', 'Brezzi-Douglas-Marini', 'BDMCF']
Expand Down Expand Up @@ -747,8 +747,8 @@ def __init__(self, solver):

family_2d = self.fs_2d.ufl_element().family()
base_element_3d = get_extruded_base_element(self.fs_3d.ufl_element())
assert isinstance(base_element_3d, ufl.TensorProductElement)
family_3dh = base_element_3d.sub_elements()[0].family()
assert isinstance(base_element_3d, firedrake.TensorProductElement)
family_3dh = base_element_3d.sub_elements[0].family()
if family_2d != family_3dh:
raise Exception('2D and 3D spaces do not match: "{0:s}" != "{1:s}"'.format(family_2d, family_3dh))

Expand Down

0 comments on commit a51d859

Please sign in to comment.