Skip to content

Commit

Permalink
Enabling Globals for Unstructured Backends (#1039)
Browse files Browse the repository at this point in the history
## Technical Description

This PR enables globals for the unstructured backends. Furthermore, a previously unreported issue where globals were not propagated from the wrapper class to the stencil class was fixed. Additionally, a bug in the unstructured CUDA codegen that occurred when translating stencils that only use dense dimensions was fixed.

### Resolves / Enhances

Fixes #1030
Fixes #1028

### Notes

The methods to set and get globals in the CUDA backend are on the inner stencils. This will be addressed in [this issue](#1038). Also, a method to communicate globals from FORTRAN still needs to be devised (not addressed in this PR).

### Testing

This PR adds new tests in dawn4py and a new unstructured integration test that checks the correct operation of the `CXXNaiveIco` backend. The `CudaIco` backend was tested manually.

### Dependencies

This PR is independent.
  • Loading branch information
mroethlin authored Oct 16, 2020
1 parent 2189984 commit 713801a
Show file tree
Hide file tree
Showing 19 changed files with 346 additions and 43 deletions.
27 changes: 17 additions & 10 deletions dawn/src/dawn/CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ std::string CXXNaiveIcoCodeGen::generateStencilInstantiation(

generateStencilWrapperCtr(StencilWrapperClass, stencilInstantiation, codeGenProperties);

generateGlobalsAPI(*stencilInstantiation, StencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(StencilWrapperClass, globalsMap, codeGenProperties);

generateStencilWrapperRun(StencilWrapperClass, stencilInstantiation, codeGenProperties);

Expand Down Expand Up @@ -230,9 +230,6 @@ void CXXNaiveIcoCodeGen::generateStencilWrapperCtr(
std::string initCtr = "m_" + stencilName;

initCtr += "(mesh, k_size";
if(!globalsMap.empty()) {
initCtr += ",m_globals";
}
for(const auto& fieldInfoPair : stencilFields) {
const auto& fieldInfo = fieldInfoPair.second;
if(fieldInfo.IsTemporary)
Expand All @@ -242,6 +239,9 @@ void CXXNaiveIcoCodeGen::generateStencilWrapperCtr(
? ("m_" + fieldInfo.Name)
: (fieldInfo.Name));
}
if(!globalsMap.empty()) {
initCtr += ",m_globals";
}
initCtr += ")";
StencilWrapperConstructor.addInit(initCtr);
}
Expand Down Expand Up @@ -312,7 +312,7 @@ void CXXNaiveIcoCodeGen::generateStencilClasses(
Class& stencilWrapperClass, const CodeGenProperties& codeGenProperties) const {

const auto& stencils = stencilInstantiation->getStencils();
// const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();
const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

// Stencil members:
// generate the code for each of the stencils
Expand Down Expand Up @@ -382,6 +382,9 @@ void CXXNaiveIcoCodeGen::generateStencilClasses(
"m_" + fieldIt.second.Name);
}
stencilClass.addMember("::dawn::unstructured_domain ", " m_unstructured_domain ");
if(!globalsMap.empty()) {
stencilClass.addMember("const globals &", " m_globals");
}

// addTmpStorageDeclaration(StencilClass, tempFields);

Expand All @@ -396,16 +399,20 @@ void CXXNaiveIcoCodeGen::generateStencilClasses(
}

// stencilClassCtr.addInit("m_dom(dom_)");
// if(!globalsMap.empty()) {
// stencilClassCtr.addArg("m_globals(globals_)");
// }
if(!globalsMap.empty()) {
stencilClassCtr.addArg("const globals &globals_");
}

stencilClassCtr.addInit("m_mesh(mesh)");
stencilClassCtr.addInit("m_k_size(k_size)");
for(auto fieldIt : nonTempFields) {
stencilClassCtr.addInit("m_" + fieldIt.second.Name + "(" + fieldIt.second.Name + ")");
}

if(!globalsMap.empty()) {
stencilClassCtr.addInit("m_globals(globals_)");
}

// addTmpStorageInit(stencilClassCtr, *stencil, tempFields);
stencilClassCtr.commit();

Expand Down Expand Up @@ -635,7 +642,7 @@ void CXXNaiveIcoCodeGen::generateStencilFunctions(

// add global parameter
if(stencilFun->hasGlobalVariables()) {
stencilFunMethod.addArg("const globals& m_globals");
stencilFunMethod.addArg("globals m_globals");
}
ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData(),
StencilContext::SC_StencilFunction);
Expand Down Expand Up @@ -680,7 +687,7 @@ std::unique_ptr<TranslationUnit> CXXNaiveIcoCodeGen::generateCode() {
stencils.emplace(nameStencilCtxPair.first, std::move(code));
}

std::string globals = generateGlobals(context_, "::dawn_generated", "cxxnaiveico");
std::string globals = generateGlobals(context_, "dawn_generated", "cxxnaiveico");

std::vector<std::string> ppDefines;
ppDefines.push_back("#define DAWN_GENERATED 1");
Expand Down
18 changes: 9 additions & 9 deletions dawn/src/dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ std::string CXXNaiveCodeGen::generateStencilInstantiation(

generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties);

generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties);

generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties);

Expand Down Expand Up @@ -321,7 +321,7 @@ void CXXNaiveCodeGen::generateStencilClasses(

stencilClassCtr.addArg("const " + c_dgt + "domain& dom_");
if(!globalsMap.empty()) {
stencilClassCtr.addArg("const globals& globals_");
stencilClassCtr.addArg("globals& globals_");
}
stencilClassCtr.addArg("int rank");
stencilClassCtr.addArg("int xcols");
Expand Down Expand Up @@ -399,14 +399,14 @@ void CXXNaiveCodeGen::generateStencilClasses(
for(auto it = nonTempFields.begin(); it != nonTempFields.end(); ++it) {
const auto fieldName = (*it).second.Name;
std::string type = stencilProperties->paramNameToType_.at(fieldName);
stencilRunMethod.addStatement(c_gt + "data_view<" + type + "> " + fieldName + "= " +
c_gt + "make_host_view(" + fieldName + "_)");
stencilRunMethod.addStatement(c_gt + "data_view<" + type + "> " + fieldName + "= " + c_gt +
"make_host_view(" + fieldName + "_)");
stencilRunMethod.addStatement("std::array<int,3> " + fieldName + "_offsets{0,0,0}");
}
for(const auto& fieldPair : tempFields) {
const auto fieldName = fieldPair.second.Name;
stencilRunMethod.addStatement(c_gt + "data_view<tmp_storage_t> " + fieldName + "= " +
c_gt + "make_host_view(m_" + fieldName + ")");
stencilRunMethod.addStatement(c_gt + "data_view<tmp_storage_t> " + fieldName + "= " + c_gt +
"make_host_view(m_" + fieldName + ")");
stencilRunMethod.addStatement("std::array<int,3> " + fieldName + "_offsets{0,0,0}");
}

Expand Down Expand Up @@ -569,7 +569,7 @@ void CXXNaiveCodeGen::generateStencilFunctions(

// add global parameter
if(stencilFun->hasGlobalVariables()) {
stencilFunMethod.addArg("const globals& m_globals");
stencilFunMethod.addArg("globals m_globals");
}
ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData(),
StencilContext::SC_StencilFunction);
Expand All @@ -581,8 +581,8 @@ void CXXNaiveCodeGen::generateStencilFunctions(
std::string paramName =
stencilFun->getOriginalNameFromCallerAccessID(fields[m].getAccessID());

stencilFunMethod << c_gt << "data_view<StorageType" + std::to_string(m) + "> "
<< paramName << " = pw_" << paramName << ".dview_;";
stencilFunMethod << c_gt << "data_view<StorageType" + std::to_string(m) + "> " << paramName
<< " = pw_" << paramName << ".dview_;";
stencilFunMethod << "auto " << paramName << "_offsets = pw_" << paramName << ".offsets_;";
}
stencilBodyCXXVisitor.setCurrentStencilFunction(stencilFun);
Expand Down
2 changes: 1 addition & 1 deletion dawn/src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ std::string CXXOptCodeGen::generateStencilInstantiation(

generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties);

generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties);

generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties);

Expand Down
5 changes: 2 additions & 3 deletions dawn/src/dawn/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ std::string CodeGen::generateGlobals(const sir::GlobalVariableMap& globalsMap,
return ss.str();
}

void CodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation,
Class& stencilWrapperClass,
void CodeGen::generateGlobalsAPI(Structure& stencilWrapperClass,
const sir::GlobalVariableMap& globalsMap,
const CodeGenProperties& codeGenProperties) const {

Expand All @@ -131,7 +130,7 @@ void CodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstant
setter.addArg(std::string(sir::Value::typeToString(globalValue.getType())) + " " +
globalProp.first);
setter.finishArgs();
setter.addStatement("m_globals." + globalProp.first + "=" + globalProp.first);
setter.addStatement("m_globals." + globalProp.first + "=" + globalProp.first);
setter.commit();
}
}
Expand Down
3 changes: 1 addition & 2 deletions dawn/src/dawn/CodeGen/CodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ class CodeGen {
CodeGenProperties
computeCodeGenProperties(const iir::StencilInstantiation* stencilInstantiation) const;

virtual void generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation,
Class& stencilWrapperClass,
virtual void generateGlobalsAPI(Structure& stencilWrapperClass,
const sir::GlobalVariableMap& globalsMap,
const CodeGenProperties& codeGenProperties) const;
virtual std::string generateGlobals(const StencilInstantiationContext& context,
Expand Down
15 changes: 14 additions & 1 deletion dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,20 @@ void ASTStencilBody::visit(const std::shared_ptr<iir::ReturnStmt>& stmt) {
}

void ASTStencilBody::visit(const std::shared_ptr<iir::VarAccessExpr>& expr) {
DAWN_ASSERT_MSG(0, "Var Access not allowed in this context");
std::string name = getName(expr);
int AccessID = iir::getAccessID(expr);

if(metadata_.isAccessType(iir::FieldAccessType::GlobalVariable, AccessID)) {
ss_ << "globals." << name;
} else {
ss_ << name;

if(expr->isArrayAccess()) {
ss_ << "[";
expr->getIndex()->accept(*this);
ss_ << "]";
}
}
}

void ASTStencilBody::visit(const std::shared_ptr<iir::AssignmentExpr>& expr) {
Expand Down
47 changes: 40 additions & 7 deletions dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ void CudaIcoCodeGen::generateRunFun(
const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation, MemberFunction& runFun,
CodeGenProperties& codeGenProperties) {

const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

// find block sizes to generate
std::set<ast::LocationType> stageLocType;
for(const auto& ms : iterateIIROver<iir::MultiStage>(*(stencilInstantiation->getIIR()))) {
Expand Down Expand Up @@ -308,7 +310,9 @@ void CudaIcoCodeGen::generateRunFun(
kernelCall << "<<<"
<< "dG" + std::to_string(stage->getStageID()) + ",dB"
<< ">>>(";

if(!globalsMap.empty()) {
kernelCall << "m_globals, ";
}
kernelCall << numElString << ", ";

// which loc size args (int CellIdx, int EdgeIdx, int CellIdx) need to be passed additionally?
Expand Down Expand Up @@ -364,9 +368,13 @@ void CudaIcoCodeGen::generateRunFun(
}

void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::Stencil& stencil,
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const {

// arguments: mesh, kSize, fields
if(!globalsMap.empty()) {
ctor.addArg("globals globals");
}
ctor.addArg("const dawn::mesh_t<LibTag>& mesh");
ctor.addArg("int kSize");
for(auto field : support::orderMap(stencil.getFields())) {
Expand All @@ -389,6 +397,9 @@ void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::St
ctor.addInit("sbase(\"" + stencilName + "\")");
ctor.addInit("mesh_(mesh)");
ctor.addInit("kSize_(kSize)");
if(!globalsMap.empty()) {
ctor.addInit("m_globals(globals)");
}

std::stringstream fieldsStr;
{
Expand All @@ -406,8 +417,12 @@ void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::St

void CudaIcoCodeGen::generateStencilClassCtrMinimal(MemberFunction& ctor,
const iir::Stencil& stencil,
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const {

if(!globalsMap.empty()) {
ctor.addArg("globals globals");
}
// arguments: mesh, kSize, fields
ctor.addArg("const dawn::GlobalGpuTriMesh *mesh");
ctor.addArg("int kSize");
Expand All @@ -417,6 +432,9 @@ void CudaIcoCodeGen::generateStencilClassCtrMinimal(MemberFunction& ctor,
codeGenProperties.getStencilName(StencilContext::SC_Stencil, stencil.getStencilID());
ctor.addInit("sbase(\"" + stencilName + "\")");
ctor.addInit("mesh_(mesh)");
if(!globalsMap.empty()) {
ctor.addInit("m_globals(globals)");
}
ctor.addInit("kSize_(kSize)");
}

Expand Down Expand Up @@ -576,6 +594,7 @@ void CudaIcoCodeGen::generateStencilClasses(
Class& stencilWrapperClass, CodeGenProperties& codeGenProperties) {

const auto& stencils = stencilInstantiation->getStencils();
const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

// Stencil members:
// generate the code for each of the stencils
Expand All @@ -587,6 +606,8 @@ void CudaIcoCodeGen::generateStencilClasses(

Structure stencilClass = stencilWrapperClass.addStruct(stencilName, "", "sbase");

generateGlobalsAPI(stencilClass, globalsMap, codeGenProperties);

// generate members (fields + kSize + gpuMesh)
stencilClass.changeAccessibility("private");
for(auto field : support::orderMap(stencil.getFields())) {
Expand All @@ -597,9 +618,13 @@ void CudaIcoCodeGen::generateStencilClasses(

stencilClass.changeAccessibility("public");

if(!globalsMap.empty()) {
stencilClass.addMember("globals", "m_globals");
}

// constructor from library
auto stencilClassConstructor = stencilClass.addConstructor();
generateStencilClassCtr(stencilClassConstructor, stencil, codeGenProperties);
generateStencilClassCtr(stencilClassConstructor, stencil, globalsMap, codeGenProperties);
stencilClassConstructor.commit();

// grid helper fun
Expand All @@ -613,7 +638,8 @@ void CudaIcoCodeGen::generateStencilClasses(

// minimal ctor
auto stencilClassMinimalConstructor = stencilClass.addConstructor();
generateStencilClassCtrMinimal(stencilClassMinimalConstructor, stencil, codeGenProperties);
generateStencilClassCtrMinimal(stencilClassMinimalConstructor, stencil, globalsMap,
codeGenProperties);
stencilClassMinimalConstructor.commit();

// run method
Expand Down Expand Up @@ -647,6 +673,8 @@ void CudaIcoCodeGen::generateAllAPIRunFunctions(
CodeGenProperties& codeGenProperties, bool fromHost) {
const auto& stencils = stencilInstantiation->getStencils();

const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

CollectChainStrings chainCollector;
std::set<std::vector<ast::LocationType>> chains;
for(const auto& doMethod : iterateIIROver<iir::DoMethod>(*(stencilInstantiation->getIIR()))) {
Expand Down Expand Up @@ -732,8 +760,10 @@ void CudaIcoCodeGen::generateAllAPIRunFunctions(
}

for(auto& apiRunFun : apiRunFuns) {
apiRunFun->addStatement(wrapperName + "<dawn::NoLibTag, " + chainSizesStr.str() +
">::" + stencilName + " s(mesh, k_size)");
apiRunFun->addStatement(wrapperName + "<dawn::NoLibTag " +
(chainSizesStr.str().empty() ? "" : ", " + chainSizesStr.str()) +
">::" + stencilName + " s(" +
(globalsMap.empty() ? "" : "globals(), ") + "mesh, k_size)");
}
if(fromHost) {
// depending if we are calling from c or from fortran, we need to transpose the data or not
Expand Down Expand Up @@ -792,6 +822,7 @@ void CudaIcoCodeGen::generateAllCudaKernels(
const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation) {

ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData());
const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

for(const auto& ms : iterateIIROver<iir::MultiStage>(*(stencilInstantiation->getIIR()))) {
for(const auto& stage : ms->getChildren()) {
Expand Down Expand Up @@ -829,6 +860,9 @@ void CudaIcoCodeGen::generateAllCudaKernels(
retString,
cuda::CodeGeneratorHelper::buildCudaKernelName(stencilInstantiation, ms, stage), ssSW);

if(!globalsMap.empty()) {
cudaKernel.addArg("globals globals");
}
auto loc = *stage->getLocationType();
cudaKernel.addArg("int " + locToDenseSizeStringGpuMesh(loc));

Expand Down Expand Up @@ -1029,8 +1063,7 @@ std::unique_ptr<TranslationUnit> CudaIcoCodeGen::generateCode() {
"using namespace gridtools::dawn;",
};

// globals not yet supported
std::string globals = "";
std::string globals = generateGlobals(context_, "dawn_generated", "cuda_ico");

DAWN_LOG(INFO) << "Done generating code";

Expand Down
4 changes: 3 additions & 1 deletion dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ class CudaIcoCodeGen : public CodeGen {
void generateGridFun(MemberFunction& runFun);

void generateStencilClassCtr(MemberFunction& stencilClassCtor, const iir::Stencil& stencil,
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const;

void generateStencilClassCtrMinimal(MemberFunction& stencilClassCtor, const iir::Stencil& stencil,
CodeGenProperties& codeGenProperties) const;
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const;

void generateStencilClassRawPtrCtr(MemberFunction& stencilClassCtor, const iir::Stencil& stencil,
CodeGenProperties& codeGenProperties) const;
Expand Down
2 changes: 1 addition & 1 deletion dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ std::string CudaCodeGen::generateStencilInstantiation(
generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties);

if(!globalsMap.empty()) {
generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties);
}

generateStencilWrapperSyncMethod(stencilWrapperClass);
Expand Down
Loading

0 comments on commit 713801a

Please sign in to comment.