diff --git a/grudge/discretization.py b/grudge/discretization.py index 466ca163b..3ead9485f 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -29,6 +29,7 @@ THE SOFTWARE. """ +from typing import Callable from pytools import memoize_method from grudge.dof_desc import ( @@ -37,6 +38,7 @@ DISCR_TAG_MODAL, DTAG_BOUNDARY, DOFDesc, + DiscretizationTag, as_dofdesc ) @@ -44,10 +46,12 @@ from arraycontext import ArrayContext +from meshmode.discretization import ElementGroupFactory from meshmode.discretization.connection import ( FACE_RESTR_INTERIOR, FACE_RESTR_ALL, - make_face_restriction + make_face_restriction, + DiscretizationConnection ) from meshmode.mesh import Mesh, BTAG_PARTITION @@ -174,9 +178,18 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh, # }}} - self._dist_boundary_connections = \ - self._set_up_distributed_communication( - mpi_communicator, array_context) + self._dist_boundary_connections = _get_dist_boundary_connections_single_vol( + volume_discr=self._volume_discr, + mpi_communicator=mpi_communicator, + array_context=self._setup_actx, + get_group_factory_for_discretization_tag=( + self.group_factory_for_discretization_tag), + get_connection_to_rank_boundary=( + lambda remote_rank: + self.connection_from_dds( + DOFDesc("vol", DISCR_TAG_BASE), + DOFDesc(BTAG_PARTITION(remote_rank), + DISCR_TAG_BASE))),) # }}} @@ -201,40 +214,6 @@ def is_management_rank(self): # {{{ distributed - def _set_up_distributed_communication(self, mpi_communicator, array_context): - from_dd = DOFDesc("vol", DISCR_TAG_BASE) - - boundary_connections = {} - - from meshmode.distributed import get_connected_partitions - connected_parts = get_connected_partitions(self._volume_discr.mesh) - - if connected_parts: - if mpi_communicator is None: - raise RuntimeError("must supply an MPI communicator when using a " - "distributed mesh") - - grp_factory = \ - self.group_factory_for_discretization_tag(DISCR_TAG_BASE) - - local_boundary_connections = {} - for i_remote_part in connected_parts: - local_boundary_connections[i_remote_part] = self.connection_from_dds( - from_dd, DOFDesc(BTAG_PARTITION(i_remote_part), - DISCR_TAG_BASE)) - - from meshmode.distributed import MPIBoundaryCommSetupHelper - with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, - local_boundary_connections, grp_factory) as bdry_setup_helper: - while True: - conns = bdry_setup_helper.complete_some() - if not conns: - break - for i_remote_part, conn in conns.items(): - boundary_connections[i_remote_part] = conn - - return boundary_connections - def distributed_boundary_swap_connection(self, dd): """Provides a mapping from the base volume discretization to the exterior boundary restriction on a parallel boundary @@ -719,6 +698,45 @@ def normal(self, dd): # }}} +# {{{ distributed setup + +def _get_dist_boundary_connections_single_vol( + volume_discr, mpi_communicator, array_context, + get_group_factory_for_discretization_tag: Callable[ + [DiscretizationTag], ElementGroupFactory], + get_connection_to_rank_boundary: Callable[[int], DiscretizationConnection]): + boundary_connections = {} + + from meshmode.distributed import get_connected_partitions + connected_parts = get_connected_partitions(volume_discr.mesh) + + if connected_parts: + if mpi_communicator is None: + raise RuntimeError("must supply an MPI communicator when using a " + "distributed mesh") + + grp_factory = get_group_factory_for_discretization_tag(DISCR_TAG_BASE) + + local_boundary_connections = {} + for i_remote_part in connected_parts: + local_boundary_connections[i_remote_part] = \ + get_connection_to_rank_boundary(i_remote_part) + + from meshmode.distributed import MPIBoundaryCommSetupHelper + with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, + local_boundary_connections, grp_factory) as bdry_setup_helper: + while True: + conns = bdry_setup_helper.complete_some() + if not conns: + break + for i_remote_part, conn in conns.items(): + boundary_connections[i_remote_part] = conn + + return boundary_connections + +# }}} + + def _generate_modal_group_factory(nodal_group_factory): from meshmode.discretization.poly_element import ( ModalSimplexGroupFactory, diff --git a/grudge/dof_desc.py b/grudge/dof_desc.py index 621f245e7..aba344198 100644 --- a/grudge/dof_desc.py +++ b/grudge/dof_desc.py @@ -25,6 +25,7 @@ THE SOFTWARE. """ +from typing import Hashable from meshmode.discretization.connection import \ FACE_RESTR_INTERIOR, FACE_RESTR_ALL from meshmode.mesh import \ @@ -53,6 +54,9 @@ # {{{ DOF description +DiscretizationTag = Hashable + + class DTAG_SCALAR: # noqa: N801 """A domain tag denoting scalar values."""