Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling Globals for Unstructured Backends #1039

Merged
merged 5 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions dawn/src/dawn/CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ std::string CXXNaiveIcoCodeGen::generateStencilInstantiation(

generateStencilWrapperCtr(StencilWrapperClass, stencilInstantiation, codeGenProperties);

generateGlobalsAPI(*stencilInstantiation, StencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(StencilWrapperClass, globalsMap, codeGenProperties);

generateStencilWrapperRun(StencilWrapperClass, stencilInstantiation, codeGenProperties);

Expand Down Expand Up @@ -230,9 +230,6 @@ void CXXNaiveIcoCodeGen::generateStencilWrapperCtr(
std::string initCtr = "m_" + stencilName;

initCtr += "(mesh, k_size";
if(!globalsMap.empty()) {
initCtr += ",m_globals";
}
for(const auto& fieldInfoPair : stencilFields) {
const auto& fieldInfo = fieldInfoPair.second;
if(fieldInfo.IsTemporary)
Expand All @@ -242,6 +239,9 @@ void CXXNaiveIcoCodeGen::generateStencilWrapperCtr(
? ("m_" + fieldInfo.Name)
: (fieldInfo.Name));
}
if(!globalsMap.empty()) {
initCtr += ",m_globals";
}
initCtr += ")";
StencilWrapperConstructor.addInit(initCtr);
}
Expand Down Expand Up @@ -312,7 +312,7 @@ void CXXNaiveIcoCodeGen::generateStencilClasses(
Class& stencilWrapperClass, const CodeGenProperties& codeGenProperties) const {

const auto& stencils = stencilInstantiation->getStencils();
// const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();
const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

// Stencil members:
// generate the code for each of the stencils
Expand Down Expand Up @@ -382,6 +382,9 @@ void CXXNaiveIcoCodeGen::generateStencilClasses(
"m_" + fieldIt.second.Name);
}
stencilClass.addMember("::dawn::unstructured_domain ", " m_unstructured_domain ");
if(!globalsMap.empty()) {
stencilClass.addMember("const globals &", " m_globals");
}

// addTmpStorageDeclaration(StencilClass, tempFields);

Expand All @@ -396,16 +399,20 @@ void CXXNaiveIcoCodeGen::generateStencilClasses(
}

// stencilClassCtr.addInit("m_dom(dom_)");
// if(!globalsMap.empty()) {
// stencilClassCtr.addArg("m_globals(globals_)");
// }
if(!globalsMap.empty()) {
stencilClassCtr.addArg("const globals &globals_");
}

stencilClassCtr.addInit("m_mesh(mesh)");
stencilClassCtr.addInit("m_k_size(k_size)");
for(auto fieldIt : nonTempFields) {
stencilClassCtr.addInit("m_" + fieldIt.second.Name + "(" + fieldIt.second.Name + ")");
}

if(!globalsMap.empty()) {
stencilClassCtr.addInit("m_globals(globals_)");
}

// addTmpStorageInit(stencilClassCtr, *stencil, tempFields);
stencilClassCtr.commit();

Expand Down Expand Up @@ -635,7 +642,7 @@ void CXXNaiveIcoCodeGen::generateStencilFunctions(

// add global parameter
if(stencilFun->hasGlobalVariables()) {
stencilFunMethod.addArg("const globals& m_globals");
stencilFunMethod.addArg("globals m_globals");
}
ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData(),
StencilContext::SC_StencilFunction);
Expand Down Expand Up @@ -680,7 +687,7 @@ std::unique_ptr<TranslationUnit> CXXNaiveIcoCodeGen::generateCode() {
stencils.emplace(nameStencilCtxPair.first, std::move(code));
}

std::string globals = generateGlobals(context_, "::dawn_generated", "cxxnaiveico");
std::string globals = generateGlobals(context_, "dawn_generated", "cxxnaiveico");

std::vector<std::string> ppDefines;
ppDefines.push_back("#define DAWN_GENERATED 1");
Expand Down
18 changes: 9 additions & 9 deletions dawn/src/dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ std::string CXXNaiveCodeGen::generateStencilInstantiation(

generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties);

generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties);

generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties);

