Skip to content

Commit

Permalink
Updates for Trilinos 16 (#1325)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrood-nrel authored Nov 5, 2024
1 parent 86d08b5 commit 526c879
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
28 changes: 14 additions & 14 deletions include/ngp_utils/NgpLoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ run_entity_algorithm(
const size_t bktLen = bkt.size();
Kokkos::parallel_for(
Kokkos::TeamThreadRange(team, bktLen), [&](const size_t& bktIndex) {
MeshIndex meshIdx{&bkt, static_cast<unsigned>(bktIndex)};
MeshIndex meshIdx{bkt.bucket_id(), static_cast<unsigned>(bktIndex)};
algorithm(meshIdx);
});
});
Expand Down Expand Up @@ -181,7 +181,7 @@ run_entity_par_reduce(
Kokkos::parallel_reduce(
Kokkos::TeamThreadRange(team, bktLen),
[&](const size_t& bktIndex, ReducerType& threadVal) {
MeshIndex meshIdx{&bkt, static_cast<unsigned>(bktIndex)};
MeshIndex meshIdx{bkt.bucket_id(), static_cast<unsigned>(bktIndex)};
algorithm(meshIdx, threadVal);
},
bktVal);
Expand Down Expand Up @@ -222,7 +222,7 @@ run_entity_par_reduce(
Kokkos::parallel_reduce(
Kokkos::TeamThreadRange(team, bktLen),
[&](const size_t& bktIndex, value_type& threadVal) {
MeshIndex meshIdx{&bkt, static_cast<unsigned>(bktIndex)};
MeshIndex meshIdx{bkt.bucket_id(), static_cast<unsigned>(bktIndex)};
algorithm(meshIdx, threadVal);
},
ReducerType(bktVal));
Expand Down Expand Up @@ -259,8 +259,8 @@ run_edge_algorithm(
run_entity_algorithm(
algName, mesh, rank, sel, KOKKOS_LAMBDA(MeshIndex & meshIdx) {
algorithm(EntityInfo<Mesh>{
meshIdx, (*meshIdx.bucket)[meshIdx.bucketOrd],
mesh.get_nodes(meshIdx)});
meshIdx, mesh.get_entity(rank, meshIdx),
mesh.get_nodes(rank, meshIdx)});
});
}

Expand Down Expand Up @@ -293,8 +293,8 @@ run_elem_algorithm(
run_entity_algorithm(
algName, mesh, rank, sel, KOKKOS_LAMBDA(MeshIndex & meshIdx) {
algorithm(EntityInfo<Mesh>{
meshIdx, (*meshIdx.bucket)[meshIdx.bucketOrd],
mesh.get_nodes(meshIdx)});
meshIdx, mesh.get_entity(rank, meshIdx),
mesh.get_nodes(rank, meshIdx)});
});
}

