Skip to content

Commit

Permalink
Merge branch 'main' into non_inc_imp_rk
Browse files Browse the repository at this point in the history
  • Loading branch information
tommbendall authored Nov 1, 2024
2 parents 93e20d2 + 6c5e752 commit 8f8d084
Show file tree
Hide file tree
Showing 19 changed files with 879 additions and 94 deletions.
1 change: 1 addition & 0 deletions examples/shallow_water/williamson_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def williamson_5(
# ------------------------------------------------------------------------ #

element_order = 1

# ------------------------------------------------------------------------ #
# Set up model objects
# ------------------------------------------------------------------------ #
Expand Down
23 changes: 12 additions & 11 deletions gusto/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from gusto.core.configuration import * # noqa
from gusto.core.coordinates import * # noqa
from gusto.core.coord_transforms import * # noqa
from gusto.core.domain import * # noqa
from gusto.core.fields import * # noqa
from gusto.core.function_spaces import * # noqa
from gusto.core.io import * # noqa
from gusto.core.kernels import * # noqa
from gusto.core.labels import * # noqa
from gusto.core.logging import * # noqa
from gusto.core.meshes import * # noqa
from gusto.core.configuration import * # noqa
from gusto.core.conservative_projection import * # noqa
from gusto.core.coordinates import * # noqa
from gusto.core.coord_transforms import * # noqa
from gusto.core.domain import * # noqa
from gusto.core.fields import * # noqa
from gusto.core.function_spaces import * # noqa
from gusto.core.io import * # noqa
from gusto.core.kernels import * # noqa
from gusto.core.labels import * # noqa
from gusto.core.logging import * # noqa
from gusto.core.meshes import * # noqa
20 changes: 19 additions & 1 deletion gusto/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"IntegrateByParts", "TransportEquationType", "OutputParameters",
"BoussinesqParameters", "CompressibleParameters",
"ShallowWaterParameters",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions", "MixedFSOptions",
"EmbeddedDGOptions", "ConservativeEmbeddedDGOptions", "RecoveryOptions",
"ConservativeRecoveryOptions", "SUPGOptions", "MixedFSOptions",
"SpongeLayerParameters", "DiffusionParameters", "BoundaryLayerParameters"
]

Expand Down Expand Up @@ -164,6 +165,14 @@ class EmbeddedDGOptions(WrapperOptions):
embedding_space = None


class ConservativeEmbeddedDGOptions(EmbeddedDGOptions):
"""Specifies options for a conservative embedded DG method."""

project_back_method = 'conservative_project'
rho_name = None
orig_rho_space = None


class RecoveryOptions(WrapperOptions):
"""Specifies options for a recovery wrapper method."""

Expand All @@ -177,6 +186,15 @@ class RecoveryOptions(WrapperOptions):
broken_method = 'interpolate'


class ConservativeRecoveryOptions(RecoveryOptions):
"""Specifies options for a conservative recovery wrapper method."""

rho_name = None
orig_rho_space = None
project_high_method = 'conservative_project'
project_low_method = 'conservative_project'


class SUPGOptions(WrapperOptions):
"""Specifies options for an SUPG scheme."""

