diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index ea1d5f880..115863b5b 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -59,7 +59,7 @@ ArrayContext, with_container_arithmetic, dataclass_array_container, - get_container_context_recursively, + get_container_context_recursively_opt, to_numpy, from_numpy) from arraycontext.container import ArrayOrContainerT @@ -259,7 +259,7 @@ def bv_trace_pair( DeprecationWarning, stacklevel=2) dd = dof_desc.as_dofdesc(dd) return bdry_trace_pair( - dcoll, dd, project(dcoll, "vol", dd, interior), exterior) + dcoll, dd, project(dcoll, dd.domain_tag.volume_tag, dd, interior), exterior) # }}} @@ -471,7 +471,7 @@ def inter_volume_trace_pairs(dcoll: DiscretizationCollection, result[directional_vol_dd_pair] = [tpair] for directional_vol_dd_pair, tpairs in cross_rank_tpairs.items(): - result.setdefault(directional_vol_dd_pair, []).append(tpairs) + result.setdefault(directional_vol_dd_pair, []).extend(tpairs) return result @@ -839,7 +839,7 @@ def cross_rank_trace_pairs( assert len(remote_part_ids) == len( {dcoll._part_id_helper.get_mpi_rank(part_id) for part_id in remote_part_ids}) - actx = get_container_context_recursively(ary) + actx = get_container_context_recursively_opt(ary) if actx is None: # NOTE: Assumes that the same number is passed on every rank @@ -925,7 +925,7 @@ def cross_rank_inter_volume_trace_pairs( for vol_data_pair in pairwise_volume_data.values(): for vol_data in vol_data_pair: - actx = get_container_context_recursively(vol_data) + actx = get_container_context_recursively_opt(vol_data) if actx is not None: break if actx is not None: