diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 5a7ef322d..ea1d5f880 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -840,7 +840,14 @@ def cross_rank_trace_pairs( {dcoll._part_id_helper.get_mpi_rank(part_id) for part_id in remote_part_ids}) actx = get_container_context_recursively(ary) - assert actx is not None + + if actx is None: + # NOTE: Assumes that the same number is passed on every rank + return [ + TracePair( + volume_dd.trace(BTAG_PARTITION(remote_part_id)), + interior=ary, exterior=ary) + for remote_part_id in remote_part_ids] from grudge.array_context import MPIPytatoArrayContextBase @@ -923,14 +930,6 @@ def cross_rank_inter_volume_trace_pairs( break if actx is not None: break - assert actx is not None - - from grudge.array_context import MPIPytatoArrayContextBase - - if isinstance(actx, MPIPytatoArrayContextBase): - rbc_class = _RankBoundaryCommunicationLazy - else: - rbc_class = _RankBoundaryCommunicationEager def get_remote_connected_partitions(local_vol_dd, remote_vol_dd): connected_part_ids = _connected_partitions( @@ -941,6 +940,26 @@ def get_remote_connected_partitions(local_vol_dd, remote_vol_dd): for part_id in connected_part_ids if dcoll._part_id_helper.get_mpi_rank(part_id) != rank] + if actx is None: + # NOTE: Assumes that the same number is passed on every rank for a + # given volume + return { + (remote_vol_dd, local_vol_dd): [ + TracePair( + local_vol_dd.trace(BTAG_PARTITION(remote_part_id)), + interior=local_vol_ary, exterior=remote_vol_ary) + for remote_part_id in get_remote_connected_partitions( + local_vol_dd, remote_vol_dd)] + for (remote_vol_dd, local_vol_dd), (remote_vol_ary, local_vol_ary) + in pairwise_volume_data.items()} + + from grudge.array_context import MPIPytatoArrayContextBase + + if isinstance(actx, MPIPytatoArrayContextBase): + rbc_class = _RankBoundaryCommunicationLazy + else: + rbc_class = _RankBoundaryCommunicationEager + rank_bdry_communicators = {} for vol_dd_pair, vol_data_pair in pairwise_volume_data.items():