Skip to content

Commit

Permalink
Rip _get_dist_boundary_connections_single_vol out of DColl.__init__
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Mar 15, 2022
1 parent 76725d6 commit f42f141
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 38 deletions.
94 changes: 56 additions & 38 deletions grudge/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
THE SOFTWARE.
"""

from typing import Callable
from pytools import memoize_method

from grudge.dof_desc import (
Expand All @@ -37,17 +38,20 @@
DISCR_TAG_MODAL,
DTAG_BOUNDARY,
DOFDesc,
DiscretizationTag,
as_dofdesc
)

import numpy as np # noqa: F401

from arraycontext import ArrayContext

from meshmode.discretization import ElementGroupFactory
from meshmode.discretization.connection import (
FACE_RESTR_INTERIOR,
FACE_RESTR_ALL,
make_face_restriction
make_face_restriction,
DiscretizationConnection
)
from meshmode.mesh import Mesh, BTAG_PARTITION

Expand Down Expand Up @@ -174,9 +178,18 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,

# }}}

self._dist_boundary_connections = \
self._set_up_distributed_communication(
mpi_communicator, array_context)
self._dist_boundary_connections = _get_dist_boundary_connections_single_vol(
volume_discr=self._volume_discr,
mpi_communicator=mpi_communicator,
array_context=self._setup_actx,
get_group_factory_for_discretization_tag=(
self.group_factory_for_discretization_tag),
get_connection_to_rank_boundary=(
lambda remote_rank:
self.connection_from_dds(
DOFDesc("vol", DISCR_TAG_BASE),
DOFDesc(BTAG_PARTITION(remote_rank),
DISCR_TAG_BASE))),)

# }}}

Expand All @@ -201,40 +214,6 @@ def is_management_rank(self):

# {{{ distributed

def _set_up_distributed_communication(self, mpi_communicator, array_context):
from_dd = DOFDesc("vol", DISCR_TAG_BASE)

boundary_connections = {}

from meshmode.distributed import get_connected_partitions
connected_parts = get_connected_partitions(self._volume_discr.mesh)

if connected_parts:
if mpi_communicator is None:
raise RuntimeError("must supply an MPI communicator when using a "
"distributed mesh")

grp_factory = \
self.group_factory_for_discretization_tag(DISCR_TAG_BASE)

local_boundary_connections = {}
for i_remote_part in connected_parts:
local_boundary_connections[i_remote_part] = self.connection_from_dds(
from_dd, DOFDesc(BTAG_PARTITION(i_remote_part),
DISCR_TAG_BASE))

from meshmode.distributed import MPIBoundaryCommSetupHelper
with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
local_boundary_connections, grp_factory) as bdry_setup_helper:
while True:
conns = bdry_setup_helper.complete_some()
if not conns:
break
for i_remote_part, conn in conns.items():
boundary_connections[i_remote_part] = conn

return boundary_connections

def distributed_boundary_swap_connection(self, dd):
"""Provides a mapping from the base volume discretization
to the exterior boundary restriction on a parallel boundary
Expand Down Expand Up @@ -719,6 +698,45 @@ def normal(self, dd):
# }}}


# {{{ distributed setup

def _get_dist_boundary_connections_single_vol(
volume_discr, mpi_communicator, array_context,
get_group_factory_for_discretization_tag: Callable[
[DiscretizationTag], ElementGroupFactory],
get_connection_to_rank_boundary: Callable[[int], DiscretizationConnection]):
boundary_connections = {}

from meshmode.distributed import get_connected_partitions
connected_parts = get_connected_partitions(volume_discr.mesh)

if connected_parts:
if mpi_communicator is None:
raise RuntimeError("must supply an MPI communicator when using a "
"distributed mesh")

grp_factory = get_group_factory_for_discretization_tag(DISCR_TAG_BASE)

local_boundary_connections = {}
for i_remote_part in connected_parts:
local_boundary_connections[i_remote_part] = \
get_connection_to_rank_boundary(i_remote_part)

from meshmode.distributed import MPIBoundaryCommSetupHelper
with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
local_boundary_connections, grp_factory) as bdry_setup_helper:
while True:
conns = bdry_setup_helper.complete_some()
if not conns:
break
for i_remote_part, conn in conns.items():
boundary_connections[i_remote_part] = conn

return boundary_connections

# }}}


def _generate_modal_group_factory(nodal_group_factory):
from meshmode.discretization.poly_element import (
ModalSimplexGroupFactory,
Expand Down
4 changes: 4 additions & 0 deletions grudge/dof_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
THE SOFTWARE.
"""

from typing import Hashable
from meshmode.discretization.connection import \
FACE_RESTR_INTERIOR, FACE_RESTR_ALL
from meshmode.mesh import \
Expand Down Expand Up @@ -53,6 +54,9 @@

# {{{ DOF description

DiscretizationTag = Hashable


class DTAG_SCALAR: # noqa: N801
"""A domain tag denoting scalar values."""

Expand Down

0 comments on commit f42f141

Please sign in to comment.