diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index ef4ce8795..570729e81 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -876,10 +876,19 @@ def cross_rank_trace_pairs( local_bdry_data = project(dcoll, volume_dd, bdry_dd, ary) - remote_bdry_data_template = _replace_dof_arrays( - local_bdry_data, + from meshmode.array_context import tag_axes + from meshmode.transform_metadata import ( + DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) + remote_bdry_zeros = tag_axes( dcoll._inter_partition_connections[ - remote_part_id, local_part_id].from_discr.zeros(actx)) + remote_part_id, local_part_id].from_discr.zeros(actx), + actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}) + + remote_bdry_data_template = _replace_dof_arrays( + local_bdry_data, remote_bdry_zeros) rank_bdry_communicators.append( rbc_class(actx, dcoll, @@ -996,15 +1005,19 @@ def get_remote_connected_partitions(local_vol_dd, remote_vol_dd): self_bdry_data = project( dcoll, self_vol_dd, self_bdry_dd, self_vol_data) - other_bdry_template_dd = other_vol_dd.trace( - BTAG_PARTITION(self_part_id)) - other_bdry_container_template = project( - dcoll, other_vol_dd, other_bdry_template_dd, - other_vol_data) - other_bdry_data_template = _replace_dof_arrays( - other_bdry_container_template, + from meshmode.array_context import tag_axes + from meshmode.transform_metadata import ( + DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) + other_bdry_zeros = tag_axes( dcoll._inter_partition_connections[ - other_part_id, self_part_id].from_discr.zeros(actx)) + other_part_id, self_part_id].from_discr.zeros(actx), + actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}) + + other_bdry_data_template = _replace_dof_arrays( + other_vol_data, other_bdry_zeros) rbcs.append( rbc_class(actx, dcoll,