From 713801a1b46939cc4220d2bf2665e8fa02a889c0 Mon Sep 17 00:00:00 2001 From: mroethlin <53755555+mroethlin@users.noreply.github.com> Date: Fri, 16 Oct 2020 11:50:14 +0200 Subject: [PATCH] Enabling Globals for Unstructured Backends (#1039) ## Technical Description This PR enables globals for the unstructured backends. Furthermore, an unreported issue where globals were not propagated from the wrapper class to the stencil class was fixed. Additionally, a bug in the unstructured cuda codegen was fixed when translating stencils that only use dense dimensions. ### Resolves / Enhances Fixes https://github.com/MeteoSwiss-APN/dawn/issues/1030 Fixes https://github.com/MeteoSwiss-APN/dawn/issues/1028 ### Notes The methods to set and get globals in the cuda backend are on the inner stencils. This will be addressed in [this issue](https://github.com/MeteoSwiss-APN/dawn/issues/1038). Also, a method to communicate globals from FORTRAN will need to be devised (not addressed yet). ### Testing New tests in dawn4py and a new unstructured integration test to test the correct operation of the `CXXNaiveIco` backend. `CudaIco` backend tested manually. ### Dependencies This PR is independent. --- .../CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp | 27 +++-- .../dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp | 18 ++-- .../src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp | 2 +- dawn/src/dawn/CodeGen/CodeGen.cpp | 5 +- dawn/src/dawn/CodeGen/CodeGen.h | 3 +- .../dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp | 15 ++- .../dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp | 47 ++++++-- .../dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h | 4 +- dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp | 2 +- dawn/src/dawn/CodeGen/GridTools/GTCodeGen.cpp | 5 +- dawn/src/dawn/CodeGen/GridTools/GTCodeGen.h | 3 +- .../dawn4py-tests/CMakeLists.txt | 2 + .../dawn4py-tests/global_var.py | 93 ++++++++++++++++ .../dawn4py-tests/global_var_unstructured.py | 101 ++++++++++++++++++ .../AtlasIntegrationTestCompareOutput.cpp | 29 +++++ .../unstructured/CMakeLists.txt | 1 + .../GenerateUnstructuredStencils.cpp | 28 ++++- .../CodeGen/reference/conditional_stencil.cpp | 2 +- .../dawn/CodeGen/reference/update_dz_c.cpp | 2 +- 19 files changed, 346 insertions(+), 43 deletions(-) create mode 100644 dawn/test/integration-test/dawn4py-tests/global_var.py create mode 100644 dawn/test/integration-test/dawn4py-tests/global_var_unstructured.py diff --git a/dawn/src/dawn/CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp b/dawn/src/dawn/CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp index a3b54fbb0..29ac27b92 100644 --- a/dawn/src/dawn/CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp +++ b/dawn/src/dawn/CodeGen/CXXNaive-ico/CXXNaiveCodeGen.cpp @@ -131,7 +131,7 @@ std::string CXXNaiveIcoCodeGen::generateStencilInstantiation( generateStencilWrapperCtr(StencilWrapperClass, stencilInstantiation, codeGenProperties); - generateGlobalsAPI(*stencilInstantiation, StencilWrapperClass, globalsMap, codeGenProperties); + generateGlobalsAPI(StencilWrapperClass, globalsMap, codeGenProperties); generateStencilWrapperRun(StencilWrapperClass, stencilInstantiation, codeGenProperties); @@ -230,9 +230,6 @@ void CXXNaiveIcoCodeGen::generateStencilWrapperCtr( std::string initCtr = "m_" + stencilName; initCtr += "(mesh, k_size"; - if(!globalsMap.empty()) { - initCtr += ",m_globals"; - } for(const auto& fieldInfoPair : stencilFields) { const auto& fieldInfo = fieldInfoPair.second; if(fieldInfo.IsTemporary) @@ -242,6 +239,9 @@ void CXXNaiveIcoCodeGen::generateStencilWrapperCtr( ? ("m_" + fieldInfo.Name) : (fieldInfo.Name)); } + if(!globalsMap.empty()) { + initCtr += ",m_globals"; + } initCtr += ")"; StencilWrapperConstructor.addInit(initCtr); } @@ -312,7 +312,7 @@ void CXXNaiveIcoCodeGen::generateStencilClasses( Class& stencilWrapperClass, const CodeGenProperties& codeGenProperties) const { const auto& stencils = stencilInstantiation->getStencils(); - // const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap(); + const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap(); // Stencil members: // generate the code for each of the stencils @@ -382,6 +382,9 @@ void CXXNaiveIcoCodeGen::generateStencilClasses( "m_" + fieldIt.second.Name); } stencilClass.addMember("::dawn::unstructured_domain ", " m_unstructured_domain "); + if(!globalsMap.empty()) { + stencilClass.addMember("const globals &", " m_globals"); + } // addTmpStorageDeclaration(StencilClass, tempFields); @@ -396,9 +399,9 @@ void CXXNaiveIcoCodeGen::generateStencilClasses( } // stencilClassCtr.addInit("m_dom(dom_)"); - // if(!globalsMap.empty()) { - // stencilClassCtr.addArg("m_globals(globals_)"); - // } + if(!globalsMap.empty()) { + stencilClassCtr.addArg("const globals &globals_"); + } stencilClassCtr.addInit("m_mesh(mesh)"); stencilClassCtr.addInit("m_k_size(k_size)"); @@ -406,6 +409,10 @@ void CXXNaiveIcoCodeGen::generateStencilClasses( stencilClassCtr.addInit("m_" + fieldIt.second.Name + "(" + fieldIt.second.Name + ")"); } + if(!globalsMap.empty()) { + stencilClassCtr.addInit("m_globals(globals_)"); + } + // addTmpStorageInit(stencilClassCtr, *stencil, tempFields); stencilClassCtr.commit(); @@ -635,7 +642,7 @@ void CXXNaiveIcoCodeGen::generateStencilFunctions( // add global parameter if(stencilFun->hasGlobalVariables()) { - stencilFunMethod.addArg("const globals& m_globals"); + stencilFunMethod.addArg("globals m_globals"); } ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData(), StencilContext::SC_StencilFunction); @@ -680,7 +687,7 @@ std::unique_ptr CXXNaiveIcoCodeGen::generateCode() { stencils.emplace(nameStencilCtxPair.first, std::move(code)); } - std::string globals = generateGlobals(context_, "::dawn_generated", "cxxnaiveico"); + std::string globals = generateGlobals(context_, "dawn_generated", "cxxnaiveico"); std::vector ppDefines; ppDefines.push_back("#define DAWN_GENERATED 1"); diff --git a/dawn/src/dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp b/dawn/src/dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp index 78bed4d85..17ca73701 100644 --- a/dawn/src/dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp +++ b/dawn/src/dawn/CodeGen/CXXNaive/CXXNaiveCodeGen.cpp @@ -123,7 +123,7 @@ std::string CXXNaiveCodeGen::generateStencilInstantiation( generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties); - generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties); + generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties); generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties); @@ -321,7 +321,7 @@ void CXXNaiveCodeGen::generateStencilClasses( stencilClassCtr.addArg("const " + c_dgt + "domain& dom_"); if(!globalsMap.empty()) { - stencilClassCtr.addArg("const globals& globals_"); + stencilClassCtr.addArg("globals& globals_"); } stencilClassCtr.addArg("int rank"); stencilClassCtr.addArg("int xcols"); @@ -399,14 +399,14 @@ void CXXNaiveCodeGen::generateStencilClasses( for(auto it = nonTempFields.begin(); it != nonTempFields.end(); ++it) { const auto fieldName = (*it).second.Name; std::string type = stencilProperties->paramNameToType_.at(fieldName); - stencilRunMethod.addStatement(c_gt + "data_view<" + type + "> " + fieldName + "= " + - c_gt + "make_host_view(" + fieldName + "_)"); + stencilRunMethod.addStatement(c_gt + "data_view<" + type + "> " + fieldName + "= " + c_gt + + "make_host_view(" + fieldName + "_)"); stencilRunMethod.addStatement("std::array " + fieldName + "_offsets{0,0,0}"); } for(const auto& fieldPair : tempFields) { const auto fieldName = fieldPair.second.Name; - stencilRunMethod.addStatement(c_gt + "data_view " + fieldName + "= " + - c_gt + "make_host_view(m_" + fieldName + ")"); + stencilRunMethod.addStatement(c_gt + "data_view " + fieldName + "= " + c_gt + + "make_host_view(m_" + fieldName + ")"); stencilRunMethod.addStatement("std::array " + fieldName + "_offsets{0,0,0}"); } @@ -569,7 +569,7 @@ void CXXNaiveCodeGen::generateStencilFunctions( // add global parameter if(stencilFun->hasGlobalVariables()) { - stencilFunMethod.addArg("const globals& m_globals"); + stencilFunMethod.addArg("globals m_globals"); } ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData(), StencilContext::SC_StencilFunction); @@ -581,8 +581,8 @@ void CXXNaiveCodeGen::generateStencilFunctions( std::string paramName = stencilFun->getOriginalNameFromCallerAccessID(fields[m].getAccessID()); - stencilFunMethod << c_gt << "data_view " - << paramName << " = pw_" << paramName << ".dview_;"; + stencilFunMethod << c_gt << "data_view " << paramName + << " = pw_" << paramName << ".dview_;"; stencilFunMethod << "auto " << paramName << "_offsets = pw_" << paramName << ".offsets_;"; } stencilBodyCXXVisitor.setCurrentStencilFunction(stencilFun); diff --git a/dawn/src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp b/dawn/src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp index 83dffde16..031ecf5b8 100644 --- a/dawn/src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp +++ b/dawn/src/dawn/CodeGen/CXXOpt/CXXOptCodeGen.cpp @@ -131,7 +131,7 @@ std::string CXXOptCodeGen::generateStencilInstantiation( generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties); - generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties); + generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties); generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties); diff --git a/dawn/src/dawn/CodeGen/CodeGen.cpp b/dawn/src/dawn/CodeGen/CodeGen.cpp index d439418a3..667ea030d 100644 --- a/dawn/src/dawn/CodeGen/CodeGen.cpp +++ b/dawn/src/dawn/CodeGen/CodeGen.cpp @@ -107,8 +107,7 @@ std::string CodeGen::generateGlobals(const sir::GlobalVariableMap& globalsMap, return ss.str(); } -void CodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation, - Class& stencilWrapperClass, +void CodeGen::generateGlobalsAPI(Structure& stencilWrapperClass, const sir::GlobalVariableMap& globalsMap, const CodeGenProperties& codeGenProperties) const { @@ -131,7 +130,7 @@ void CodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstant setter.addArg(std::string(sir::Value::typeToString(globalValue.getType())) + " " + globalProp.first); setter.finishArgs(); - setter.addStatement("m_globals." + globalProp.first + "=" + globalProp.first); + setter.addStatement("m_globals." + globalProp.first + "=" + globalProp.first); setter.commit(); } } diff --git a/dawn/src/dawn/CodeGen/CodeGen.h b/dawn/src/dawn/CodeGen/CodeGen.h index 03755eceb..d31bdcca6 100644 --- a/dawn/src/dawn/CodeGen/CodeGen.h +++ b/dawn/src/dawn/CodeGen/CodeGen.h @@ -92,8 +92,7 @@ class CodeGen { CodeGenProperties computeCodeGenProperties(const iir::StencilInstantiation* stencilInstantiation) const; - virtual void generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation, - Class& stencilWrapperClass, + virtual void generateGlobalsAPI(Structure& stencilWrapperClass, const sir::GlobalVariableMap& globalsMap, const CodeGenProperties& codeGenProperties) const; virtual std::string generateGlobals(const StencilInstantiationContext& context, diff --git a/dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp b/dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp index 2c69fb2bb..cb0887521 100644 --- a/dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp +++ b/dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp @@ -74,7 +74,20 @@ void ASTStencilBody::visit(const std::shared_ptr& stmt) { } void ASTStencilBody::visit(const std::shared_ptr& expr) { - DAWN_ASSERT_MSG(0, "Var Access not allowed in this context"); + std::string name = getName(expr); + int AccessID = iir::getAccessID(expr); + + if(metadata_.isAccessType(iir::FieldAccessType::GlobalVariable, AccessID)) { + ss_ << "globals." << name; + } else { + ss_ << name; + + if(expr->isArrayAccess()) { + ss_ << "["; + expr->getIndex()->accept(*this); + ss_ << "]"; + } + } } void ASTStencilBody::visit(const std::shared_ptr& expr) { diff --git a/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp b/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp index 69b5ab718..49963421e 100644 --- a/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp +++ b/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.cpp @@ -161,6 +161,8 @@ void CudaIcoCodeGen::generateRunFun( const std::shared_ptr& stencilInstantiation, MemberFunction& runFun, CodeGenProperties& codeGenProperties) { + const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap(); + // find block sizes to generate std::set stageLocType; for(const auto& ms : iterateIIROver(*(stencilInstantiation->getIIR()))) { @@ -308,7 +310,9 @@ void CudaIcoCodeGen::generateRunFun( kernelCall << "<<<" << "dG" + std::to_string(stage->getStageID()) + ",dB" << ">>>("; - + if(!globalsMap.empty()) { + kernelCall << "m_globals, "; + } kernelCall << numElString << ", "; // which loc size args (int CellIdx, int EdgeIdx, int CellIdx) need to be passed additionally? @@ -364,9 +368,13 @@ void CudaIcoCodeGen::generateRunFun( } void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::Stencil& stencil, + const sir::GlobalVariableMap& globalsMap, CodeGenProperties& codeGenProperties) const { // arguments: mesh, kSize, fields + if(!globalsMap.empty()) { + ctor.addArg("globals globals"); + } ctor.addArg("const dawn::mesh_t& mesh"); ctor.addArg("int kSize"); for(auto field : support::orderMap(stencil.getFields())) { @@ -389,6 +397,9 @@ void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::St ctor.addInit("sbase(\"" + stencilName + "\")"); ctor.addInit("mesh_(mesh)"); ctor.addInit("kSize_(kSize)"); + if(!globalsMap.empty()) { + ctor.addInit("m_globals(globals)"); + } std::stringstream fieldsStr; { @@ -406,8 +417,12 @@ void CudaIcoCodeGen::generateStencilClassCtr(MemberFunction& ctor, const iir::St void CudaIcoCodeGen::generateStencilClassCtrMinimal(MemberFunction& ctor, const iir::Stencil& stencil, + const sir::GlobalVariableMap& globalsMap, CodeGenProperties& codeGenProperties) const { + if(!globalsMap.empty()) { + ctor.addArg("globals globals"); + } // arguments: mesh, kSize, fields ctor.addArg("const dawn::GlobalGpuTriMesh *mesh"); ctor.addArg("int kSize"); @@ -417,6 +432,9 @@ void CudaIcoCodeGen::generateStencilClassCtrMinimal(MemberFunction& ctor, codeGenProperties.getStencilName(StencilContext::SC_Stencil, stencil.getStencilID()); ctor.addInit("sbase(\"" + stencilName + "\")"); ctor.addInit("mesh_(mesh)"); + if(!globalsMap.empty()) { + ctor.addInit("m_globals(globals)"); + } ctor.addInit("kSize_(kSize)"); } @@ -576,6 +594,7 @@ void CudaIcoCodeGen::generateStencilClasses( Class& stencilWrapperClass, CodeGenProperties& codeGenProperties) { const auto& stencils = stencilInstantiation->getStencils(); + const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap(); // Stencil members: // generate the code for each of the stencils @@ -587,6 +606,8 @@ void CudaIcoCodeGen::generateStencilClasses( Structure stencilClass = stencilWrapperClass.addStruct(stencilName, "", "sbase"); + generateGlobalsAPI(stencilClass, globalsMap, codeGenProperties); + // generate members (fields + kSize + gpuMesh) stencilClass.changeAccessibility("private"); for(auto field : support::orderMap(stencil.getFields())) { @@ -597,9 +618,13 @@ void CudaIcoCodeGen::generateStencilClasses( stencilClass.changeAccessibility("public"); + if(!globalsMap.empty()) { + stencilClass.addMember("globals", "m_globals"); + } + // constructor from library auto stencilClassConstructor = stencilClass.addConstructor(); - generateStencilClassCtr(stencilClassConstructor, stencil, codeGenProperties); + generateStencilClassCtr(stencilClassConstructor, stencil, globalsMap, codeGenProperties); stencilClassConstructor.commit(); // grid helper fun @@ -613,7 +638,8 @@ void CudaIcoCodeGen::generateStencilClasses( // minmal ctor auto stencilClassMinimalConstructor = stencilClass.addConstructor(); - generateStencilClassCtrMinimal(stencilClassMinimalConstructor, stencil, codeGenProperties); + generateStencilClassCtrMinimal(stencilClassMinimalConstructor, stencil, globalsMap, + codeGenProperties); stencilClassMinimalConstructor.commit(); // run method @@ -647,6 +673,8 @@ void CudaIcoCodeGen::generateAllAPIRunFunctions( CodeGenProperties& codeGenProperties, bool fromHost) { const auto& stencils = stencilInstantiation->getStencils(); + const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap(); + CollectChainStrings chainCollector; std::set> chains; for(const auto& doMethod : iterateIIROver(*(stencilInstantiation->getIIR()))) { @@ -732,8 +760,10 @@ void CudaIcoCodeGen::generateAllAPIRunFunctions( } for(auto& apiRunFun : apiRunFuns) { - apiRunFun->addStatement(wrapperName + "::" + stencilName + " s(mesh, k_size)"); + apiRunFun->addStatement(wrapperName + "::" + stencilName + " s(" + + (globalsMap.empty() ? "" : "globals(), ") + "mesh, k_size)"); } if(fromHost) { // depending if we are calling from c or from fortran, we need to transpose the data or not @@ -792,6 +822,7 @@ void CudaIcoCodeGen::generateAllCudaKernels( const std::shared_ptr& stencilInstantiation) { ASTStencilBody stencilBodyCXXVisitor(stencilInstantiation->getMetaData()); + const auto& globalsMap = stencilInstantiation->getIIR()->getGlobalVariableMap(); for(const auto& ms : iterateIIROver(*(stencilInstantiation->getIIR()))) { for(const auto& stage : ms->getChildren()) { @@ -829,6 +860,9 @@ void CudaIcoCodeGen::generateAllCudaKernels( retString, cuda::CodeGeneratorHelper::buildCudaKernelName(stencilInstantiation, ms, stage), ssSW); + if(!globalsMap.empty()) { + cudaKernel.addArg("globals globals"); + } auto loc = *stage->getLocationType(); cudaKernel.addArg("int " + locToDenseSizeStringGpuMesh(loc)); @@ -1029,8 +1063,7 @@ std::unique_ptr CudaIcoCodeGen::generateCode() { "using namespace gridtools::dawn;", }; - // globals not yet supported - std::string globals = ""; + std::string globals = generateGlobals(context_, "dawn_generated", "cuda_ico"); DAWN_LOG(INFO) << "Done generating code"; diff --git a/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h b/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h index bad5c0f95..e93cf6a61 100644 --- a/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h +++ b/dawn/src/dawn/CodeGen/Cuda-ico/CudaIcoCodeGen.h @@ -86,10 +86,12 @@ class CudaIcoCodeGen : public CodeGen { void generateGridFun(MemberFunction& runFun); void generateStencilClassCtr(MemberFunction& stencilClassCtor, const iir::Stencil& stencil, + const sir::GlobalVariableMap& globalsMap, CodeGenProperties& codeGenProperties) const; void generateStencilClassCtrMinimal(MemberFunction& stencilClassCtor, const iir::Stencil& stencil, - CodeGenProperties& codeGenProperties) const; + const sir::GlobalVariableMap& globalsMap, + CodeGenProperties& codeGenProperties) const; void generateStencilClassRawPtrCtr(MemberFunction& stencilClassCtor, const iir::Stencil& stencil, CodeGenProperties& codeGenProperties) const; diff --git a/dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp b/dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp index b14087671..5ee0b0832 100644 --- a/dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp +++ b/dawn/src/dawn/CodeGen/Cuda/CudaCodeGen.cpp @@ -125,7 +125,7 @@ std::string CudaCodeGen::generateStencilInstantiation( generateStencilWrapperCtr(stencilWrapperClass, stencilInstantiation, codeGenProperties); if(!globalsMap.empty()) { - generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties); + generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties); } generateStencilWrapperSyncMethod(stencilWrapperClass); diff --git a/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.cpp b/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.cpp index 5045f54ff..9997fdc8e 100644 --- a/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.cpp +++ b/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.cpp @@ -226,8 +226,7 @@ void GTCodeGen::generatePlaceholderDefinitions( } } -void GTCodeGen::generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation, - Class& stencilWrapperClass, +void GTCodeGen::generateGlobalsAPI(Structure& stencilWrapperClass, const sir::GlobalVariableMap& globalsMap, const CodeGenProperties& codeGenProperties) const { @@ -287,7 +286,7 @@ std::string GTCodeGen::generateStencilInstantiation( generateStencilWrapperRun(stencilWrapperClass, stencilInstantiation, codeGenProperties); if(!globalsMap.empty()) { - generateGlobalsAPI(*stencilInstantiation, stencilWrapperClass, globalsMap, codeGenProperties); + generateGlobalsAPI(stencilWrapperClass, globalsMap, codeGenProperties); } generateStencilWrapperPublicMemberFunctions(stencilWrapperClass, codeGenProperties); diff --git a/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.h b/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.h index 33748f532..b984b21f7 100644 --- a/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.h +++ b/dawn/src/dawn/CodeGen/GridTools/GTCodeGen.h @@ -93,8 +93,7 @@ class GTCodeGen : public CodeGen { bool isTemporary(iir::Stencil::FieldInfo const& f) const { return f.IsTemporary; } - void generateGlobalsAPI(const iir::StencilInstantiation& stencilInstantiation, - Class& stencilWrapperClass, const sir::GlobalVariableMap& globalsMap, + void generateGlobalsAPI(Structure& stencilWrapperClass, const sir::GlobalVariableMap& globalsMap, const CodeGenProperties& codeGenProperties) const override; void generateStencilWrapperMembers( diff --git a/dawn/test/integration-test/dawn4py-tests/CMakeLists.txt b/dawn/test/integration-test/dawn4py-tests/CMakeLists.txt index 1867664df..6af989480 100644 --- a/dawn/test/integration-test/dawn4py-tests/CMakeLists.txt +++ b/dawn/test/integration-test/dawn4py-tests/CMakeLists.txt @@ -53,3 +53,5 @@ add_python_example(NAME generate_empty_stage) add_python_example(NAME generate_versioned_field) add_python_example(NAME unstructured_masked_fields) add_python_example(NAME vertical_indirection) +add_python_example(NAME global_var) +add_python_example(NAME global_var_unstructured) diff --git a/dawn/test/integration-test/dawn4py-tests/global_var.py b/dawn/test/integration-test/dawn4py-tests/global_var.py new file mode 100644 index 000000000..d83666b84 --- /dev/null +++ b/dawn/test/integration-test/dawn4py-tests/global_var.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python + +##===-----------------------------------------------------------------------------*- Python -*-===## +## _ +## | | +## __| | __ ___ ___ ___ +## / _` |/ _` \ \ /\ / / '_ | +## | (_| | (_| |\ V V /| | | | +## \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain +## +## +## This file is distributed under the MIT License (MIT). +## See LICENSE.txt for details. +## +##===------------------------------------------------------------------------------------------===## + +import argparse +import os + +import dawn4py +from dawn4py.serialization import SIR +from dawn4py.serialization import utils as sir_utils + +OUTPUT_NAME = "global_var_stencil" +OUTPUT_FILE = f"{OUTPUT_NAME}.cpp" +OUTPUT_PATH = f"{OUTPUT_NAME}.cpp" + + +def main(args: argparse.Namespace): + interval = sir_utils.make_interval( + SIR.Interval.Start, SIR.Interval.End, 0, 0) + + body_ast = sir_utils.make_ast( + [ + sir_utils.make_assignment_stmt( + sir_utils.make_field_access_expr("out", [0, 0, 0]), + sir_utils.make_binary_operator(sir_utils.make_var_access_expr( + "dt", is_external=True), "*", sir_utils.make_field_access_expr("in", [1, 0, 0])), + "="), + ] + ) + + vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt( + body_ast, interval, SIR.VerticalRegion.Forward + ) + + globals = SIR.GlobalVariableMap() + globals.map["dt"].double_value = 0.5 + + sir = sir_utils.make_sir( + OUTPUT_FILE, + SIR.GridType.Value("Cartesian"), + [ + sir_utils.make_stencil( + OUTPUT_NAME, + sir_utils.make_ast([vertical_region_stmt]), + [ + sir_utils.make_field( + "in", sir_utils.make_field_dimensions_cartesian()), + sir_utils.make_field( + "out", sir_utils.make_field_dimensions_cartesian()), + ], + ) + ], + global_variables=globals + ) + + # print the SIR + if args.verbose: + sir_utils.pprint(sir) + + # compile + code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaive) + + # write to file + print(f"Writing generated code to '{OUTPUT_PATH}'") + with open(OUTPUT_PATH, "w") as f: + f.write(code) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate a simple stencil with globals using Dawn compiler" + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + default=False, + help="Print the generated SIR", + ) + main(parser.parse_args()) diff --git a/dawn/test/integration-test/dawn4py-tests/global_var_unstructured.py b/dawn/test/integration-test/dawn4py-tests/global_var_unstructured.py new file mode 100644 index 000000000..e60432c13 --- /dev/null +++ b/dawn/test/integration-test/dawn4py-tests/global_var_unstructured.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python + +##===-----------------------------------------------------------------------------*- Python -*-===## +## _ +## | | +## __| | __ ___ ___ ___ +## / _` |/ _` \ \ /\ / / '_ | +## | (_| | (_| |\ V V /| | | | +## \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain +## +## +## This file is distributed under the MIT License (MIT). +## See LICENSE.txt for details. +## +##===------------------------------------------------------------------------------------------===## + +import argparse +import os + +import dawn4py +from dawn4py.serialization import SIR +from dawn4py.serialization import utils as sir_utils + +OUTPUT_NAME = "global_var_stencil_unstructured" +OUTPUT_FILE = f"{OUTPUT_NAME}.cpp" +OUTPUT_PATH = f"{OUTPUT_NAME}.cpp" + + +def main(args: argparse.Namespace): + interval = sir_utils.make_interval( + SIR.Interval.Start, SIR.Interval.End, 0, 0) + + body_ast = sir_utils.make_ast( + [ + sir_utils.make_assignment_stmt( + sir_utils.make_field_access_expr("out"), + sir_utils.make_binary_operator(sir_utils.make_var_access_expr( + "dt", is_external=True), "*", sir_utils.make_field_access_expr("in")), + "="), + ] + ) + + vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt( + body_ast, interval, SIR.VerticalRegion.Forward + ) + + globals = SIR.GlobalVariableMap() + globals.map["dt"].double_value = 0.5 + + sir = sir_utils.make_sir( + OUTPUT_FILE, + SIR.GridType.Value("Unstructured"), + [ + sir_utils.make_stencil( + OUTPUT_NAME, + sir_utils.make_ast([vertical_region_stmt]), + [ + sir_utils.make_field( + "in", + sir_utils.make_field_dimensions_unstructured( + [SIR.LocationType.Value("Edge")], 1 + ), + ), + sir_utils.make_field( + "out", + sir_utils.make_field_dimensions_unstructured( + [SIR.LocationType.Value("Edge")], 1 + ), + ), + ], + ) + ], + global_variables=globals + ) + + # print the SIR + if args.verbose: + sir_utils.pprint(sir) + + # compile + code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDAIco) + + # write to file + print(f"Writing generated code to '{OUTPUT_PATH}'") + with open(OUTPUT_PATH, "w") as f: + f.write(code) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate a simple stencil with globals using Dawn compiler" + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + default=False, + help="Print the generated SIR", + ) + main(parser.parse_args()) diff --git a/dawn/test/integration-test/unstructured/AtlasIntegrationTestCompareOutput.cpp b/dawn/test/integration-test/unstructured/AtlasIntegrationTestCompareOutput.cpp index a1064ed97..c4d777fcb 100644 --- a/dawn/test/integration-test/unstructured/AtlasIntegrationTestCompareOutput.cpp +++ b/dawn/test/integration-test/unstructured/AtlasIntegrationTestCompareOutput.cpp @@ -1078,4 +1078,33 @@ TEST(AtlasIntegrationTestCompareOutput, iterationSpaceUnstructured) { } } // namespace +namespace { +#include +TEST(AtlasIntegrationTestCompareOutput, globalVar) { + auto mesh = generateQuadMesh(10, 10); + size_t nb_levels = 10; + const double dt = 2.0; + + auto [in_F, in_v] = makeAtlasField("in", mesh.cells().size(), nb_levels); + auto [out_F, out_v] = makeAtlasField("out", mesh.cells().size(), nb_levels); + + // Initialize fields with data + initField(in_v, mesh.cells().size(), nb_levels, 1.0); + initField(out_v, mesh.cells().size(), nb_levels, -1.0); + + // Run the stencil + auto stencil = dawn_generated::cxxnaiveico::globalVar( + mesh, static_cast(nb_levels), in_v, out_v); + stencil.set_dt(dt); + stencil.run(); + + // Check correctness of the output + for(int k = 0; k < nb_levels; k++) { + for(int cell_idx = 0; cell_idx < mesh.cells().size(); ++cell_idx) { + ASSERT_EQ(out_v(cell_idx, k), dt); + } + } +} +} // namespace + } // namespace \ No newline at end of file diff --git a/dawn/test/integration-test/unstructured/CMakeLists.txt b/dawn/test/integration-test/unstructured/CMakeLists.txt index da30ac520..960dcbade 100644 --- a/dawn/test/integration-test/unstructured/CMakeLists.txt +++ b/dawn/test/integration-test/unstructured/CMakeLists.txt @@ -39,6 +39,7 @@ set(generated_stencil_codes generated_accumulateEdgeToCell.hpp generated_tridiagonalSolve.hpp generated_verticalIndirecion.hpp generated_verticalSum.hpp + generated_globalVar.hpp ) set(reference_stencil_codes reference_diffusion.hpp diff --git a/dawn/test/integration-test/unstructured/GenerateUnstructuredStencils.cpp b/dawn/test/integration-test/unstructured/GenerateUnstructuredStencils.cpp index 25390bff0..5f53cc219 100644 --- a/dawn/test/integration-test/unstructured/GenerateUnstructuredStencils.cpp +++ b/dawn/test/integration-test/unstructured/GenerateUnstructuredStencils.cpp @@ -239,7 +239,7 @@ int main() { b.reduceOverNeighborExpr( Op::plus, b.at(in_f, HOffsetType::withOffset, 0), b.binaryExpr(b.unaryExpr(b.at(cnt), Op::minus), - b.at(in_f, HOffsetType::withOffset, 0), Op::multiply), + b.at(in_f, HOffsetType::noOffset, 0), Op::multiply), {LocType::Cells, LocType::Edges, LocType::Cells}))), b.stmt(b.assignExpr( b.at(out_f), @@ -879,5 +879,31 @@ int main() { of << dawn::codegen::generate(tu) << std::endl; } + { + using namespace dawn::iir; + using LocType = dawn::ast::LocationType; + + UnstructuredIIRBuilder b; + auto in_f = b.field("in_field", LocType::Cells); + auto out_f = b.field("out_field", LocType::Cells); + auto global = b.globalvar("dt", 0.5); + std::string stencilName = "globalVar"; + + auto stencilInstantiation = b.build( + stencilName, + b.stencil(b.multistage( + LoopOrderKind::Parallel, + b.stage( + LocType::Cells, + b.doMethod(dawn::sir::Interval::Start, dawn::sir::Interval::End, + b.stmt(b.assignExpr(b.at(out_f), b.binaryExpr(b.at(global), b.at(in_f), + Op::multiply)))))))); + + std::ofstream of("generated/generated_" + stencilName + ".hpp"); + DAWN_ASSERT_MSG(of, "couldn't open output file!\n"); + auto tu = dawn::codegen::run(stencilInstantiation, dawn::codegen::Backend::CXXNaiveIco); + of << dawn::codegen::generate(tu) << std::endl; + } + return 0; } \ No newline at end of file diff --git a/dawn/test/unit-test/dawn/CodeGen/reference/conditional_stencil.cpp b/dawn/test/unit-test/dawn/CodeGen/reference/conditional_stencil.cpp index ca3dbf029..aee86942b 100644 --- a/dawn/test/unit-test/dawn/CodeGen/reference/conditional_stencil.cpp +++ b/dawn/test/unit-test/dawn/CodeGen/reference/conditional_stencil.cpp @@ -70,7 +70,7 @@ class conditional_stencil { // Input/Output storages public: - stencil_21(const gridtools::dawn::domain& dom_, const globals& globals_, int rank, int xcols, int ycols) : m_dom(dom_), m_globals(globals_){} + stencil_21(const gridtools::dawn::domain& dom_, globals& globals_, int rank, int xcols, int ycols) : m_dom(dom_), m_globals(globals_){} static constexpr ::dawn::driver::cartesian_extent in_extent = {-1,1, -1,1, 0,0}; static constexpr ::dawn::driver::cartesian_extent out_extent = {0,0, 0,0, 0,0}; diff --git a/dawn/test/unit-test/dawn/CodeGen/reference/update_dz_c.cpp b/dawn/test/unit-test/dawn/CodeGen/reference/update_dz_c.cpp index f769c862f..1a763f143 100644 --- a/dawn/test/unit-test/dawn/CodeGen/reference/update_dz_c.cpp +++ b/dawn/test/unit-test/dawn/CodeGen/reference/update_dz_c.cpp @@ -74,7 +74,7 @@ class update_dz_c { tmp_storage_t m_fy; public: - stencil_443(const gridtools::dawn::domain& dom_, const globals& globals_, int rank, int xcols, int ycols) : m_dom(dom_), m_globals(globals_), m_tmp_meta_data(dom_.isize() + 1, dom_.jsize() + 1, dom_.ksize() + 2*0), m_xfx(m_tmp_meta_data), m_yfx(m_tmp_meta_data), m_fx(m_tmp_meta_data), m_fy(m_tmp_meta_data){} + stencil_443(const gridtools::dawn::domain& dom_, globals& globals_, int rank, int xcols, int ycols) : m_dom(dom_), m_globals(globals_), m_tmp_meta_data(dom_.isize() + 1, dom_.jsize() + 1, dom_.ksize() + 2*0), m_xfx(m_tmp_meta_data), m_yfx(m_tmp_meta_data), m_fx(m_tmp_meta_data), m_fy(m_tmp_meta_data){} static constexpr ::dawn::driver::cartesian_extent dp_ref_extent = {0,1, 0,1, -2,1}; static constexpr ::dawn::driver::cartesian_extent zs_extent = {0,0, 0,0, 0,0}; static constexpr ::dawn::driver::cartesian_extent area_extent = {0,0, 0,0, 0,0};