Expand Down
93 changes: 93 additions & 0 deletions gusto/core/conservative_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
This provides an operator for perform a conservative projection.
The :class:`ConservativeProjector` provided in this module is an operator that
projects a field such as a mixing ratio from one function space to another,
weighted by a density field to ensure that mass is conserved by the projection.
"""

from firedrake import (Function, TestFunction, TrialFunction, lhs, rhs, inner,
dx, LinearVariationalProblem, LinearVariationalSolver,
Constant, assemble)
import ufl

__all__ = ["ConservativeProjector"]


class ConservativeProjector(object):
"""
Projects a field such that mass is conserved.
This object is designed for projecting fields such as mixing ratios of
tracer species from one function space to another, but weighted by density
such that mass is conserved by the projection.
"""

def __init__(self, rho_source, rho_target, m_source, m_target,
subtract_mean=False):
"""
Args:
rho_source (:class:`Function`): the density to use for weighting the
source mixing ratio field. Can also be a :class:`ufl.Expr`.
rho_target (:class:`Function`): the density to use for weighting the
target mixing ratio field. Can also be a :class:`ufl.Expr`.
m_source (:class:`Function`): the source mixing ratio field. Can
also be a :class:`ufl.Expr`.
m_target (:class:`Function`): the target mixing ratio field to
compute.
subtract_mean (bool, optional): whether to solve the projection by
subtracting the mean value of m for both sides. This is more
expensive as it involves calculating the mean, but will ensure
preservation of a constant when projecting to a continuous
space. Default to False.
Raises:
RuntimeError: the geometric shape of the two rho fields must be equal.
RuntimeError: the geometric shape of the two m fields must be equal.
"""

self.subtract_mean = subtract_mean

if not isinstance(rho_source, (ufl.core.expr.Expr, Function)):
raise ValueError("Can only recover UFL expression or Functions not '%s'" % type(rho_source))

if not isinstance(rho_target, (ufl.core.expr.Expr, Function)):
raise ValueError("Can only recover UFL expression or Functions not '%s'" % type(rho_target))

if not isinstance(m_source, (ufl.core.expr.Expr, Function)):
raise ValueError("Can only recover UFL expression or Functions not '%s'" % type(m_source))

# Check shape values
if m_source.ufl_shape != m_target.ufl_shape:
raise RuntimeError('Shape mismatch between source %s and target function spaces %s in project' % (m_source.ufl_shape, m_target.ufl_shape))

if rho_source.ufl_shape != rho_target.ufl_shape:
raise RuntimeError('Shape mismatch between source %s and target function spaces %s in project' % (rho_source.ufl_shape, rho_target.ufl_shape))

self.m_source = m_source
self.m_target = m_target

V = self.m_target.function_space()
mesh = V.mesh()

self.m_mean = Constant(0.0, domain=mesh)
self.volume = assemble(Constant(1.0, domain=mesh)*dx)

test = TestFunction(V)
m_trial = TrialFunction(V)
eqn = (rho_source*inner(test, m_source - self.m_mean)*dx
- rho_target*inner(test, m_trial - self.m_mean)*dx)
problem = LinearVariationalProblem(lhs(eqn), rhs(eqn), self.m_target)
self.solver = LinearVariationalSolver(problem)

def project(self):
"""Apply the projection."""

# Compute mean value
if self.subtract_mean:
self.m_mean.assign(assemble(self.m_source*dx) / self.volume)

# Solve projection
self.solver.solve()

return self.m_target
67 changes: 58 additions & 9 deletions gusto/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,22 @@
from pyop2.mpi import MPI
import numpy as np
from gusto.core.logging import logger, update_logfile_location
from collections import namedtuple

__all__ = ["pick_up_mesh", "IO"]
__all__ = ["pick_up_mesh", "IO", "TimeData"]


class GustoIOError(IOError):
pass


# A named tuple object encapsulating data about timing
TimeData = namedtuple(
'TimeData',
['t', 'step', 'initial_steps', 'last_ref_update_time']
)


def pick_up_mesh(output, mesh_name):
"""
Picks up a checkpointed mesh. This must be the first step of any model being
Expand Down Expand Up @@ -531,7 +539,14 @@ def setup_dump(self, state_fields, t, pick_up=False):

# dump initial fields
if not pick_up:
self.dump(state_fields, t, step=1)
step = 1
last_ref_update_time = None
initial_steps = None
time_data = TimeData(
t=t, step=step, initial_steps=initial_steps,
last_ref_update_time=last_ref_update_time
)
self.dump(state_fields, time_data)

def pick_up_from_checkpoint(self, state_fields):
"""
Expand All @@ -541,7 +556,10 @@ def pick_up_from_checkpoint(self, state_fields):
state_fields (:class:`StateFields`): the model's field container.
Returns:
float: the checkpointed model time.
tuple of (`time_data`, `reference_profiles`): where `time_data`
itself is a named tuple containing the timing data.
The `reference_profiles` are a list of (`field_name`, expr)
pairs describing the reference profile fields.
"""

# -------------------------------------------------------------------- #
Expand Down Expand Up @@ -602,6 +620,13 @@ def pick_up_from_checkpoint(self, state_fields):
except AttributeError:
initial_steps = None

# Try to pick up number last_ref_update_time
# Not compulsory so errors allowed
try:
last_ref_update_time = chk.read_attribute("/", "last_ref_update_time")
except AttributeError:
last_ref_update_time = None

# Finally pick up time and step number
t = chk.read_attribute("/", "time")
step = chk.read_attribute("/", "step")
Expand Down Expand Up @@ -632,6 +657,13 @@ def pick_up_from_checkpoint(self, state_fields):
else:
initial_steps = None

# Try to pick up last reference profile update time
# Not compulsory so errors allowed
if chk.has_attr("/", "last_ref_update_time"):
last_ref_update_time = chk.get_attr("/", "last_ref_update_time")
else:
last_ref_update_time = None

# Finally pick up time
t = chk.get_attr("/", "time")
step = chk.get_attr("/", "step")
Expand All @@ -647,9 +679,14 @@ def pick_up_from_checkpoint(self, state_fields):
if hasattr(diagnostic_field, "init_field_set"):
diagnostic_field.init_field_set = True

return t, reference_profiles, step, initial_steps
time_data = TimeData(
t=t, step=step, initial_steps=initial_steps,
last_ref_update_time=last_ref_update_time
)

return time_data, reference_profiles

