Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/multi-volume' into multi-volume
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed May 2, 2022
2 parents 7ea9d73 + 5ee00b4 commit e705304
Show file tree
Hide file tree
Showing 9 changed files with 554 additions and 160 deletions.
52 changes: 29 additions & 23 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import pyopencl.tools as cl_tools

from arraycontext import (
thaw, freeze,
thaw,
with_container_arithmetic,
dataclass_array_container
)
Expand All @@ -45,7 +45,7 @@
from grudge.dof_desc import as_dofdesc, DOFDesc, DISCR_TAG_BASE, DISCR_TAG_QUAD
from grudge.trace_pair import TracePair
from grudge.discretization import DiscretizationCollection
from grudge.shortcuts import make_visualizer, rk4_step
from grudge.shortcuts import make_visualizer, compiled_lsrk45_step

import grudge.op as op

Expand All @@ -57,7 +57,8 @@

# {{{ wave equation bits

@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True,
_cls_has_array_context_attr=True)
@dataclass_array_container
@dataclass(frozen=True)
class WaveState:
Expand Down Expand Up @@ -251,7 +252,8 @@ def main(ctx_factory, dim=2, order=3,
c = 1

# FIXME: Sketchy, empirically determined fudge factor
dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c))
# 5/4 to account for larger LSRK45 stability region
dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c)) * 5/4

vis = make_visualizer(dcoll)

Expand All @@ -271,25 +273,32 @@ def rhs(t, w):
istep = 0
while t < t_final:
start = time.time()
if lazy:
fields = thaw(freeze(fields, actx), actx)

fields = rk4_step(fields, t, dt, compiled_rhs)

l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2))
fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs)

if istep % 10 == 0:
stop = time.time()
linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf))
nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u))
nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u))
if comm.rank == 0:
logger.info(f"step: {istep} t: {t} "
f"L2: {l2norm} "
f"Linf: {linfnorm} "
f"sol max: {nodalmax} "
f"sol min: {nodalmin} "
f"wall: {stop-start} ")
if args.no_diagnostics:
if comm.rank == 0:
logger.info(f"step: {istep} t: {t} "
f"wall: {stop-start} ")
else:
l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2))

# NOTE: These are here to ensure the solution is bounded for the
# time interval specified
assert l2norm < 1

linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf))
nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u))
nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u))
if comm.rank == 0:
logger.info(f"step: {istep} t: {t} "
f"L2: {l2norm} "
f"Linf: {linfnorm} "
f"sol max: {nodalmax} "
f"sol min: {nodalmin} "
f"wall: {stop-start} ")
if visualize:
vis.write_parallel_vtk_file(
comm,
Expand All @@ -304,10 +313,6 @@ def rhs(t, w):
t += dt
istep += 1

# NOTE: These are here to ensure the solution is bounded for the
# time interval specified
assert l2norm < 1


if __name__ == "__main__":
import argparse
Expand All @@ -320,6 +325,7 @@ def rhs(t, w):
help="switch to a lazy computation mode")
parser.add_argument("--quad", action="store_true")
parser.add_argument("--nonaffine", action="store_true")
parser.add_argument("--no-diagnostics", action="store_true")

args = parser.parse_args()

Expand Down
10 changes: 1 addition & 9 deletions grudge/models/euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,18 +318,10 @@ def operator(self, t, q):
qtag = self.qtag
dq = DOFDesc("vol", qtag)
df = DOFDesc("all_faces", qtag)
df_int = DOFDesc("int_faces", qtag)

def interp_to_quad(u):
return op.project(dcoll, "vol", dq, u)

def interp_to_quad_surf(u):
return TracePair(
df_int,
interior=op.project(dcoll, "int_faces", df_int, u.int),
exterior=op.project(dcoll, "int_faces", df_int, u.ext)
)

# Compute volume fluxes
volume_fluxes = op.weak_local_div(
dcoll, dq,
Expand All @@ -341,7 +333,7 @@ def interp_to_quad_surf(u):
sum(
euler_numerical_flux(
dcoll,
interp_to_quad_surf(tpair),
op.tracepair_with_discr_tag(dcoll, qtag, tpair),
gamma=gamma,
lf_stabilization=self.lf_stabilization
) for tpair in op.interior_trace_pairs(dcoll, q)
Expand Down
Loading

0 comments on commit e705304

Please sign in to comment.