Skip to content

Commit

Permalink
Increment form for implicit RK added and tested (#566)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Bendall <[email protected]>
  • Loading branch information
atb1995 and tommbendall authored Jan 7, 2025
1 parent d61f81f commit 4fab020
Show file tree
Hide file tree
Showing 6 changed files with 512 additions and 352 deletions.
87 changes: 28 additions & 59 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
limiter=limiter, options=options,
augmentation=augmentation)
self.butcher_matrix = butcher_matrix
self.nbutcher = int(np.shape(self.butcher_matrix)[0])
self.nStages = int(np.shape(self.butcher_matrix)[0])
self.rk_formulation = rk_formulation

@property
def nStages(self):
return self.nbutcher

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Set up the time discretisation based on the equation.
Expand Down Expand Up @@ -163,7 +159,7 @@ def solver(self):
for stage in range(self.nStages):
# setup linear solver using lhs and rhs defined in derived class
problem = NonlinearVariationalProblem(
self.lhs[stage].form - self.rhs[stage].form,
self.res[stage].form,
self.field_i[stage+1], bcs=self.bcs
)
solver_name = self.field_name+self.__class__.__name__+str(stage)
Expand All @@ -176,7 +172,7 @@ def solver(self):

elif self.rk_formulation == RungeKuttaFormulation.linear:
problem = NonlinearVariationalProblem(
self.lhs - self.rhs[0], self.x1, bcs=self.bcs
self.res[0], self.x1, bcs=self.bcs
)
solver_name = self.field_name+self.__class__.__name__
solver = NonlinearVariationalSolver(
Expand All @@ -186,7 +182,7 @@ def solver(self):

# Set up problem for final step
problem_last = NonlinearVariationalProblem(
self.lhs - self.rhs[1], self.x1, bcs=self.bcs
self.res[1], self.x1, bcs=self.bcs
)
solver_name = self.field_name+self.__class__.__name__+'_last'
solver_last = NonlinearVariationalSolver(
Expand All @@ -202,54 +198,21 @@ def solver(self):
)

@cached_property
def lhs(self):
"""Set up the discretisation's left hand side (the time derivative)."""
def res(self):
"""Set up the discretisation's residual."""

if self.rk_formulation == RungeKuttaFormulation.increment:
l = self.residual.label_map(
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x_out, old_idx=self.idx),
map_if_false=drop)

return l.form

elif self.rk_formulation == RungeKuttaFormulation.predictor:
lhs_list = []
for stage in range(self.nStages):
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.field_i[stage+1], old_idx=self.idx),
map_if_false=drop)
lhs_list.append(l)

return lhs_list

if self.rk_formulation == RungeKuttaFormulation.linear:
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x1, old_idx=self.idx),
map_if_false=drop)

return l.form

else:
raise NotImplementedError(
'Runge-Kutta formulation is not implemented'
)

@cached_property
def rhs(self):
"""Set up the time discretisation's right hand side."""

if self.rk_formulation == RungeKuttaFormulation.increment:
r = self.residual.label_map(
all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))

r = r.label_map(
residual += r.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=lambda t: -1*t)
map_if_true=drop)

# If there are no active labels, we may have no terms at this point
# So that we can still do xnp1 = xn, put in a zero term here
Expand All @@ -261,19 +224,22 @@ def rhs(self):
# Drop label from this
map_if_true=lambda t: time_derivative.remove(t),
map_if_false=drop)
r += null_term
residual += null_term

return r.form
return residual.form

elif self.rk_formulation == RungeKuttaFormulation.predictor:
rhs_list = []

residual_list = []
for stage in range(self.nStages):
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.field_i[stage+1], self.idx),
map_if_false=drop)
r = self.residual.label_map(
all_terms,
map_if_true=replace_subject(self.field_i[0], old_idx=self.idx))

r = r.label_map(
residual -= r.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=keep,
map_if_false=lambda t: -self.butcher_matrix[stage, 0]*self.dt*t)
Expand All @@ -285,14 +251,16 @@ def rhs(self):
map_if_false=replace_subject(self.field_i[i], old_idx=self.idx)
)

r -= self.butcher_matrix[stage, i]*self.dt*r_i

rhs_list.append(r)
residual += self.butcher_matrix[stage, i]*self.dt*r_i
residual_list.append(residual)

return rhs_list

elif self.rk_formulation == RungeKuttaFormulation.linear:
return residual_list

if self.rk_formulation == RungeKuttaFormulation.linear:
time_term = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x1, self.idx),
map_if_false=drop)
r = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x0, old_idx=self.idx),
Expand Down Expand Up @@ -331,8 +299,9 @@ def rhs(self):
map_if_true=keep,
map_if_false=lambda t: -self.dt*t
)

return r_all_but_last.form, r.form
res = time_term - r
res_all_but_last = time_term - r_all_but_last
return res_all_but_last.form, res.form

