Skip to content

Commit

Permalink
Merged in extremeParalfix (pull request #646)
Browse files (browse the repository at this point in the history)
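Fixed a crash in the force computation at extreme parallel scaling: the wavefunction contraction kernels and the Eshelby-tensor memory copy are now skipped when numCells is 0.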

Approved-by: Phani Motamarri
  • Loading branch information
dsambit committed Dec 24, 2024
2 parents b7ec0e0 + e2bb987 commit 893ffce
Showing 1 changed file with 95 additions and 136 deletions.
231 changes: 95 additions & 136 deletions src/force/forceWfcContractions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,26 +642,27 @@ namespace dftfe
// MPI_Barrier(d_mpiCommParent);
// double kernel1_time = MPI_Wtime();

computeELocWfcEshelbyTensor(basisOperationsPtr,
densityQuadratureId,
BLASWrapperPtr,
flattenedArrayBlock,
numPsi,
numCells,
numQuads,
eigenValues,
partialOccupancies,
kcoordx,
kcoordy,
kcoordz,
onesVec,
cellsBlockSize,
psiQuadsFlat,
gradPsiQuadsFlat,
eshelbyTensorContributions,
eshelbyTensorQuadValues,
isFloatingChargeForces,
addEk);
if (numCells > 0)
computeELocWfcEshelbyTensor(basisOperationsPtr,
densityQuadratureId,
BLASWrapperPtr,
flattenedArrayBlock,
numPsi,
numCells,
numQuads,
eigenValues,
partialOccupancies,
kcoordx,
kcoordy,
kcoordz,
onesVec,
cellsBlockSize,
psiQuadsFlat,
gradPsiQuadsFlat,
eshelbyTensorContributions,
eshelbyTensorQuadValues,
isFloatingChargeForces,
addEk);

// dftfe::utils::deviceSynchronize();
// MPI_Barrier(d_mpiCommParent);
Expand All @@ -672,7 +673,7 @@ namespace dftfe
// interpolatePsiComputeELocWfcEshelbyTensorD inside blocked
// loop: "<<kernel1_time<<std::endl;

if (isPsp)
if (isPsp and numCells > 0)
{
// dftfe::utils::deviceSynchronize();
// MPI_Barrier(d_mpiCommParent);
Expand All @@ -687,67 +688,20 @@ namespace dftfe
CouplingStructure::diagonal,
oncvClassPtr->getCouplingMatrix(),
projectorKetTimesVector);
/*
dftfe::utils::MemoryStorage<dataTypes::number,
dftfe::utils::MemorySpace::HOST>
projectorKetTimesVectorHostData;
projectorKetTimesVectorHostData.resize(projectorKetTimesVector.getData().size());
projectorKetTimesVectorHostData.copyFrom(projectorKetTimesVector.getData());
double projectorKetTimesVectorHostDataNorm = 0.0;
for( unsigned int i = 0; i <
projectorKetTimesVectorHostData.size(); i++)
{
projectorKetTimesVectorHostDataNorm +=
projectorKetTimesVectorHostData.data()[i]*
projectorKetTimesVectorHostData.data()[i];
}
std::cout<<" projectorKetTimesVectorHostDataNorm =
"<<projectorKetTimesVectorHostDataNorm<<"\n";
*/
}

if (useHubbard)
{
flattenedArrayBlock.updateGhostValues();
basisOperationsPtr->distribute(flattenedArrayBlock);

hubbardClassPtr->getNonLocalOperator()->applyVCconjtransOnX(
flattenedArrayBlock,
kPointIndex,
CouplingStructure::dense,
hubbardClassPtr->getCouplingMatrix(spinIndex),
projectorKetTimesVectorHubbard);

/*
dftfe::utils::MemoryStorage<dataTypes::number,
dftfe::utils::MemorySpace::HOST>
projectorKetTimesVectorHubbardHostData;
projectorKetTimesVectorHubbardHostData.resize(projectorKetTimesVectorHubbard.getData().size());
projectorKetTimesVectorHubbardHostData.copyFrom(projectorKetTimesVectorHubbard.getData());
double projectorKetTimesVectorHubbardHostDataNorm = 0.0;
for( unsigned int i = 0; i <
projectorKetTimesVectorHubbardHostData.size(); i++)
{
std::cout<<" i = "<<i<<" HubbardHostData =
"<<projectorKetTimesVectorHubbardHostData.data()[i]<<"\n";
projectorKetTimesVectorHubbardHostDataNorm +=
projectorKetTimesVectorHubbardHostData.data()[i]*
projectorKetTimesVectorHubbardHostData.data()[i];
}
std::cout<<" projectorKetTimesVectorHubbardHostDataNorm =
"<<projectorKetTimesVectorHubbardHostDataNorm<<"\n";
*/
if (numCells > 0)
hubbardClassPtr->getNonLocalOperator()->applyVCconjtransOnX(
flattenedArrayBlock,
kPointIndex,
CouplingStructure::dense,
hubbardClassPtr->getCouplingMatrix(spinIndex),
projectorKetTimesVectorHubbard);
}

// dftfe::utils::deviceSynchronize();
Expand All @@ -759,71 +713,74 @@ namespace dftfe
// MPI_Barrier(d_mpiCommParent);
// double kernel3_time = MPI_Wtime();

if (isPsp || useHubbard)
if (numCells > 0)
{
interpolatePsiGradPsiNlpQuads(basisOperationsPtr,
nlpspQuadratureId,
BLASWrapperPtr,
flattenedArrayBlock,
numPsi,
numCells,
cellsBlockSize,
psiQuadsNLP,
gradPsiQuadsNLP);
}
if (isPsp || useHubbard)
{
interpolatePsiGradPsiNlpQuads(basisOperationsPtr,
nlpspQuadratureId,
BLASWrapperPtr,
flattenedArrayBlock,
numPsi,
numCells,
cellsBlockSize,
psiQuadsNLP,
gradPsiQuadsNLP);
}



if (totalNonTrivialPseudoWfcs > 0)
{
nlpPsiContraction(
BLASWrapperPtr,
psiQuadsNLP,
gradPsiQuadsNLP,
partialOccupancies,
onesVecNLP,
projectorKetTimesVector.data(),
nonTrivialIdToElemIdMap,
projecterKetTimesFlattenedVectorLocalIds,
numCells,
numQuadsNLP,
numPsi,
totalNonTrivialPseudoWfcs,
innerBlockSizeEnlp,
nlpContractionContribution,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedBlock,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedH,
if (totalNonTrivialPseudoWfcs > 0)
{
nlpPsiContraction(
BLASWrapperPtr,
psiQuadsNLP,
gradPsiQuadsNLP,
partialOccupancies,
onesVecNLP,
projectorKetTimesVector.data(),
nonTrivialIdToElemIdMap,
projecterKetTimesFlattenedVectorLocalIds,
numCells,
numQuadsNLP,
numPsi,
totalNonTrivialPseudoWfcs,
innerBlockSizeEnlp,
nlpContractionContribution,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedBlock,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedH,
#ifdef USE_COMPLEX
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedBlock,
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedH,
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedBlock,
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedH,
#endif
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedHPinnedTemp);
}
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedHPinnedTemp);
}

if (totalNonTrivialHubbardProjectors > 0)
{
nlpPsiContraction(
BLASWrapperPtr,
psiQuadsNLP,
gradPsiQuadsNLP,
partialOccupancies,
onesVecNLP,
projectorKetTimesVectorHubbard.data(),
nonTrivialIdToElemIdMapHubbard,
projecterKetTimesFlattenedVectorLocalIdsHubbard,
numCells,
numQuadsNLP,
numPsi,
totalNonTrivialHubbardProjectors,
innerBlockSizeHubbard,
contractionContributionHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedBlockHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedHHubbard,
if (totalNonTrivialHubbardProjectors > 0)
{
nlpPsiContraction(
BLASWrapperPtr,
psiQuadsNLP,
gradPsiQuadsNLP,
partialOccupancies,
onesVecNLP,
projectorKetTimesVectorHubbard.data(),
nonTrivialIdToElemIdMapHubbard,
projecterKetTimesFlattenedVectorLocalIdsHubbard,
numCells,
numQuadsNLP,
numPsi,
totalNonTrivialHubbardProjectors,
innerBlockSizeHubbard,
contractionContributionHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedBlockHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedHHubbard,
#ifdef USE_COMPLEX
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedBlockHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedHHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedBlockHubbard,
projectorKetTimesPsiTimesVTimesPartOccContractionPsiQuadsFlattenedHHubbard,
#endif
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedHPinnedTempHubbard);
projectorKetTimesPsiTimesVTimesPartOccContractionGradPsiQuadsFlattenedHPinnedTempHubbard);
}
}

// dftfe::utils::deviceSynchronize();
Expand Down Expand Up @@ -1274,11 +1231,13 @@ namespace dftfe
} // band parallelization
} // ivec loop

dftfe::utils::
MemoryTransfer<dftfe::utils::MemorySpace::HOST, memorySpace>::copy(
numCells * numQuads * 9,
eshelbyTensorQuadValuesH + kPoint * numCells * numQuads * 9,
elocWfcEshelbyTensorQuadValues.data());
if (numCells > 0)
dftfe::utils::MemoryTransfer<
dftfe::utils::MemorySpace::HOST,
memorySpace>::copy(numCells * numQuads * 9,
eshelbyTensorQuadValuesH +
kPoint * numCells * numQuads * 9,
elocWfcEshelbyTensorQuadValues.data());
/*
dftfe::utils::deviceMemcpyD2H(eshelbyTensorQuadValuesH +
kPoint * numCells * numQuads * 9,
Expand Down

0 comments on commit 893ffce

Please sign in to comment.