Skip to content

Commit

Permalink
handle all-Number cases
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Mar 25, 2022
1 parent ae737f7 commit 012341d
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,14 @@ def cross_rank_trace_pairs(
{dcoll._part_id_helper.get_mpi_rank(part_id) for part_id in remote_part_ids})

actx = get_container_context_recursively(ary)
assert actx is not None

if actx is None:
# NOTE: Assumes that the same number is passed on every rank
return [
TracePair(
volume_dd.trace(BTAG_PARTITION(remote_part_id)),
interior=ary, exterior=ary)
for remote_part_id in remote_part_ids]

from grudge.array_context import MPIPytatoArrayContextBase

Expand Down Expand Up @@ -923,14 +930,6 @@ def cross_rank_inter_volume_trace_pairs(
break
if actx is not None:
break
assert actx is not None

from grudge.array_context import MPIPytatoArrayContextBase

if isinstance(actx, MPIPytatoArrayContextBase):
rbc_class = _RankBoundaryCommunicationLazy
else:
rbc_class = _RankBoundaryCommunicationEager

def get_remote_connected_partitions(local_vol_dd, remote_vol_dd):
connected_part_ids = _connected_partitions(
Expand All @@ -941,6 +940,26 @@ def get_remote_connected_partitions(local_vol_dd, remote_vol_dd):
for part_id in connected_part_ids
if dcoll._part_id_helper.get_mpi_rank(part_id) != rank]

if actx is None:
# NOTE: Assumes that the same number is passed on every rank for a
# given volume
return {
(remote_vol_dd, local_vol_dd): [
TracePair(
local_vol_dd.trace(BTAG_PARTITION(remote_part_id)),
interior=local_vol_ary, exterior=remote_vol_ary)
for remote_part_id in get_remote_connected_partitions(
local_vol_dd, remote_vol_dd)]
for (remote_vol_dd, local_vol_dd), (remote_vol_ary, local_vol_ary)
in pairwise_volume_data.items()}

from grudge.array_context import MPIPytatoArrayContextBase

if isinstance(actx, MPIPytatoArrayContextBase):
rbc_class = _RankBoundaryCommunicationLazy
else:
rbc_class = _RankBoundaryCommunicationEager

rank_bdry_communicators = {}

for vol_dd_pair, vol_data_pair in pairwise_volume_data.items():
Expand Down

0 comments on commit 012341d

Please sign in to comment.