diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py
index 7d9f216bf..a16b1d100 100644
--- a/grudge/trace_pair.py
+++ b/grudge/trace_pair.py
@@ -50,7 +50,7 @@
 
 from warnings import warn
 
-from typing import List, Hashable, Optional, Type, Any, Sequence
+from typing import List, Hashable, Optional, Tuple, Type, Any, Sequence, Mapping
 
 from pytools.persistent_dict import KeyBuilder
 
@@ -60,9 +60,8 @@
     with_container_arithmetic, dataclass_array_container,
     get_container_context_recursively,
-    flatten, to_numpy,
-    unflatten, from_numpy,
-    flat_size_and_dtype)
+    to_numpy,
+    from_numpy)
 from arraycontext.container import ArrayOrContainerT
 
 from dataclasses import dataclass
 
@@ -70,7 +69,6 @@
 from numbers import Number
 
 from pytools import memoize_on_first_arg
-from pytools.obj_array import obj_array_vectorize
 
 from grudge.discretization import DiscretizationCollection
 from grudge.projection import project
@@ -296,16 +294,22 @@ def local_interior_trace_pair(
 
     interior = project(dcoll, volume_dd, trace_dd, vec)
 
-    def get_opposite_trace(el):
-        if isinstance(el, Number):
-            return el
+    opposite_face_conn = dcoll.opposite_face_connection(trace_dd.domain_tag)
+
+    def get_opposite_trace(ary):
+        if isinstance(ary, Number):
+            return ary
         else:
-            assert isinstance(trace_dd.domain_tag, BoundaryDomainTag)
-            return dcoll.opposite_face_connection(trace_dd.domain_tag)(el)
+            return opposite_face_conn(ary)
 
-    e = obj_array_vectorize(get_opposite_trace, interior)
+    from arraycontext import rec_map_array_container
+    from meshmode.dof_array import DOFArray
+    exterior = rec_map_array_container(
+        get_opposite_trace,
+        interior,
+        leaf_class=DOFArray)
 
-    return TracePair(trace_dd, interior=interior, exterior=e)
+    return TracePair(trace_dd, interior=interior, exterior=exterior)
 
 
 def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair:
@@ -363,58 +367,90 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
 
 def local_inter_volume_trace_pairs(
         dcoll: DiscretizationCollection,
-        self_volume_dd: DOFDesc, self_ary: ArrayOrContainerT,
-        other_volume_dd: DOFDesc, other_ary: ArrayOrContainerT,
-        ) -> ArrayOrContainerT:
-    if not isinstance(self_volume_dd.domain_tag, VolumeDomainTag):
-        raise ValueError("self_volume_dd must describe a volume")
-    if not isinstance(other_volume_dd.domain_tag, VolumeDomainTag):
-        raise ValueError("other_volume_dd must describe a volume")
-    if self_volume_dd.discretization_tag != DISCR_TAG_BASE:
-        raise TypeError(
-            f"expected a base-discretized self DOFDesc, got '{self_volume_dd}'")
-    if other_volume_dd.discretization_tag != DISCR_TAG_BASE:
-        raise TypeError(
-            f"expected a base-discretized other DOFDesc, got '{other_volume_dd}'")
+        pairwise_volume_data: Mapping[
+            Tuple[DOFDesc, DOFDesc],
+            Tuple[ArrayOrContainerT, ArrayOrContainerT]]
+        ) -> Mapping[Tuple[DOFDesc, DOFDesc], TracePair]:
+    for vol_dd_pair in pairwise_volume_data.keys():
+        for vol_dd in vol_dd_pair:
+            if not isinstance(vol_dd.domain_tag, VolumeDomainTag):
+                raise ValueError(
+                    "pairwise_volume_data keys must describe volumes, "
+                    f"got '{vol_dd}'")
+            if vol_dd.discretization_tag != DISCR_TAG_BASE:
+                raise ValueError(
+                    "expected base-discretized DOFDesc in pairwise_volume_data, "
+                    f"got '{vol_dd}'")
 
     rank = (
         dcoll.mpi_communicator.Get_rank()
         if dcoll.mpi_communicator is not None
        else None)
 
-    self_part_id = dcoll._part_id_helper.make(rank, self_volume_dd.domain_tag.tag)
-    other_part_id = dcoll._part_id_helper.make(rank, other_volume_dd.domain_tag.tag)
-
-    self_trace_dd = self_volume_dd.trace(BTAG_PARTITION(other_part_id))
-    other_trace_dd = other_volume_dd.trace(BTAG_PARTITION(self_part_id))
-
-    # FIXME: In all likelihood, these traces will be reevaluated from
-    # the other side, which is hard to prevent given the interface we
-    # have. Lazy eval will hopefully collapse those redundant evaluations...
-    self_trace = project(
-        dcoll, self_volume_dd, self_trace_dd, self_ary)
-    other_trace = project(
-        dcoll, other_volume_dd, other_trace_dd, other_ary)
+    result: Mapping[Tuple[DOFDesc, DOFDesc], TracePair] = {}
+
+    for vol_dd_pair, vol_data_pair in pairwise_volume_data.items():
+        directional_vol_dd_pairs = [
+            (vol_dd_pair[1], vol_dd_pair[0]),
+            (vol_dd_pair[0], vol_dd_pair[1])]
+
+        trace_dd_pair = tuple(
+            self_vol_dd.trace(
+                BTAG_PARTITION(
+                    dcoll._part_id_helper.make(
+                        rank, other_vol_dd.domain_tag.tag)))
+            for other_vol_dd, self_vol_dd in directional_vol_dd_pairs)
+
+        # Pre-compute the projections out here to avoid doing it twice inside
+        # the loop below
+        trace_data = {
+            trace_dd: project(dcoll, vol_dd, trace_dd, vol_data)
+            for vol_dd, trace_dd, vol_data in zip(
+                vol_dd_pair, trace_dd_pair, vol_data_pair)}
+
+        for other_vol_dd, self_vol_dd in directional_vol_dd_pairs:
+            self_part_id = dcoll._part_id_helper.make(
+                rank, self_vol_dd.domain_tag.tag)
+            other_part_id = dcoll._part_id_helper.make(
+                rank, other_vol_dd.domain_tag.tag)
+
+            self_trace_dd = self_vol_dd.trace(BTAG_PARTITION(other_part_id))
+            other_trace_dd = other_vol_dd.trace(BTAG_PARTITION(self_part_id))
+
+            self_trace_data = trace_data[self_trace_dd]
+            unswapped_other_trace_data = trace_data[other_trace_dd]
+
+            other_to_self = dcoll._inter_partition_connections[
+                other_part_id, self_part_id]
+
+            def get_opposite_trace(ary):
+                if isinstance(ary, Number):
+                    return ary
+                else:
+                    return other_to_self(ary)
+
+            from arraycontext import rec_map_array_container
+            from meshmode.dof_array import DOFArray
+            other_trace_data = rec_map_array_container(
+                get_opposite_trace,
+                unswapped_other_trace_data,
+                leaf_class=DOFArray)
+
+            result[other_vol_dd, self_vol_dd] = TracePair(
+                self_trace_dd,
+                interior=self_trace_data,
+                exterior=other_trace_data)
 
-    other_to_self = dcoll._inter_partition_connections[
-        other_part_id, self_part_id]
-
-    def get_opposite_trace(el):
-        if isinstance(el, Number):
-            return el
-        else:
-            return other_to_self(el)
-
-    return TracePair(
-        self_trace_dd,
-        interior=self_trace,
-        exterior=obj_array_vectorize(get_opposite_trace, other_trace))
+    return result
 
 
 def inter_volume_trace_pairs(dcoll: DiscretizationCollection,
-        self_volume_dd: DOFDesc, self_ary: ArrayOrContainerT,
-        other_volume_dd: DOFDesc, other_ary: ArrayOrContainerT,
-        comm_tag: Hashable = None) -> List[ArrayOrContainerT]:
+        pairwise_volume_data: Mapping[
+            Tuple[DOFDesc, DOFDesc],
+            Tuple[ArrayOrContainerT, ArrayOrContainerT]],
+        comm_tag: Hashable = None) -> Mapping[
+            Tuple[DOFDesc, DOFDesc],
+            List[TracePair]]:
     """
     Note that :func:`local_inter_volume_trace_pairs` provides the rank-local
     contributions if those are needed in isolation. Similarly,
@@ -423,13 +459,21 @@ def inter_volume_trace_pairs(dcoll: DiscretizationCollection,
     """
 
     # TODO documentation
-    return (
-        [local_inter_volume_trace_pairs(dcoll,
-            self_volume_dd, self_ary, other_volume_dd, other_ary)]
-        + cross_rank_inter_volume_trace_pairs(dcoll,
-            self_volume_dd, self_ary, other_volume_dd, other_ary,
-            comm_tag=comm_tag)
-        )
+    result: Mapping[
+        Tuple[DOFDesc, DOFDesc],
+        List[TracePair]] = {}
+
+    local_tpairs = local_inter_volume_trace_pairs(dcoll, pairwise_volume_data)
+    cross_rank_tpairs = cross_rank_inter_volume_trace_pairs(
+        dcoll, pairwise_volume_data, comm_tag=comm_tag)
+
+    for directional_vol_dd_pair, tpair in local_tpairs.items():
+        result[directional_vol_dd_pair] = [tpair]
+
+    for directional_vol_dd_pair, tpairs in cross_rank_tpairs.items():
+        result.setdefault(directional_vol_dd_pair, []).extend(tpairs)
+
+    return result
 
 
 # }}}
@@ -500,12 +544,14 @@ def __init__(self,
             local_part_id: PartitionID,
             remote_part_id: PartitionID,
             local_bdry_data: ArrayOrContainerT,
-            send_data: ArrayOrContainerT,
+            remote_bdry_data_template: ArrayOrContainerT,
             comm_tag: Optional[Hashable] = None):
 
         comm = dcoll.mpi_communicator
         assert comm is not None
 
+        local_vtag = dcoll._part_id_helper.get_volume(local_part_id)
+
         remote_rank = dcoll._part_id_helper.get_mpi_rank(remote_part_id)
         assert remote_rank is not None
 
@@ -513,9 +559,14 @@
         self.array_context = actx
         self.local_part_id = local_part_id
         self.remote_part_id = remote_part_id
-        self.bdry_discr = dcoll.discr_from_dd(
-            BoundaryDomainTag(BTAG_PARTITION(remote_part_id)))
+        self.local_bdry_dd = DOFDesc(
+            BoundaryDomainTag(
+                BTAG_PARTITION(remote_part_id),
+                volume_tag=local_vtag),
+            DISCR_TAG_BASE)
+        self.bdry_discr = dcoll.discr_from_dd(self.local_bdry_dd)
         self.local_bdry_data = local_bdry_data
+        self.remote_bdry_data_template = remote_bdry_data_template
 
         self.comm_tag = self.base_comm_tag
         comm_tag = _sym_tag_to_num_tag(comm_tag)
@@ -528,36 +579,73 @@
         # requests is complete, however it is not clear that this is documented
         # behavior. We hold on to the buffer (via the instance attribute)
         # as well, just in case.
-        self.send_data_np = to_numpy(flatten(send_data, actx), actx)
-        self.send_req = comm.Isend(self.send_data_np,
-                                   remote_rank,
-                                   tag=self.comm_tag)
+        self.send_reqs = []
+        self.send_data = []
+
+        def send_single_array(key, local_subary):
+            if not isinstance(local_subary, Number):
+                local_subary_np = to_numpy(local_subary, actx)
+                self.send_reqs.append(
+                    comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag))
+                self.send_data.append(local_subary_np)
+            return local_subary
+
+        self.recv_reqs = []
+        self.recv_data = {}
+
+        def recv_single_array(key, remote_subary_template):
+            if not isinstance(remote_subary_template, Number):
+                remote_subary_np = np.empty(
+                    remote_subary_template.shape,
+                    remote_subary_template.dtype)
+                self.recv_reqs.append(
+                    comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag))
+                self.recv_data[key] = remote_subary_np
+            return remote_subary_template
 
-        recv_size, recv_dtype = flat_size_and_dtype(local_bdry_data)
-        self.recv_data_np = np.empty(recv_size, recv_dtype)
-        self.recv_req = comm.Irecv(self.recv_data_np, remote_rank, tag=self.comm_tag)
+        from arraycontext.container.traversal import rec_keyed_map_array_container
+        rec_keyed_map_array_container(send_single_array, local_bdry_data)
+        rec_keyed_map_array_container(recv_single_array, remote_bdry_data_template)
 
     def finish(self):
-        # Wait for the nonblocking receive request to complete before
+        from mpi4py import MPI
+
+        # Wait for the nonblocking receive requests to complete before
         # accessing the data
-        self.recv_req.Wait()
+        MPI.Request.waitall(self.recv_reqs)
+
+        def finish_single_array(key, remote_subary_template):
+            if isinstance(remote_subary_template, Number):
+                # NOTE: Assumes that the same number is passed on every rank
+                return remote_subary_template
+            else:
+                return from_numpy(self.recv_data[key], self.array_context)
 
-        recv_data_flat = from_numpy(
-            self.recv_data_np, self.array_context)
-        unswapped_remote_bdry_data = unflatten(self.local_bdry_data,
-            recv_data_flat, self.array_context)
-        bdry_conn = self.dcoll._inter_partition_connections[
+        from arraycontext.container.traversal import rec_keyed_map_array_container
+        unswapped_remote_bdry_data = rec_keyed_map_array_container(
+            finish_single_array, self.remote_bdry_data_template)
+
+        remote_to_local = self.dcoll._inter_partition_connections[
             self.remote_part_id, self.local_part_id]
-        remote_bdry_data = bdry_conn(unswapped_remote_bdry_data)
 
-        # Complete the nonblocking send request associated with communicating
-        # `self.local_bdry_data_np`
-        self.send_req.Wait()
+        def get_opposite_trace(ary):
+            if isinstance(ary, Number):
+                return ary
+            else:
+                return remote_to_local(ary)
+
+        from arraycontext import rec_map_array_container
+        from meshmode.dof_array import DOFArray
+        remote_bdry_data = rec_map_array_container(
+            get_opposite_trace,
+            unswapped_remote_bdry_data,
+            leaf_class=DOFArray)
+
+        # Complete the nonblocking send requests
+        MPI.Request.waitall(self.send_reqs)
 
         return TracePair(
-            DOFDesc(
-                BoundaryDomainTag(BTAG_PARTITION(self.remote_part_id)),
-                DISCR_TAG_BASE),
+            self.local_bdry_dd,
             interior=self.local_bdry_data,
             exterior=remote_bdry_data)
@@ -574,63 +662,109 @@ def __init__(self,
             local_part_id: PartitionID,
             remote_part_id: PartitionID,
             local_bdry_data: ArrayOrContainerT,
-            send_data: ArrayOrContainerT,
+            remote_bdry_data_template: ArrayOrContainerT,
             comm_tag: Optional[Hashable] = None) -> None:
 
         if comm_tag is None:
             raise ValueError("lazy communication requires 'comm_tag' to be supplied")
 
-        self.dcoll = dcoll
-        self.array_context = actx
-        self.bdry_discr = dcoll.discr_from_dd(
-            BoundaryDomainTag(BTAG_PARTITION(remote_part_id)))
-        self.local_part_id = local_part_id
-        self.remote_part_id = remote_part_id
+        local_vtag = dcoll._part_id_helper.get_volume(local_part_id)
 
         remote_rank = dcoll._part_id_helper.get_mpi_rank(remote_part_id)
         assert remote_rank is not None
 
-        self.local_bdry_data = local_bdry_data
-
-        from arraycontext.container.traversal import rec_keyed_map_array_container
-
-        key_to_send_subary = {}
-
-        def store_send_subary(key, send_subary):
-            key_to_send_subary[key] = send_subary
-            return send_subary
-        rec_keyed_map_array_container(store_send_subary, send_data)
+        self.dcoll = dcoll
+        self.array_context = actx
+        self.local_bdry_dd = DOFDesc(
+            BoundaryDomainTag(
+                BTAG_PARTITION(remote_part_id),
+                volume_tag=local_vtag),
+            DISCR_TAG_BASE)
+        self.bdry_discr = dcoll.discr_from_dd(self.local_bdry_dd)
+        self.local_part_id = local_part_id
+        self.remote_part_id = remote_part_id
 
         from pytato import make_distributed_recv, staple_distributed_send
 
-        def communicate_single_array(key, local_bdry_subary):
-            ary_tag = (comm_tag, key)
-            return staple_distributed_send(
-                key_to_send_subary[key], dest_rank=remote_rank, comm_tag=ary_tag,
-                stapled_to=make_distributed_recv(
-                    src_rank=remote_rank, comm_tag=ary_tag,
-                    shape=local_bdry_subary.shape,
-                    dtype=local_bdry_subary.dtype))
+        # Staple the sends to a bunch of dummy arrays of zeros
+        def send_single_array(key, local_subary):
+            if isinstance(local_subary, Number):
+                return 0
+            else:
+                ary_tag = (comm_tag, key)
+                return staple_distributed_send(
+                    local_subary, dest_rank=remote_rank, comm_tag=ary_tag,
+                    stapled_to=actx.zeros_like(local_subary))
+
+        def recv_single_array(key, remote_subary_template):
+            if isinstance(remote_subary_template, Number):
+                # NOTE: Assumes that the same number is passed on every rank
+                return remote_subary_template
+            else:
+                ary_tag = (comm_tag, key)
+                return make_distributed_recv(
+                    src_rank=remote_rank, comm_tag=ary_tag,
+                    shape=remote_subary_template.shape,
+                    dtype=remote_subary_template.dtype)
 
-        self.remote_data = rec_keyed_map_array_container(
-            communicate_single_array, self.local_bdry_data)
+        from arraycontext.container.traversal import rec_keyed_map_array_container
+        zeros_like_local_bdry_data = rec_keyed_map_array_container(
+            send_single_array, local_bdry_data)
+        unswapped_remote_bdry_data = rec_keyed_map_array_container(
+            recv_single_array, remote_bdry_data_template)
+
+        # Sum up the dummy zeros
+        zero = actx.np.sum(zeros_like_local_bdry_data)
+
+        # Add the dummy zeros and hope that the caller proceeds to actually
+        # use some of this data on every rank...
+        from arraycontext import rec_map_array_container
+        # This caused test_mpi_communication.py::test_func_comparison_mpi to fail
+        # for some reason
+#        self.local_bdry_data = rec_map_array_container(
+#            lambda x: x + zero,
+#            local_bdry_data)
+        self.local_bdry_data = local_bdry_data
+        self.unswapped_remote_bdry_data = rec_map_array_container(
+            lambda x: x + zero,
+            unswapped_remote_bdry_data)
 
     def finish(self):
-        bdry_conn = self.dcoll._inter_partition_connections[
+        remote_to_local = self.dcoll._inter_partition_connections[
             self.remote_part_id, self.local_part_id]
 
+        def get_opposite_trace(ary):
+            if isinstance(ary, Number):
+                return ary
+            else:
+                return remote_to_local(ary)
+
+        from arraycontext import rec_map_array_container
+        from meshmode.dof_array import DOFArray
+        remote_bdry_data = rec_map_array_container(
+            get_opposite_trace,
+            self.unswapped_remote_bdry_data,
+            leaf_class=DOFArray)
+
         return TracePair(
-            DOFDesc(
-                BoundaryDomainTag(BTAG_PARTITION(self.remote_part_id)),
-                DISCR_TAG_BASE),
+            self.local_bdry_dd,
             interior=self.local_bdry_data,
-            exterior=bdry_conn(self.remote_data))
+            exterior=remote_bdry_data)
 
 # }}}
 
 
 # {{{ cross_rank_trace_pairs
 
+def _replace_dof_arrays(array_container, dof_array):
+    from arraycontext import rec_map_array_container
+    from meshmode.dof_array import DOFArray
+    return rec_map_array_container(
+        lambda x: dof_array if isinstance(x, DOFArray) else x,
+        array_container,
+        leaf_class=DOFArray)
+
+
 def cross_rank_trace_pairs(
         dcoll: DiscretizationCollection, ary: ArrayOrContainerT,
         tag: Hashable = None,
@@ -708,42 +842,37 @@ def cross_rank_trace_pairs(
     assert len(remote_part_ids) == len(
         {dcoll._part_id_helper.get_mpi_rank(part_id) for part_id in remote_part_ids})
 
-    if isinstance(ary, Number):
-        # NOTE: Assumes that the same number is passed on every rank
-        return [
-            TracePair(
-                DOFDesc(
-                    BoundaryDomainTag(BTAG_PARTITION(remote_part_id)),
-                    DISCR_TAG_BASE),
-                interior=ary, exterior=ary)
-            for remote_part_id in remote_part_ids]
-
     actx = get_container_context_recursively(ary)
     assert actx is not None
 
     from grudge.array_context import MPIPytatoArrayContextBase
 
     if isinstance(actx, MPIPytatoArrayContextBase):
-        rbc = _RankBoundaryCommunicationLazy
+        rbc_class = _RankBoundaryCommunicationLazy
     else:
-        rbc = _RankBoundaryCommunicationEager
+        rbc_class = _RankBoundaryCommunicationEager
 
-    def start_comm(remote_part_id):
-        bdtag = BoundaryDomainTag(BTAG_PARTITION(remote_part_id))
+    rank_bdry_communicators = []
 
-        local_bdry_data = project(dcoll, volume_dd, bdtag, ary)
+    for remote_part_id in remote_part_ids:
+        bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_part_id))
 
-        return rbc(actx, dcoll,
-            local_part_id=local_part_id,
-            remote_part_id=remote_part_id,
-            local_bdry_data=local_bdry_data,
-            send_data=local_bdry_data,
-            comm_tag=comm_tag)
+        local_bdry_data = project(dcoll, volume_dd, bdry_dd, ary)
 
-    rank_bdry_communcators = [
-        start_comm(remote_part_id)
-        for remote_part_id in remote_part_ids]
-    return [rc.finish() for rc in rank_bdry_communcators]
+        remote_bdry_data_template = _replace_dof_arrays(
+            local_bdry_data,
+            dcoll._inter_partition_connections[
+                remote_part_id, local_part_id].from_discr.zeros(actx))
+
+        rank_bdry_communicators.append(
+            rbc_class(actx, dcoll,
+                local_part_id=local_part_id,
+                remote_part_id=remote_part_id,
+                local_bdry_data=local_bdry_data,
+                remote_bdry_data_template=remote_bdry_data_template,
+                comm_tag=comm_tag))
+
+    return [rbc.finish() for rbc in rank_bdry_communicators]
 
 
 # }}}
 
 
@@ -752,10 +881,13 @@ def start_comm(remote_part_id):
 
 def cross_rank_inter_volume_trace_pairs(
         dcoll: DiscretizationCollection,
-        self_volume_dd: DOFDesc, self_ary: ArrayOrContainerT,
-        other_volume_dd: DOFDesc, other_ary: ArrayOrContainerT,
+        pairwise_volume_data: Mapping[
+            Tuple[DOFDesc, DOFDesc],
+            Tuple[ArrayOrContainerT, ArrayOrContainerT]],
         *, comm_tag: Hashable = None,
-        ) -> List[TracePair]:
+        ) -> Mapping[
+            Tuple[DOFDesc, DOFDesc],
+            List[TracePair]]:
     # FIXME: Should this interface take in boundary data instead?
     # TODO: Docs
     r"""Get a :class:`list` of *ary* trace pairs for each partition boundary.
@@ -769,16 +901,16 @@ def cross_rank_inter_volume_trace_pairs(
     """
     # {{{ process arguments
 
-    if not isinstance(self_volume_dd.domain_tag, VolumeDomainTag):
-        raise ValueError("self_volume_dd must describe a volume")
-    if not isinstance(other_volume_dd.domain_tag, VolumeDomainTag):
-        raise ValueError("other_volume_dd must describe a volume")
-    if self_volume_dd.discretization_tag != DISCR_TAG_BASE:
-        raise TypeError(
-            f"expected a base-discretized self DOFDesc, got '{self_volume_dd}'")
-    if other_volume_dd.discretization_tag != DISCR_TAG_BASE:
-        raise TypeError(
-            f"expected a base-discretized other DOFDesc, got '{other_volume_dd}'")
+    for vol_dd_pair in pairwise_volume_data.keys():
+        for vol_dd in vol_dd_pair:
+            if not isinstance(vol_dd.domain_tag, VolumeDomainTag):
+                raise ValueError(
+                    "pairwise_volume_data keys must describe volumes, "
+                    f"got '{vol_dd}'")
+            if vol_dd.discretization_tag != DISCR_TAG_BASE:
+                raise ValueError(
+                    "expected base-discretized DOFDesc in pairwise_volume_data, "
+                    f"got '{vol_dd}'")
 
     # }}}
@@ -787,51 +919,77 @@ def cross_rank_inter_volume_trace_pairs(
 
     rank = dcoll.mpi_communicator.Get_rank()
 
-    local_part_id = dcoll._part_id_helper.make(rank, self_volume_dd.domain_tag.tag)
-
-    connected_part_ids = _connected_partitions(
-        dcoll, self_volume_tag=self_volume_dd.domain_tag.tag,
-        other_volume_tag=other_volume_dd.domain_tag.tag)
-
-    remote_part_ids = [
-        part_id
-        for part_id in connected_part_ids
-        if dcoll._part_id_helper.get_mpi_rank(part_id) != rank]
-
-    # This asserts that there is only one data exchange per rank, so that
-    # there is no risk of mismatched data reaching the wrong recipient.
-    # (Since we have only a single tag.)
-    assert len(remote_part_ids) == len(
-        {dcoll._part_id_helper.get_mpi_rank(part_id) for part_id in remote_part_ids})
-
-    actx = get_container_context_recursively(self_ary)
+    for vol_data_pair in pairwise_volume_data.values():
+        for vol_data in vol_data_pair:
+            actx = get_container_context_recursively(vol_data)
+            if actx is not None:
+                break
+        if actx is not None:
+            break
     assert actx is not None
 
     from grudge.array_context import MPIPytatoArrayContextBase
 
     if isinstance(actx, MPIPytatoArrayContextBase):
-        rbc = _RankBoundaryCommunicationLazy
+        rbc_class = _RankBoundaryCommunicationLazy
     else:
-        rbc = _RankBoundaryCommunicationEager
+        rbc_class = _RankBoundaryCommunicationEager
 
-    def start_comm(remote_part_id):
-        bdtag = BoundaryDomainTag(BTAG_PARTITION(remote_part_id))
+    def get_remote_connected_partitions(local_vol_dd, remote_vol_dd):
+        connected_part_ids = _connected_partitions(
+            dcoll, self_volume_tag=local_vol_dd.domain_tag.tag,
+            other_volume_tag=remote_vol_dd.domain_tag.tag)
+        return [
+            part_id
+            for part_id in connected_part_ids
+            if dcoll._part_id_helper.get_mpi_rank(part_id) != rank]
 
-        local_bdry_data = project(dcoll, self_volume_dd, bdtag, self_ary)
-        send_data = project(dcoll, other_volume_dd,
-            BTAG_PARTITION(local_part_id), other_ary)
+    rank_bdry_communicators = {}
 
-        return rbc(actx, dcoll,
-            local_part_id=local_part_id,
-            remote_part_id=remote_part_id,
-            local_bdry_data=local_bdry_data,
-            send_data=send_data,
-            comm_tag=comm_tag)
+    for vol_dd_pair, vol_data_pair in pairwise_volume_data.items():
+        directional_volume_data = {
+            (vol_dd_pair[0], vol_dd_pair[1]): (vol_data_pair[0], vol_data_pair[1]),
+            (vol_dd_pair[1], vol_dd_pair[0]): (vol_data_pair[1], vol_data_pair[0])}
 
-    rank_bdry_communcators = [
-        start_comm(remote_part_id)
-        for remote_part_id in remote_part_ids]
-    return [rc.finish() for rc in rank_bdry_communcators]
+        for dd_pair, data_pair in directional_volume_data.items():
+            other_vol_dd, self_vol_dd = dd_pair
+            other_vol_data, self_vol_data = data_pair
+
+            self_part_id = dcoll._part_id_helper.make(
+                rank, self_vol_dd.domain_tag.tag)
+            other_part_ids = get_remote_connected_partitions(
+                self_vol_dd, other_vol_dd)
+
+            rbcs = []
+
+            for other_part_id in other_part_ids:
+                self_bdry_dd = self_vol_dd.trace(BTAG_PARTITION(other_part_id))
+                self_bdry_data = project(
+                    dcoll, self_vol_dd, self_bdry_dd, self_vol_data)
+
+                other_bdry_template_dd = other_vol_dd.trace(
+                    BTAG_PARTITION(self_part_id))
+                other_bdry_container_template = project(
+                    dcoll, other_vol_dd, other_bdry_template_dd,
+                    other_vol_data)
+                other_bdry_data_template = _replace_dof_arrays(
+                    other_bdry_container_template,
+                    dcoll._inter_partition_connections[
+                        other_part_id, self_part_id].from_discr.zeros(actx))
+
+                rbcs.append(
+                    rbc_class(actx, dcoll,
+                        local_part_id=self_part_id,
+                        remote_part_id=other_part_id,
+                        local_bdry_data=self_bdry_data,
+                        remote_bdry_data_template=other_bdry_data_template,
+                        comm_tag=comm_tag))
+
+            rank_bdry_communicators[other_vol_dd, self_vol_dd] = rbcs
+
+    return {
+        directional_vol_dd_pair: [rbc.finish() for rbc in rbcs]
+        for directional_vol_dd_pair, rbcs in rank_bdry_communicators.items()}
 
 
 # }}}
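
For reviewers: a minimal usage sketch of the reworked pairwise interface. All
names below (dcoll, fluid_state, wall_state, the "fluid"/"wall" volume tags,
the comm tag) are illustrative assumptions, not part of this diff.

    # Hypothetical sketch -- not code from this diff.
    from grudge.dof_desc import DOFDesc, VolumeDomainTag, DISCR_TAG_BASE
    from grudge.trace_pair import inter_volume_trace_pairs

    dcoll = ...         # a DiscretizationCollection with "fluid"/"wall" volumes
    fluid_state = ...   # volume data living on the fluid volume
    wall_state = ...    # volume data living on the wall volume

    dd_vol_fluid = DOFDesc(VolumeDomainTag("fluid"), DISCR_TAG_BASE)
    dd_vol_wall = DOFDesc(VolumeDomainTag("wall"), DISCR_TAG_BASE)

    # One entry per coupled volume pair; values are the data to be traced.
    pairwise_volume_data = {
        (dd_vol_fluid, dd_vol_wall): (fluid_state, wall_state)}

    # Maps directional (other_vol_dd, self_vol_dd) keys to lists of TracePairs:
    # one rank-local pair plus one pair per communicating remote rank.
    tpairs = inter_volume_trace_pairs(
        dcoll, pairwise_volume_data, comm_tag=("fluid_wall_exchange", 0))

    # Trace pairs seen from the fluid side of the fluid-wall interface:
    fluid_side_tpairs = tpairs[dd_vol_wall, dd_vol_fluid]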
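
Also for reviewers: the template-driven exchange that replaces the old
flatten/unflatten scheme in _RankBoundaryCommunicationEager can be shown
standalone. A minimal sketch in plain mpi4py/numpy (assumes exactly two
ranks; not grudge code): each leaf array gets its own Isend/Irecv, and the
receive buffer is allocated from a shape/dtype template describing the
remote side's data, which need not match the local boundary's shape.

    # Standalone illustration; run with: mpiexec -n 2 python sketch.py
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    remote_rank = 1 - rank  # two-rank setup assumed

    # Stand-in for a container of leaf arrays, keyed as by
    # rec_keyed_map_array_container in the diff.
    local_data = {"u": np.full(4, float(rank)), "v": np.full(2, 10.0*rank)}

    # Shape/dtype template of what the remote side will send; the diff builds
    # this via _replace_dof_arrays from the remote connection's from_discr,
    # since remote DOF counts need not equal local ones.
    remote_template = {"u": np.empty(4), "v": np.empty(2)}

    send_reqs = [comm.Isend(ary, remote_rank, tag=77)
                 for ary in local_data.values()]
    recv_data = {key: np.empty(tmpl.shape, tmpl.dtype)
                 for key, tmpl in remote_template.items()}
    recv_reqs = [comm.Irecv(ary, remote_rank, tag=77)
                 for ary in recv_data.values()]

    # All leaves share one tag; matching relies on MPI's non-overtaking
    # ordering between a fixed sender/receiver pair, as in the diff.
    MPI.Request.Waitall(recv_reqs)
    MPI.Request.Waitall(send_reqs)
    print(rank, recv_data)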