Expand Down Expand Up @@ -370,10 +370,10 @@ run_elem_algorithm(

for (int is = 0; is < nSimdElems; ++is) {
const unsigned bktOrd = bktIndex * simdLen + is;
MeshIndex meshIdx{&bkt, bktOrd};
MeshIndex meshIdx{bkt.bucket_id(), bktOrd};
const auto& elem = bkt[bktOrd];
elemData.elemInfo[is] =
EntityInfo<Mesh>{meshIdx, elem, ngpMesh.get_nodes(meshIdx)};
EntityInfo<Mesh>{meshIdx, elem, ngpMesh.get_nodes(rank, meshIdx)};

fill_pre_req_data(
dataReqNGP, ngpMesh, rank, elem, *elemData.scrView[is]);
Expand Down Expand Up @@ -469,10 +469,10 @@ run_elem_par_reduce(

for (int is = 0; is < nSimdElems; ++is) {
const unsigned bktOrd = bktIndex * simdLen + is;
MeshIndex meshIdx{&bkt, bktOrd};
MeshIndex meshIdx{bkt.bucket_id(), bktOrd};
const auto& elem = bkt[bktOrd];
elemData.elemInfo[is] =
EntityInfo<Mesh>{meshIdx, elem, ngpMesh.get_nodes(meshIdx)};
EntityInfo<Mesh>{meshIdx, elem, ngpMesh.get_nodes(rank, meshIdx)};

fill_pre_req_data(
dataReqNGP, ngpMesh, rank, elem, *elemData.scrView[is]);
Expand Down Expand Up @@ -571,7 +571,7 @@ run_face_elem_algorithm(
break;

const auto elems = ngpMesh.get_elements(sideRank, faceIdx);
MeshIndex meshIdx{&bkt, static_cast<unsigned>(bktOrd)};
MeshIndex meshIdx{bkt.bucket_id(), static_cast<unsigned>(bktOrd)};
const auto elem = elems[0];
const auto elemIdx = ngpMesh.fast_mesh_index(elem);
faceElemData.faceInfo[simdFaceIdx] = BcFaceElemInfo<Mesh>{
Expand Down Expand Up @@ -698,7 +698,7 @@ run_face_elem_par_reduce(
break;

const auto elems = ngpMesh.get_elements(sideRank, faceIdx);
MeshIndex meshIdx{&bkt, static_cast<unsigned>(bktOrd)};
MeshIndex meshIdx{bkt.bucket_id(), static_cast<unsigned>(bktOrd)};
const auto elem = elems[0];
const auto elemIdx = ngpMesh.fast_mesh_index(elem);
faceElemData.faceInfo[simdFaceIdx] = BcFaceElemInfo<Mesh>{
Expand Down Expand Up @@ -802,7 +802,7 @@ run_face_elem_algorithm_nosimd(
const size_t bktLen = bkt.size();
Kokkos::parallel_for(
Kokkos::TeamThreadRange(team, bktLen), [&](const size_t& bktIndex) {
MeshIndex meshIdx{&bkt, static_cast<unsigned>(bktIndex)};
MeshIndex meshIdx{bkt.bucket_id(), static_cast<unsigned>(bktIndex)};
const auto face = bkt[bktIndex];
const auto faceIdx = ngpMesh.fast_mesh_index(face);
const auto elements = ngpMesh.get_elements(sideRank, faceIdx);
Expand Down
4 changes: 2 additions & 2 deletions src/HypreLinearSystem.C
Original file line number Diff line number Diff line change
Expand Up @@ -2444,7 +2444,7 @@ HypreLinearSystem::applyDirichletBCs(
nalu_ngp::run_entity_algorithm(
"HypreLinearSystem::applyDirichletBCs", ngpMesh, stk::topology::NODE_RANK,
selector, KOKKOS_LAMBDA(const Traits::MeshIndex& mi) {
const auto node = (*mi.bucket)[mi.bucketOrd];
const auto node = ngpMesh.get_entity(stk::topology::NODE_RANK, mi);
HypreIntType hid = hypreGID.get(ngpMesh, node, 0);
for (unsigned d = 0; d < numDof; ++d) {
HypreIntType lid = hid * numDof + d;
Expand Down Expand Up @@ -2586,7 +2586,7 @@ HypreLinearSystem::copy_hypre_to_stk(stk::mesh::FieldBase* stkField)
nalu_ngp::run_entity_algorithm(
"HypreLinearSystem::copy_hypre_to_stk", ngpMesh, stk::topology::NODE_RANK,
selector, KOKKOS_LAMBDA(const Traits::MeshIndex& mi) {
const auto node = (*mi.bucket)[mi.bucketOrd];
const auto node = ngpMesh.get_entity(stk::topology::NODE_RANK, mi);
HypreIntType hid;
if (periodic_node_to_hypre_id.exists(node.local_offset()))
hid = periodic_node_to_hypre_id.value_at(
Expand Down
6 changes: 3 additions & 3 deletions src/HypreUVWLinearSystem.C
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ HypreUVWLinearSystem::applyDirichletBCs(
"HypreUVWLinearSystem::applyDirichletBCs", ngpMesh,
stk::topology::NODE_RANK, selector,
KOKKOS_LAMBDA(const Traits::MeshIndex& mi) {
const auto node = (*mi.bucket)[mi.bucketOrd];
const auto node = ngpMesh.get_entity(stk::topology::NODE_RANK, mi);
HypreIntType hid = hypreGID.get(ngpMesh, node, 0);
unsigned matIndex = mat_row_start_owned(hid - iLower);
vals(matIndex) = 1.0;
Expand Down Expand Up @@ -580,7 +580,7 @@ HypreUVWLinearSystem::copy_hypre_to_stk(
"HypreUVWLinearSystem::copy_hypre_to_stk_3D", ngpMesh,
stk::topology::NODE_RANK, selector,
KOKKOS_LAMBDA(const Traits::MeshIndex& mi) {
const auto node = (*mi.bucket)[mi.bucketOrd];
const auto node = ngpMesh.get_entity(stk::topology::NODE_RANK, mi);
HypreIntType hid;
if (periodic_node_to_hypre_id.exists(node.local_offset()))
hid = periodic_node_to_hypre_id.value_at(
Expand All @@ -607,7 +607,7 @@ HypreUVWLinearSystem::copy_hypre_to_stk(
"HypreUVWLinearSystem::copy_hypre_to_stk_3D", ngpMesh,
stk::topology::NODE_RANK, selector,
KOKKOS_LAMBDA(const Traits::MeshIndex& mi) {
const auto node = (*mi.bucket)[mi.bucketOrd];
const auto node = ngpMesh.get_entity(stk::topology::NODE_RANK, mi);
HypreIntType hid;
if (periodic_node_to_hypre_id.exists(node.local_offset()))
hid = periodic_node_to_hypre_id.value_at(
Expand Down
6 changes: 4 additions & 2 deletions src/TpetraLinearSystem.C
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,8 @@ TpetraLinearSystem::applyDirichletBCs(
nalu_ngp::run_entity_algorithm(
"TpetraLinSys::applyDirichletBCs", ngpMesh, stk::topology::NODE_RANK,
selector, KOKKOS_LAMBDA(const MeshIndex& meshIdx) {
stk::mesh::Entity entity = (*meshIdx.bucket)[meshIdx.bucketOrd];
stk::mesh::Entity entity =
ngpMesh.get_entity(stk::topology::NODE_RANK, meshIdx);
const LocalOrdinal localIdOffset = entityToLID[entity.local_offset()];
const bool useOwned = localIdOffset < maxOwnedRowId;
const LinSys::LocalMatrix& local_matrix =
Expand Down Expand Up @@ -2210,7 +2211,8 @@ TpetraLinearSystem::copy_tpetra_to_stk(
nalu_ngp::run_entity_algorithm(
"TpetraLinSys::copy_tpetra_to_stk", ngpMesh, stk::topology::NODE_RANK,
selector, KOKKOS_LAMBDA(const MeshIndex& meshIdx) {
stk::mesh::Entity node = (*meshIdx.bucket)[meshIdx.bucketOrd];
stk::mesh::Entity node =
ngpMesh.get_entity(stk::topology::NODE_RANK, meshIdx);
const LocalOrdinal localIdOffset = entityToLID[node.local_offset()];
for (unsigned d = 0; d < numDof; ++d) {
const LocalOrdinal localId = localIdOffset + d;
Expand Down

0 comments on commit 526c879

Please sign in to comment.