Expand Down Expand Up @@ -321,7 +321,7 @@ void CXXNaiveCodeGen::generateStencilClasses(

stencilClassCtr.addArg("const " + c_dgt + "domain& dom_");
if(!globalsMap.empty()) {
stencilClassCtr.addArg("const globals& globals_");
stencilClassCtr.addArg("globals& globals_");
}
stencilClassCtr.addArg("int rank");
stencilClassCtr.addArg("int xcols");
Expand Down Expand Up @@ -399,14 +399,14 @@ void CXXNaiveCodeGen::generateStencilClasses(
for(auto it = nonTempFields.begin(); it != nonTempFields.end(); ++it) {
const auto fieldName = (*it).second.Name;
std::string type = stencilProperties->paramNameToType_.at(fieldName);
stencilRunMethod.addStatement(c_gt + "data_view<" + type + "> " + fieldName + "= " +
c_gt + "make_host_view(" + fieldName + "_)");
stencilRunMethod.addStatement(c_gt + "data_view<" + type + "> " + fieldName + "= " + c_gt +
"make_host_view(" + fieldName + "_)");
stencilRunMethod.addStatement("std::array<int,3> " + fieldName + "_offsets{0,0,0}");
}
for(const auto& fieldPair : tempFields) {
const auto fieldName = fieldPair.second.Name;
stencilRunMethod.addStatement(c_gt + "data_view<tmp_storage_t> " + fieldName + "= " +
c_gt + "make_host_view(m_" + fieldName + ")");
stencilRunMethod.addStatement(c_gt + "data_view<tmp_storage_t> " + fieldName + "= " + c_gt +
"make_host_view(m_" + fieldName + ")");
stencilRunMethod.addStatement("std::array<int,3> " + fieldName + "_offsets{0,0,0}");
}

Expand Down Expand Up @@ -569,7 +569,7 @@ void CXXNaiveCodeGen::generateStencilFunctions(

// add global parameter
if(stencilFun->hasGlobalVariables()) {
stencilFunMethod.addArg("const globals& m_globals");
stencilFunMethod.addArg("globals m_globals");
}
ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData(),
StencilContext::SC_StencilFunction);
Expand All @@ -581,8 +581,8 @@ void CXXNaiveCodeGen::generateStencilFunctions(
std::string paramName =
stencilFun->getOriginalNameFromCallerAccessID(fields[m].getAccessID());

stencilFunMethod << c_gt << "data_view<StorageType" + std::to_string(m) + "> "
<< paramName << " = pw_" << paramName << ".dview_;";
stencilFunMethod << c_gt << "data_view<StorageType" + std::to_string(m) + "> " << paramName
<< " = pw_" << paramName << ".dview_;";
stencilFunMethod << "auto " << paramName << "_offsets = pw_" << paramName << ".offsets_;";
}
stencilBodyCXXVisitor.setCurrentStencilFunction(stencilFun);
Expand Down
2 changes: 1 addition & 1 deletion dawn/src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ std::string CXXOptCodeGen::generateStencilInstantiation(

generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties);

generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties);

generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties);

Expand Down
5 changes: 2 additions & 3 deletions dawn/src/dawn/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ std::string CodeGen::generateGlobals(const sir::GlobalVariableMap& globalsMap,
return ss.str();
}

void CodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation,
Class& stencilWrapperClass,
void CodeGen::generateGlobalsAPI(Structure& stencilWrapperClass,
const sir::GlobalVariableMap& globalsMap,
const CodeGenProperties& codeGenProperties) const {

Expand All @@ -131,7 +130,7 @@ void CodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstant
setter.addArg(std::string(sir::Value::typeToString(globalValue.getType())) + " " +
globalProp.first);
setter.finishArgs();
setter.addStatement("m_globals." + globalProp.first + "=" + globalProp.first);
setter.addStatement("m_globals." + globalProp.first + "=" + globalProp.first);
setter.commit();
}
}
Expand Down
3 changes: 1 addition & 2 deletions dawn/src/dawn/CodeGen/CodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ class CodeGen {
CodeGenProperties
computeCodeGenProperties(const iir::StencilInstantiation* stencilInstantiation) const;

virtual void generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation,
Class& stencilWrapperClass,
virtual void generateGlobalsAPI(Structure& stencilWrapperClass,
const sir::GlobalVariableMap& globalsMap,
const CodeGenProperties& codeGenProperties) const;
virtual std::string generateGlobals(const StencilInstantiationContext& context,
Expand Down
15 changes: 14 additions & 1 deletion dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,20 @@ void ASTStencilBody::visit(const std::shared_ptr<iir::ReturnStmt>& stmt) {
}

void ASTStencilBody::visit(const std::shared_ptr<iir::VarAccessExpr>& expr) {
DAWN_ASSERT_MSG(0, "Var Access not allowed in this context");
std::string name = getName(expr);
int AccessID = iir::getAccessID(expr);

if(metadata_.isAccessType(iir::FieldAccessType::GlobalVariable, AccessID)) {
ss_ << "globals." << name;
} else {
ss_ << name;
mroethlin marked this conversation as resolved.
Show resolved Hide resolved

if(expr->isArrayAccess()) {
ss_ << "[";
expr->getIndex()->accept(*this);
ss_ << "]";
}
}
}

void ASTStencilBody::visit(const std::shared_ptr<iir::AssignmentExpr>& expr) {
Expand Down
47 changes: 40 additions & 7 deletions dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ void CudaIcoCodeGen::generateRunFun(
const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation, MemberFunction& runFun,
CodeGenProperties& codeGenProperties) {

const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

// find block sizes to generate
std::set<ast::LocationType> stageLocType;
for(const auto& ms : iterateIIROver<iir::MultiStage>(*(stencilInstantiation->getIIR()))) {
Expand Down Expand Up @@ -308,7 +310,9 @@ void CudaIcoCodeGen::generateRunFun(
kernelCall << "<<<"
<< "dG" + std::to_string(stage->getStageID()) + ",dB"
<< ">>>(";

if(!globalsMap.empty()) {
kernelCall << "m_globals, ";
}
kernelCall << numElString << ", ";

// which loc size args (int CellIdx, int EdgeIdx, int CellIdx) need to be passed additionally?
Expand Down Expand Up @@ -364,9 +368,13 @@ void CudaIcoCodeGen::generateRunFun(
}

void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::Stencil& stencil,
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const {

// arguments: mesh, kSize, fields
if(!globalsMap.empty()) {
ctor.addArg("globals globals");
}
ctor.addArg("const dawn::mesh_t<LibTag>& mesh");
ctor.addArg("int kSize");
for(auto field : support::orderMap(stencil.getFields())) {
Expand All @@ -389,6 +397,9 @@ void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::St
ctor.addInit("sbase(\"" + stencilName + "\")");
ctor.addInit("mesh_(mesh)");
ctor.addInit("kSize_(kSize)");
if(!globalsMap.empty()) {
ctor.addInit("m_globals(globals)");
}

std::stringstream fieldsStr;
{
Expand All @@ -406,8 +417,12 @@ void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::St

void CudaIcoCodeGen::generateStencilClassCtrMinimal(MemberFunction& ctor,
const iir::Stencil& stencil,
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const {

if(!globalsMap.empty()) {
ctor.addArg("globals globals");
}
// arguments: mesh, kSize, fields
ctor.addArg("const dawn::GlobalGpuTriMesh *mesh");
ctor.addArg("int kSize");
Expand All @@ -417,6 +432,9 @@ void CudaIcoCodeGen::generateStencilClassCtrMinimal(MemberFunction& ctor,
codeGenProperties.getStencilName(StencilContext::SC_Stencil, stencil.getStencilID());
ctor.addInit("sbase(\"" + stencilName + "\")");
ctor.addInit("mesh_(mesh)");
if(!globalsMap.empty()) {
ctor.addInit("m_globals(globals)");
}
ctor.addInit("kSize_(kSize)");
}

Expand Down Expand Up @@ -576,6 +594,7 @@ void CudaIcoCodeGen::generateStencilClasses(
Class& stencilWrapperClass, CodeGenProperties& codeGenProperties) {

const auto& stencils = stencilInstantiation->getStencils();
const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

// Stencil members:
// generate the code for each of the stencils
Expand All @@ -587,6 +606,8 @@ void CudaIcoCodeGen::generateStencilClasses(

Structure stencilClass = stencilWrapperClass.addStruct(stencilName, "", "sbase");

generateGlobalsAPI(stencilClass, globalsMap, codeGenProperties);

// generate members (fields + kSize + gpuMesh)
stencilClass.changeAccessibility("private");
for(auto field : support::orderMap(stencil.getFields())) {
Expand All @@ -597,9 +618,13 @@ void CudaIcoCodeGen::generateStencilClasses(

stencilClass.changeAccessibility("public");

if(!globalsMap.empty()) {
stencilClass.addMember("globals", "m_globals");
}

// constructor from library
auto stencilClassConstructor = stencilClass.addConstructor();
generateStencilClassCtr(stencilClassConstructor, stencil, codeGenProperties);
generateStencilClassCtr(stencilClassConstructor, stencil, globalsMap, codeGenProperties);
stencilClassConstructor.commit();

// grid helper fun
Expand All @@ -613,7 +638,8 @@ void CudaIcoCodeGen::generateStencilClasses(

// minmal ctor
auto stencilClassMinimalConstructor = stencilClass.addConstructor();
generateStencilClassCtrMinimal(stencilClassMinimalConstructor, stencil, codeGenProperties);
generateStencilClassCtrMinimal(stencilClassMinimalConstructor, stencil, globalsMap,
codeGenProperties);
stencilClassMinimalConstructor.commit();

// run method
Expand Down Expand Up @@ -647,6 +673,8 @@ void CudaIcoCodeGen::generateAllAPIRunFunctions(
CodeGenProperties& codeGenProperties, bool fromHost) {
const auto& stencils = stencilInstantiation->getStencils();

const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

CollectChainStrings chainCollector;
std::set<std::vector<ast::LocationType>> chains;
for(const auto& doMethod : iterateIIROver<iir::DoMethod>(*(stencilInstantiation->getIIR()))) {
Expand Down Expand Up @@ -732,8 +760,10 @@ void CudaIcoCodeGen::generateAllAPIRunFunctions(
}

for(auto& apiRunFun : apiRunFuns) {
apiRunFun->addStatement(wrapperName + "<dawn::NoLibTag, " + chainSizesStr.str() +
">::" + stencilName + " s(mesh, k_size)");
apiRunFun->addStatement(wrapperName + "<dawn::NoLibTag " +
(chainSizesStr.str().empty() ? "" : ", " + chainSizesStr.str()) +
">::" + stencilName + " s(" +
(globalsMap.empty() ? "" : "globals(), ") + "mesh, k_size)");
}
if(fromHost) {
// depending if we are calling from c or from fortran, we need to transpose the data or not
Expand Down Expand Up @@ -792,6 +822,7 @@ void CudaIcoCodeGen::generateAllCudaKernels(
const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation) {

ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData());
const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap();

for(const auto& ms : iterateIIROver<iir::MultiStage>(*(stencilInstantiation->getIIR()))) {
for(const auto& stage : ms->getChildren()) {
Expand Down Expand Up @@ -829,6 +860,9 @@ void CudaIcoCodeGen::generateAllCudaKernels(
retString,
cuda::CodeGeneratorHelper::buildCudaKernelName(stencilInstantiation, ms, stage), ssSW);

if(!globalsMap.empty()) {
cudaKernel.addArg("globals globals");
}
auto loc = *stage->getLocationType();
cudaKernel.addArg("int " + locToDenseSizeStringGpuMesh(loc));

Expand Down Expand Up @@ -1029,8 +1063,7 @@ std::unique_ptr<TranslationUnit> CudaIcoCodeGen::generateCode() {
"using namespace gridtools::dawn;",
};

// globals not yet supported
std::string globals = "";
std::string globals = generateGlobals(context_, "dawn_generated", "cuda_ico");

DAWN_LOG(INFO) << "Done generating code";

Expand Down
4 changes: 3 additions & 1 deletion dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ class CudaIcoCodeGen : public CodeGen {
void generateGridFun(MemberFunction& runFun);

void generateStencilClassCtr(MemberFunction& stencilClassCtor, const iir::Stencil& stencil,
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const;

void generateStencilClassCtrMinimal(MemberFunction& stencilClassCtor, const iir::Stencil& stencil,
CodeGenProperties& codeGenProperties) const;
const sir::GlobalVariableMap& globalsMap,
CodeGenProperties& codeGenProperties) const;

void generateStencilClassRawPtrCtr(MemberFunction& stencilClassCtor, const iir::Stencil& stencil,
CodeGenProperties& codeGenProperties) const;
Expand Down
2 changes: 1 addition & 1 deletion dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ std::string CudaCodeGen::generateStencilInstantiation(
generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties);

if(!globalsMap.empty()) {
generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties);
generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties);
}

generateStencilWrapperSyncMethod(stencilWrapperClass);
Expand Down
Loading