else:
raise NotImplementedError(
Expand Down
54 changes: 31 additions & 23 deletions gusto/time_discretisation/imex_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class IMEXRungeKutta(TimeDiscretisation):

def __init__(self, domain, butcher_imp, butcher_exp, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
limiter=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -82,10 +82,13 @@ def __init__(self, domain, butcher_imp, butcher_exp, field_name=None,
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
"""
super().__init__(domain, field_name=field_name,
solver_parameters=nonlinear_solver_parameters,
options=options)
options=options, augmentation=augmentation)
self.butcher_imp = butcher_imp
self.butcher_exp = butcher_exp
self.nStages = int(np.shape(self.butcher_imp)[1])
Expand Down Expand Up @@ -127,16 +130,6 @@ def setup(self, equation, apply_bcs=True, *active_labels):

self.xs = [Function(self.fs) for i in range(self.nStages)]

@cached_property
def lhs(self):
"""Set up the discretisation's left hand side (the time derivative)."""
return super(IMEXRungeKutta, self).lhs

@cached_property
def rhs(self):
"""Set up the discretisation's right hand side (the time derivative)."""
return super(IMEXRungeKutta, self).rhs

def res(self, stage):
"""Set up the discretisation's residual for a given stage."""
# Add time derivative terms y_s - y^n for stage s
Expand Down Expand Up @@ -226,7 +219,7 @@ def solvers(self):
@cached_property
def final_solver(self):
"""Set up a solver for the final solve to evaluate time level n+1."""
# setup solver using lhs and rhs defined in derived class
# setup solver using residual (res) defined in derived class
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__
return NonlinearVariationalSolver(problem, solver_parameters=self.linear_solver_parameters, options_prefix=solver_name)
Expand Down Expand Up @@ -269,7 +262,7 @@ class IMEX_Euler(IMEXRungeKutta):
"""
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
limiter=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -286,13 +279,16 @@ def __init__(self, domain, field_name=None,
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
"""
butcher_imp = np.array([[0., 0.], [0., 1.], [0., 1.]])
butcher_exp = np.array([[0., 0.], [1., 0.], [1., 0.]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options, augmentation=augmentation)


class IMEX_ARS3(IMEXRungeKutta):
Expand All @@ -313,7 +309,7 @@ class IMEX_ARS3(IMEXRungeKutta):
"""
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
limiter=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -330,6 +326,9 @@ def __init__(self, domain, field_name=None,
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
"""
g = (3. + np.sqrt(3.))/6.
butcher_imp = np.array([[0., 0., 0.], [0., g, 0.], [0., 1-2.*g, g], [0., 0.5, 0.5]])
Expand All @@ -338,7 +337,7 @@ def __init__(self, domain, field_name=None,
super().__init__(domain, butcher_imp, butcher_exp, field_name,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options, augmentation=augmentation)


class IMEX_ARK2(IMEXRungeKutta):
Expand All @@ -359,7 +358,7 @@ class IMEX_ARK2(IMEXRungeKutta):
"""
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
limiter=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -376,6 +375,9 @@ def __init__(self, domain, field_name=None,
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
"""
g = 1. - 1./np.sqrt(2.)
d = 1./(2.*np.sqrt(2.))
Expand All @@ -385,7 +387,7 @@ def __init__(self, domain, field_name=None,
super().__init__(domain, butcher_imp, butcher_exp, field_name,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options, augmentation=augmentation)


class IMEX_SSP3(IMEXRungeKutta):
Expand All @@ -404,7 +406,7 @@ class IMEX_SSP3(IMEXRungeKutta):
"""
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
limiter=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -421,14 +423,17 @@ def __init__(self, domain, field_name=None,
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
"""
g = 1. - (1./np.sqrt(2.))
butcher_imp = np.array([[g, 0., 0.], [1-2.*g, g, 0.], [0.5-g, 0., g], [(1./6.), (1./6.), (2./3.)]])
butcher_exp = np.array([[0., 0., 0.], [1., 0., 0.], [0.25, 0.25, 0.], [(1./6.), (1./6.), (2./3.)]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options, augmentation=augmentation)


class IMEX_Trap2(IMEXRungeKutta):
Expand All @@ -447,7 +452,7 @@ class IMEX_Trap2(IMEXRungeKutta):
"""
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
limiter=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -464,11 +469,14 @@ def __init__(self, domain, field_name=None,
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
"""
e = 0.
butcher_imp = np.array([[0., 0., 0., 0.], [e, 0., 0., 0.], [0.5, 0., 0.5, 0.], [0.5, 0., 0., 0.5], [0.5, 0., 0., 0.5]])
butcher_exp = np.array([[0., 0., 0., 0.], [1., 0., 0., 0.], [0.5, 0.5, 0., 0.], [0.5, 0., 0.5, 0.], [0.5, 0., 0.5, 0.]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options, augmentation=augmentation)
Loading

0 comments on commit 4fab020

Please sign in to comment.