def dump(self, state_fields, t, step, initial_steps=None):
def dump(self, state_fields, time_data):
"""
Dumps all of the required model output.
Expand All @@ -659,12 +696,20 @@ def dump(self, state_fields, t, step, initial_steps=None):
Args:
state_fields (:class:`StateFields`): the model's field container.
t (float): the simulation's current time.
step (int): the number of time steps.
initial_steps (int, optional): the number of initial time steps
completed by a multi-level time scheme. Defaults to None.
time_data (namedtuple): contains information relating to the time in
the simulation. The tuple is structured as follows:
- t: current time in s
- step: the index of the time step
- initial_steps: number of initial time steps completed by a
multi-level time scheme (could be None)
- last_ref_update_time: the last time in s that the reference
profiles were updated (could be None)
"""
output = self.output
t = time_data.t
step = time_data.step
initial_steps = time_data.initial_steps
last_ref_update_time = time_data.last_ref_update_time

# Diagnostics:
# Compute diagnostic fields
Expand All @@ -688,6 +733,8 @@ def dump(self, state_fields, t, step, initial_steps=None):
self.chkpt.write_attribute("/", "step", step)
if initial_steps is not None:
self.chkpt.write_attribute("/", "initial_steps", initial_steps)
if last_ref_update_time is not None:
self.chkpt.write_attribute("/", "last_ref_update_time", last_ref_update_time)
else:
with CheckpointFile(self.chkpt_path, 'w') as chk:
chk.save_mesh(self.domain.mesh)
Expand All @@ -697,6 +744,8 @@ def dump(self, state_fields, t, step, initial_steps=None):
chk.set_attr("/", "step", step)
if initial_steps is not None:
chk.set_attr("/", "initial_steps", initial_steps)
if last_ref_update_time is not None:
chk.set_attr("/", "last_ref_update_time", last_ref_update_time)

if (next(self.dumpcount) % output.dumpfreq) == 0:
if output.dump_nc:
Expand Down
22 changes: 1 addition & 21 deletions gusto/equations/common_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"kinetic_energy_form", "advection_equation_circulation_form",
"diffusion_form", "diffusion_form_1d",
"linear_advection_form", "linear_continuity_form",
"linear_continuity_form_1d",
"split_continuity_form", "tracer_conservative_form"]


Expand Down Expand Up @@ -134,26 +133,7 @@ def linear_continuity_form(test, qbar, ubar):
:class:`LabelledForm`: a labelled transport form.
"""

L = qbar*test*div(ubar)*dx
form = transporting_velocity(L, ubar)

return transport(form, TransportEquationType.conservative)


def linear_continuity_form_1d(test, qbar, ubar):
"""
The form corresponding to the linearised continuity transport operator.
Args:
test (:class:`TestFunction`): the test function.
qbar (:class:`ufl.Expr`): the variable to be transported.
ubar (:class:`ufl.Expr`): the transporting velocity.
Returns:
:class:`LabelledForm`: a labelled transport form.
"""

L = qbar*test*ubar.dx(0)*dx
L = test*div(qbar*ubar)*dx
form = transporting_velocity(L, ubar)

return transport(form, TransportEquationType.conservative)
Expand Down
11 changes: 11 additions & 0 deletions gusto/equations/prognostic_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,17 @@ def add_tracers_to_prognostics(self, domain, active_tracers):
name of the active tracer.
"""

# Check if there are any conservatively transported tracers.
# If so, ensure that the reference density is indexed before this tracer.
for i in range(len(active_tracers) - 1):
tracer = active_tracers[i]
if tracer.transport_eqn == TransportEquationType.tracer_conservative:
ref_density = next(x for x in active_tracers if x.name == tracer.density_name)
j = active_tracers.index(ref_density)
if j > i:
# Swap the indices of the tracer and the reference density
active_tracers[i], active_tracers[j] = active_tracers[j], active_tracers[i]

# Loop through tracer fields and add field names and spaces
for tracer in active_tracers:
if isinstance(tracer, ActiveTracer):
Expand Down
4 changes: 2 additions & 2 deletions gusto/equations/shallow_water_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
advection_form, advection_form_1d, continuity_form,
continuity_form_1d, vector_invariant_form,
kinetic_energy_form, advection_equation_circulation_form, diffusion_form_1d,
linear_continuity_form, linear_continuity_form_1d
linear_continuity_form
)
from gusto.equations.prognostic_equations import PrognosticEquationSet

Expand Down Expand Up @@ -361,7 +361,7 @@ def __init__(self, domain, parameters,

# Transport term needs special linearisation
if self.linearisation_map(D_adv.terms[0]):
linear_D_adv = linear_continuity_form_1d(phi, H, u_trial)
linear_D_adv = linear_continuity_form(phi, H, u_trial)
# Add linearisation to D_adv
D_adv = linearisation(D_adv, linear_D_adv)

Expand Down
Loading

0 comments on commit 8f8d084

Please sign in to comment.