

set up connections between volumes
majosm committed Sep 26, 2022
1 parent 0aa399a commit afc56b6
Showing 1 changed file with 208 additions and 73 deletions.
281 changes: 208 additions & 73 deletions grudge/discretization.py
@@ -7,6 +7,7 @@
.. autofunction:: make_discretization_collection
.. currentmodule:: grudge.discretization
.. autoclass:: PartID
"""

__copyright__ = """
@@ -34,10 +35,12 @@
THE SOFTWARE.
"""

from typing import Mapping, Optional, Union, TYPE_CHECKING, Any
from typing import Sequence, Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any

from pytools import memoize_method, single_valued

from dataclasses import dataclass, replace

from grudge.dof_desc import (
    VTAG_ALL,
    DD_VOLUME_ALL,
@@ -71,6 +74,75 @@
import mpi4py.MPI


@dataclass(frozen=True)
class PartID:
    """Unique identifier for a piece of a partitioned mesh.

    .. attribute:: volume_tag

        The volume of the part.

    .. attribute:: rank

        The (optional) MPI rank of the part.
    """
    volume_tag: VolumeTag
    rank: Optional[int] = None
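# Illustrative sketch, not part of this commit (the volume tag names are
# hypothetical): the same volume on two different ranks gets two distinct part
# IDs, and the rank is simply left unset in non-distributed runs, e.g.
#
#     PartID("fluid", rank=0)   # the "fluid" volume on MPI rank 0
#     PartID("fluid", rank=1)   # the same volume on rank 1
#     PartID("wall")            # a "wall" volume in a single-rank run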


# {{{ part ID normalization

def _normalize_mesh_part_ids(
        mesh: Mesh,
        self_volume_tag: VolumeTag,
        all_volume_tags: Sequence[VolumeTag],
        mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None):
    """Convert a mesh's configuration-dependent "part ID" into a fixed type."""
    from numbers import Integral
    if mpi_communicator is not None:
        # Accept PartID or rank (assume intra-volume for the latter)
        def as_part_id(mesh_part_id):
            if isinstance(mesh_part_id, PartID):
                return mesh_part_id
            elif isinstance(mesh_part_id, Integral):
                return PartID(self_volume_tag, int(mesh_part_id))
            else:
                raise TypeError(f"Unable to convert {mesh_part_id} to PartID.")
    else:
        # Accept PartID or volume tag
        def as_part_id(mesh_part_id):
            if isinstance(mesh_part_id, PartID):
                return mesh_part_id
            elif mesh_part_id in all_volume_tags:
                return PartID(mesh_part_id)
            else:
                raise TypeError(f"Unable to convert {mesh_part_id} to PartID.")

    facial_adjacency_groups = mesh.facial_adjacency_groups

    new_facial_adjacency_groups = []

    from meshmode.mesh import InterPartAdjacencyGroup
    for grp_list in facial_adjacency_groups:
        new_grp_list = []
        for fagrp in grp_list:
            if isinstance(fagrp, InterPartAdjacencyGroup):
                part_id = as_part_id(fagrp.part_id)
                new_fagrp = replace(
                    fagrp,
                    boundary_tag=BTAG_PARTITION(part_id),
                    part_id=part_id)
            else:
                new_fagrp = fagrp
            new_grp_list.append(new_fagrp)
        new_facial_adjacency_groups.append(new_grp_list)

    return mesh.copy(facial_adjacency_groups=new_facial_adjacency_groups)

# }}}
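# Illustrative effect of the normalization above (not part of this commit;
# "wall" is a hypothetical volume tag): on a rank-partitioned mesh an integer
# part identifier such as 3 becomes PartID(self_volume_tag, 3), while on a
# single-rank multi-volume mesh a raw volume tag such as "wall" becomes
# PartID("wall"). In both cases the adjacency group's boundary_tag is
# rewritten to BTAG_PARTITION of the normalized PartID, so downstream code
# can key inter-part connections on PartID values alone.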


# {{{ discr_tag_to_group_factory normalization

def _normalize_discr_tag_to_group_factory(
Expand Down Expand Up @@ -156,6 +228,9 @@ def __init__(self, array_context: ArrayContext,
            discr_tag_to_group_factory: Optional[
                Mapping[DiscretizationTag, ElementGroupFactory]] = None,
            mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None,
            inter_part_connections: Optional[
                Mapping[Tuple[PartID, PartID],
                    DiscretizationConnection]] = None,
            ) -> None:
        """
        :arg discr_tag_to_group_factory: A mapping from discretization tags
@@ -206,6 +281,9 @@ def __init__(self, array_context: ArrayContext,

            mesh = volume_discrs

            mesh = _normalize_mesh_part_ids(
                mesh, VTAG_ALL, [VTAG_ALL], mpi_communicator=mpi_communicator)

            discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory(
                dim=mesh.dim,
                discr_tag_to_group_factory=discr_tag_to_group_factory,
@@ -219,17 +297,32 @@

            del mesh

            if inter_part_connections is not None:
                raise TypeError("may not pass inter_part_connections when "
                    "DiscretizationCollection constructor is called in "
                    "legacy mode")

            self._inter_part_connections = \
                _set_up_inter_part_connections(
                    array_context=self._setup_actx,
                    mpi_communicator=mpi_communicator,
                    volume_discrs=volume_discrs,
                    base_group_factory=(
                        discr_tag_to_group_factory[DISCR_TAG_BASE]))

            # }}}
        else:
            assert discr_tag_to_group_factory is not None
            self._discr_tag_to_group_factory = discr_tag_to_group_factory

        self._volume_discrs = volume_discrs
            if inter_part_connections is None:
                raise TypeError("inter_part_connections must be passed when "
                    "DiscretizationCollection constructor is called in "
                    "'modern' mode")

            self._inter_part_connections = inter_part_connections

        self._dist_boundary_connections = {
            vtag: self._set_up_distributed_communication(
                vtag, mpi_communicator, array_context)
            for vtag in self._volume_discrs.keys()}
        self._volume_discrs = volume_discrs

        # }}}

@@ -252,71 +345,6 @@ def is_management_rank(self):
        return self.mpi_communicator.Get_rank() \
                == self.get_management_rank_index()

    # {{{ distributed

    def _set_up_distributed_communication(
            self, vtag, mpi_communicator, array_context):
        from_dd = DOFDesc(VolumeDomainTag(vtag), DISCR_TAG_BASE)

        boundary_connections = {}

        from meshmode.distributed import get_connected_partitions
        connected_parts = get_connected_partitions(self._volume_discrs[vtag].mesh)

        if connected_parts:
            if mpi_communicator is None:
                raise RuntimeError("must supply an MPI communicator when using a "
                    "distributed mesh")

            grp_factory = \
                self.group_factory_for_discretization_tag(DISCR_TAG_BASE)

            local_boundary_connections = {}
            for i_remote_part in connected_parts:
                local_boundary_connections[i_remote_part] = self.connection_from_dds(
                    from_dd, from_dd.trace(BTAG_PARTITION(i_remote_part)))

            from meshmode.distributed import MPIBoundaryCommSetupHelper
            with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
                    local_boundary_connections, grp_factory) as bdry_setup_helper:
                while True:
                    conns = bdry_setup_helper.complete_some()
                    if not conns:
                        break
                    for i_remote_part, conn in conns.items():
                        boundary_connections[i_remote_part] = conn

        return boundary_connections

    def distributed_boundary_swap_connection(self, dd):
        """Provides a mapping from the base volume discretization
        to the exterior boundary restriction on a parallel boundary
        partition described by *dd*. This connection is used to
        communicate across element boundaries in different parallel
        partitions during distributed runs.

        :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
            convertible to one. The domain tag must be a subclass
            of :class:`grudge.dof_desc.BoundaryDomainTag` with an
            associated :class:`meshmode.mesh.BTAG_PARTITION`
            corresponding to a particular communication rank.
        """
        if dd.discretization_tag is not DISCR_TAG_BASE:
            # FIXME
            raise NotImplementedError(
                "Distributed communication with discretization tag "
                f"{dd.discretization_tag} is not implemented."
            )

        assert isinstance(dd.domain_tag, BoundaryDomainTag)
        assert isinstance(dd.domain_tag.tag, BTAG_PARTITION)

        vtag = dd.domain_tag.volume_tag

        return self._dist_boundary_connections[vtag][dd.domain_tag.tag.part_nr]

    # }}}

    # {{{ discr_from_dd

    @memoize_method
@@ -772,6 +800,105 @@ def normal(self, dd):
    # }}}


# {{{ distributed/multi-volume setup

def _set_up_inter_part_connections(
        array_context: ArrayContext,
        mpi_communicator: Optional["mpi4py.MPI.Intracomm"],
        volume_discrs: Mapping[VolumeTag, Discretization],
        base_group_factory: ElementGroupFactory,
        ) -> Mapping[
            Tuple[PartID, PartID],
            DiscretizationConnection]:

    from meshmode.distributed import (get_connected_parts,
        make_remote_group_infos, InterRankBoundaryInfo,
        MPIBoundaryCommSetupHelper)

    rank = mpi_communicator.Get_rank() if mpi_communicator is not None else None

    # Save boundary restrictions as they're created to avoid potentially creating
    # them twice in the loop below
    cached_part_bdry_restrictions: Mapping[
        Tuple[PartID, PartID],
        DiscretizationConnection] = {}

    def get_part_bdry_restriction(self_part_id, other_part_id):
        cached_result = cached_part_bdry_restrictions.get(
            (self_part_id, other_part_id), None)
        if cached_result is not None:
            return cached_result
        return cached_part_bdry_restrictions.setdefault(
            (self_part_id, other_part_id),
            make_face_restriction(
                array_context, volume_discrs[self_part_id.volume_tag],
                base_group_factory,
                boundary_tag=BTAG_PARTITION(other_part_id)))

    inter_part_conns: Mapping[
        Tuple[PartID, PartID],
        DiscretizationConnection] = {}

    irbis = []

    for vtag, volume_discr in volume_discrs.items():
        part_id = PartID(vtag, rank)
        connected_part_ids = get_connected_parts(volume_discr.mesh)
        for connected_part_id in connected_part_ids:
            bdry_restr = get_part_bdry_restriction(
                self_part_id=part_id, other_part_id=connected_part_id)

            if connected_part_id.rank == rank:
                # {{{ rank-local interface between multiple volumes

                connected_bdry_restr = get_part_bdry_restriction(
                    self_part_id=connected_part_id, other_part_id=part_id)

                from meshmode.discretization.connection import \
                    make_partition_connection
                inter_part_conns[connected_part_id, part_id] = \
                    make_partition_connection(
                        array_context,
                        local_bdry_conn=bdry_restr,
                        remote_bdry_discr=connected_bdry_restr.to_discr,
                        remote_group_infos=make_remote_group_infos(
                            array_context, part_id, connected_bdry_restr))

                # }}}
            else:
                # {{{ cross-rank interface

                if mpi_communicator is None:
                    raise RuntimeError("must supply an MPI communicator "
                        "when using a distributed mesh")

                irbis.append(
                    InterRankBoundaryInfo(
                        local_part_id=part_id,
                        remote_part_id=connected_part_id,
                        remote_rank=connected_part_id.rank,
                        local_boundary_connection=bdry_restr))

                # }}}

    if irbis:
        assert mpi_communicator is not None

        with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
                irbis, base_group_factory) as bdry_setup_helper:
            while True:
                conns = bdry_setup_helper.complete_some()
                if not conns:
                    # We're done.
                    break

                inter_part_conns.update(conns)

    return inter_part_conns

# }}}
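# Illustrative sketch, not part of this commit (the volume tags and rank below
# are hypothetical): the mapping built above is keyed by (connected part,
# local part) pairs of PartIDs, so the connection coupling a "fluid" volume to
# its neighboring "wall" volume on rank 0 would be looked up as
#
#     conn = inter_part_conns[PartID("wall", 0), PartID("fluid", 0)]
#
# with the restriction of the "fluid" volume to
# BTAG_PARTITION(PartID("wall", 0)) serving as the local side of that
# connection.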


# {{{ modal group factory

def _generate_modal_group_factory(nodal_group_factory):
@@ -860,6 +987,8 @@ def make_discretization_collection(

    del order

    mpi_communicator = getattr(array_context, "mpi_communicator", None)

    if any(
            isinstance(mesh_or_discr, Discretization)
            for mesh_or_discr in volumes.values()):
@@ -868,14 +997,20 @@
    volume_discrs = {
        vtag: Discretization(
            array_context,
            mesh,
            _normalize_mesh_part_ids(
                mesh, vtag, volumes.keys(), mpi_communicator=mpi_communicator),
            discr_tag_to_group_factory[DISCR_TAG_BASE])
        for vtag, mesh in volumes.items()}

    return DiscretizationCollection(
        array_context=array_context,
        volume_discrs=volume_discrs,
        discr_tag_to_group_factory=discr_tag_to_group_factory)
        discr_tag_to_group_factory=discr_tag_to_group_factory,
        inter_part_connections=_set_up_inter_part_connections(
            array_context=array_context,
            mpi_communicator=mpi_communicator,
            volume_discrs=volume_discrs,
            base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE]))

# }}}
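# Illustrative usage sketch, not part of this commit (the array context,
# volume tags, meshes, and polynomial order are hypothetical): with this
# change, a multi-volume collection gets its inter-part connections set up
# automatically,
#
#     dcoll = make_discretization_collection(
#         actx, {"fluid": fluid_mesh, "wall": wall_mesh}, order=3)
#
# In MPI runs the rank is picked up from actx.mpi_communicator (if present),
# and cross-rank connections are established via MPIBoundaryCommSetupHelper.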

