From 232d3b60805bfc2e104e2f6ad1721f72ff62f932 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 12 Nov 2024 12:13:42 -0600 Subject: [PATCH 01/45] Cleanup noalias analysis (#2164) * Cleanup noalias analysis * cleanup * fix * fix --- enzyme/Enzyme/EnzymeLogic.cpp | 13 +- enzyme/Enzyme/EnzymeLogic.h | 6 + enzyme/Enzyme/JLInstSimplify.cpp | 181 ++-------------------- enzyme/Enzyme/LibraryFuncs.h | 17 ++- enzyme/Enzyme/Utils.cpp | 201 +++++++++++++++++++++++++ enzyme/Enzyme/Utils.h | 18 +++ enzyme/test/Enzyme/JLSimplify/lerr.ll | 2 +- enzyme/test/Enzyme/JLSimplify/loads.ll | 2 +- 8 files changed, 256 insertions(+), 184 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 4d4b2b5cbc05..02595a3cbb84 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3551,15 +3551,18 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, } Function *EnzymeLogic::CreatePrimalAndGradient( - RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, + RequestContext context, const ReverseCacheKey &&prevkey, TypeAnalysis &TA, const AugmentedReturn *augmenteddata, bool omp) { - TimeTraceScope timeScope("CreatePrimalAndGradient", key.todiff->getName()); + TimeTraceScope timeScope("CreatePrimalAndGradient", + prevkey.todiff->getName()); - assert(key.mode == DerivativeMode::ReverseModeCombined || - key.mode == DerivativeMode::ReverseModeGradient); + assert(prevkey.mode == DerivativeMode::ReverseModeCombined || + prevkey.mode == DerivativeMode::ReverseModeGradient); - FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(key.typeInfo, key.todiff); + FnTypeInfo oldTypeInfo = + preventTypeAnalysisLoops(prevkey.typeInfo, prevkey.todiff); + auto key = prevkey.replaceTypeInfo(oldTypeInfo); if (key.retType != DIFFE_TYPE::CONSTANT) assert(!key.todiff->getReturnType()->isVoidTy()); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index f3155d10894a..e01534f3edda 100644 --- 
a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -168,6 +168,12 @@ struct ReverseCacheKey { const FnTypeInfo typeInfo; bool runtimeActivity; + ReverseCacheKey replaceTypeInfo(const FnTypeInfo &newTypeInfo) const { + return {todiff, retType, constant_args, overwritten_args, + returnUsed, shadowReturnUsed, mode, width, + freeMemory, AtomicAdd, additionalType, forceAnonymousTape, + newTypeInfo, runtimeActivity}; + } /* inline bool operator==(const ReverseCacheKey& rhs) const { return todiff == rhs.todiff && diff --git a/enzyme/Enzyme/JLInstSimplify.cpp b/enzyme/Enzyme/JLInstSimplify.cpp index 9390649099de..f0d29f0f856d 100644 --- a/enzyme/Enzyme/JLInstSimplify.cpp +++ b/enzyme/Enzyme/JLInstSimplify.cpp @@ -49,6 +49,7 @@ #include "llvm-c/Types.h" #include "JLInstSimplify.h" +#include "LibraryFuncs.h" #include "Utils.h" using namespace llvm; @@ -58,80 +59,6 @@ using namespace llvm; #define DEBUG_TYPE "jl-inst-simplify" namespace { -bool notCapturedBefore(llvm::Value *V, Instruction *inst) { - Instruction *VI = dyn_cast(V); - if (!VI) - VI = &*inst->getParent()->getParent()->getEntryBlock().begin(); - else - VI = VI->getNextNode(); - SmallPtrSet regionBetween; - { - SmallVector todo; - todo.push_back(VI->getParent()); - while (todo.size()) { - auto cur = todo.pop_back_val(); - if (regionBetween.count(cur)) - continue; - regionBetween.insert(cur); - if (cur == inst->getParent()) - continue; - for (auto BB : successors(cur)) - todo.push_back(BB); - } - } - SmallVector todo = {V}; - SmallPtrSet seen; - while (todo.size()) { - auto cur = todo.pop_back_val(); - if (seen.count(cur)) - continue; - for (auto U : cur->users()) { - auto UI = dyn_cast(U); - if (!regionBetween.count(UI->getParent())) - continue; - if (UI->getParent() == VI->getParent()) { - if (UI->comesBefore(VI)) - continue; - } - if (UI->getParent() == inst->getParent()) - if (inst->comesBefore(UI)) - continue; - - if (isPointerArithmeticInst(UI, /*includephi*/ true, - /*includebin*/ true)) { - 
todo.push_back(UI); - continue; - } - - if (auto CI = dyn_cast(UI)) { -#if LLVM_VERSION_MAJOR >= 14 - for (size_t i = 0, size = CI->arg_size(); i < size; i++) -#else - for (size_t i = 0, size = CI->getNumArgOperands(); i < size; i++) -#endif - { - if (cur == CI->getArgOperand(i)) { - if (isNoCapture(CI, i)) - continue; - return false; - } - } - return true; - } - - if (isa(UI)) { - continue; - } - if (isa(UI)) { - todo.push_back(UI); - continue; - } - return false; - } - } - return true; -} - bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI) { bool changed = false; @@ -178,108 +105,24 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, } if (legal) { - auto lhs = getBaseObject(I.getOperand(0), /*offsetAllowed*/ false); - auto rhs = getBaseObject(I.getOperand(1), /*offsetAllowed*/ false); - if (lhs == rhs) { - auto repval = ICmpInst::isTrueWhenEqual(pred) - ? ConstantInt::get(I.getType(), 1) - : ConstantInt::get(I.getType(), 0); - I.replaceAllUsesWith(repval); - changed = true; - continue; - } - if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa(rhs))) || - (isNoAlias(rhs) && isa(lhs))) { + if (auto alias = arePointersGuaranteedNoAlias( + TLI, AA, LI, I.getOperand(0), I.getOperand(1), false)) { + +#if LLVM_VERSION_MAJOR >= 16 + bool val = alias.value(); +#else + bool val = alias.getValue(); +#endif auto repval = ICmpInst::isTrueWhenEqual(pred) - ? ConstantInt::get(I.getType(), 0) - : ConstantInt::get(I.getType(), 1); + ? ConstantInt::get(I.getType(), 1 - val) + : ConstantInt::get(I.getType(), val); I.replaceAllUsesWith(repval); changed = true; continue; } - - { - bool noalias_from_capture = false; - for (int i = 0; i < 2; i++) { - Value *start = (i == 0) ? lhs : rhs; - Value *end = (i == 0) ? 
rhs : lhs; - if (isNoAlias(start)) { - if (auto endi = dyn_cast(end)) { - if (notCapturedBefore(start, endi)) { - noalias_from_capture = true; - break; - } - } - } - } - if (noalias_from_capture) { - auto repval = ICmpInst::isTrueWhenEqual(pred) - ? ConstantInt::get(I.getType(), 0) - : ConstantInt::get(I.getType(), 1); - I.replaceAllUsesWith(repval); - changed = true; - continue; - } - } - - auto llhs = dyn_cast(lhs); - auto lrhs = dyn_cast(rhs); - if (llhs && lrhs && isa(llhs->getType()) && - isa(lrhs->getType())) { - auto lhsv = - getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false); - auto rhsv = - getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false); - if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa(rhsv) || - notCapturedBefore(lhsv, &I))) || - (isNoAlias(rhsv) && - (isa(lhsv) || notCapturedBefore(rhsv, &I)))) { - bool legal = false; - for (int i = 0; i < 2; i++) { - Value *start = (i == 0) ? lhsv : rhsv; - Instruction *starti = dyn_cast(start); - if (!starti) { - if (!isa(start)) - continue; - starti = &cast(start) - ->getParent() - ->getEntryBlock() - .front(); - } - - bool overwritten = false; - allInstructionsBetween( - LI, starti, &I, [&](Instruction *I) -> bool { - if (!I->mayWriteToMemory()) - return /*earlyBreak*/ false; - - for (auto LI : {llhs, lrhs}) - if (writesToMemoryReadBy(nullptr, AA, TLI, - /*maybeReader*/ LI, - /*maybeWriter*/ I)) { - overwritten = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - if (!overwritten) { - legal = true; - break; - } - } - - if (legal && lhsv != rhsv) { - auto repval = ICmpInst::isTrueWhenEqual(pred) - ? 
ConstantInt::get(I.getType(), 0) - : ConstantInt::get(I.getType(), 1); - I.replaceAllUsesWith(repval); - changed = true; - continue; - } - } - } } } + return changed; } diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 7e4feb661c96..27697bc41a53 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -266,10 +266,14 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, static inline bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI) { - if (auto *CI = llvm::dyn_cast(TmpOrig)) { - return isAllocationFunction(getFuncNameFromCall(CI), TLI); - } - if (auto *CI = llvm::dyn_cast(TmpOrig)) { + if (auto *CI = llvm::dyn_cast(TmpOrig)) { + auto AttrList = + CI->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex); + if (AttrList.hasAttribute("enzyme_allocation")) + return true; + if (auto Fn = getFunctionFromCall(CI)) + if (Fn->hasFnAttribute("enzyme_allocation")) + return true; return isAllocationFunction(getFuncNameFromCall(CI), TLI); } return false; @@ -277,10 +281,7 @@ static inline bool isAllocationCall(const llvm::Value *TmpOrig, static inline bool isDeallocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI) { - if (auto *CI = llvm::dyn_cast(TmpOrig)) { - return isDeallocationFunction(getFuncNameFromCall(CI), TLI); - } - if (auto *CI = llvm::dyn_cast(TmpOrig)) { + if (auto *CI = llvm::dyn_cast(TmpOrig)) { return isDeallocationFunction(getFuncNameFromCall(CI), TLI); } return false; diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 2e78f7c083cd..90907d96dec3 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2218,7 +2218,10 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, const SCEV *StoreBegin = SE.getCouldNotCompute(); const SCEV *StoreEnd = SE.getCouldNotCompute(); + Value *loadPtr = nullptr; + Value *storePtr = nullptr; if (auto LI = dyn_cast(maybeReader)) { + loadPtr = 
LI->getPointerOperand(); LoadBegin = SE.getSCEV(LI->getPointerOperand()); if (LoadBegin != SE.getCouldNotCompute() && !LoadBegin->getType()->isIntegerTy()) { @@ -2236,6 +2239,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto SI = dyn_cast(maybeWriter)) { + storePtr = SI->getPointerOperand(); StoreBegin = SE.getSCEV(SI->getPointerOperand()); if (StoreBegin != SE.getCouldNotCompute() && !StoreBegin->getType()->isIntegerTy()) { @@ -2255,6 +2259,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto MS = dyn_cast(maybeWriter)) { + storePtr = MS->getArgOperand(0); StoreBegin = SE.getSCEV(MS->getArgOperand(0)); if (StoreBegin != SE.getCouldNotCompute() && !StoreBegin->getType()->isIntegerTy()) { @@ -2269,6 +2274,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto MS = dyn_cast(maybeWriter)) { + storePtr = MS->getArgOperand(0); StoreBegin = SE.getSCEV(MS->getArgOperand(0)); if (StoreBegin != SE.getCouldNotCompute() && !StoreBegin->getType()->isIntegerTy()) { @@ -2283,6 +2289,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto MS = dyn_cast(maybeReader)) { + loadPtr = MS->getArgOperand(1); LoadBegin = SE.getSCEV(MS->getArgOperand(1)); if (LoadBegin != SE.getCouldNotCompute() && !LoadBegin->getType()->isIntegerTy()) { @@ -2297,6 +2304,16 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } + if (loadPtr && storePtr) + if (auto alias = + arePointersGuaranteedNoAlias(TLI, AA, LI, loadPtr, storePtr, true)) +#if LLVM_VERSION_MAJOR >= 16 + if (alias.value()) +#else + if (alias.getValue()) +#endif + return false; + if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd, maybeWriter, StoreBegin, StoreEnd, scope)) return false; @@ -3810,3 +3827,187 @@ bool isNVLoad(const llvm::Value *V) { } return false; } + +bool notCapturedBefore(llvm::Value *V, Instruction *inst, + size_t 
checkLoadCaptures) { + Instruction *VI = dyn_cast(V); + if (!VI) + VI = &*inst->getParent()->getParent()->getEntryBlock().begin(); + else + VI = VI->getNextNode(); + SmallPtrSet regionBetween; + { + SmallVector todo; + todo.push_back(VI->getParent()); + while (todo.size()) { + auto cur = todo.pop_back_val(); + if (regionBetween.count(cur)) + continue; + regionBetween.insert(cur); + if (cur == inst->getParent()) + continue; + for (auto BB : successors(cur)) + todo.push_back(BB); + } + } + SmallVector, 1> todo; + for (auto U : V->users()) { + todo.emplace_back(cast(U), checkLoadCaptures, V); + } + std::set> seen; + while (todo.size()) { + auto pair = todo.pop_back_val(); + if (seen.count(pair)) + continue; + auto UI = std::get<0>(pair); + auto level = std::get<1>(pair); + auto prev = std::get<2>(pair); + if (!regionBetween.count(UI->getParent())) + continue; + if (UI->getParent() == VI->getParent()) { + if (UI->comesBefore(VI)) + continue; + } + if (UI->getParent() == inst->getParent()) + if (inst->comesBefore(UI)) + continue; + + if (isPointerArithmeticInst(UI, /*includephi*/ true, + /*includebin*/ true)) { + for (auto U2 : UI->users()) { + auto UI2 = cast(U2); + todo.emplace_back(UI2, level, UI); + } + continue; + } + + if (isa(UI)) + continue; + + if (isa(UI)) { + if (level == 0) + continue; + if (UI->getOperand(1) != prev) + continue; + } + + if (auto CI = dyn_cast(UI)) { +#if LLVM_VERSION_MAJOR >= 14 + for (size_t i = 0, size = CI->arg_size(); i < size; i++) +#else + for (size_t i = 0, size = CI->getNumArgOperands(); i < size; i++) +#endif + { + if (prev == CI->getArgOperand(i)) { + if (isNoCapture(CI, i) && level == 0) + continue; + return false; + } + } + return true; + } + + if (isa(UI)) { + continue; + } + if (isa(UI)) { + if (level) { + for (auto U2 : UI->users()) { + auto UI2 = cast(U2); + todo.emplace_back(UI2, level - 1, UI); + } + } + continue; + } + // storing into it. 
+ if (auto SI = dyn_cast(UI)) { + if (SI->getValueOperand() != prev) { + continue; + } + } + return false; + } + return true; +} + +// Return true if guaranteed not to alias +// Return false if guaranteed to alias [with possible offset depending on flag]. +// Return {} if no information is given. +#if LLVM_VERSION_MAJOR >= 16 +std::optional +#else +llvm::Optional +#endif +arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, + llvm::LoopInfo &LI, llvm::Value *op0, + llvm::Value *op1, bool offsetAllowed) { + auto lhs = getBaseObject(op0, offsetAllowed); + auto rhs = getBaseObject(op1, offsetAllowed); + + if (lhs == rhs) { + return false; + } + if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy()) + return {}; + + bool noalias_lhs = isNoAlias(lhs); + bool noalias_rhs = isNoAlias(rhs); + + bool noalias[2] = {noalias_lhs, noalias_rhs}; + + for (int i = 0; i < 2; i++) { + Value *start = (i == 0) ? lhs : rhs; + Value *end = (i == 0) ? rhs : lhs; + if (noalias[i]) { + if (noalias[1 - i]) { + return true; + } + if (isa(end)) { + return true; + } + if (auto endi = dyn_cast(end)) { + if (notCapturedBefore(start, endi, 0)) { + return true; + } + } + } + if (auto ld = dyn_cast(start)) { + auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false); + if (isAllocationCall(base, TLI)) { + if (isa(end)) + return true; + if (auto endi = dyn_cast(end)) + if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) { + Instruction *starti = dyn_cast(start); + if (!starti) { + if (!isa(start)) + continue; + starti = + &cast(start)->getParent()->getEntryBlock().front(); + } + + bool overwritten = false; + allInstructionsBetween( + LI, starti, endi, [&](Instruction *I) -> bool { + if (!I->mayWriteToMemory()) + return /*earlyBreak*/ false; + + if (writesToMemoryReadBy(nullptr, AA, TLI, + /*maybeReader*/ ld, + /*maybeWriter*/ I)) { + overwritten = true; + return /*earlyBreak*/ true; + } + return /*earlyBreak*/ false; + }); + + if (!overwritten) 
{ + return true; + } + } + } + } + } + + return {}; +} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index f825af47607c..62493de9046f 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -2098,4 +2098,22 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, bool isNVLoad(const llvm::Value *V); +//! Check if value is captured after definition before executing inst. +//! If checkLoadCaptured != 0, also consider captures of any loads of the value +//! as a capture (for the number of loads set). +bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst, + size_t checkLoadCaptured); + +// Return true if guaranteed not to alias +// Return false if guaranteed to alias [with possible offset depending on flag]. +// Return {} if no information is given. +#if LLVM_VERSION_MAJOR >= 16 +std::optional +#else +llvm::Optional +#endif +arePointersGuaranteedNoAlias(llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, + llvm::LoopInfo &LI, llvm::Value *op0, + llvm::Value *op1, bool offsetAllowed = false); + #endif // ENZYME_UTILS_H diff --git a/enzyme/test/Enzyme/JLSimplify/lerr.ll b/enzyme/test/Enzyme/JLSimplify/lerr.ll index a0c71cb0c425..c49b32e551b8 100644 --- a/enzyme/test/Enzyme/JLSimplify/lerr.ll +++ b/enzyme/test/Enzyme/JLSimplify/lerr.ll @@ -26,6 +26,6 @@ top: ret i1 %.not } -declare noalias {} addrspace(10)* @ijl_alloc_array_2d({} addrspace(10)*, i64, i64) +declare noalias {} addrspace(10)* @ijl_alloc_array_2d({} addrspace(10)*, i64, i64) "enzyme_allocation" ; CHECK: ret i1 false diff --git a/enzyme/test/Enzyme/JLSimplify/loads.ll b/enzyme/test/Enzyme/JLSimplify/loads.ll index b4ed3b274a93..83043d822787 100644 --- a/enzyme/test/Enzyme/JLSimplify/loads.ll +++ b/enzyme/test/Enzyme/JLSimplify/loads.ll @@ -3,7 +3,7 @@ declare void @julia.safepoint() -declare noalias nonnull {} addrspace(10)* @ijl_new_array({} addrspace(10)*, {} addrspace(10)*) +declare noalias nonnull {} addrspace(10)* @ijl_new_array({} addrspace(10)*, {} addrspace(10)*) 
"enzyme_allocation" declare {}* @julia.pointer_from_objref({} addrspace(11)*) readnone From 896da02c7eb1b2db4d3aedebea6d81b14cd1c5a9 Mon Sep 17 00:00:00 2001 From: Matt Bolitho Date: Tue, 12 Nov 2024 19:44:51 +0000 Subject: [PATCH 02/45] Marks functions inline in BLAS tblgen (#2163) * Adds blas-tblgen header guards * Marks header function implementations as inline --- enzyme/tools/enzyme-tblgen/blas-tblgen.h | 5 +++++ enzyme/tools/enzyme-tblgen/blasDeclUpdater.h | 13 +++++++++---- enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h | 9 +++++++-- enzyme/tools/enzyme-tblgen/blasTAUpdater.h | 11 ++++++++--- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.h b/enzyme/tools/enzyme-tblgen/blas-tblgen.h index da2d5b7ae901..247eb986de21 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.h @@ -1,3 +1,6 @@ +#ifndef ENZYME_TBLGEN_BLAS_TBLGEN_H +#define ENZYME_TBLGEN_BLAS_TBLGEN_H + #include "llvm/ADT/SmallString.h" #include "llvm/TableGen/Record.h" @@ -9,3 +12,5 @@ bool hasAdjoint(const TGPattern &pattern, const llvm::Init *resultTree, llvm::StringRef argName); llvm::SmallString<80> ValueType_helper(const TGPattern &pattern, ssize_t actPos, const llvm::DagInit *ruleDag); + +#endif // ENZYME_TBLGEN_BLAS_TBLGEN_H diff --git a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h index c0df30b73d9f..35c87c8e93b7 100644 --- a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h @@ -1,7 +1,10 @@ +#ifndef ENZYME_TBLGEN_BLAS_DECL_UPDATER_H +#define ENZYME_TBLGEN_BLAS_DECL_UPDATER_H + #include "datastructures.h" -void emit_attributeBLASCaller(ArrayRef blasPatterns, - raw_ostream &os) { +inline void emit_attributeBLASCaller(ArrayRef blasPatterns, + raw_ostream &os) { os << "void attributeBLAS(BlasInfo blas, llvm::Function *F) { \n"; os << " if (!F->empty())\n"; os << " return;\n"; @@ -15,7 +18,7 @@ void 
emit_attributeBLASCaller(ArrayRef blasPatterns, os << "} \n"; } -void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { +inline void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { auto name = pattern.getName(); bool lv23 = pattern.isBLASLevel2or3(); os << "llvm::Constant* attribute_" << name @@ -196,7 +199,7 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { os << "}\n"; } -void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { +inline void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { emitSourceFileHeader("Rewriters", os); const auto &blasPatterns = RK.getAllDerivedDefinitions("CallBlasPattern"); @@ -283,3 +286,5 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { os << " return changed;\n"; os << "}\n"; } + +#endif // ENZYME_TBLGEN_BLAS_DECL_UPDATER_H diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index 3da25ae96dfd..db78a62717d5 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -1,10 +1,13 @@ +#ifndef ENZYME_TBLGEN_BLAS_DIFF_USE_UPDATER_H +#define ENZYME_TBLGEN_BLAS_DIFF_USE_UPDATER_H + #include "blas-tblgen.h" #include "caching.h" #include "datastructures.h" #include "enzyme-tblgen.h" #include "llvm/Support/raw_ostream.h" -void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { +inline void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { auto typeMap = pattern.getArgTypeMap(); auto argUsers = pattern.getArgUsers(); bool lv23 = pattern.isBLASLevel2or3(); @@ -139,7 +142,7 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { os << "}\n"; } -void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { +inline void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { emitSourceFileHeader("Rewriters", os); const auto &blasPatterns = 
RK.getAllDerivedDefinitions("CallBlasPattern"); @@ -193,3 +196,5 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { os << " }\n"; os << "}\n"; } + +#endif // ENZYME_TBLGEN_BLAS_DIFF_USE_UPDATER_H diff --git a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h index d87d2eb702d4..33f63d163382 100644 --- a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h @@ -1,6 +1,9 @@ +#ifndef ENZYME_TBLGEN_BLAS_TA_UPDATER_H +#define ENZYME_TBLGEN_BLAS_TA_UPDATER_H + #include "datastructures.h" -void emit_BLASTypes(raw_ostream &os) { +inline void emit_BLASTypes(raw_ostream &os) { os << "const bool byRef = blas.prefix == \"\" || blas.prefix == " "\"cublas_\";\n"; os << "const bool byRefFloat = byRef || blas.prefix == " @@ -65,7 +68,7 @@ void emit_BLASTypes(raw_ostream &os) { // cblas lv23 => layout // cublas => always handle -void emit_BLASTA(TGPattern &pattern, raw_ostream &os) { +inline void emit_BLASTA(TGPattern &pattern, raw_ostream &os) { auto name = pattern.getName(); bool lv23 = pattern.isBLASLevel2or3(); @@ -180,7 +183,7 @@ void emit_BLASTA(TGPattern &pattern, raw_ostream &os) { os << "}\n"; } -void emitBlasTAUpdater(const RecordKeeper &RK, raw_ostream &os) { +inline void emitBlasTAUpdater(const RecordKeeper &RK, raw_ostream &os) { emitSourceFileHeader("Rewriters", os); const auto &blasPatterns = RK.getAllDerivedDefinitions("CallBlasPattern"); @@ -199,3 +202,5 @@ void emitBlasTAUpdater(const RecordKeeper &RK, raw_ostream &os) { emit_BLASTA(newPattern, os); } } + +#endif // ENZYME_TBLGEN_BLAS_TA_UPDATER_H From 624621aac978b1726cb6330ec442f4b63955a4a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 12 Nov 2024 20:23:35 -0600 Subject: [PATCH 03/45] Mac optional (#2165) --- enzyme/Enzyme/JLInstSimplify.cpp | 6 +----- enzyme/Enzyme/Utils.cpp | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/JLInstSimplify.cpp 
b/enzyme/Enzyme/JLInstSimplify.cpp index f0d29f0f856d..36f7d21095eb 100644 --- a/enzyme/Enzyme/JLInstSimplify.cpp +++ b/enzyme/Enzyme/JLInstSimplify.cpp @@ -108,11 +108,7 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, if (auto alias = arePointersGuaranteedNoAlias( TLI, AA, LI, I.getOperand(0), I.getOperand(1), false)) { -#if LLVM_VERSION_MAJOR >= 16 - bool val = alias.value(); -#else - bool val = alias.getValue(); -#endif + bool val = *alias; auto repval = ICmpInst::isTrueWhenEqual(pred) ? ConstantInt::get(I.getType(), 1 - val) : ConstantInt::get(I.getType(), val); diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 90907d96dec3..bfdee67dd7f9 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2307,11 +2307,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, if (loadPtr && storePtr) if (auto alias = arePointersGuaranteedNoAlias(TLI, AA, LI, loadPtr, storePtr, true)) -#if LLVM_VERSION_MAJOR >= 16 - if (alias.value()) -#else - if (alias.getValue()) -#endif + if (*alias) return false; if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd, From 0a81fa1c4835699549dbd06000cee81ea3b6b1ea Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Nov 2024 01:01:33 -0500 Subject: [PATCH 04/45] Clear attribute of generated (#2166) * Clear attribute of generated * fix * Update inactbuffree.ll --- enzyme/Enzyme/EnzymeLogic.cpp | 18 ++++++++++++++++-- .../ForwardMode/intelSubscriptIntrinsic.ll | 8 ++++---- enzyme/test/Enzyme/ReverseMode/hascast.ll | 1 + enzyme/test/Enzyme/ReverseMode/inactbuffree.ll | 2 +- enzyme/test/Enzyme/ReverseMode/tgamma.ll | 4 ++-- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 02595a3cbb84..06f11c933bb9 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1696,8 +1696,22 @@ void clearFunctionAttributes(Function *f) { if 
(Arg.hasAttribute(Attribute::StructRet)) Arg.removeAttr(Attribute::StructRet); } - if (f->hasFnAttribute(Attribute::OptimizeNone)) - f->removeFnAttr(Attribute::OptimizeNone); + + Attribute::AttrKind fnattrs[] = { +#if LLVM_VERSION_MAJOR >= 16 + Attribute::Memory, +#endif + Attribute::ReadOnly, + Attribute::ReadNone, + Attribute::WriteOnly, + Attribute::WillReturn, + Attribute::OptimizeNone + }; + for (auto attr : fnattrs) { + if (f->hasFnAttribute(attr)) { + f->removeFnAttr(attr); + } + } if (f->getAttributes().getRetDereferenceableBytes()) { f->removeRetAttr(Attribute::Dereferenceable); diff --git a/enzyme/test/Enzyme/ForwardMode/intelSubscriptIntrinsic.ll b/enzyme/test/Enzyme/ForwardMode/intelSubscriptIntrinsic.ll index e2c325370740..38cb109e362d 100644 --- a/enzyme/test/Enzyme/ForwardMode/intelSubscriptIntrinsic.ll +++ b/enzyme/test/Enzyme/ForwardMode/intelSubscriptIntrinsic.ll @@ -38,17 +38,17 @@ declare void @__enzyme_fwddiff(float (...)* noalias, %"QNCA_a0$float*$rank1$"* n !8 = !{!"ifx$unique_sym$10", !9, i64 0} !9 = !{!"Fortran Data Symbol", !4, i64 0} -; CHECK: define internal float @fwddiffeselectfirst(%"QNCA_a0$float*$rank1$"* noalias nocapture readonly dereferenceable(72) "ptrnoalias" %X, %"QNCA_a0$float*$rank1$"* nocapture %"X'") local_unnamed_addr #0 { +; CHECK: define internal float @fwddiffeselectfirst(%"QNCA_a0$float*$rank1$"* noalias nocapture readonly dereferenceable(72) "ptrnoalias" %X, %"QNCA_a0$float*$rank1$"* nocapture %"X'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %"X.addr_a0$'ipg" = getelementptr inbounds %"QNCA_a0$float*$rank1$", %"QNCA_a0$float*$rank1$"* %"X'", i64 0, i32 0 ; CHECK-NEXT: %"X.addr_a0$" = getelementptr inbounds %"QNCA_a0$float*$rank1$", %"QNCA_a0$float*$rank1$"* %X, i64 0, i32 0 ; CHECK-NEXT: %"X.addr_a0$_fetch.29'ipl" = load float*, float** %"X.addr_a0$'ipg", align 1, !tbaa !0, !alias.scope !10, !noalias !13 ; CHECK-NEXT: %"X.addr_a0$_fetch.29" = load float*, float** %"X.addr_a0$", align 1, !tbaa !0, !alias.scope !13, !noalias 
!10 ; CHECK-NEXT: %"X.dim_info$.lower_bound$" = getelementptr inbounds %"QNCA_a0$float*$rank1$", %"QNCA_a0$float*$rank1$"* %X, i64 0, i32 6, i64 0, i32 2 -; CHECK-NEXT: %"X.dim_info$.lower_bound$[]" = tail call i64* @llvm.intel.subscript.p0i64.i64.i32.p0i64.i32(i8 0, i64 0, i32 24, i64* nonnull elementtype(i64) %"X.dim_info$.lower_bound$", i32 0) #0 +; CHECK-NEXT: %"X.dim_info$.lower_bound$[]" = tail call i64* @llvm.intel.subscript.p0i64.i64.i32.p0i64.i32(i8 0, i64 0, i32 24, i64* nonnull elementtype(i64) %"X.dim_info$.lower_bound$", i32 0) ; CHECK-NEXT: %"X.dim_info$.lower_bound$[]_fetch.30" = load i64, i64* %"X.dim_info$.lower_bound$[]", align 1, !tbaa !6, !alias.scope !13, !noalias !10 ; CHECK-NEXT: %0 = call float* @llvm.intel.subscript.p0f32.i64.i64.p0f32.i64(i8 0, i64 %"X.dim_info$.lower_bound$[]_fetch.30", i64 4, float* elementtype(float) %"X.addr_a0$_fetch.29'ipl", i64 1) -; CHECK-NEXT: %"X.addr_a0$_fetch.29[]" = tail call float* @llvm.intel.subscript.p0f32.i64.i64.p0f32.i64(i8 0, i64 %"X.dim_info$.lower_bound$[]_fetch.30", i64 4, float* elementtype(float) %"X.addr_a0$_fetch.29", i64 1) #0 +; CHECK-NEXT: %"X.addr_a0$_fetch.29[]" = tail call float* @llvm.intel.subscript.p0f32.i64.i64.p0f32.i64(i8 0, i64 %"X.dim_info$.lower_bound$[]_fetch.30", i64 4, float* elementtype(float) %"X.addr_a0$_fetch.29", i64 1) ; CHECK-NEXT: %"X.addr_a0$_fetch.29[]_fetch.32'ipl" = load float, float* %0, align 1, !tbaa !7, !alias.scope !15, !noalias !18 ; CHECK-NEXT: ret float %"X.addr_a0$_fetch.29[]_fetch.32'ipl" -; CHECK-NEXT: } \ No newline at end of file +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/hascast.ll b/enzyme/test/Enzyme/ReverseMode/hascast.ll index cbfb21960ad8..ace44e2edf5c 100644 --- a/enzyme/test/Enzyme/ReverseMode/hascast.ll +++ b/enzyme/test/Enzyme/ReverseMode/hascast.ll @@ -132,6 +132,7 @@ attributes #3 = { nounwind } ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[loadcallp:.+]] = load double, double* %[[callp]] ; CHECK-NEXT: store double 0.000000e+00, 
double* %[[callp]] +; CHECK-NEXT: call void @diffecast(double* %x, double* %"x'") ; CHECK-NEXT: %[[m0diffez:.+]] = fmul fast double %[[loadcallp]], %y ; CHECK-NEXT: %[[m1diffey:.+]] = fmul fast double %[[loadcallp]], %z ; CHECK-NEXT: %[[toret0:.+]] = insertvalue { double, double } undef, double %[[m1diffey]], 0 diff --git a/enzyme/test/Enzyme/ReverseMode/inactbuffree.ll b/enzyme/test/Enzyme/ReverseMode/inactbuffree.ll index 77e1b75901d1..8a42a83eb96a 100644 --- a/enzyme/test/Enzyme/ReverseMode/inactbuffree.ll +++ b/enzyme/test/Enzyme/ReverseMode/inactbuffree.ll @@ -42,7 +42,7 @@ declare double @__enzyme_autodiff(i8*, double, i64) ; CHECK: define internal { double } @diffesquare(double %x, i64 %i, double %differeturn) ; CHECK-NEXT: entry: -; CHECK-NEXT: %call = tail call double* @alloc() #3 +; CHECK-NEXT: %call = tail call double* @alloc() ; CHECK-NEXT: store double 3.000000e+00, double* %call ; CHECK-NEXT: %arrayidx2 = getelementptr inbounds double, double* %call, i64 %i ; CHECK-NEXT: %ld = load double, double* %arrayidx2 diff --git a/enzyme/test/Enzyme/ReverseMode/tgamma.ll b/enzyme/test/Enzyme/ReverseMode/tgamma.ll index 5bd34a8757dc..bad1d7fbf0f3 100644 --- a/enzyme/test/Enzyme/ReverseMode/tgamma.ll +++ b/enzyme/test/Enzyme/ReverseMode/tgamma.ll @@ -34,10 +34,10 @@ declare double @tgamma(double) ; CHECK: define internal fastcc { double } @diffea5(double %arg, double %differeturn) ; CHECK-NEXT: bb: -; CHECK-NEXT: %i8 = call double @tgamma(double %arg) #1 +; CHECK-NEXT: %i8 = call double @tgamma(double %arg) ; CHECK-NEXT: %0 = fmul fast double %differeturn, %arg ; CHECK-NEXT: %1 = fmul fast double %i8, %differeturn -; CHECK-NEXT: %2 = call fast double @digamma(double %arg) #0 +; CHECK-NEXT: %2 = call fast double @digamma(double %arg) ; CHECK-NEXT: %3 = fmul fast double %2, %0 ; CHECK-NEXT: %4 = fadd fast double %1, %3 ; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0 From c87b3ad80ebdc6503966d7187a008ed7c84a7e4c Mon Sep 17 00:00:00 2001 From: William 
Moses Date: Sat, 16 Nov 2024 03:15:50 -0500 Subject: [PATCH 05/45] BLAS: fix blas fptype for complex (#2167) --- enzyme/Enzyme/Utils.cpp | 6 +++++- enzyme/Enzyme/Utils.h | 2 +- enzyme/tools/enzyme-tblgen/blasTAUpdater.h | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index bfdee67dd7f9..e1f9d21dc376 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -579,14 +579,18 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, call->setDebugLoc(loc); } -Type *BlasInfo::fpType(LLVMContext &ctx) const { +Type *BlasInfo::fpType(LLVMContext &ctx, bool to_scalar) const { if (floatType == "d" || floatType == "D") { return Type::getDoubleTy(ctx); } else if (floatType == "s" || floatType == "S") { return Type::getFloatTy(ctx); } else if (floatType == "c" || floatType == "C") { + if (to_scalar) + return Type::getFloatTy(ctx); return VectorType::get(Type::getFloatTy(ctx), 2, false); } else if (floatType == "z" || floatType == "Z") { + if (to_scalar) + return Type::getDoubleTy(ctx); return VectorType::get(Type::getDoubleTy(ctx), 2, false); } else { assert(false && "Unreachable"); diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 62493de9046f..02ce4b8b47e2 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -678,7 +678,7 @@ struct BlasInfo { std::string function; bool is64; - llvm::Type *fpType(llvm::LLVMContext &ctx) const; + llvm::Type *fpType(llvm::LLVMContext &ctx, bool to_scalar = false) const; llvm::IntegerType *intType(llvm::LLVMContext &ctx) const; }; diff --git a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h index 33f63d163382..2ff92d047b06 100644 --- a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h @@ -16,7 +16,7 @@ inline void emit_BLASTypes(raw_ostream &os) { "\"cublas\" && StringRef(blas.suffix).contains(\"v2\");\n"; os << "TypeTree ttFloat;\n" - << 
"llvm::Type *floatType = blas.fpType(call.getContext()); \n" + << "llvm::Type *floatType = blas.fpType(call.getContext(), true); \n" << "if (byRefFloat) {\n" << " ttFloat.insert({-1},BaseType::Pointer);\n" << " ttFloat.insert({-1,0},floatType);\n" From d9d6338c8177e9fd676d0c0afb7e5689119f9405 Mon Sep 17 00:00:00 2001 From: Matt Bolitho Date: Sun, 17 Nov 2024 19:43:34 +0000 Subject: [PATCH 06/45] Adds initial CMake presets configuration (#2169) * Adds CMakeUserPresets.json to gitignore * Adds presets based on root CMakeLists * Adds build presets * Adds naive preset configuration detection * Adds self review changes * Uses NOT DEFINED for preset check --- .gitignore | 2 + enzyme/CMakeLists.txt | 10 ++-- enzyme/CMakePresets.json | 103 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 enzyme/CMakePresets.json diff --git a/.gitignore b/.gitignore index a3c2751a0f5c..5e7285d5f0af 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ enzyme/benchmarks/ReverseMode/*/*.exe enzyme/benchmarks/ReverseMode/*/results.txt enzyme/benchmarks/ReverseMode/*/results.json .cache +CMakeUserPresets.json +/out diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 1f51afe2789d..2ba42e969ce8 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -17,11 +17,13 @@ add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch") -SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") -SET(CMAKE_CXX_FLAGS_RELEASE "-O2") -SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -ggdb -fno-omit-frame-pointer") +if (NOT DEFINED ENZYME_CONFIGURED_WITH_PRESETS) + set(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable 
-Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") + set(CMAKE_CXX_FLAGS_RELEASE "-O2") + set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -ggdb -fno-omit-frame-pointer") +endif() #SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer -fsanitize=address") #SET(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address") diff --git a/enzyme/CMakePresets.json b/enzyme/CMakePresets.json new file mode 100644 index 000000000000..1a5ffa80c5d9 --- /dev/null +++ b/enzyme/CMakePresets.json @@ -0,0 +1,103 @@ +{ + "version": 3, + "configurePresets": [ + { + "name": "config-base", + "description": "Base configure preset.", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/out/build/${presetName}", + "installDir": "${sourceDir}/out/install/${presetName}", + "cacheVariables": { + "CMAKE_CXX_STANDARD": "17", + "CMAKE_CXX_STANDARD_REQUIRED": "ON", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "ENZYME_CONFIGURED_WITH_PRESETS": "ON" + } + }, + { + "name": "config-base-linux", + "description": "Base configure preset for Linux.", + "inherits": "config-base", + "hidden": true, + "cacheVariables": { + "CMAKE_POSITION_INDEPENDENT_CODE": "ON" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + } + }, + { + "name": "config-base-x64", + "description": "Base preset for x64 platforms.", + "hidden": true, + "architecture": { + "value": "x64", + "strategy": "external" + } + }, + { + "name": "x64-linux-clang", + "description": "Base preset for Linux development using Clang compilers.", + "hidden": true, + "inherits": [ + "config-base-x64", + "config-base-linux" + ], + "cacheVariables": { + "CMAKE_C_COMPILER": "clang", + "CMAKE_CXX_COMPILER": "clang++", + "CMAKE_CXX_FLAGS": "-Wall -fno-rtti -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable 
-Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch", + "CMAKE_CXX_FLAGS_DEBUG": "-O0 -g -ggdb -fno-omit-frame-pointer", + "CMAKE_CXX_FLAGS_RELEASE": "-O2", + "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O2 -g -ggdb" + } + }, + { + "name": "x64-linux-clang-debug", + "displayName": "Clang x64 Linux Debug", + "inherits": "x64-linux-clang", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug" + } + }, + { + "name": "x64-linux-clang-release", + "displayName": "Clang x64 Linux Release", + "inherits": "x64-linux-clang", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release" + } + }, + { + "name": "x64-linux-clang-release-with-debug-info", + "displayName": "Clang x64 Linux Release with Debug Info", + "inherits": "x64-linux-clang", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo" + } + } + ], + "buildPresets": [ + { + "name": "x64-linux-clang-debug", + "displayName": "Clang x64 Linux Debug", + "description": "Builds the project using Clang on Linux in Debug configuration.", + "configurePreset": "x64-linux-clang-debug" + }, + { + "name": "x64-linux-clang-release", + "displayName": "Clang x64 Linux Release", + "description": "Builds the project using Clang on Linux in Release configuration.", + "configurePreset": "x64-linux-clang-release" + }, + { + "name": "x64-linux-clang-release-with-debug-info", + "displayName": "Clang x64 Linux Release with Debug Info", + "description": "Builds the project using Clang on Linux in Release configuration with debug info.", + "configurePreset": "x64-linux-clang-release-with-debug-info" + } + ] +} From e42f5aabd1b5583e0e300550dd2d401535a9a87f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 19 Nov 2024 13:01:03 -0500 Subject: [PATCH 07/45] Generalize blas attributor (#2171) --- enzyme/Enzyme/Utils.cpp | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index e1f9d21dc376..933a22304e29 100644 --- 
a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -929,8 +929,25 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, FunctionType *FT = FunctionType::get(copy_retty, tys, false); auto fn = M.getOrInsertFunction(copy_name, FT); - Function *F = cast(fn.getCallee()); - attributeKnownFunctions(*F); + Value *callVal = fn.getCallee(); + Function *called = nullptr; + while (!called) { + if (auto castinst = dyn_cast(callVal)) + if (castinst->isCast()) { + callVal = castinst->getOperand(0); + continue; + } + if (auto fn = dyn_cast(callVal)) { + called = fn; + break; + } + if (auto alias = dyn_cast(callVal)) { + callVal = alias->getAliasee(); + continue; + } + break; + } + attributeKnownFunctions(*called); B.CreateCall(fn, args, bundles); } From 3a64b164db6556dda5daecf7e3746076d82bf720 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 20 Nov 2024 05:05:09 +0000 Subject: [PATCH 08/45] mlir: implement MemorySlot Interfaces for Gradient ops (#2168) --- enzyme/BUILD | 2 + enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 14 ++- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 115 ++++++++++++++++++++++++ enzyme/Enzyme/MLIR/Dialect/Ops.h | 1 + enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 2 + enzyme/test/MLIR/Passes/mem2reg.mlir | 14 +++ 6 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 enzyme/test/MLIR/Passes/mem2reg.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index 2517c3a9ce43..bfb0f2f5c446 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -239,6 +239,7 @@ td_library( "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:FunctionInterfacesTdFiles", "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:MemorySlotInterfacesTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", @@ -604,6 +605,7 @@ cc_library( "@llvm-project//mlir:LinalgStructuredOpsIncGen", "@llvm-project//mlir:LinalgTransforms", 
"@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemorySlotInterfaces", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:OpenMPDialect", diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 5d61cec8287b..ca7659143370 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -24,6 +24,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -121,8 +122,9 @@ def PopOp : Enzyme_Op<"pop"> { let results = (outs AnyType:$output); } -def InitOp : Enzyme_Op<"init"> { - let summary = "Creat enzyme.gradient and enzyme.cache"; +def InitOp : Enzyme_Op<"init", + [DeclareOpInterfaceMethods]> { + let summary = "Create enzyme.gradient and enzyme.cache"; let arguments = (ins ); let results = (outs AnyType); } @@ -147,14 +149,18 @@ def Gradient : Enzyme_Type<"Gradient"> { let assemblyFormat = "`<` $basetype `>`"; } -def SetOp : Enzyme_Op<"set"> { +def SetOp : Enzyme_Op<"set", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Store the current value of the gradient"; let arguments = (ins Arg:$gradient, AnyType : $value); let results = (outs ); } -def GetOp : Enzyme_Op<"get"> { +def GetOp : Enzyme_Op<"get", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Load current value of gradient"; let arguments = (ins Arg:$gradient); diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 5c2e4283d300..3e3185427306 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -8,12 +8,14 @@ #include "Ops.h" #include "Dialect.h" +#include "Interfaces/AutoDiffTypeInterface.h" #include 
"mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -36,6 +38,119 @@ using namespace mlir; using namespace enzyme; using namespace mlir::arith; +//===----------------------------------------------------------------------===// +// InitOp +//===----------------------------------------------------------------------===// + +llvm::SmallVector InitOp::getPromotableSlots() { + auto Ty = this->getType(); + if (isa(Ty)) + return {}; + + if (!getOperation()->getBlock()->isEntryBlock()) + return {}; + + auto gTy = cast(Ty); + MemorySlot slot = {this->getResult(), gTy.getBasetype()}; + + return {slot}; +} + +Value InitOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { + auto gTy = cast(this->getType()); + return cast(gTy.getBasetype()) + .createNullValue(builder, this->getLoc()); +} + +void InitOp::handleBlockArgument(const MemorySlot &slot, BlockArgument argument, + OpBuilder &builder) {} + +std::optional +InitOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue, + OpBuilder &builder) { + if (defaultValue && defaultValue.use_empty()) + defaultValue.getDefiningOp()->erase(); + this->erase(); + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// GetOp +//===----------------------------------------------------------------------===// + +bool GetOp::loadsFrom(const MemorySlot &slot) { + return this->getGradient() == slot.ptr; +} + +bool GetOp::storesTo(const MemorySlot &slot) { return false; } + +Value GetOp::getStored(const MemorySlot &slot, OpBuilder &builder, + Value reachingDef, const DataLayout &dataLayout) { + return {}; +} + +bool GetOp::canUsesBeRemoved( + 
const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, + llvm::SmallVectorImpl &newBlockingUses, + const mlir::DataLayout &dataLayout) { + if (blockingUses.size() != 1) + return false; + + Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getGradient() == slot.ptr; +} + +DeletionKind GetOp::removeBlockingUses( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, OpBuilder &builder, + Value reachingDefinition, const DataLayout &dataLayout) { + this->getResult().replaceAllUsesWith(reachingDefinition); + return DeletionKind::Delete; +} + +llvm::LogicalResult GetOp::ensureOnlySafeAccesses( + const MemorySlot &slot, llvm::SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + return success(slot.ptr == getGradient()); +} + +//===----------------------------------------------------------------------===// +// SetOp +//===----------------------------------------------------------------------===// + +bool SetOp::loadsFrom(const MemorySlot &slot) { return false; } + +bool SetOp::storesTo(const MemorySlot &slot) { + return this->getGradient() == slot.ptr; +} + +Value SetOp::getStored(const MemorySlot &slot, OpBuilder &builder, + Value reachingDef, const DataLayout &dataLayout) { + return this->getValue(); +} + +bool SetOp::canUsesBeRemoved( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, + llvm::SmallVectorImpl &newBlockingUses, + const mlir::DataLayout &dataLayout) { + return true; +} + +DeletionKind SetOp::removeBlockingUses( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, OpBuilder &builder, + Value reachingDefinition, const DataLayout &dataLayout) { + return DeletionKind::Delete; +} + +llvm::LogicalResult SetOp::ensureOnlySafeAccesses( + const MemorySlot &slot, llvm::SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + return success(slot.ptr == getGradient()); +} + 
//===----------------------------------------------------------------------===// // GetFuncOp //===----------------------------------------------------------------------===// diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.h b/enzyme/Enzyme/MLIR/Dialect/Ops.h index 11aaa5f4291f..69aa6496b84f 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.h +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.h @@ -14,6 +14,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Bytecode/BytecodeOpInterface.h" diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index f608cecd3a2c..0e6bdf7b101e 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -77,6 +77,7 @@ int main(int argc, char **argv) { // Register the standard passes we want. mlir::registerCSEPass(); + mlir::registerMem2RegPass(); mlir::registerConvertAffineToStandardPass(); mlir::registerSCCPPass(); mlir::registerInlinerPass(); @@ -84,6 +85,7 @@ int main(int argc, char **argv) { mlir::registerSymbolDCEPass(); mlir::registerLoopInvariantCodeMotionPass(); mlir::registerConvertSCFToOpenMPPass(); + mlir::registerSCFToControlFlowPass(); mlir::affine::registerAffinePasses(); mlir::registerReconcileUnrealizedCasts(); diff --git a/enzyme/test/MLIR/Passes/mem2reg.mlir b/enzyme/test/MLIR/Passes/mem2reg.mlir new file mode 100644 index 000000000000..e850f1d8fccf --- /dev/null +++ b/enzyme/test/MLIR/Passes/mem2reg.mlir @@ -0,0 +1,14 @@ +// RUN: %eopt %s -mem2reg | FileCheck %s + +module { + func.func @main(%arg0: f32) -> f32 { + %0 = "enzyme.init"() : () -> !enzyme.Gradient + "enzyme.set"(%0, %arg0) : (!enzyme.Gradient, f32) -> () + %2 = "enzyme.get"(%0) : (!enzyme.Gradient) -> f32 + return %2 : f32 + } +} + +// CHECK: func.func @main(%arg0: f32) -> f32 { +// CHECK-NEXT: return %arg0 : f32 +// 
CHECK-NEXT: } From caa9a3a7818b09b240a41ba31117670c505708dc Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 25 Nov 2024 23:33:08 +0100 Subject: [PATCH 09/45] Support batching in MLIR autodiff operations (2nd try) (#2173) * Reorder generated code for Attributes, Enums, Types and Operations * add width attribute to Forwarddiffop and use it in `HandleAutoDiff` * add width attribute to autodiff as well * formatting --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 4 ++-- enzyme/Enzyme/MLIR/Dialect/Ops.h | 14 ++++++++------ enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index ca7659143370..be139fb3d8ba 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -80,7 +80,7 @@ def PlaceholderOp : Enzyme_Op<"placeholder", def ForwardDiffOp : Enzyme_Op<"fwddiff", [DeclareOpInterfaceMethods]> { let summary = "Perform forward mode AD on a funcop"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); let results = (outs Variadic:$outputs); let assemblyFormat = [{ @@ -91,7 +91,7 @@ def ForwardDiffOp : Enzyme_Op<"fwddiff", def AutoDiffOp : Enzyme_Op<"autodiff", [DeclareOpInterfaceMethods]> { let summary = "Perform reverse mode AD on a funcop"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); let results = (outs Variadic:$outputs); let assemblyFormat = [{ diff --git 
a/enzyme/Enzyme/MLIR/Dialect/Ops.h b/enzyme/Enzyme/MLIR/Dialect/Ops.h index 69aa6496b84f..cd2eb1f70d42 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.h +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.h @@ -19,15 +19,17 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" -#define GET_OP_CLASSES -#include "Dialect/EnzymeOps.h.inc" -#define GET_TYPEDEF_CLASSES -#include "Dialect/EnzymeOpsTypes.h.inc" -// #include "Dialect/EnzymeTypes.h.inc" - #include "Dialect/EnzymeEnums.h.inc" #define GET_ATTRDEF_CLASSES #include "Dialect/EnzymeAttributes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "Dialect/EnzymeOpsTypes.h.inc" + +#define GET_OP_CLASSES +#include "Dialect/EnzymeOps.h.inc" + +// #include "Dialect/EnzymeTypes.h.inc" + #endif // ENZYMEOPS_H diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index bf1ec877f9d3..c3fe53a7c4ea 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -139,7 +139,7 @@ struct DifferentiatePass : public DifferentiatePassBase { MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); bool freeMemory = true; - size_t width = 1; + size_t width = CI.getWidth(); std::vector volatile_args; for (auto &a : fn.getFunctionBody().getArguments()) { @@ -259,7 +259,7 @@ struct DifferentiatePass : public DifferentiatePassBase { MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); bool freeMemory = true; - size_t width = 1; + size_t width = CI.getWidth(); std::vector volatile_args; for (auto &a : fn.getFunctionBody().getArguments()) { From 0a8fdc762af91ab1bc30bcf171ca4ff4c0c33f60 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 25 Nov 2024 18:35:22 -0500 Subject: [PATCH 10/45] Update benchmark.yml --- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index b85eb1aaf083..5aa355f4bfd2 100644 --- a/.github/workflows/benchmark.yml +++ 
b/.github/workflows/benchmark.yml @@ -23,7 +23,7 @@ jobs: matrix: llvm: ["16", "17", "18"] build: ["Release"] #, "Debug" "RelWithDebInfo" - os: [openstack22] + os: [] # [openstack22] timeout-minutes: 120 steps: - name: add llvm From 32a21180fd5c65c40cabea29d311b6a3ddcccfee Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Wed, 27 Nov 2024 03:59:52 +0100 Subject: [PATCH 11/45] Support batching scalar types (#2175) * support batching scalar types * formatting --- enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp | 6 +++++- enzyme/test/MLIR/Batch/batched_scalar.mlir | 21 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/MLIR/Batch/batched_scalar.mlir diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp index 11a0fb3180f6..1e505fe7b907 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp @@ -30,7 +30,11 @@ namespace { static mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef batchSizes) { - auto T = cast(Ty); + auto T = dyn_cast(Ty); + if (!T) { + return RankedTensorType::get(batchSizes, Ty); + } + SmallVector shape(batchSizes.begin(), batchSizes.end()); shape.append(T.getShape().begin(), T.getShape().end()); auto T2 = T.clone(shape); diff --git a/enzyme/test/MLIR/Batch/batched_scalar.mlir b/enzyme/test/MLIR/Batch/batched_scalar.mlir new file mode 100644 index 000000000000..16cd543873ec --- /dev/null +++ b/enzyme/test/MLIR/Batch/batched_scalar.mlir @@ -0,0 +1,21 @@ +// RUN: %eopt --enzyme-batch %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %y = math.sin %x : f64 + return %y : f64 + } + func.func @dsq(%x : tensor<10x2xf64>) -> tensor<10x2xf64> { + %r = enzyme.batch @square(%x) { batch_shape=array } : (tensor<10x2xf64>) -> (tensor<10x2xf64>) + return %r : tensor<10x2xf64> + } +} + +// CHECK: func.func @dsq(%arg0: tensor<10x2xf64>) -> 
tensor<10x2xf64> { +// CHECK-NEXT: %0 = call @batched_square(%arg0) : (tensor<10x2xf64>) -> tensor<10x2xf64> +// CHECK-NEXT: return %0 : tensor<10x2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @batched_square(%arg0: tensor<10x2xf64>) -> tensor<10x2xf64> { +// CHECK-NEXT: %0 = math.sin %arg0 : tensor<10x2xf64> +// CHECK-NEXT: return %0 : tensor<10x2xf64> +// CHECK-NEXT: } \ No newline at end of file From b59ab66f1792bf1a801f1803b63c9066adca4925 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 00:22:02 -0500 Subject: [PATCH 12/45] Cleanup julia api usage (#2179) --- enzyme/Enzyme/CApi.cpp | 8 +++++--- enzyme/Enzyme/GradientUtils.cpp | 11 ++++++++++- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 4e583273698a..ca71867462f3 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -333,9 +333,11 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, return unwrap( AHandle(wrap(&B), wrap(CI), Args.size(), refs.data(), gutils)); }; - shadowErasers[Name] = [=](IRBuilder<> &B, Value *ToFree) -> llvm::CallInst * { - return cast_or_null(unwrap(FHandle(wrap(&B), wrap(ToFree)))); - }; + if (FHandle) + shadowErasers[Name] = [=](IRBuilder<> &B, + Value *ToFree) -> llvm::CallInst * { + return cast_or_null(unwrap(FHandle(wrap(&B), wrap(ToFree)))); + }; } void EnzymeRegisterCallHandler(char *Name, diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 7c012506e374..8dfd7ae81041 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9325,7 +9325,16 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, } if (allocationfn == "julia.gc_alloc_obj" || allocationfn == "jl_gc_alloc_typed" || - allocationfn == "ijl_gc_alloc_typed") + allocationfn == "ijl_gc_alloc_typed" || + allocationfn == "jl_alloc_array_1d" || + 
allocationfn == "ijl_alloc_array_1d" || + allocationfn == "jl_alloc_array_2d" || + allocationfn == "ijl_alloc_array_2d" || + allocationfn == "jl_alloc_array_3d" || + allocationfn == "ijl_alloc_array_3d" || allocationfn == "jl_new_array" || + allocationfn == "ijl_new_array" || + allocationfn == "jl_alloc_genericmemory" || + allocationfn == "ijl_alloc_genericmemory") return nullptr; if (allocationfn == "enzyme_allocator") { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index c77d86efc554..fcecc69c3ec6 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5243,6 +5243,22 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } + if (funcName == "julia.except_enter" || funcName == "ijl_excstack_state" || + funcName == "jl_excstack_state") { + updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); + return; + } + if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" || + funcName == "jl_inactive_inout" || + funcName == "jl_genericmemory_copy_slice" || + funcName == "ijl_genericmemory_copy_slice") { + if (direction & DOWN) + updateAnalysis(&call, getAnalysis(call.getOperand(0)), &call); + if (direction & UP) + updateAnalysis(call.getOperand(0), getAnalysis(&call), &call); + return; + } + if (isAllocationFunction(funcName, TLI)) { size_t Idx = 0; for (auto &Arg : ci->args()) { From 095ee7e3f42931360c3771a6767eea24b2e5c4e2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 02:15:45 -0500 Subject: [PATCH 13/45] Support batching scalar types (#2175) (#2180) * support batching scalar types * formatting Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> --- enzyme/Enzyme/CallDerivatives.cpp | 4 ++++ enzyme/Enzyme/FunctionUtils.cpp | 3 ++- enzyme/Enzyme/GradientUtils.cpp | 8 ++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git 
a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 243077c13769..22df3dab9a78 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3312,6 +3312,10 @@ bool AdjointGenerator::handleKnownCallDerivatives( } #endif Value *replacement = B.CreateAlloca(elTy, Size); + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", + "enzymejl_allocart"}) + if (auto M = call.getMetadata(MD)) + cast(replacement)->setMetadata(MD, M); if (I) replacement->takeName(I); else diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 3cd0203c08a9..1eac88e5b54c 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -508,7 +508,8 @@ UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode, {ConstantAsMetadata::get(ConstantInt::get( IntegerType::get(AI->getContext(), 64), align))})); - for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type"}) + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", + "enzymejl_allocart"}) if (auto M = AI->getMetadata(MD)) CI->setMetadata(MD, M); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 8dfd7ae81041..f5d5ffa6854d 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3280,6 +3280,10 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto replacement = NB.CreateAlloca( Type::getInt8Ty(I.getContext()), lookupM(getNewFromOriginal(I.getOperand(0)), NB, available)); + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", + "enzymejl_allocart"}) + if (auto M = I.getMetadata(MD)) + replacement->setMetadata(MD, M); auto Alignment = cast( cast(MD->getOperand(0))->getValue()) @@ -3524,6 +3528,10 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto rule = [&](Value *anti) { AllocaInst *replacement = NB.CreateAlloca( Type::getInt8Ty(orig->getContext()), args[0]); + for (auto 
MD : {"enzyme_active", "enzyme_inactive", + "enzyme_type", "enzymejl_allocart"}) + if (auto M = I.getMetadata(MD)) + replacement->setMetadata(MD, M); replacement->takeName(anti); auto Alignment = cast(cast( MD->getOperand(0)) From 7b97a9b6b6f92516a733d12c78a33f118e63091d Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Fri, 29 Nov 2024 20:00:10 +0000 Subject: [PATCH 14/45] tblgen: Implement SelectIfComplex (#2183) * selectifcomplex * remove ConjIfComplex this is unused and can now be implemented using SelectIfComplex * used primal --- enzyme/Enzyme/MLIR/Implementations/Common.td | 9 ++- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 65 ++++++++++++++------ 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index c40be825a827..9846f9ae6183 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -103,6 +103,10 @@ def SelectIfActive : Operation { } +def SelectIfComplex : Operation { + +} + class ConstantFP : Operation { string value = val; string dialect = dialect_; @@ -110,11 +114,6 @@ class ConstantFP : Ope string type = type_; } -class ConjIfComplex : Operation { - string dialect = dialect_; - string opName = op_; -} - def ResultTypes : GlobalExprgetResultTypes()">; def TypeOf : Operation { diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index b6848cedcebe..f789a125b222 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -488,29 +488,24 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << curIndent << INDENT << "imVal;\n"; os << curIndent << "})"; return true; - } else if (opName == "ConjIfComplex" || - Def->isSubClassOf("ConjIfComplex")) { - if (resultRoot->getNumArgs() != 1) + } else if (opName == "SelectIfComplex" || + Def->isSubClassOf("SelectIfComplex")) { 
+ if (resultRoot->getNumArgs() != 3) PrintFatalError(pattern->getLoc(), - "only three op ConjIfComplex supported"); + "only three op SelectIfComplex supported"); os << "({\n"; - os << curIndent << INDENT << "// Computing ConjIfComplex\n"; + os << curIndent << INDENT << "// Computing SelectIfComplex\n"; if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value imVal"; + os << curIndent << INDENT << "mlir::Value imVal = "; else - os << curIndent << INDENT << "llvm::Value *imVal"; - - os << curIndent << INDENT << "if (!gutils->isConstantValue("; + os << curIndent << INDENT << "llvm::Value *imVal = "; if (isa(resultRoot->getArg(0)) && resultRoot->getArgName(0)) { auto name = resultRoot->getArgName(0)->getAsUnquotedString(); auto [ord, isVec, ext] = nameToOrdinal.lookup(name, pattern, resultRoot); - os << ord; - assert(!ext.size()); - os << ord; - os << ";\n"; + os << ord << ";\n"; } else { handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, resultRoot->getArg(0), builder, nameToOrdinal, lookup, retidx, @@ -518,15 +513,45 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << ";\n"; } - os << " (isa(imVal.getType()) || " + os << curIndent << INDENT + << "if (isa(imVal.getType()) || " "(isa(imVal.getType()) && " "isa(cast(imVal.getType()).getElementType(" - ")))) ? 
"; - os << builder << ".create<" - << cast(Def->getValueInit("dialect"))->getValue() - << "::" << cast(Def->getValueInit("opName"))->getValue() - << ">(op.getLoc(), imVal.getType(), imVal) : imVal;\n"; - os << curIndent << "})"; + ")))) {\n"; + + os << curIndent << INDENT << INDENT << "imVal = "; + if (isa(resultRoot->getArg(1)) && resultRoot->getArgName(1)) { + auto name = resultRoot->getArgName(1)->getAsUnquotedString(); + auto [ord, isVec, ext] = + nameToOrdinal.lookup(name, pattern, resultRoot); + assert(!ext.size()); + os << ord << ";\n"; + } else { + handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, + resultRoot->getArg(1), builder, nameToOrdinal, lookup, retidx, + origName, newFromOriginal, intrinsic); + os << ";\n"; + } + + os << curIndent << INDENT << "} else {\n"; + + os << curIndent << INDENT << INDENT << "imVal = "; + if (isa(resultRoot->getArg(2)) && resultRoot->getArgName(2)) { + auto name = resultRoot->getArgName(2)->getAsUnquotedString(); + auto [ord, isVec, ext] = + nameToOrdinal.lookup(name, pattern, resultRoot); + assert(!ext.size()); + os << ord << ";\n"; + } else { + handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, + resultRoot->getArg(2), builder, nameToOrdinal, lookup, retidx, + origName, newFromOriginal, intrinsic); + os << ";\n"; + } + + os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "imVal;"; + os << curIndent << INDENT << "})\n"; return true; } else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) { auto value = dyn_cast(Def->getValueInit("value")); From 06367e4128076271a55719c0c19aec3a58970578 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Dec 2024 21:40:50 -0600 Subject: [PATCH 15/45] add erfinv (#2182) --- enzyme/Enzyme/InstructionDerivatives.td | 9 +++++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 1 + 2 files changed, 10 insertions(+) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 
fa415cd46bcc..b3fc4d4448c5 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -691,6 +691,15 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["erfinv","erfinvf","erfinfvl", "__nv_erfinv","__nv_erfinvf","__nv_erfinfvl", ], + [ + (FMul (FMul (ConstantFP<"0.8862269254527580136490837416705725913987747280611935641069038949264556422955160906874753283692723327"> $x), (FMul (Call<(SameFunc), [ReadNone, NoUnwind]> $x):$ei, $ei)), (DiffeRet)) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def ToStruct2 : SubRoutine<(Op (Op $re, $im):$z), (RetMultiReturnRet $re, $im) >; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index fcecc69c3ec6..a4384cbf4dd2 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -162,6 +162,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"erf", Intrinsic::not_intrinsic}, {"erfi", Intrinsic::not_intrinsic}, {"erfc", Intrinsic::not_intrinsic}, + {"erfinv", Intrinsic::not_intrinsic}, {"__fd_sincos_1", Intrinsic::not_intrinsic}, {"sincospi", Intrinsic::not_intrinsic}, From 57b718b196ee16d07406470472113390b9043076 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 4 Dec 2024 00:18:43 -0600 Subject: [PATCH 16/45] Improve error on inserted phi scev (#2185) * Improve error on inserted phi scev * fix * more * fix * fix * fix * fix * fix * fix --- enzyme/Enzyme/GradientUtils.cpp | 49 ++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index f5d5ffa6854d..c169365371ac 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -7137,7 +7137,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, { SCEVExpander OrigExp( *OrigSE, ctx->getParent()->getParent()->getDataLayout(), - "enzyme"); + "enzyme", 
/*PreserveLCSSA = */ false); OrigExp.setInsertPoint( isOriginal(l1.header)->getTerminator()); @@ -7160,22 +7160,45 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, return OrigDT->dominates(A, B); }); for (auto a : InsertedInstructions) { - assert(!isa(a)); - auto uw = cast( + if (isa(a)) { + std::string str; + raw_string_ostream ss(str); + ss << "oldFunc: " << *oldFunc << "\n"; + ss << "newFunc: " << *newFunc << "\n"; + ss << "li: " << *li << "\n"; + ss << "start0: " << *start0 << "\n"; + ss << "Inserted a phi node (" << *a + << ") during unwrap of SCEV: " << *ar1->getStart() + << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(li), + ErrorType::InternalError, nullptr, + nullptr, nullptr); + } else { + EmitFailure("InsertedPHISCEV", li->getDebugLoc(), li, + ss.str()); + } + } + auto uwV = unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, - /*scope*/ nullptr, /*cache*/ false)); - assert(uw->getType() == a->getType()); + /*scope*/ nullptr, /*cache*/ false); + auto uw = dyn_cast(uwV); + assert(uwV->getType() == a->getType()); #ifndef NDEBUG - for (size_t i = 0; i < uw->getNumOperands(); i++) { - auto op = uw->getOperand(i); - if (auto arg = dyn_cast(op)) - assert(arg->getParent() == newFunc); - else if (auto inst = dyn_cast(op)) - assert(inst->getParent()->getParent() == newFunc); + if (uw) { + for (size_t i = 0; i < uw->getNumOperands(); i++) { + auto op = uw->getOperand(i); + if (auto arg = dyn_cast(op)) + assert(arg->getParent() == newFunc); + else if (auto inst = dyn_cast(op)) + assert(inst->getParent()->getParent() == newFunc); + } + assert(uw->getParent()->getParent() == newFunc); } #endif - available[a] = uw; - unwrappedLoads.erase(cast(uw)); + available[a] = uwV; + if (uw) + unwrappedLoads.erase(uw); } start = From a9cc3c919c263bcfbeff1791ee5b3455bf62f4e7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 21:32:50 -0600 Subject: [PATCH 17/45] Simplify enzymejl_needs_restoration (#2187) --- 
enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index f789a125b222..2da86c18e5e5 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -92,7 +92,9 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, os << curIndent << "auto " << callval << " = gutils->oldFunc->getParent()->getOrInsertFunction("; os << Def->getValueInit("name")->getAsString(); - os << ", " << FT << ", called->getAttributes()).getCallee();\n"; + os << ", " << FT + << ", called->getAttributes().removeFnAttribute(called->getContext(), " + "\"enzymejl_needs_restoration\")).getCallee();\n"; os << curIndent << "auto " << cconv << " = cast(&" << origName << ")->getCallingConv();\n"; return; @@ -118,7 +120,9 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, os << curIndent << "auto " << callval << " = gutils->oldFunc->getParent()->getOrInsertFunction("; os << Def->getValueInit("name")->getAsString(); - os << ", " << FT << ", called->getAttributes()).getCallee();\n"; + os << ", " << FT + << ", called->getAttributes().removeFnAttribute(called->getContext(), " + "\"enzymejl_needs_restoration\")).getCallee();\n"; os << curIndent << "auto " << cconv << " = cast(&" << origName << ")->getCallingConv();\n"; return; @@ -133,7 +137,9 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, os << curIndent << "auto " << callval << " = gutils->oldFunc->getParent()->getOrInsertFunction("; os << Def->getValueInit("name")->getAsString(); - os << ", " << FT << ", called->getAttributes()).getCallee();\n"; + os << ", " << FT + << ", called->getAttributes().removeFnAttribute(called->getContext(), " + "\"enzymejl_needs_restoration\")).getCallee();\n"; os << curIndent << "auto " << cconv << " = cast(&" << origName << 
")->getCallingConv();\n"; return; From 31efe088a2c5d82f9f48323fee6e01fb3ecd5ef0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 21:32:59 -0600 Subject: [PATCH 18/45] Add complex bessel support (#2188) * Add complex bessel support * add cfsub * add commag * add commag --- enzyme/Enzyme/InstructionDerivatives.td | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index b3fc4d4448c5..2973c4b84cd8 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -144,6 +144,11 @@ def CFAdd : SubRoutine<(Op (Op $re1, $im1):$z1, (Op $re2, $im2):$z2), (FAdd $re1, $re2), (FAdd $im1, $im2) )>; +def CFSub : SubRoutine<(Op (Op $re1, $im1):$z1, (Op $re2, $im2):$z2), + (ArrayRet + (FSub $re1, $re2), + (FSub $im1, $im2) + )>; def CFMul_splat : SubRoutine<(Op $re1, $im1, $re2, $im2), (ArrayRet @@ -666,6 +671,17 @@ def : CallPattern<(Op $n, $x), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $n, $x), + ["cmplx_jn","cmplx_yn"], + [ + (InactiveArg), + // Reverse mode needs to return the conjugate + (CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x))))) + ], + (CFMul (Shadow $x), (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x)))), + [ReadNone, NoUnwind] + >; + def : CallPattern<(Op $x), ["erf","erff","erfl"], [ From 07e1826f2bc1090c3d8bc5bd30b9e2797e65a6dc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 23:22:48 -0600 Subject: [PATCH 19/45] More complex specialfunctions (#2190) --- enzyme/Enzyme/InstructionDerivatives.td | 31 ++++++++++++-------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 5 +++- 2 
files changed, 23 insertions(+), 13 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 2973c4b84cd8..c778e63aa258 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -4,6 +4,15 @@ class ForwardFromSummedReverseInternal { } def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} +class ConstantFP : Operation { + string value = val; +} + class Attribute { string name = name_; @@ -46,11 +55,6 @@ class InstPattern { - bit usesPrimal = usesPrimal_; - bit usesShadow = usesShadow_; - bit usesCustom = usesCustom_; -} class Inst : Operation { string name = mnemonic; } @@ -175,6 +179,12 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z), (FNeg $im) )>; +def Complex : SubRoutine<(Op $x), + (ArrayRet + $x, + (ConstantFP<"0"> $x) + )>; + def Conj : SubRoutine<(Op (Op $re, $im):$z), (ArrayRet $re, @@ -272,9 +282,6 @@ def MantissaMaskOfReturnForFrexp : GlobalExpr; -class ConstantFP : Operation { - string value = val; -} def Zero : Operation { } class ConstantCFP : Operation { @@ -671,14 +678,14 @@ def : CallPattern<(Op $n, $x), [ReadNone, NoUnwind] >; -def : CallPattern<(Op $n, $x), +def : CallPattern<(Op $n, $z), ["cmplx_jn","cmplx_yn"], [ - (InactiveArg), + (AssertingInactiveArg), // Reverse mode needs to return the conjugate - (CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x))))) + (CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $z), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $z), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $z))))) ], - (CFMul (Shadow $x), (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub 
(Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x)))), + (CFMul (Shadow $z), (CFMul (ConstantCFP<"0.5", "0"> $z), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $z), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $z)))), [ReadNone, NoUnwind] >; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 2da86c18e5e5..900c5c813cd7 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -695,7 +695,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, << rvalue->getValue() << "\"), (llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" << ivalue->getValue() << "\")});\n"; - os << curIndent << INDENT << "} else assert(0 && \"unhandled cfp\");\n"; + os << curIndent << INDENT << "} else {\n"; + os << curIndent << INDENT << " llvm::errs() << *ty << \"\\n\";\n"; + os << curIndent << INDENT << " assert(0 && \"unhandled cfp\");\n"; + os << curIndent << INDENT << "}\n"; os << curIndent << INDENT << "ret;\n"; os << curIndent << "})\n"; return false; From 9fe27b8df837e1a67e81f48cf2ff8be44dc40987 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 9 Dec 2024 12:26:32 -0600 Subject: [PATCH 20/45] Add loose types for extract (#2193) --- enzyme/Enzyme/AdjointGenerator.h | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 2eae03312b73..96b8494302ec 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1840,14 +1840,25 @@ class AdjointGenerator : public llvm::InstVisitor { } unsigned size = nextStart - start; if (!dt.isKnown()) { - - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of extract " << EVI << 
vd.str() - << " start: " << start << " size: " << size - << " extractSize: " << storeSize; - EmitNoTypeError(str, EVI, gutils, Builder2); - break; + bool found = false; + if (looseTypeAnalysis) { + if (EVI.getType()->isFPOrFPVectorTy()) { + dt = ConcreteType(EVI.getType()->getScalarType()); + found = true; + } else if (EVI.getType()->isIntOrIntVectorTy() || + EVI.getType()->isPointerTy()) { + dt = BaseType::Integer; + found = true; + } + } + if (!found) { + std::string str; + raw_string_ostream ss(str); + ss << "Cannot deduce type of extract " << EVI << vd.str() + << " start: " << start << " size: " << size + << " extractSize: " << storeSize; + EmitNoTypeError(str, EVI, gutils, Builder2); + } } if (auto FT = dt.isFloat()) ((DiffeGradientUtils *)gutils) From 068ad9c6f8bc7d8c7ad3806fd148492e323bc4b1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 13 Dec 2024 13:08:37 -0600 Subject: [PATCH 21/45] MLIR: improve num returns error (#2196) --- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index c3fe53a7c4ea..d83532db35a6 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -157,6 +157,11 @@ struct DifferentiatePass : public DifferentiatePassBase { OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); + if (dCI.getNumResults() != CI.getNumResults()) { + CI.emitError() << "Incorrect number of results for enzyme operation: " + << *CI << " expected " << *dCI; + return failure(); + } CI.replaceAllUsesWith(dCI); CI->erase(); return success(); From 5f1d3325bf6f25ddbdf24759822f122ab2b72088 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 13 Dec 2024 18:19:32 -0600 Subject: [PATCH 22/45] Activity Analysis: strengthen recursive hyp (#2197) * Activity Analysis: strengthen recursive hyp * fix test * fix --- 
enzyme/Enzyme/ActivityAnalysis.cpp | 25 ++++++-- enzyme/test/ActivityAnalysis/mallocuse.ll | 43 ++++++++++++++ enzyme/test/Enzyme/ReverseMode/mallocuse.ll | 66 +++++++++++++++++++++ 3 files changed, 128 insertions(+), 6 deletions(-) create mode 100644 enzyme/test/ActivityAnalysis/mallocuse.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/mallocuse.ll diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index cedaa1d19c3f..2a22512732b4 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -980,12 +980,10 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, insertConstantsFrom(TR, *DownHypothesis); return true; } else if (directions == 3) { - if (isa(I) || isa(I) || isa(I)) { - for (auto &op : I->operands()) { - if (!UpHypothesis->isConstantValue(TR, op) && - EnzymeEnableRecursiveHypotheses) { - ReEvaluateInstIfInactiveValue[op].insert(I); - } + for (auto &op : I->operands()) { + if (!UpHypothesis->isConstantValue(TR, op) && + EnzymeEnableRecursiveHypotheses) { + ReEvaluateInstIfInactiveValue[op].insert(I); } } } @@ -1785,6 +1783,13 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { InsertConstantValue(TR, Val); insertConstantsFrom(TR, *UpHypothesis); return true; + } else if (directions == 3) { + for (auto &op : inst->operands()) { + if (!UpHypothesis->isConstantValue(TR, op) && + EnzymeEnableRecursiveHypotheses) { + ReEvaluateValueIfInactiveValue[op].insert(Val); + } + } } } } @@ -1826,6 +1831,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { if (EnzymePrintActivity) llvm::errs() << " cannot show constant instruction hypothesis: " << *VI << "\n"; + if (directions == 3) { + for (auto &op : VI->operands()) { + if (!UpHypothesis->isConstantValue(TR, op) && + EnzymeEnableRecursiveHypotheses) { + ReEvaluateValueIfInactiveValue[op].insert(Val); + } + } + } } } diff --git a/enzyme/test/ActivityAnalysis/mallocuse.ll 
b/enzyme/test/ActivityAnalysis/mallocuse.ll new file mode 100644 index 000000000000..2b66d2b61b8e --- /dev/null +++ b/enzyme/test/ActivityAnalysis/mallocuse.ll @@ -0,0 +1,43 @@ +; RUN: %opt < %s %newLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=_take -opaque-pointers -S -o /dev/null | FileCheck %s + +declare ptr @malloc(i64) + +define double @_take(ptr %a0, i1 %a1) { +entry: + %a3 = tail call ptr @malloc(i64 10) + %a4 = tail call ptr @malloc(i64 10) + %a5 = ptrtoint ptr %a4 to i64 + %a6 = or i64 %a5, 1 + %a7 = inttoptr i64 %a6 to ptr + %a8 = load double, ptr %a7, align 8 + store double %a8, ptr %a0, align 8 + br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next + +.lr.ph1.peel.next: ; preds = %2 + %.pre = load double, ptr %a4, align 8 + ret double %.pre + +.lr.ph: ; preds = %.lr.ph, %2 + %a9 = load double, ptr %a3, align 4 + store double %a9, ptr %a4, align 8 + br label %.lr.ph +} + +; CHECK: ptr %a0: icv:0 +; CHECK-NEXT: i1 %a1: icv:1 +; CHECK-NEXT: entry +; CHECK-NEXT: %a3 = tail call ptr @malloc(i64 10): icv:1 ici:1 +; CHECK-NEXT: %a4 = tail call ptr @malloc(i64 10): icv:1 ici:1 +; CHECK-NEXT: %a5 = ptrtoint ptr %a4 to i64: icv:1 ici:1 +; CHECK-NEXT: %a6 = or i64 %a5, 1: icv:1 ici:1 +; CHECK-NEXT: %a7 = inttoptr i64 %a6 to ptr: icv:1 ici:1 +; CHECK-NEXT: %a8 = load double, ptr %a7, align 8: icv:1 ici:1 +; CHECK-NEXT: store double %a8, ptr %a0, align 8: icv:1 ici:1 +; CHECK-NEXT: br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next: icv:1 ici:1 +; CHECK-NEXT: .lr.ph1.peel.next +; CHECK-NEXT: %.pre = load double, ptr %a4, align 8: icv:1 ici:1 +; CHECK-NEXT: ret double %.pre: icv:1 ici:1 +; CHECK-NEXT: .lr.ph +; CHECK-NEXT: %a9 = load double, ptr %a3, align 4: icv:1 ici:1 +; CHECK-NEXT: store double %a9, ptr %a4, align 8: icv:1 ici:1 +; CHECK-NEXT: br label %.lr.ph: icv:1 ici:1 diff --git a/enzyme/test/Enzyme/ReverseMode/mallocuse.ll b/enzyme/test/Enzyme/ReverseMode/mallocuse.ll new file mode 100644 index 000000000000..eb1fb7c478cd --- /dev/null 
+++ b/enzyme/test/Enzyme/ReverseMode/mallocuse.ll @@ -0,0 +1,66 @@ +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,early-cse,sroa,instsimplify,%simplifycfg,adce)" -enzyme-preopt=false -opaque-pointers -S | FileCheck %s + +declare ptr @__enzyme_virtualreverse(...) + +declare ptr @malloc(i64) + +define void @my_model.fullgrad1() { + %z = call ptr (...) @__enzyme_virtualreverse(ptr nonnull @_take) + ret void +} + +define double @_take(ptr %a0, i1 %a1) { + %a3 = tail call ptr @malloc(i64 10) + %a4 = tail call ptr @malloc(i64 10) + %a5 = ptrtoint ptr %a4 to i64 + %a6 = or i64 %a5, 1 + %a7 = inttoptr i64 %a6 to ptr + %a8 = load double, ptr %a7, align 8 + store double %a8, ptr %a0, align 8 + br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next + +.lr.ph1.peel.next: ; preds = %2 + %.pre = load double, ptr %a4, align 8 + ret double %.pre + +.lr.ph: ; preds = %.lr.ph, %2 + %a9 = load double, ptr %a3, align 4 + store double %a9, ptr %a4, align 8 + br label %.lr.ph +} + +; CHECK: define internal { ptr, double } @augmented__take(ptr %a0, ptr %"a0'", i1 %a1) +; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) ptr @malloc(i64 8) +; CHECK-NEXT: %a3 = tail call ptr @malloc(i64 10) +; CHECK-NEXT: %a4 = tail call ptr @malloc(i64 10) +; CHECK-NEXT: store ptr %a4, ptr %malloccall, align 8 +; CHECK-NEXT: %a5 = ptrtoint ptr %a4 to i64 +; CHECK-NEXT: %a6 = or i64 %a5, 1 +; CHECK-NEXT: %a7 = inttoptr i64 %a6 to ptr +; CHECK-NEXT: %a8 = load double, ptr %a7, align 8 +; CHECK-NEXT: store double %a8, ptr %a0, align 8 +; CHECK-NEXT: br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next + +; CHECK: .lr.ph1.peel.next: ; preds = %0 +; CHECK-NEXT: %.pre = load double, ptr %a4, align 8, !alias.scope !10, !noalias !13 +; CHECK-NEXT: %.fca.0.insert = insertvalue { ptr, double } poison, ptr %malloccall, 0 +; CHECK-NEXT: %.fca.1.insert = insertvalue { ptr, double } %.fca.0.insert, double %.pre, 1 +; CHECK-NEXT: ret { ptr, double } 
%.fca.1.insert + +; CHECK: .lr.ph: ; preds = %0, %.lr.ph +; CHECK-NEXT: %a9 = load double, ptr %a3, align 4 +; CHECK-NEXT: store double %a9, ptr %a4, align 8 +; CHECK-NEXT: br label %.lr.ph +; CHECK-NEXT: } + +; CHECK: define internal void @diffe_take(ptr %a0, ptr %"a0'", i1 %a1, double %differeturn, ptr %tapeArg) +; CHECK-NEXT: tail call void @free(ptr nonnull %tapeArg) +; CHECK-NEXT: br i1 %a1, label %.lr.ph, label %invert.lr.ph1.peel.next + +; CHECK: .lr.ph: ; preds = %0, %.lr.ph +; CHECK-NEXT: br label %.lr.ph + +; CHECK: invert.lr.ph1.peel.next: ; preds = %0 +; CHECK-NEXT: store double 0.000000e+00, ptr %"a0'", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } From b387a389b11040b2d8de7849aa063e0087f0ae05 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Dec 2024 13:01:19 -0600 Subject: [PATCH 23/45] Update BUILD --- enzyme/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/BUILD b/enzyme/BUILD index bfb0f2f5c446..d15ae7a1badc 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -235,6 +235,7 @@ td_library( srcs = [ "Enzyme/MLIR/Dialect/Dialect.td", ], + includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:FunctionInterfacesTdFiles", From a03c1d4008369c27536ab8450e808e34d4514a5f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Dec 2024 19:11:58 -0500 Subject: [PATCH 24/45] Fix forward rewrite (#2201) --- enzyme/Enzyme/CallDerivatives.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 22df3dab9a78..db7ab6ac4450 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3236,6 +3236,13 @@ bool AdjointGenerator::handleKnownCallDerivatives( CI->setCallingConv(call.getCallingConv()); CI->setTailCallKind(call.getTailCallKind()); CI->setDebugLoc(dbgLoc); + + if (funcName == "julia.gc_alloc_obj" || + funcName == "jl_gc_alloc_typed" || + funcName == "ijl_gc_alloc_typed") { + if 
(EnzymeShadowAllocRewrite) + EnzymeShadowAllocRewrite(wrap(CI), gutils); + } return CI; }; From 534a28595a9bf1f41d4a8222615f00e36611dc0f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 14:56:08 -0500 Subject: [PATCH 25/45] Expand shadow_alloc_rewrite capabilities (#2202) --- enzyme/Enzyme/CallDerivatives.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index db7ab6ac4450..60503c2b0fe4 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -29,7 +29,8 @@ using namespace llvm; extern "C" { -void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr; +void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t, + LLVMValueRef) = nullptr; } void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, @@ -3014,6 +3015,9 @@ bool AdjointGenerator::handleKnownCallDerivatives( bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ)); } else { bool zeroed = false; + uint64_t idx = 0; + Value *prev = nullptr; + ; auto rule = [&]() { Value *anti = bb.CreateCall(call.getFunctionType(), call.getCalledOperand(), @@ -3059,7 +3063,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { if (EnzymeShadowAllocRewrite) - EnzymeShadowAllocRewrite(wrap(anti), gutils); + EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call), + idx, wrap(prev)); } } if (Mode == DerivativeMode::ReverseModeCombined || @@ -3075,6 +3080,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( zeroed = true; } } + idx++; + prev = anti; return anti; }; @@ -3224,6 +3231,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( args.push_back(gutils->getNewFromOriginal(arg)); } + uint64_t idx = 0; + Value *prev = gutils->getNewFromOriginal(&call); auto rule = [&]() { SmallVector BundleTypes(args.size(), ValueType::Primal); @@ -3241,8 
+3250,11 @@ bool AdjointGenerator::handleKnownCallDerivatives( funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { if (EnzymeShadowAllocRewrite) - EnzymeShadowAllocRewrite(wrap(CI), gutils); + EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx, + wrap(prev)); } + idx++; + prev = CI; return CI; }; From f5767ef45ac7071ac44dd75d77771475377328e5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 18:03:34 -0500 Subject: [PATCH 26/45] Fix activity analysis store (#2203) --- enzyme/Enzyme/ActivityAnalysis.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 2a22512732b4..88ceb7e9f472 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -2797,8 +2797,16 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR, if (isa(TmpOrig) || isAllocationCall(TmpOrig, TLI)) { done.insert( std::make_tuple((User *)SI, SI->getPointerOperand(), UA)); + // If we are capturing a variable v, we need to check any loads or + // stores into that variable, even if we are checking only for + // stores. 
+ auto UA2 = UA; + if (UA == UseActivity::OnlyStores || + UA == UseActivity::OnlyNonPointerStores || + UA == UseActivity::AllStores) + UA2 = UseActivity::None; for (const auto a : TmpOrig->users()) { - todo.push_back(std::make_tuple(a, TmpOrig, UA)); + todo.push_back(std::make_tuple(a, TmpOrig, UA2)); } AllocaSet.insert(TmpOrig); if (EnzymePrintActivity) From 79ac40699969e416e6088bb99dfdd523f0d7f40e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 20:14:21 -0500 Subject: [PATCH 27/45] Non-power of two cache switch (#2204) --- enzyme/Enzyme/Utils.cpp | 4 ++++ enzyme/Enzyme/Utils.h | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 933a22304e29..f4655bac8455 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -102,6 +102,10 @@ llvm::cl::opt EnzymeMemmoveWarning( llvm::cl::opt EnzymeRuntimeError( "enzyme-runtime-error", cl::init(false), cl::Hidden, cl::desc("Emit Runtime errors instead of compile time ones")); + +llvm::cl::opt EnzymeNonPower2Cache( + "enzyme-non-power2-cache", cl::init(false), cl::Hidden, + cl::desc("Disable caching of integers which are not a power of 2")); } void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 02ce4b8b47e2..089a99c86912 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -91,6 +91,7 @@ enum class ErrorType { extern "C" { /// Print additional debug info relevant to performance extern llvm::cl::opt EnzymePrintPerf; +extern llvm::cl::opt EnzymeNonPower2Cache; extern llvm::cl::opt EnzymeStrongZero; extern llvm::cl::opt EnzymeBlasCopy; extern llvm::cl::opt EnzymeLapackCopy; @@ -1194,6 +1195,10 @@ static inline bool hasNoCache(llvm::Value *op) { } } } + if (auto IT = dyn_cast(op->getType())) + if (!isPowerOf2_64(IT->getBitWidth()) && !EnzymeNonPower2Cache) + return true; + return false; } From 1ca3ceb12723b7246fe19016bca6e91fba37f1c7 Mon Sep 
17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 23:11:13 -0500 Subject: [PATCH 28/45] ActivityAnalysis: consider atomicrmw (#2205) * ActivityAnalysis: consider atomicrmw * fix * fix --- enzyme/Enzyme/ActivityAnalysis.cpp | 22 +++++++++++++++++++++- enzyme/Enzyme/GradientUtils.cpp | 10 ++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 88ceb7e9f472..67cbb5e9ddfe 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1554,6 +1554,26 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(TmpOrig); } } + } else if (auto RMW = dyn_cast(TmpOrig)) { + if (directions == UP) { + if (isConstantValue(TR, RMW->getPointerOperand())) { + InsertConstantValue(TR, Val); + return true; + } + } else { + if (UpHypothesis->isConstantValue(TR, RMW->getPointerOperand())) { + InsertConstantValue(TR, Val); + insertConstantsFrom(TR, *UpHypothesis); + return true; + } + } + if (EnzymeEnableRecursiveHypotheses) { + ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(Val); + if (TmpOrig != Val) { + ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert( + TmpOrig); + } + } } else if (auto op = dyn_cast(TmpOrig)) { if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") || op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, @@ -1940,7 +1960,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { isRefSet(AARes)) { if (EnzymePrintActivity) llvm::errs() << "potential active load: " << *I << "\n"; - if (isa(I) || isNVLoad(I)) { + if (isa(I) || isNVLoad(I) || isa(I)) { // If the ref'ing value is a load check if the loaded value is // active if (!Hypothesis->isConstantValue(TR, I)) { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index c169365371ac..414374dd7799 100644 --- 
a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3822,6 +3822,9 @@ bool GradientUtils::legalRecompute(const Value *val, } } + if (isa(val)) + return false; + if (auto phi = dyn_cast(val)) { if (auto uiv = hasUninverted(val)) { if (auto dli = dyn_cast_or_null(uiv)) { @@ -3835,6 +3838,13 @@ bool GradientUtils::legalRecompute(const Value *val, } } + auto found = fictiousPHIs.find(const_cast(phi)); + if (found != fictiousPHIs.end()) { + auto orig = found->second; + if (isa(orig)) + return false; + } + if (phi->getNumIncomingValues() == 0) { llvm::errs() << *oldFunc << "\n"; llvm::errs() << *newFunc << "\n"; From 59540625e212697ea4ab1d7beeecaa0eeec14426 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Dec 2024 20:05:23 -0500 Subject: [PATCH 29/45] Pass if is unnecessary (#2207) --- enzyme/Enzyme/CallDerivatives.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 60503c2b0fe4..bc5a095cfae5 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -30,7 +30,7 @@ using namespace llvm; extern "C" { void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t, - LLVMValueRef) = nullptr; + LLVMValueRef, uint8_t) = nullptr; } void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, @@ -3062,9 +3062,12 @@ bool AdjointGenerator::handleKnownCallDerivatives( if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) + if (EnzymeShadowAllocRewrite) { + bool used = unnecessaryInstructions.find(&call) == + unnecessaryInstructions.end(); EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call), - idx, wrap(prev)); + idx, wrap(prev), used); + } } } if (Mode == DerivativeMode::ReverseModeCombined || @@ -3249,9 +3252,12 @@ bool AdjointGenerator::handleKnownCallDerivatives( if (funcName == 
"julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) + if (EnzymeShadowAllocRewrite) { + bool used = unnecessaryInstructions.find(&call) == + unnecessaryInstructions.end(); EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx, - wrap(prev)); + wrap(prev), used); + } } idx++; prev = CI; From 8e79483d4c2d4cb2a6ccf1354be0595c6658a73d Mon Sep 17 00:00:00 2001 From: Kiran Shila Date: Fri, 27 Dec 2024 10:48:13 -0800 Subject: [PATCH 30/45] Add nixpkgs to README (#2208) --- Readme.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Readme.md b/Readme.md index 7dd210b8b7e8..544e12dfc5df 100644 --- a/Readme.md +++ b/Readme.md @@ -39,6 +39,10 @@ brew install enzyme ``` spack install enzyme ``` +[Nix](https://nixos.org/) +``` +nix-shell -p enzyme +``` To get involved or if you have questions, please join our [mailing list](https://groups.google.com/d/forum/enzyme-dev). From eeb6200dafad352aa44ba163e3e9cd4f4eae5a8f Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 27 Dec 2024 21:14:15 +0100 Subject: [PATCH 31/45] Batched autodiff (#2181) * add type conversions for width != 1. This still requires changes in the tblgenerated derivative files. 
For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this: ``` LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const { auto op = cast(op0); if (gutils->width != 1) { auto newop = gutils->getNewFromOriginal(op0); for (auto res : newop->getResults()) { res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType())); } } gutils->eraseIfUnused(op); if (gutils->isConstantInstruction(op)) return success(); mlir::Value res = nullptr; if (!gutils->isConstantValue(op->getOperand(0))) { auto dif = gutils->invertPointerM(op->getOperand(0), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); // TODO: gutils->makeBatched(...) auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1)); builder.create(op.getLoc(), fwdarg_0, fwdarg_1); }); itmp.dump(); if (!res) res = itmp; else { auto operandType = cast(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } if (!gutils->isConstantValue(op->getOperand(1))) { auto dif = gutils->invertPointerM(op->getOperand(1), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0)); builder.create(op.getLoc(), fwdarg_0, fwdarg_1); }); if (!res) res = itmp; else { auto operandType = cast(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } assert(res); gutils->setDiffe(op->getResult(0), res, builder); return success(); } ``` * add code to tblgen generator, this eventually needs to be a single function call. 
* a test and formatting * use tensor splatop * remove stale enzyme-tblgen changes * do the simple batching in enzyme-tblgen * include tensor in all AutoDiffOpInterfaceImpls * add enzyme broadcastop * getShadowType for TensorTypeInterface * create broadcastop in enzyme-tblgen * Revert "include tensor in all AutoDiffOpInterfaceImpls" This reverts commit c06ed01709b51bff5b794a7e4dc83b63510b9a84. * test * DenseI64ArrayAttr for shape instead of scalar width * `llvm::SmallVector` --> `ArrayRef` * formatting * use getShadowType in BroadcastOp builder Co-authored-by: Billy Moses * unstructured control flow test * scf.for * formatting * support `scf.if` test * formatting * forgotten includes --------- Co-authored-by: Jules Merckx Co-authored-by: Billy Moses --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 16 +++++++ enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 15 +++++++ .../ArithAutoDiffOpInterfaceImpl.cpp | 8 ++++ .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 16 +++++-- .../CoreDialectsAutoDiffImplementations.cpp | 8 ++-- .../CoreDialectsAutoDiffImplementations.h | 1 + .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 6 ++- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 3 +- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Passes/Passes.h | 5 +++ enzyme/Enzyme/MLIR/Passes/Passes.td | 3 +- enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + .../test/MLIR/ForwardMode/batched_branch.mlir | 26 +++++++++++ enzyme/test/MLIR/ForwardMode/batched_for.mlir | 33 ++++++++++++++ enzyme/test/MLIR/ForwardMode/batched_if.mlir | 43 +++++++++++++++++++ .../test/MLIR/ForwardMode/batched_scalar.mlir | 26 +++++++++++ .../test/MLIR/ForwardMode/batched_tensor.mlir | 26 +++++++++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 13 +++++- 18 files changed, 239 insertions(+), 11 deletions(-) create mode 100644 enzyme/test/MLIR/ForwardMode/batched_branch.mlir create mode 100644 enzyme/test/MLIR/ForwardMode/batched_for.mlir create mode 100644 enzyme/test/MLIR/ForwardMode/batched_if.mlir create mode 100644 
enzyme/test/MLIR/ForwardMode/batched_scalar.mlir create mode 100644 enzyme/test/MLIR/ForwardMode/batched_tensor.mlir diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index be139fb3d8ba..72672a959403 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> { } +def BroadcastOp : Enzyme_Op<"broadcast"> { + let description = [{ + Broadcast the operand by adding extra dimensions with sizes provided by the `shape` attribute to the front. + For scalar operands, ranked tensor is created. + + NOTE: Only works for scalar and *ranked* tensor operands for now. + }]; + + let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape); + let results = (outs AnyRankedTensor:$output); + + let builders = [ + OpBuilder<(ins "Value":$input, "ArrayRef":$shape)> + ]; +} + #endif // ENZYME_OPS diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 3e3185427306..7e48db2d583b 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input, + ArrayRef shape) { + auto shapeAttr = builder.getDenseI64ArrayAttr(shape); + auto resultTy = input.getType(); + for (auto s : llvm::reverse(shape)) { + resultTy = resultTy.cast().getShadowType(s); + } + build(builder, result, resultTy, input, shapeAttr); +} 
diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp index 9b27503d79dc..8d3650969d09 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp @@ -17,6 +17,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface( arith::ConstantOp::attachInterface(*context); }); } + +void mlir::enzyme::registerTensorDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index d2d6ddfe19be..7c72b97d0934 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -45,8 +45,11 @@ class FloatTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); - return self; + if (width > 1) { + return RankedTensorType::get({width}, self); + } else { + return self; + } } bool isMutable(Type self) const { return false; } @@ -106,7 +109,14 @@ class TensorTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); + if (width != 1) { + auto tenType = self.cast(); + auto shape = tenType.getShape(); + SmallVector newShape; + newShape.push_back(width); + newShape.append(shape.begin(), shape.end()); + return RankedTensorType::get(newShape, 
tenType.getElementType()); + } return self; } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 355808cdbcc1..f727dca2f877 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, newVals.push_back(gutils->invertPointerM(op, builder)); } else { Type retTy = - arg.getType().cast().getShadowType(); + arg.getType().cast().getShadowType( + gutils->width); auto toret = retTy.cast().createNullValue( builder, op.getLoc()); newVals.push_back(toret); @@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( if (auto iface = dyn_cast(operand.get().getType())) { if (!iface.isMutable()) { - Type retTy = iface.getShadowType(); + Type retTy = iface.getShadowType(gutils->width); auto toret = retTy.cast().createNullValue( builder, operand.get().getLoc()); newOperands.push_back(toret); @@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( << result.getType() << "\n"; return failure(); } - newOpResultTypes.push_back(typeIface.getShadowType()); + newOpResultTypes.push_back(typeIface.getShadowType(gutils->width)); } SmallVector newOperands; @@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); enzyme::registerFuncDialectAutoDiffInterface(registry); + enzyme::registerTensorDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index d6f28ccfc736..650f6c6326bb 100644 --- 
a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); +void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 69cfad436cfd..5ec908f1268a 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -245,9 +245,11 @@ FunctionOpInterface CloneFunctionWithReturns( mlir::Value val = blk.getArgument(i); mlir::Value dval; if (i == ArgActivity.size() - 1) - dval = blk.addArgument(val.getType(), val.getLoc()); + dval = blk.addArgument(getShadowType(val.getType(), width), + val.getLoc()); else - dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(), + dval = blk.insertArgument(blk.args_begin() + i + 1, + getShadowType(val.getType(), width), val.getLoc()); ptrInputs.map(oval, dval); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 1ec4212dc5a5..32cb5b796144 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -108,7 +108,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, return invertedPointers.lookupOrNull(v); if (isConstantValue(v)) { - if (auto iface = v.getType().dyn_cast()) { + if (auto iface = + getShadowType(v.getType()).dyn_cast()) { OpBuilder::InsertionGuard guard(Builder2); if (auto op = v.getDefiningOp()) Builder2.setInsertionPoint(getNewFromOriginal(op)); diff 
--git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 0445fc430649..99db4d80034c 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms MLIRFuncDialect MLIRFuncTransforms MLIRGPUDialect + MLIRTensorDialect MLIRIR MLIRLLVMDialect MLIRMathDialect diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 58c43be236de..fb6df3e2208c 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "Dialect/Dialect.h" @@ -80,6 +81,10 @@ namespace affine { class AffineDialect; } // end namespace affine +namespace tensor { +class TensorDialect; +} // end namespace tensor + namespace LLVM { class LLVMDialect; } // end namespace LLVM diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 6458e63b2735..c5b4df769172 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> { let dependentDialects = [ "arith::ArithDialect", "complex::ComplexDialect", - "cf::ControlFlowDialect" + "cf::ControlFlowDialect", + "tensor::TensorDialect", ]; let constructor = "mlir::enzyme::createDifferentiatePass()"; } diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 0e6bdf7b101e..99e7243129be 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -67,6 +67,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir new 
file mode 100644 index 000000000000..f20989aa4245 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %y : f64) -> f64 { + %c = arith.cmpf ult, %x, %y : f64 + cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64) + + ^blk2(%r : f64): + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 +// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) +// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir b/enzyme/test/MLIR/ForwardMode/batched_for.mlir new file mode 100644 index 000000000000..95557cb0b6fc --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir @@ -0,0 +1,33 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : 
index + %c10 = arith.constant 10 : index + %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) { + %n = arith.addf %arg2, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64> +// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) { +// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_if.mlir b/enzyme/test/MLIR/ForwardMode/batched_if.mlir new file mode 100644 index 000000000000..33c9e1b9fe8b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_if.mlir @@ -0,0 +1,43 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + 
} + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) { +// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64> +// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64> +// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64> +// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64> +// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64> +// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir new file mode 100644 index 000000000000..f06f86d2a043 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ 
-0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir new file mode 100644 index 000000000000..11b75f634a67 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ + %y = arith.mulf %x, %x : tensor<10xf64> + return %y : tensor<10xf64> + } + func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<10xf64>, 
tensor<2x10xf64>) -> (tensor<2x10xf64>) + return %r : tensor<2x10xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 900c5c813cd7..dccbc7b7923c 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -275,8 +275,19 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, os << ord; } if (!vecValue && !startsWith(ord, "local")) { - if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) + if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << "if (gutils->width != 1) {\n" + << " " << argName << "_" << (idx - 1) + << " = builder.create(\n" + << " op.getLoc(),\n" + << " " << argName << "_" << (idx - 1) << ",\n" + << " llvm::SmallVector({gutils->width}));\n" + << "}"; + } + } if (lookup && intrinsic != 
MLIRDerivatives) os << ", " << builder << ")"; From 29fe86bde582cc4d8a433e5f0064138cc76b1715 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Dec 2024 20:09:40 -0500 Subject: [PATCH 32/45] slice activity analysis (#2209) --- enzyme/Enzyme/ActivityAnalysis.cpp | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 67cbb5e9ddfe..67972fdcc8c1 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -579,6 +579,11 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") return val != CI->getArgOperand(1); + // Only the 0-th arg impacts activity + if (Name == "jl_genericmemory_copy_slice" || + Name == "ijl_genericmemory_copy_slice") + return val != CI->getArgOperand(0); + // Allocations, deallocations, and c++ guards don't impact the activity // of arguments if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI)) @@ -660,6 +665,13 @@ static inline void propagateArgumentInformation( return; } + // Only the 0-th arg impacts activity + if (Name == "jl_genericmemory_copy_slice" || + Name == "ijl_genericmemory_copy_slice") { + propagateFromOperand(CI.getArgOperand(1)); + return; + } + // Only the 1-th arg impacts activity if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") { propagateFromOperand(CI.getArgOperand(1)); @@ -1601,6 +1613,29 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } + if (funcName == "jl_genericmemory_copy_slice" || + funcName == "ijl_genericmemory_copy_slice") { + if (directions == UP) { + if (isConstantValue(TR, op->getArgOperand(0))) { + InsertConstantValue(TR, Val); + return true; + } + } else { + if (UpHypothesis->isConstantValue(TR, op->getArgOperand(0))) { + InsertConstantValue(TR, Val); + insertConstantsFrom(TR, *UpHypothesis); + return true; 
+ } + } + if (EnzymeEnableRecursiveHypotheses) { + ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert(Val); + if (TmpOrig != Val) { + ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert( + TmpOrig); + } + } + } + // If requesting empty unknown functions to be considered inactive, // abide by those rules if (called && EnzymeEmptyFnInactive && called->empty() && From bdda0ce16544f503f06a627c3db0579d876a7e04 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Dec 2024 21:12:20 -0500 Subject: [PATCH 33/45] Fix batched shadow reverse runtime AA (#2210) --- enzyme/Enzyme/DiffeGradientUtils.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index eba0de11f54d..e88e762e7f03 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -1179,9 +1179,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe( // the pointers and conditionally execute. 
if ((!isa(basePtr) && !isAllocationCall(basePtr, TLI)) && runtimeActivity && !merge) { - Value *shadow = Builder2.CreateICmpNE( - lookupM(getNewFromOriginal(origptr), Builder2), - lookupM(invertPointerM(origptr, Builder2), Builder2)); + Value *primal_val = lookupM(getNewFromOriginal(origptr), Builder2); + Value *shadow_val = + lookupM(invertPointerM(origptr, Builder2), Builder2); + if (getWidth() != 1) { + shadow_val = extractMeta(Builder2, shadow_val, 0); + } + Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val); BasicBlock *current = Builder2.GetInsertBlock(); BasicBlock *conditional = From 7cf9e90de211ed525aca58e0930b14a827e20bba Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Dec 2024 22:38:20 -0500 Subject: [PATCH 34/45] Fix copyslice v2 (#2211) * Fix copyslice v2 * fix --- enzyme/Enzyme/ActivityAnalysis.cpp | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 67972fdcc8c1..44e88abe3519 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -668,7 +668,7 @@ static inline void propagateArgumentInformation( // Only the 0-th arg impacts activity if (Name == "jl_genericmemory_copy_slice" || Name == "ijl_genericmemory_copy_slice") { - propagateFromOperand(CI.getArgOperand(1)); + propagateFromOperand(CI.getArgOperand(0)); return; } @@ -1613,29 +1613,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - if (funcName == "jl_genericmemory_copy_slice" || - funcName == "ijl_genericmemory_copy_slice") { - if (directions == UP) { - if (isConstantValue(TR, op->getArgOperand(0))) { - InsertConstantValue(TR, Val); - return true; - } - } else { - if (UpHypothesis->isConstantValue(TR, op->getArgOperand(0))) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - if (EnzymeEnableRecursiveHypotheses) { - 
ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert(Val); - if (TmpOrig != Val) { - ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert( - TmpOrig); - } - } - } - // If requesting empty unknown functions to be considered inactive, // abide by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -2751,6 +2728,11 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR, if (AllocaSet.count(TmpOrig)) { continue; } + // We are literally storing our value into ourselves [or relevant + // derived pointer] + if (TmpOrig == val) { + continue; + } if (isa(TmpOrig)) { newAllocaSet.insert(TmpOrig); continue; From 8e36e65e6570827796acf3ae7df421fc244ca8aa Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 14:18:42 -0500 Subject: [PATCH 35/45] Improve cast ft error (#2213) --- enzyme/Enzyme/AdjointGenerator.h | 46 +++++++++++++++++--------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 96b8494302ec..655bdca6943f 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1380,31 +1380,33 @@ class AdjointGenerator : public llvm::InstVisitor { ss << "Cannot deduce adding type (cast) of " << I; EmitNoTypeError(str, I, gutils, Builder2); } - assert(FT); - auto rule = [&](Value *dif) { - if (I.getOpcode() == CastInst::CastOps::FPTrunc || - I.getOpcode() == CastInst::CastOps::FPExt) { - return Builder2.CreateFPCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::BitCast) { - return Builder2.CreateBitCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::Trunc) { - // TODO CHECK THIS - return Builder2.CreateZExt(dif, op0->getType()); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << *I.getParent()->getParent() << "\n"; - ss << "cannot handle above cast " << I << "\n"; - EmitNoDerivativeError(ss.str(), I, gutils, Builder2); - return 
(llvm::Value *)UndefValue::get(op0->getType()); - } - }; + if (FT) { + + auto rule = [&](Value *dif) { + if (I.getOpcode() == CastInst::CastOps::FPTrunc || + I.getOpcode() == CastInst::CastOps::FPExt) { + return Builder2.CreateFPCast(dif, op0->getType()); + } else if (I.getOpcode() == CastInst::CastOps::BitCast) { + return Builder2.CreateBitCast(dif, op0->getType()); + } else if (I.getOpcode() == CastInst::CastOps::Trunc) { + // TODO CHECK THIS + return Builder2.CreateZExt(dif, op0->getType()); + } else { + std::string s; + llvm::raw_string_ostream ss(s); + ss << *I.getParent()->getParent() << "\n"; + ss << "cannot handle above cast " << I << "\n"; + EmitNoDerivativeError(ss.str(), I, gutils, Builder2); + return (llvm::Value *)UndefValue::get(op0->getType()); + } + }; - Value *dif = diffe(&I, Builder2); - Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); + Value *dif = diffe(&I, Builder2); + Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); - addToDiffe(orig_op0, diff, Builder2, FT); + addToDiffe(orig_op0, diff, Builder2, FT); + } } Type *diffTy = gutils->getShadowType(I.getType()); From 7bc73fa8291cfed08456ae93e6f460060a6c8344 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 16:42:51 -0500 Subject: [PATCH 36/45] MLIR: post optimization pipeline (#2214) * MLIR: post optimization pipeline * build start * fix * fix * fix build * format * fixup --- enzyme/.bazelversion | 1 + .../FuncAutoDiffOpInterfaceImpl.cpp | 4 +-- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 27 ++++++++++++++----- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 19 +++++++------ .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 21 +++++++++++++-- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 8 +++--- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 14 ++++++---- .../MLIR/Interfaces/GradientUtilsReverse.cpp | 8 +++--- .../MLIR/Interfaces/GradientUtilsReverse.h | 17 +++++++----- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 17 ++++++++++-- 
enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 4 +-- enzyme/Enzyme/MLIR/Passes/Passes.td | 9 +++++++ .../test/MLIR/ForwardMode/batched_branch.mlir | 4 +-- enzyme/test/MLIR/ForwardMode/batched_for.mlir | 2 +- .../test/MLIR/ForwardMode/batched_scalar.mlir | 4 +-- .../test/MLIR/ForwardMode/batched_tensor.mlir | 4 +-- 16 files changed, 113 insertions(+), 50 deletions(-) create mode 100644 enzyme/.bazelversion diff --git a/enzyme/.bazelversion b/enzyme/.bazelversion new file mode 100644 index 000000000000..f22d756da39d --- /dev/null +++ b/enzyme/.bazelversion @@ -0,0 +1 @@ +6.5.0 diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp index 5308304f5b77..54845c740d30 100644 --- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -73,7 +73,7 @@ class AutoDiffCallFwd fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode, freeMemory, width, /* addedType */ nullptr, type_args, volatile_args, - /* augmented */ nullptr); + /* augmented */ nullptr, gutils->postpasses); SmallVector fwdArguments; @@ -173,7 +173,7 @@ class AutoDiffCallRev auto revFn = gutils->Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, - volatile_args, /*augmented*/ nullptr); + volatile_args, /*augmented*/ nullptr, gutils->postpasses); SmallVector revArguments; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index fbd337813bcb..7a5770ccdaa8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -13,6 +13,9 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dominance.h" +#include "mlir/Pass/PassManager.h" +#include 
"mlir/Pass/PassRegistry.h" + #include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" @@ -78,7 +81,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( std::vector ArgActivity, MTypeAnalysis &TA, std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented) { + std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); @@ -105,7 +109,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto gutils = MDiffeGradientUtils::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, RetActivity, ArgActivity, addedType, - /*omp*/ false); + /*omp*/ false, postpasses); ForwardCachedFunctions[tup] = gutils->newFunc; insert_or_assign2( @@ -195,10 +199,19 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( if (!valid) return nullptr; - // if (PostOpt) - // PPC.optimizeIntermediate(nf); - // if (EnzymePrint) { - // llvm::errs() << nf << "\n"; - //} + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index c8cad6eee272..aef498d52274 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -196,14 +196,17 @@ class MEnzymeLogic { std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, - void 
*augmented); - - FunctionOpInterface CreateReverseDiff( - FunctionOpInterface fn, std::vector retType, - std::vector constants, MTypeAnalysis &TA, - std::vector returnPrimals, std::vector returnShadows, - DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented); + void *augmented, llvm::StringRef postpasses); + + FunctionOpInterface + CreateReverseDiff(FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, + std::vector returnPrimals, + std::vector returnShadows, DerivativeMode mode, + bool freeMemory, size_t width, mlir::Type addedType, + MFnTypeInfo type_args, std::vector volatile_args, + void *augmented, llvm::StringRef postpasses); + void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 0812a7ccde5d..7ca0e9ea72f1 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -11,6 +11,8 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -182,7 +184,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( std::vector constants, MTypeAnalysis &TA, std::vector returnPrimals, std::vector returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -214,7 +217,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( MGradientUtilsReverse *gutils = 
MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, - retType, constants, addedType); + retType, constants, addedType, postpasses); ReverseCachedFunctions[tup] = gutils->newFunc; @@ -254,5 +257,19 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( if (!res.succeeded()) return nullptr; + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 32cb5b796144..0dab1032af91 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -37,15 +37,15 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses) : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), returnPrimals(returnPrimals), - returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_), - RetDiffeTypes(ReturnActivity) {} + TA(TA_), TR(TR_), omp(omp), postpasses(postpasses), + returnPrimals(returnPrimals), returnShadows(returnShadows), width(width), + ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {} mlir::Value 
mlir::enzyme::MGradientUtils::getNewFromOriginal( const mlir::Value originst) const { diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 1fac52caab39..085bd678f834 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -36,6 +36,7 @@ class MGradientUtils { MTypeAnalysis &TA; MTypeResults TR; bool omp; + llvm::StringRef postpasses; const llvm::ArrayRef returnPrimals; const llvm::ArrayRef returnShadows; @@ -58,7 +59,8 @@ class MGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp); + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses); void erase(Operation *op) { op->erase(); } void replaceOrigOpWith(Operation *op, ValueRange vals) { for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { @@ -113,11 +115,12 @@ class MDiffeGradientUtils : public MGradientUtils { ArrayRef RetActivity, ArrayRef ArgActivity, IRMapping &origToNew_, std::map &origToNewOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, RetActivity, ArgActivity, origToNew_, - origToNewOps_, mode, width, omp), + origToNewOps_, mode, width, omp, postpasses), initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor @@ -127,7 +130,7 @@ class MDiffeGradientUtils : public MGradientUtils { const llvm::ArrayRef returnPrimals, const llvm::ArrayRef returnShadows, ArrayRef RetActivity, ArrayRef ArgActivity, - mlir::Type additionalArg, bool omp) { + mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) { std::string prefix; switch (mode) { @@ -163,7 +166,8 @@ class MDiffeGradientUtils : public MGradientUtils { return 
new MDiffeGradientUtils( Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, RetActivity, - ArgActivity, originalToNew, originalToNewOps, mode, width, omp); + ArgActivity, originalToNew, originalToNewOps, mode, width, omp, + postpasses); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 793b073de0f9..c9fe98bc5a5e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width) + DerivativeMode mode_, unsigned width, StringRef postpasses) : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, - mode_, width, /*omp*/ false) {} + mode_, width, /*omp*/ false, postpasses) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, const ArrayRef returnShadows, ArrayRef retType, ArrayRef constant_args, - mlir::Type additionalArg) { + mlir::Type additionalArg, llvm::StringRef postpasses) { std::string prefix; switch (mode_) { @@ -174,5 +174,5 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( return new MGradientUtilsReverse( Logic, newFunc, todiff, TA, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, retType, - constant_args, originalToNew, originalToNewOps, mode_, width); + constant_args, 
originalToNew, originalToNewOps, mode_, width, postpasses); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index b6b63c6d13d5..7f2d26cba2ee 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -36,7 +36,8 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width); + DerivativeMode mode_, unsigned width, + llvm::StringRef postpasses); IRMapping mapReverseModeBlocks; @@ -64,12 +65,14 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { void createReverseModeBlocks(Region &oldFunc, Region &newFunc); - static MGradientUtilsReverse *CreateFromClone( - MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, - FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, - const ArrayRef returnPrimals, const ArrayRef returnShadows, - llvm::ArrayRef retType, - llvm::ArrayRef constant_args, mlir::Type additionalArg); + static MGradientUtilsReverse * + CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, + FunctionOpInterface todiff, MTypeAnalysis &TA, + MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, + const ArrayRef returnShadows, + llvm::ArrayRef retType, + llvm::ArrayRef constant_args, + mlir::Type additionalArg, llvm::StringRef postpasses); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index d83532db35a6..c91f5400fefa 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/PassManager.h" #define DEBUG_TYPE "enzyme" @@ -31,6 +32,18 @@ struct 
DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + mlir::OpPassManager pm; + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (!mlir::failed(result)) { + pm.getDependentDialects(registry); + } + + registry + .insert(); + } + static std::vector mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { std::vector retTypes; @@ -150,7 +163,7 @@ struct DifferentiatePass : public DifferentiatePassBase { FunctionOpInterface newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, returnPrimals, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, postpasses); if (!newFunc) return failure(); @@ -276,7 +289,7 @@ struct DifferentiatePass : public DifferentiatePassBase { Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals, returnShadows, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, postpasses); if (!newFunc) return failure(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 629a567815e7..1e01c8f87bc2 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -121,13 +121,13 @@ struct DifferentiateWrapperPass returnPrimal, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, ""); } else { newFunc = Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, ""); } if (!newFunc) { signalPassFailure(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index c5b4df769172..758f27946a79 100644 --- 
a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -19,6 +19,15 @@ def DifferentiatePass : Pass<"enzyme"> { "cf::ControlFlowDialect", "tensor::TensorDialect", ]; + let options = [ + Option< + /*C++ variable name=*/"postpasses", + /*CLI argument=*/"postpasses", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Optimization passes to apply to generated derivative functions" + >, + ]; let constructor = "mlir::enzyme::createDifferentiatePass()"; } diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir index f20989aa4245..d663eea5afe1 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir @@ -15,10 +15,10 @@ module { } // CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> // CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } -// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { // CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 // CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) // CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir 
b/enzyme/test/MLIR/ForwardMode/batched_for.mlir index 95557cb0b6fc..3ec17ec50f5b 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_for.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir @@ -18,7 +18,7 @@ module { } } -// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { // CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64> // CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index f06f86d2a043..d384bdd09337 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -16,9 +16,9 @@ module { // CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> -// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir index 11b75f634a67..2a565f9ff410 100644 --- 
a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -16,9 +16,9 @@ module { // CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> // CHECK-NEXT: } // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { -// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> -// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> From 9ebb61d61c25132c12da68d6eeebf1fc58389cbb Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Fri, 3 Jan 2025 18:06:50 +0000 Subject: [PATCH 37/45] add EnzymeCreatePrimalAndGradient to CApi.h (#2215) * add EnzymeCreatePrimalAndGradient to CApi.h * update EnzymeCreatePrimalAndGradient signature --- enzyme/Enzyme/CApi.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index be3536c6b7f6..3a38a68c4c74 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -212,6 +212,16 @@ LLVMValueRef EnzymeCreateForwardDiff( uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented); +LLVMValueRef EnzymeCreatePrimalAndGradient( + EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, + LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, + size_t 
constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, + uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, + unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg, + uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, + uint8_t *_overwritten_args, size_t overwritten_args_size, + EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd); + #ifdef __cplusplus } #endif From c759460200c9f60ca2c6bc94e70634afd52665a5 Mon Sep 17 00:00:00 2001 From: Matt Bolitho Date: Sun, 5 Jan 2025 16:05:29 +0000 Subject: [PATCH 38/45] Adds Intel oneAPI and GCC Linux CMake presets (#2218) * Adds GCC presets (cloned from clang ones) * Adds Intel presets * Moves compiler flags to relevant preset to avoid CMake warnings * Adds CMAKE_BUILD_WITH_INSTALL_RPATH to Intel presets --- enzyme/CMakePresets.json | 133 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 7 deletions(-) diff --git a/enzyme/CMakePresets.json b/enzyme/CMakePresets.json index 1a5ffa80c5d9..14927307a5af 100644 --- a/enzyme/CMakePresets.json +++ b/enzyme/CMakePresets.json @@ -49,10 +49,7 @@ "cacheVariables": { "CMAKE_C_COMPILER": "clang", "CMAKE_CXX_COMPILER": "clang++", - "CMAKE_CXX_FLAGS": "-Wall -fno-rtti -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch", - "CMAKE_CXX_FLAGS_DEBUG": "-O0 -g -ggdb -fno-omit-frame-pointer", - "CMAKE_CXX_FLAGS_RELEASE": "-O2", - "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O2 -g -ggdb" + "CMAKE_CXX_FLAGS": "-Wall -fno-rtti -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch" } }, { @@ -60,7 +57,8 @@ "displayName": "Clang x64 Linux Debug", "inherits": "x64-linux-clang", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug" + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS_DEBUG": "-O0 -g -ggdb -fno-omit-frame-pointer" } }, { 
@@ -68,7 +66,8 @@ "displayName": "Clang x64 Linux Release", "inherits": "x64-linux-clang", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release" + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS_RELEASE": "-O2" } }, { @@ -76,7 +75,91 @@ "displayName": "Clang x64 Linux Release with Debug Info", "inherits": "x64-linux-clang", "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo" + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O2 -g -ggdb" + } + }, + { + "name": "x64-linux-gcc", + "description": "Base preset for Linux development using GNU compilers.", + "hidden": true, + "inherits": [ + "config-base-x64", + "config-base-linux" + ], + "cacheVariables": { + "CMAKE_C_COMPILER": "gcc", + "CMAKE_CXX_COMPILER": "g++", + "CMAKE_CXX_FLAGS": "-Wall -fno-rtti -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch -Wno-comment" + } + }, + { + "name": "x64-linux-gcc-debug", + "displayName": "GCC x64 Linux Debug", + "inherits": "x64-linux-gcc", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS_DEBUG": "-O0 -g -ggdb -fno-omit-frame-pointer" + } + }, + { + "name": "x64-linux-gcc-release", + "displayName": "GCC x64 Linux Release", + "inherits": "x64-linux-gcc", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS_RELEASE": "-O2" + } + }, + { + "name": "x64-linux-gcc-release-with-debug-info", + "displayName": "GCC x64 Linux Release with Debug Info", + "inherits": "x64-linux-gcc", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O2 -g -ggdb" + } + }, + { + "name": "x64-linux-intel", + "description": "Base preset for Linux development using Intel oneAPI compilers.", + "hidden": true, + "inherits": [ + "config-base-x64", + "config-base-linux" + ], + "cacheVariables": { + "CMAKE_C_COMPILER": "icx", + "CMAKE_CXX_COMPILER": "icpx", + "CMAKE_CXX_FLAGS": 
"-Wall -fno-rtti -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull -Werror=unused-result -Werror=reorder -Werror=switch", + "CMAKE_BUILD_WITH_INSTALL_RPATH": "ON" + } + }, + { + "name": "x64-linux-intel-debug", + "displayName": "Intel x64 Linux Debug", + "inherits": "x64-linux-intel", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS_DEBUG": "-O0 -g -ggdb -fno-omit-frame-pointer" + } + }, + { + "name": "x64-linux-intel-release", + "displayName": "Intel x64 Linux Release", + "inherits": "x64-linux-intel", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS_RELEASE": "-O2" + } + }, + { + "name": "x64-linux-intel-release-with-debug-info", + "displayName": "Intel x64 Linux Release with Debug Info", + "inherits": "x64-linux-intel", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O2 -g -ggdb" } } ], @@ -98,6 +181,42 @@ "displayName": "Clang x64 Linux Release with Debug Info", "description": "Builds the project using Clang on Linux in Release configuration with debug info.", "configurePreset": "x64-linux-clang-release-with-debug-info" + }, + { + "name": "x64-linux-gcc-debug", + "displayName": "GCC x64 Linux Debug", + "description": "Builds the project using GCC on Linux in Debug configuration.", + "configurePreset": "x64-linux-gcc-debug" + }, + { + "name": "x64-linux-gcc-release", + "displayName": "GCC x64 Linux Release", + "description": "Builds the project using GCC on Linux in Release configuration.", + "configurePreset": "x64-linux-gcc-release" + }, + { + "name": "x64-linux-gcc-release-with-debug-info", + "displayName": "GCC x64 Linux Release with Debug Info", + "description": "Builds the project using GCC on Linux in Release configuration with debug info.", + "configurePreset": "x64-linux-gcc-release-with-debug-info" + }, + { + "name": "x64-linux-intel-debug", + "displayName": "Intel x64 Linux Debug", + 
"description": "Builds the project using Intel oneAPI compilers on Linux in Debug configuration.", + "configurePreset": "x64-linux-intel-debug" + }, + { + "name": "x64-linux-intel-release", + "displayName": "Intel x64 Linux Release", + "description": "Builds the project using Intel oneAPI compilers on Linux in Release configuration.", + "configurePreset": "x64-linux-intel-release" + }, + { + "name": "x64-linux-intel-release-with-debug-info", + "displayName": "Intel x64 Linux Release with Debug Info", + "description": "Builds the project using Intel oneAPI compilers on Linux in Release configuration with debug info.", + "configurePreset": "x64-linux-intel-release-with-debug-info" } ] } From 5b330a905044a32afae8f0e4b69fc558c08c8cc1 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Sun, 5 Jan 2025 21:08:32 +0100 Subject: [PATCH 39/45] mlir: Add Enzyme ops removal on structured control flow (#2200) * mlir: Add Enzyme ops removal on structured control flow * format * use AutoDiffTypeInterface for batching * remove * add test with unknown number of iterations * don't push same value twice * tensor extract/insert * reserve the right size * better batchType * better comment --- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 46 +- .../SCFAutoDiffOpInterfaceImpl.cpp | 551 +++++++++++++++--- .../MLIR/Interfaces/AutoDiffOpInterface.td | 18 + .../MLIR/Interfaces/AutoDiffTypeInterface.td | 2 +- enzyme/Enzyme/MLIR/Passes/Passes.h | 4 + enzyme/Enzyme/MLIR/Passes/Passes.td | 3 + enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 58 ++ enzyme/Enzyme/MLIR/Passes/RemovalUtils.h | 54 ++ .../MLIR/Passes/RemoveUnusedEnzymeOps.cpp | 27 +- enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + enzyme/test/MLIR/ReverseMode/pow.mlir | 43 +- enzyme/test/MLIR/ReverseMode/scf_for.mlir | 42 ++ 12 files changed, 701 insertions(+), 148 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp create mode 100644 enzyme/Enzyme/MLIR/Passes/RemovalUtils.h create mode 100644 
enzyme/test/MLIR/ReverseMode/scf_for.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 7c72b97d0934..c38b990ceb6b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -25,6 +25,22 @@ using namespace mlir; using namespace mlir::enzyme; namespace { + +static mlir::Type batchType(mlir::Type type, int64_t width) { + if (width == 1) + return type; + + if (auto TT = dyn_cast(type)) { + SmallVector shape; + shape.reserve(TT.getShape().size() + 1); + shape.push_back(width); + shape.append(TT.getShape().begin(), TT.getShape().end()); + return TT.clone(shape); + } + + return RankedTensorType::get({width}, type); +} + class FloatTypeInterface : public AutoDiffTypeInterface::ExternalModel { @@ -44,12 +60,8 @@ class FloatTypeInterface return a; } - Type getShadowType(Type self, unsigned width) const { - if (width > 1) { - return RankedTensorType::get({width}, self); - } else { - return self; - } + Type getShadowType(Type self, int64_t width) const { + return batchType(self, width); } bool isMutable(Type self) const { return false; } @@ -108,16 +120,8 @@ class TensorTypeInterface return added; } - Type getShadowType(Type self, unsigned width) const { - if (width != 1) { - auto tenType = self.cast(); - auto shape = tenType.getShape(); - SmallVector newShape; - newShape.push_back(width); - newShape.append(shape.begin(), shape.end()); - return RankedTensorType::get(newShape, tenType.getElementType()); - } - return self; + Type getShadowType(Type self, int64_t width) const { + return batchType(self, width); } bool isMutable(Type self) const { return false; } @@ -148,9 +152,8 @@ class IntegerTypeInterface return a; } - Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); - return self; + Type 
getShadowType(Type self, int64_t width) const { + return batchType(self, width); } bool isMutable(Type self) const { return false; } @@ -182,9 +185,8 @@ class ComplexTypeInterface return builder.create(loc, a)->getResult(0); } - Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); - return self; + Type getShadowType(Type self, int64_t width) const { + return batchType(self, width); } bool isMutable(Type self) const { return false; } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index cbecedc21827..4af2c2d97e63 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -17,7 +17,10 @@ #include "Interfaces/EnzymeLogic.h" #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" +#include "Passes/RemovalUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -32,126 +35,499 @@ using namespace mlir::enzyme; namespace { #include "Implementations/SCFDerivatives.inc" -struct ForOpInterfaceReverse - : public ReverseAutoDiffOpInterface::ExternalModel { - LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { +// TODO: support non constant number of iteration by using unknown dimensions +static std::optional getConstantNumberOfIterations(scf::ForOp forOp) { + auto lb = forOp.getLowerBound(); + auto ub = forOp.getUpperBound(); + auto step = forOp.getStep(); + + IntegerAttr lbAttr, ubAttr, stepAttr; + if (!matchPattern(lb, m_Constant(&lbAttr))) + return std::nullopt; + if (!matchPattern(ub, m_Constant(&ubAttr))) + return std::nullopt; + if 
(!matchPattern(step, m_Constant(&stepAttr))) + return std::nullopt; + + int64_t lbI = lbAttr.getInt(), ubI = ubAttr.getInt(), + stepI = stepAttr.getInt(); + + return (ubI - lbI) / stepI; +} + +static Value getNumberOfIterations(OpBuilder &builder, scf::ForOp forOp) { + Value lb = forOp.getLowerBound(), ub = forOp.getUpperBound(), + step = forOp.getStep(); + Value diff = builder.create(forOp->getLoc(), ub, lb); + Value nSteps = builder.create(forOp->getLoc(), diff, step); + return nSteps; +} + +struct ForOpEnzymeOpsRemover + : public EnzymeOpsRemoverOpInterface::ExternalModel { + + LogicalResult removeEnzymeOps(Operation *op) const { auto forOp = cast(op); + scf::ForOp otherForOp; // where caches pops are + + if (removeOpsWithinBlock(forOp.getBody()).failed()) + return failure(); + + // Gradients whose values need to be passed as iteration variables. + llvm::SmallDenseSet updatedGradients; + + llvm::MapVector cachesMap; + + Block *body = forOp.getBody(); - // Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 - SmallVector resDiffes; - for (OpResult v : forOp.getResults()) { - if (!gutils->isConstantValue(v)) { - auto autoDiffType = cast(v.getType()); - if (!autoDiffType.isMutable()) { - auto prev = gutils->diffe(v, builder); - gutils->zeroDiffe(v, builder); - resDiffes.push_back(prev); - continue; + for (auto &it : *body) { + Operation *op = &it; + + if (auto setOp = dyn_cast(op)) + updatedGradients.insert(setOp.getGradient()); + + if (auto pushOp = dyn_cast(op)) { + CacheInfo info(pushOp.getCache()); + + Value pushedValue = info.pushedValue(); + if (cachesMap.contains(pushedValue)) { + info = info.merge(cachesMap.lookup(pushedValue)); } + cachesMap[pushedValue] = info; + + otherForOp = cast(info.popOp->getParentOp()); + } + } + + SmallVector caches; + caches.reserve(cachesMap.size()); + for (auto &&[_, info] : cachesMap) { + caches.push_back(info); + } + + // nothing to do + if (updatedGradients.empty() && caches.empty()) + return success(); + +
OpBuilder builder(forOp); + for (auto &it : *body) { + Operation *op = &it; + + auto getOp = dyn_cast(op); + if (!getOp || updatedGradients.contains(getOp.getGradient())) + continue; + + auto outerGet = builder.create( + getOp->getLoc(), + cast(getOp.getResult().getType()).getBasetype(), + getOp.getGradient()); + + getOp.getResult().replaceAllUsesWith(outerGet.getResult()); + getOp->erase(); + } + + auto term = body->getTerminator(); + + SmallVector newOperands(forOp.getInitArgs()); + for (auto grad : updatedGradients) { + auto Ty = cast(grad.getType()).getBasetype(); + auto outerGet = builder.create(grad.getLoc(), Ty, grad); + + newOperands.push_back(outerGet.getResult()); + auto newArg = body->addArgument(Ty, grad.getLoc()); + + { + OpBuilder::InsertionGuard guard(builder); + + builder.setInsertionPointToStart(body); + builder.create(grad.getLoc(), grad, newArg); + + builder.setInsertionPoint(term); + + auto outputVal = + builder.create(grad.getLoc(), Ty, grad).getResult(); + term->insertOperands(term->getNumOperands(), ValueRange(outputVal)); } - resDiffes.push_back(nullptr); } - for (auto &reg : op->getRegions()) { - auto termIface = - cast(reg.begin()->getTerminator()); - - SmallVector successors; - termIface.getSuccessorRegions( - SmallVector(termIface->getNumOperands(), Attribute()), - successors); - - for (auto &successor : successors) { - if (!successor.isParent()) - continue; - OperandRange operandRange = termIface.getSuccessorOperands(successor); - assert(operandRange.size() == resDiffes.size()); - - // There is an assumption here that there is only regions that branch to - // the successor.
Specifically, otherwise we would need to - // gutils->addToDiffe select (if came from that result) - for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { - if (!post) - continue; - if (!gutils->isConstantValue(prev)) - gutils->addToDiffe(prev, post, builder); + auto numIters = getConstantNumberOfIterations(forOp); + Value inductionVariable; // [0,..., N - 1] counter + + if (matchPattern(forOp.getLowerBound(), m_Zero()) && + matchPattern(forOp.getStep(), m_One())) { + inductionVariable = body->getArgument(0); + } + + for (auto info : caches) { + Value cache = info.initOp.getResult(); + + // push does not depend on a value inside the loop, we can hoist the + // push/pop before the for loops. + if (info.pushedValue().getParentRegion() != forOp->getRegion(0)) { + auto newPush = builder.create(cache.getLoc(), cache, + info.pushedValue()); + info.pushOp->erase(); + info.pushOp = newPush; + + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(info.popOp->getParentOp()); + + auto popVal = info.popOp.getResult(); + auto newPop = builder.create(cache.getLoc(), + popVal.getType(), cache); + popVal.replaceAllUsesWith(newPop.getResult()); + info.popOp->erase(); + info.popOp = newPop; } + + continue; + } + + if (!inductionVariable) { + Value zero = builder.create(forOp->getLoc(), + builder.getIndexAttr(0)); + newOperands.push_back(zero); + + inductionVariable = body->addArgument(zero.getType(), forOp->getLoc()); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(term); + + auto one = builder.create(forOp->getLoc(), + builder.getIndexAttr(1)); + auto newInductionVar = builder.create( + forOp->getLoc(), inductionVariable, one); + term->insertOperands(term->getNumOperands(), + ValueRange(newInductionVar)); + } + } + + auto newType = + info.cachedType() + .cast() + .getShadowType(numIters.value_or(mlir::ShapedType::kDynamic)) + .cast(); + + SmallVector dynamicDims; + + for (auto it : llvm::enumerate(newType.getShape())) { + 
if (ShapedType::isDynamic(it.value())) { + if (it.index() == 0) + dynamicDims.push_back(getNumberOfIterations(builder, forOp)); + else + return failure(); // TODO: find dynamic dims within the body. + } + } + + Value initValue = builder.create(info.initOp->getLoc(), + newType, dynamicDims); + + // cast(newType).createNullValue( + // builder, info.initOp->getLoc()); + + newOperands.push_back(initValue); + + auto cacheValue = body->addArgument(newType, info.pushOp->getLoc()); + + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(info.pushOp); + + // TODO: if type is tensor, use insert_slice instead + Value newCacheValue; + if (auto TT = dyn_cast(info.cachedType())) { + auto shape = TT.getShape(); + + SmallVector offsets(shape.size() + 1, 0); + offsets[0] = ShapedType::kDynamic; + + SmallVector sizes; + sizes.reserve(shape.size() + 1); + sizes.push_back(1); + sizes.append(shape.begin(), shape.end()); + + SmallVector strides(shape.size() + 1, 1); + + newCacheValue = builder.create( + info.pushOp->getLoc(), info.pushOp.getValue(), cacheValue, + ValueRange(inductionVariable), ValueRange(), ValueRange(), + builder.getDenseI64ArrayAttr(offsets), + builder.getDenseI64ArrayAttr(sizes), + builder.getDenseI64ArrayAttr(strides)); + } else { + newCacheValue = builder.create( + info.pushOp->getLoc(), info.pushOp.getValue(), cacheValue, + inductionVariable); + } + + term->insertOperands(term->getNumOperands(), ValueRange(newCacheValue)); } } - // End Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + + auto numInitArgs = forOp.getInitArgs().size(); + auto newFor = builder.create( + op->getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newOperands); + + newFor.getRegion().takeBody(forOp.getRegion()); + + unsigned resultIdx = numInitArgs; + for (auto grad : updatedGradients) { + // set the updated gradient after the new for op. 
+ OpBuilder::InsertionGuard guard(builder); + builder.create(grad.getLoc(), grad, + newFor->getResult(resultIdx)); + ++resultIdx; + } + + if (inductionVariable && caches.size()) { + if (isa(inductionVariable) && + cast(inductionVariable).getArgNumber() != 0) + resultIdx++; + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(otherForOp); + SmallVector operands(otherForOp.getInitArgs().begin(), + otherForOp.getInitArgs().end()); + operands.push_back(numIters.has_value() + ? builder.create( + otherForOp->getLoc(), + builder.getIndexAttr(numIters.value() - 1)) + : getNumberOfIterations(builder, forOp)); + + Block *otherBody = otherForOp.getBody(); + Value otherInductionVariable = + otherBody->addArgument(builder.getIndexType(), otherForOp->getLoc()); + auto otherTerm = otherBody->getTerminator(); + + builder.setInsertionPoint(otherTerm); + + otherInductionVariable = + builder + .create( + otherForOp->getLoc(), otherInductionVariable, + builder + .create(otherForOp->getLoc(), + builder.getIndexAttr(1)) + .getResult()) + .getResult(); + otherTerm->insertOperands(otherTerm->getNumOperands(), + ValueRange(otherInductionVariable)); + + builder.setInsertionPoint(otherForOp); + auto newOtherForOp = builder.create( + otherForOp->getLoc(), otherForOp.getLowerBound(), + otherForOp.getUpperBound(), otherForOp.getStep(), operands); + + for (auto &&[res, newRes] : + llvm::zip(otherForOp->getResults(), newOtherForOp->getResults())) { + res.replaceAllUsesWith(newRes); + } + newOtherForOp.getRegion().takeBody(otherForOp.getRegion()); + + otherForOp->erase(); + otherForOp = newOtherForOp; + } + + for (auto info : caches) { + if (info.pushedValue().getParentRegion() != newFor->getRegion(0)) + continue; + + Value cache = info.initOp.getResult(); + + auto newType = + info.cachedType().cast().getShadowType( + numIters.value_or(ShapedType::kDynamic)); + enzyme::InitOp newInit = ({ + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(info.initOp); + + 
builder.create( + info.initOp->getLoc(), + enzyme::CacheType::get(cache.getContext(), newType)); + }); + info.pushOp = ({ + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(newFor); + auto newPush = builder.create( + cache.getLoc(), newInit.getResult(), newFor->getResult(resultIdx)); + info.pushOp->erase(); + newPush; + }); + + resultIdx++; + + { + OpBuilder::InsertionGuard guard(builder); + + builder.setInsertionPoint(otherForOp); + + auto popNewValue = builder.create( + info.popOp->getLoc(), newType, newInit.getResult()); + + Block *popBody = otherForOp.getBody(); + builder.setInsertionPoint(info.popOp); + + Value newInductionVariable = + popBody->getArgument(popBody->getNumArguments() - 1); + + Value popValue; + if (auto TT = dyn_cast(info.cachedType())) { + auto shape = TT.getShape(); + SmallVector offsets(shape.size() + 1, 0); + offsets[0] = ShapedType::kDynamic; + + SmallVector sizes; + sizes.reserve(shape.size() + 1); + sizes.push_back(1); + sizes.append(shape.begin(), shape.end()); + + SmallVector strides(shape.size() + 1, 1); + + popValue = + builder + .create( + info.popOp->getLoc(), TT, popNewValue, + ValueRange(newInductionVariable), ValueRange(), + ValueRange(), builder.getDenseI64ArrayAttr(offsets), + builder.getDenseI64ArrayAttr(sizes), + builder.getDenseI64ArrayAttr(strides)) + .getResult(); + } else { + popValue = + builder + .create(info.popOp->getLoc(), popNewValue, + newInductionVariable) + .getResult(); + } + + info.popOp.getResult().replaceAllUsesWith(popValue); + info.popOp->erase(); + } + } + + forOp->erase(); + + return success(); + } +}; + +struct ForOpInterfaceReverse + : public ReverseAutoDiffOpInterface::ExternalModel { + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + // SCF ForOp has 3 more operands than results (lb, ub, step). + // Its body has 1 more argument than yielded values (the induction + // variable). 
+ + auto forOp = cast(op); + + SmallVector operandsActive(forOp.getNumOperands() - 3, false); + for (int i = 0, e = operandsActive.size(); i < e; ++i) { + operandsActive[i] = !gutils->isConstantValue(op->getOperand(i + 3)) || + !gutils->isConstantValue(op->getResult(i)); + } auto start = gutils->popCache(caches[0], builder); auto end = gutils->popCache(caches[1], builder); auto step = gutils->popCache(caches[2], builder); + SmallVector incomingGradients; + for (auto &&[active, res] : + llvm::zip_equal(operandsActive, op->getResults())) { + if (active) { + incomingGradients.push_back(gutils->diffe(res, builder)); + if (!gutils->isConstantValue(res)) + gutils->zeroDiffe(res, builder); + } + } + auto repFor = builder.create(forOp.getLoc(), start, end, step, - ArrayRef()); - // erase scf yield - repFor.getBody()->begin()->erase(); + incomingGradients); + bool valid = true; for (auto &&[oldReg, newReg] : llvm::zip(op->getRegions(), repFor->getRegions())) { + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + OpBuilder bodyBuilder(&revBB, revBB.end()); + + // Create implicit terminator if not present (when num results > 0) + if (revBB.empty()) { + bodyBuilder.create(repFor->getLoc()); + } + bodyBuilder.setInsertionPoint(revBB.getTerminator()); - // This code assumes at most one terminating block for each region (lest - // the append happen multiple times) - auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { - auto loc = oBB->rbegin()->getLoc(); + // All values defined in the body should have no use outside this block + // therefore we can set their diffe to zero upon entering the reverse + // block to simplify the work of the remove-unnecessary-enzyme-ops pass. 
+ for (auto operand : oBB.getArguments().slice(1)) { + if (!gutils->isConstantValue(operand)) { + gutils->zeroDiffe(operand, bodyBuilder); + } + } - auto idx = repFor.getInductionVar(); + for (auto &it : oBB.getOperations()) { + for (auto res : it.getResults()) { + if (!gutils->isConstantValue(res)) { + gutils->zeroDiffe(res, bodyBuilder); + } + } + } - auto lhs = builder.create(loc, idx, step); + auto term = oBB.getTerminator(); - // This needs to know a condition describing which predecessor this will - // return to, to select the right value Here we use the condition i + - // step >= end to determine the last iteration + for (auto &&[active, arg, operand] : + llvm::zip_equal(operandsActive, revBB.getArguments().slice(1), + term->getOperands())) { + if (active) { + // Set diffe here, not add because it should not accumulate across + // iterations. Instead the new gradient for this operand is passed + // in the return of the reverse for body. + gutils->setDiffe(operand, arg, bodyBuilder); + } + } - auto condition = builder.create( - loc, arith::CmpIPredicate::sge, lhs, end); + auto first = oBB.rbegin(); + first++; // skip terminator - for (auto [arg, init_arg] : - llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { - if (!gutils->isConstantValue(arg) && - !cast(arg.getType()).isMutable()) { - auto diffe = gutils->diffe(arg, builder); - gutils->zeroDiffe(arg, builder); + auto last = oBB.rend(); - auto zero = cast(diffe.getType()) - .createNullValue(builder, loc); - auto outside = - builder.create(loc, condition, diffe, zero); - auto inside = - builder.create(loc, condition, zero, diffe); + for (auto it = first; it != last; ++it) { + Operation *op = &*it; + valid &= + gutils->Logic.visitChild(op, bodyBuilder, gutils).succeeded(); + } - // For each predecessor, if we came from that predecessor += the - // shadow of the arg [after zero'ing] - if (!gutils->isConstantValue(init_arg)) { - gutils->addToDiffe(init_arg, outside, builder); - } + SmallVector 
newResults; + newResults.reserve(incomingGradients.size()); - if (!gutils->isConstantValue(arg)) { - gutils->addToDiffe(arg, inside, builder); - } + for (auto &&[active, arg] : + llvm::zip_equal(operandsActive, oBB.getArguments().slice(1))) { + if (active) { + newResults.push_back(gutils->diffe(arg, bodyBuilder)); + if (!gutils->isConstantValue(arg)) + gutils->zeroDiffe(arg, bodyBuilder); } } - builder.create(loc); - }; - for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { - gutils->mapReverseModeBlocks.map(&oBB, &revBB); + // yield new gradient values + revBB.getTerminator()->setOperands(newResults); } - for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { - auto sub = gutils->Logic.visitChildren(&oBB, &revBB, gutils); - if (!sub.succeeded()) - return sub; - Block *newBB = gutils->getNewFromOriginal(&oBB); - gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, - buildFuncReturnOp); + } + + for (auto &&[active, res, arg] : llvm::zip_equal( + operandsActive, repFor->getResults(), forOp.getInitArgs())) { + if (active) { + if (!gutils->isConstantValue(arg)) + gutils->addToDiffe(arg, res, builder); } } - return success(); + + return success(valid); } SmallVector cacheValues(Operation *op, @@ -190,5 +566,6 @@ void mlir::enzyme::registerSCFDialectAutoDiffInterface( registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *) { registerInterfaces(context); scf::ForOp::attachInterface(*context); + scf::ForOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 771dc22001a1..7510d3772984 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -180,4 +180,22 @@ def BatchOpInterface : OpInterface<"BatchOpInterface"> { ]; } +def EnzymeOpsRemoverOpInterface : OpInterface<"EnzymeOpsRemoverOpInterface"> { + let description = [{ + An operation with nested operations which can move 
inner enzyme operations outside of itself. + }]; + let cppNamespace = "::mlir::enzyme"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Pushes the inner enzyme operations outside of self. + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"removeEnzymeOps", + /*args=*/(ins) + > + ]; +} + #endif // ENZYME_MLIR_INTERFACES_AUTODIFFOPINTERFACES diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td index 2e9f1697af43..ced1e68700b5 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td @@ -57,7 +57,7 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { }], /*retTy=*/"::mlir::Type", /*methodName=*/"getShadowType", - /*args=*/(ins "unsigned":$width) + /*args=*/(ins "int64_t":$width) >, InterfaceMethod< /*desc=*/[{ diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index fb6df3e2208c..fff304a7e492 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -89,6 +89,10 @@ namespace LLVM { class LLVMDialect; } // end namespace LLVM +namespace tensor { +class TensorDialect; +} // end namespace tensor + #define GEN_PASS_REGISTRATION #include "Passes/Passes.h.inc" diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 758f27946a79..d3494956a12f 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -183,6 +183,9 @@ def AddToOpToSplitPass : Pass<"add-to-op-to-split"> { def RemoveUnusedEnzymeOpsPass : Pass<"remove-unnecessary-enzyme-ops"> { let summary = "Remove Unnecessary Enzyme Ops"; + let dependentDialects = [ + "tensor::TensorDialect" + ]; let constructor = "mlir::enzyme::createRemoveUnusedEnzymeOpsPass()"; } diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp new file mode 100644 index 
000000000000..002b11d6bc92 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -0,0 +1,58 @@ +//===- RemovalUtils.cpp - Utilities to remove Enzyme ops -------* C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "RemovalUtils.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include + +mlir::enzyme::CacheInfo +mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) { + assert(other.pushOp->getBlock() == pushOp->getBlock()); + assert(other.popOp->getBlock() == popOp->getBlock()); + + enzyme::InitOp newInitOp; + if (other.initOp->isBeforeInBlock(initOp)) { + newInitOp = other.initOp; + initOp.getResult().replaceAllUsesWith(newInitOp.getResult()); + initOp->erase(); + } else { + newInitOp = initOp; + other.initOp.getResult().replaceAllUsesWith(newInitOp.getResult()); + other.initOp->erase(); + } + + enzyme::PushOp newPushOp = pushOp; + other.pushOp->erase(); + + enzyme::PopOp newPopOp; + if (other.popOp->isBeforeInBlock(popOp)) { + newPopOp = other.popOp; + popOp.getResult().replaceAllUsesWith(newPopOp.getResult()); + popOp->erase(); + } else { + newPopOp = popOp; + other.popOp.getResult().replaceAllUsesWith(newPopOp.getResult()); + other.popOp->erase(); + } + + CacheInfo newInfo{newInitOp}; + return newInfo; +} + +mlir::LogicalResult mlir::enzyme::removeOpsWithinBlock(mlir::Block *block) { + bool valid = true; + + for (auto &it : *block) { + mlir::Operation *op = &it; + if (auto iface = dyn_cast(op)) { + valid &= iface.removeEnzymeOps().succeeded(); + } + } + + return success(valid); +} diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h new file mode 100644 index 000000000000..32308ed1d6b0 --- /dev/null +++
b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h @@ -0,0 +1,54 @@ +//===- RemovalUtils.h - Utilities to remove Enzyme ops -------* C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "Dialect/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" + +namespace mlir { +namespace enzyme { + +/// Information about a cache, each cache init should have one corresponding +/// push and pop. +struct CacheInfo { + enzyme::InitOp initOp; + enzyme::PushOp pushOp; + enzyme::PopOp popOp; + + CacheInfo() { + initOp = nullptr; + pushOp = nullptr; + popOp = nullptr; + } + CacheInfo(Value cache) { + initOp = cache.getDefiningOp(); + unsigned nusers = 0; + for (auto user : cache.getUsers()) { + nusers++; + if (!popOp) + popOp = dyn_cast(user); + if (!pushOp) + pushOp = dyn_cast(user); + } + assert(nusers == 2); // TODO: support more uses + } + + Value pushedValue() { return pushOp.getValue(); } + Type cachedType() { + return initOp.getResult().getType().cast().getType(); + } + + // Pushed values must be the same + CacheInfo merge(CacheInfo other); +}; + +LogicalResult removeOpsWithinBlock(Block *block); + +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index d478380f1df0..8ee77113e9ae 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -12,13 +12,16 @@ #include "Dialect/Dialect.h" #include "Dialect/Ops.h" +#include "Interfaces/AutoDiffOpInterface.h" #include "PassDetails.h" #include "Passes/Passes.h" +#include "Passes/RemovalUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include 
"mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Rewrite/PatternApplicator.h" @@ -287,17 +290,29 @@ struct InitSimplify : public OpRewritePattern { } }; +static void applyPatterns(Operation *op) { + RewritePatternSet patterns(op->getContext()); + patterns.insert(op->getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(op, std::move(patterns), config); +} + struct RemoveUnusedEnzymeOpsPass : public enzyme::RemoveUnusedEnzymeOpsPassBase { void runOnOperation() override { + auto op = getOperation(); + + applyPatterns(op); - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); + op->walk([&](FunctionOpInterface func) { + func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) { + iface.removeEnzymeOps(); + }); + }); - GreedyRewriteConfig config; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + applyPatterns(op); } }; diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 99e7243129be..c3764baf7d55 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -28,6 +28,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index ef3e9d599529..b8a07fb3928c 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -20,45 +20,24 @@ module { } // CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { 
+// CHECK-NEXT: %c9 = arith.constant 9 : index // CHECK-NEXT: %c10 = arith.constant 10 : index // CHECK-NEXT: %c1 = arith.constant 1 : index // CHECK-NEXT: %c0 = arith.constant 0 : index // CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 // CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 -// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64> -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> () -// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64> -// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> () -// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f64> -// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f64> -// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64> -// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> () - -// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) { -// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache<f64>, f64) -> () -// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache<f64>, f64) -> () +// CHECK-NEXT: %[[cache:.+]] = tensor.empty() : tensor<10xf64> +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]], %[[cache_iter:.+]] = %[[cache]]) -> (f64, tensor<10xf64>) { +// CHECK-NEXT: %[[cache_new:.+]] = tensor.insert %[[r_it]] into %[[cache_iter]][%[[iv]]] : tensor<10xf64> // CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 -// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]], %[[cache_new]] : f64, tensor<10xf64> // CHECK-NEXT: } -// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient<f64>, f64) -> () -// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient<f64>) -> f64 -// CHECK-NEXT:
"enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 -// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] -// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 -// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %{{.+}} = scf.for %[[div:.+]] = %c0 to %c10 step %c1 iter_args(%[[dr_it:.+]] = %[[dr]], %[[rev_idx:.+]] = %c9, %[[dx0:.+]] = %[[zero]]) -> (f64, index, f64) { +// CHECK-NEXT: %[[r_cached:.+]] = tensor.extract %1#1[%[[rev_idx]]] : tensor<10xf64> +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] : f64 // CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : // CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index -// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index -// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64 -// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[new_rev_idx:.+]] = arith.subi %[[rev_idx]], %c1 : index +// CHECK-NEXT: scf.yield %[[dr_next]], %[[new_rev_idx]], %[[dx1]] : f64, index, f64 // CHECK-NEXT: } -// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] \ No newline at end of file +// CHECK-NEXT: return %2#2 : f64 diff --git a/enzyme/test/MLIR/ReverseMode/scf_for.mlir 
b/enzyme/test/MLIR/ReverseMode/scf_for.mlir new file mode 100644 index 000000000000..22f7a93762a1 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/scf_for.mlir @@ -0,0 +1,42 @@ +// RUN: %eopt %s --enzyme-wrap="infn=reduce outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops | FileCheck %s + +func.func @reduce(%x: f32, %ub: index) -> (f32) { + %lb = arith.constant 0 : index + %step = arith.constant 1 : index + + // Initial value set to 1. + %sum_0 = arith.constant 1.0 : f32 + // iter_args binds initial values to the loop's region arguments. + %sum = scf.for %iv = %lb to %ub step %step + iter_args(%sum_iter = %sum_0) -> (f32) { + %sum_next = arith.mulf %sum_iter, %x : f32 + // Yield current iteration sum to next iteration %sum_iter or to %sum + // if final iteration. + scf.yield %sum_next : f32 + } + return %sum : f32 +} + +// CHECK: func.func @reduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 { +// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f32 +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %0 = tensor.empty(%arg1) : tensor<?xf32> +// CHECK-NEXT: %1:2 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %cst, %arg5 = %0) -> (f32, tensor<?xf32>) { +// CHECK-NEXT: %inserted = tensor.insert %arg4 into %arg5[%arg3] : tensor<?xf32> +// CHECK-NEXT: %4 = arith.mulf %arg4, %arg0 : f32 +// CHECK-NEXT: scf.yield %4, %inserted : f32, tensor<?xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %2 = arith.addf %arg2, %cst_0 : f32 +// CHECK-NEXT: %3:3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %2, %arg5 = %arg1, %arg6 = %cst_0) -> (f32, index, f32) { +// CHECK-NEXT: %extracted = tensor.extract %1#1[%arg5] : tensor<?xf32> +// CHECK-NEXT: %4 = arith.mulf %arg4, %arg0 : f32 +// CHECK-NEXT: %5 = arith.addf %4, %cst_0 : f32 +// CHECK-NEXT: %6 = arith.mulf %arg4, %extracted : f32 +// CHECK-NEXT: %7 =
arith.addf %arg6, %6 : f32 +// CHECK-NEXT: %8 = arith.subi %arg5, %c1 : index +// CHECK-NEXT: scf.yield %5, %8, %7 : f32, index, f32 +// CHECK-NEXT: } +// CHECK-NEXT: return %3#2 : f32 +// CHECK-NEXT: } From 96b8efc8125ffc83244233dd54b7582a9b1b5d85 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 7 Jan 2025 04:51:15 +0100 Subject: [PATCH 40/45] Batched reverse mode (#2216) * fix for reversemode * fix test * fixup * fixup --------- Co-authored-by: William S. Moses --- .github/workflows/enzyme-mlir.yml | 2 +- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 6 ++--- enzyme/Enzyme/MLIR/Passes/Passes.td | 1 + enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 1 - enzyme/Enzyme/MLIR/Passes/RemovalUtils.h | 2 +- .../MLIR/Passes/RemoveUnusedEnzymeOps.cpp | 8 +++++- .../test/MLIR/ForwardMode/batched_scalar.mlir | 2 +- .../test/MLIR/ReverseMode/batched_square.mlir | 27 +++++++++++++++++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 21 ++++++++------- 10 files changed, 53 insertions(+), 18 deletions(-) create mode 100644 enzyme/test/MLIR/ReverseMode/batched_square.mlir diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index 89ae72957c79..16b3fe6e11ee 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' - ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e' + ref: 'ff24e9a19e3db330dd6412aac9d1d6c0b416697f' path: 'llvm-project' - name: Get MLIR commit hash diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 99db4d80034c..00b2cae1e381 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms SimplifyMath.cpp AddToOpToIndexAndLoad.cpp AddToOpToSplit.cpp + RemovalUtils.cpp 
RemoveUnusedEnzymeOps.cpp SimplifyMemrefCache.cpp Utils.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index c91f5400fefa..972222f87bae 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -39,9 +39,9 @@ struct DifferentiatePass : public DifferentiatePassBase { pm.getDependentDialects(registry); } - registry - .insert(); + registry.insert(); } static std::vector mode_from_fn(FunctionOpInterface fn, diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index d3494956a12f..ebe00135b9fe 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> { "complex::ComplexDialect", "cf::ControlFlowDialect", "tensor::TensorDialect", + "enzyme::EnzymeDialect", ]; let options = [ Option< diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index 002b11d6bc92..572fddd1cae5 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) { other.initOp->erase(); } - enzyme::PushOp newPushOp = pushOp; other.pushOp->erase(); enzyme::PopOp newPopOp; diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h index 32308ed1d6b0..d56ce6018daf 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h @@ -41,7 +41,7 @@ struct CacheInfo { Value pushedValue() { return pushOp.getValue(); } Type cachedType() { - return initOp.getResult().getType().cast().getType(); + return cast(initOp.getResult().getType()).getType(); } // Pushed values must be the same diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index 8ee77113e9ae..cb25fa6fa8bb 100644 
--- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass applyPatterns(op); + bool failed = false; op->walk([&](FunctionOpInterface func) { func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) { - iface.removeEnzymeOps(); + auto result = iface.removeEnzymeOps(); + if (!result.succeeded()) + failed = true; }); }); + if (failed) + return signalPassFailure(); + applyPatterns(op); } }; diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index d384bdd09337..8acd131c169b 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -21,6 +21,6 @@ module { // CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> -// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] // CHECK-NEXT: return %[[i2]] : tensor<2xf64> // CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/batched_square.mlir b/enzyme/test/MLIR/ReverseMode/batched_square.mlir new file mode 100644 index 000000000000..86c286703129 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/batched_square.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math %s | FileCheck %s + +module { + func.func @square(%x: f64) -> f64 { + %next = arith.mulf %x, %x : f64 + return %next : f64 + } + + func.func @dsquare(%x: f64, %dr: tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> tensor<2xf64> + return %r : tensor<2xf64> + } +} + +// CHECK: func.func 
@dsquare(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %0 = call @diffe2square(%arg0, %arg1) : (f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %0 : tensor<2xf64> +// CHECK-NEXT: } + +// CHECK: func.func private @diffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %0 = "enzyme.broadcast"(%arg0) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %1 = arith.mulf %arg1, %0 : tensor<2xf64> +// CHECK-NEXT: %2 = "enzyme.broadcast"(%arg0) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %3 = arith.mulf %arg1, %2 : tensor<2xf64> +// CHECK-NEXT: %4 = arith.addf %1, %3 : tensor<2xf64> +// CHECK-NEXT: return %4 : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index dccbc7b7923c..3f85a07548e1 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -277,16 +277,17 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, if (!vecValue && !startsWith(ord, "local")) { if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; - if (intrinsic == MLIRDerivatives) { - os << ";\n"; - os << "if (gutils->width != 1) {\n" - << " " << argName << "_" << (idx - 1) - << " = builder.create(\n" - << " op.getLoc(),\n" - << " " << argName << "_" << (idx - 1) << ",\n" - << " llvm::SmallVector({gutils->width}));\n" - << "}"; - } + } + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << curIndent << "if (gutils->width != 1) {\n" + << curIndent << " " << argName << "_" << (idx - 1) + << " = builder.create(\n" + << curIndent << " op.getLoc(),\n" + << curIndent << " " << argName << "_" << (idx - 1) << ",\n" + << curIndent + << " llvm::SmallVector({gutils->width}));\n" + << curIndent << "}"; } if (lookup && intrinsic != MLIRDerivatives) From ac3be7e59da58e8e57d51d712b52270929986518 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Wed, 8 Jan 
2025 22:52:25 +0100 Subject: [PATCH 41/45] Add derivative for LLVM:ExpOp (#2220) Co-authored-by: BuildKite --- enzyme/Enzyme/MLIR/Implementations/Common.td | 4 ++++ .../LLVMAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/LLVMDerivatives.td | 6 ++++++ enzyme/test/MLIR/ForwardMode/llvm.mlir | 17 +++++++++++++++++ 4 files changed, 28 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 9846f9ae6183..29f977c95a98 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -121,6 +121,7 @@ def TypeOf : Operation { class ComplexInst : Inst; class ArithInst : Inst; +class LlvmInst : Inst; class MathInst : Inst; def AddF : ArithInst<"AddFOp">; @@ -133,6 +134,9 @@ def RemF : ArithInst<"RemFOp">; def CheckedMulF : ArithInst<"MulFOp">; def CheckedDivF : ArithInst<"DivFOp">; +def LlvmCheckedMulF : LlvmInst<"FMulOp">; +def LlvmExpF : LlvmInst<"ExpOp">; + def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; def ExpF : MathInst<"ExpOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index b9e9ade7421f..11b191d7fdd6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -15,6 +15,7 @@ #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td index e77e88aea47f..949eeb22e09b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td +++ 
b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -26,3 +26,9 @@ def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; def : AllocationOp<"LLVM", "AllocaOp">; + +def : MLIRDerivative<"LLVM", "ExpOp", (Op $x), + [ + (LlvmCheckedMulF (DiffeRet), (LlvmExpF $x)) + ] + >; diff --git a/enzyme/test/MLIR/ForwardMode/llvm.mlir b/enzyme/test/MLIR/ForwardMode/llvm.mlir index df7e572ce4ce..f6cd8c3a9c2c 100644 --- a/enzyme/test/MLIR/ForwardMode/llvm.mlir +++ b/enzyme/test/MLIR/ForwardMode/llvm.mlir @@ -13,6 +13,16 @@ module { %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) return %r : f64 } + + func.func @exp(%x: f32) -> f32 { + %0 = llvm.intr.exp(%x) : (f32) -> f32 + return %0 : f32 + } + + func.func @dexp(%x: f32, %dx: f32) -> f32 { + %r = enzyme.fwddiff @exp(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f32, f32) -> f32 + return %r : f32 + } } // CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { @@ -29,3 +39,10 @@ module { // CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr -> f64 // CHECK-NEXT: return %[[i6]] : f64 // CHECK-NEXT: } + +// CHECK: func.func private @fwddiffeexp(%[[arg0:.+]]: f32, %[[arg1:.+]]: f32) -> f32 { +// CHECK-NEXT: %[[der:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32 +// CHECK-NEXT: %[[res:.+]] = llvm.fmul %[[arg1]], %[[der]] : f32 +// CHECK-NEXT: %[[exp:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32 +// CHECK-NEXT: return %[[res]] : f32 +// CHECK-NEXT: } From 495fde356c97a6095c34d42a503d82de70cfe63f Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Thu, 9 Jan 2025 23:25:54 +0100 Subject: [PATCH 42/45] tablegen: Add StaticSelect to select based on static condition (#2206) * tablegen: Add StaticIf to select based on static condition * rename to StaticSelect and implement SelectIfActive and SelectIfComplex with it * define for llvm * put vector mode for LLVM back * basic use analysis 
* StaticSelect use analysis * fixup --------- Co-authored-by: William S. Moses --- enzyme/Enzyme/InstructionDerivatives.td | 8 +- enzyme/Enzyme/MLIR/Implementations/Common.td | 25 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 446 +++++++++++-------- 3 files changed, 273 insertions(+), 206 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index c778e63aa258..1c38bd9dfcaf 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -13,6 +13,11 @@ class ConstantFP : Operation { string value = val; } +class StaticSelect : Operation { + string condition = condition_; +} + +def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">; class Attribute { string name = name_; @@ -62,9 +67,6 @@ class Inst : Operation { def TypeOf : Operation { } def VectorSize : Operation { -} -def SelectIfActive : Operation { - } // Define ops to rewrite. diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 29f977c95a98..170f4e26bf12 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -86,10 +86,20 @@ def DiffeRet : DiffeRetIndex<[-1]>; def Shadow : Operation { } -class GlobalExpr : Operation{ +class GlobalExpr : Operation { string value = val; } +// Class for a dag operator that generates either a or b +// It can then be used with a two or three arguments. +// The two arguments version is (StaticSelect a, b) +// The three arguments version accepts a name as a first argument +// which is then available in the condition as a `Value` under the +// variable `imVal`. 
+class StaticSelect : Operation { + string condition = condition_; +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; @@ -99,13 +109,14 @@ class Inst : Operation { - -} - -def SelectIfComplex : Operation { +def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">; -} +def SelectIfComplex : StaticSelect<[{ + auto ty = imVal.getType(); + ty.isa() || + ty.isa() && + ty.cast().getElementType().isa(); +}]>; class ConstantFP : Operation { string value = val; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 3f85a07548e1..50efdaeae61c 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -436,53 +436,100 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << ".Min"; #endif return false; - } else if (opName == "SelectIfActive" || - Def->isSubClassOf("SelectIfActive")) { - if (resultRoot->getNumArgs() != 3) + } else if (Def->isSubClassOf("StaticSelect")) { + auto numArgs = resultRoot->getNumArgs(); + + if (numArgs != 2 && numArgs != 3) PrintFatalError(pattern->getLoc(), - "only three op SelectIfActive supported"); + "only two/three op StaticSelect supported"); os << "({\n"; - os << curIndent << INDENT << "// Computing SelectIfActive\n"; + os << curIndent << INDENT << "// Computing " << opName << "\n"; if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value imVal = nullptr;\n"; + os << curIndent << INDENT << "mlir::Value imVal = "; else - os << curIndent << INDENT << "llvm::Value *imVal = nullptr;\n"; + os << curIndent << INDENT << "llvm::Value *imVal = "; - os << curIndent << INDENT << "if (!gutils->isConstantValue("; + int index = numArgs == 3; - if (isa(resultRoot->getArg(0)) && resultRoot->getArgName(0)) { - auto name = resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - 
assert(!isVec); - // This assumes that activity of inner extractions are the same as - // outer. assert(!ext.size()); - os << ord; - } else - assert("Requires name for arg"); + // First one is a name, set imVal to it + if (numArgs == 3) { + if (isa(resultRoot->getArg(0)) && + resultRoot->getArgName(0)) { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord, isVec, ext] = + nameToOrdinal.lookup(name, pattern, resultRoot); + assert(!isVec); + os << ord << ";\n"; + } else + assert("Requires name for arg"); + } else { + os << "nullptr;\n"; + } + + os << curIndent << INDENT << "bool condition = "; + + auto condition = dyn_cast(Def->getValueInit("condition")); + if (!condition) + PrintFatalError(pattern->getLoc(), + Twine("string 'condition' not defined in ") + + resultTree->getAsString()); + auto conditionStr = condition->getValue(); + + if (conditionStr.contains("imVal") && numArgs == 2) + PrintFatalError(pattern->getLoc(), "need a name as first argument"); - os << ")) {\n"; + bool complexExpr = conditionStr.contains(';'); + if (complexExpr) + os << "({\n"; + os << conditionStr; + if (complexExpr) + os << "\n" << curIndent << INDENT << "})"; + + os << ";\n"; - for (size_t i = 1; i < 3; i++) { + os << curIndent << INDENT << "bool vectorized = false;\n"; + + os << curIndent << INDENT << "if (condition) {\n"; + + bool any_vector = false; + bool all_vector = true; + for (size_t i = index; i < numArgs; ++i) { os << curIndent << INDENT << INDENT << "imVal = "; + bool vector; if (isa(resultRoot->getArg(i)) && resultRoot->getArgName(i)) { auto name = resultRoot->getArgName(i)->getAsUnquotedString(); auto [ord, isVec, ext] = nameToOrdinal.lookup(name, pattern, resultRoot); - vector = isVec; assert(!ext.size()); + vector = isVec; os << ord; - } else - vector = handle(curIndent + INDENT + INDENT, - argPattern + "_sia_" + Twine(i), os, pattern, - resultRoot->getArg(i), builder, nameToOrdinal, lookup, - retidx, origName, newFromOriginal, intrinsic); + } else 
{ + vector = + handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, + pattern, resultRoot->getArg(i), builder, nameToOrdinal, + lookup, retidx, origName, newFromOriginal, intrinsic); + } os << ";\n"; + if (vector) { + any_vector = true; + os << curIndent << INDENT << INDENT << "vectorized = true;\n"; + } else { + all_vector = false; + } + + if (i == numArgs - 1) { + os << curIndent << INDENT << "}\n"; + } else { + os << curIndent << INDENT << "} else {\n"; + } + } - if (!vector && intrinsic != MLIRDerivatives) { + if (any_vector && !all_vector) { + os << curIndent << INDENT << "if (!vectorized) {\n"; + if (intrinsic != MLIRDerivatives) { os << curIndent << INDENT << INDENT << "llvm::Value* vec_imVal = gutils->getWidth() == 1 ? imVal : " "UndefValue::get(gutils->getShadowType(imVal" @@ -496,81 +543,19 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, << ".CreateInsertValue(vec_imVal, imVal, " "std::vector({(unsigned)i}));\n"; os << curIndent << INDENT << INDENT << "imVal = vec_imVal;\n"; + } else { + os << curIndent << INDENT << "if (gutils->width != 1)\n" + << curIndent << INDENT << INDENT + << "imVal = builder.create(imVal.getLoc(), " + "imVal, SmallVector({gutils->width}));\n"; } - if (i == 1) - os << curIndent << INDENT << "} else {\n"; - else - os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "}\n"; } os << curIndent << INDENT << "imVal;\n"; - os << curIndent << "})"; - return true; - } else if (opName == "SelectIfComplex" || - Def->isSubClassOf("SelectIfComplex")) { - if (resultRoot->getNumArgs() != 3) - PrintFatalError(pattern->getLoc(), - "only three op SelectIfComplex supported"); - - os << "({\n"; - os << curIndent << INDENT << "// Computing SelectIfComplex\n"; - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value imVal = "; - else - os << curIndent << INDENT << "llvm::Value *imVal = "; - - if (isa(resultRoot->getArg(0)) && resultRoot->getArgName(0)) { - auto name = 
resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - os << ord << ";\n"; - } else { - handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, - resultRoot->getArg(0), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal, intrinsic); - os << ";\n"; - } + os << curIndent << INDENT << "})"; - os << curIndent << INDENT - << "if (isa(imVal.getType()) || " - "(isa(imVal.getType()) && " - "isa(cast(imVal.getType()).getElementType(" - ")))) {\n"; - - os << curIndent << INDENT << INDENT << "imVal = "; - if (isa(resultRoot->getArg(1)) && resultRoot->getArgName(1)) { - auto name = resultRoot->getArgName(1)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - assert(!ext.size()); - os << ord << ";\n"; - } else { - handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, - resultRoot->getArg(1), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal, intrinsic); - os << ";\n"; - } - - os << curIndent << INDENT << "} else {\n"; - - os << curIndent << INDENT << INDENT << "imVal = "; - if (isa(resultRoot->getArg(2)) && resultRoot->getArgName(2)) { - auto name = resultRoot->getArgName(2)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - assert(!ext.size()); - os << ord << ";\n"; - } else { - handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, - resultRoot->getArg(2), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal, intrinsic); - os << ";\n"; - } - - os << curIndent << INDENT << "}\n"; - os << curIndent << INDENT << "imVal;"; - os << curIndent << INDENT << "})\n"; - return true; + return any_vector; } else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) { auto value = dyn_cast(Def->getValueInit("value")); if (!value) @@ -1192,6 +1177,131 @@ bool handle(const Twine &curIndent, const Twine 
&argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("unknown operation")); } +std::string ReplaceAll(std::string str, const std::string &from, + const std::string &to) { + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += + to.length(); // Handles case where 'to' is a substring of 'from' + } + return str; +} + +void handleUse( + const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, + std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, + const DagInit *tree, + StringMap> &varNameToCondition); + +void handleUseArgument( + StringRef name, const Init *arg, bool usesPrimal, bool usesShadow, + const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, + std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, + const DagInit *tree, + StringMap> &varNameToCondition) { + + auto arg2 = dyn_cast(arg); + + if (arg2) { + // Recursive use of shadow is unhandled + assert(!usesShadow); + + std::string foundPrimalUse2 = ""; + std::string foundShadowUse2 = ""; + + bool foundDiffRet2 = false; + // We set precondition to be false (aka "") if we do not need the + // primal, since we are now only recurring to set variables + // correctly. + if (name.size() || usesPrimal) + handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, + name.size() ? foundShadowUse2 : foundShadowUse, + name.size() ? foundDiffRet2 : foundDiffRet, + usesPrimal ? 
precondition : "", tree, varNameToCondition); + + if (name.size()) { + if (foundPrimalUse2.size() && + !(startsWith(foundPrimalUse, foundPrimalUse2) || + endsWith(foundPrimalUse, foundPrimalUse2))) { + if (foundPrimalUse.size() == 0) + foundPrimalUse = foundPrimalUse2; + else + foundPrimalUse += " || " + foundPrimalUse2; + } + if (foundShadowUse2.size() && + !(startsWith(foundShadowUse, foundShadowUse2) || + endsWith(foundShadowUse, foundShadowUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundShadowUse2; + else + foundShadowUse += " || " + foundShadowUse2; + } + foundDiffRet |= foundDiffRet2; + + varNameToCondition[name] = + std::make_tuple(foundPrimalUse2, foundShadowUse2, foundDiffRet2); + } + } else { + assert(name.size()); + + if (name.size()) { + auto found = varNameToCondition.find(name); + if (found == varNameToCondition.end()) { + llvm::errs() << "tree scope: " << *tree << "\n"; + llvm::errs() << "root scope: " << *root << "\n"; + llvm::errs() << "could not find var name: " << name << "\n"; + } + assert(found != varNameToCondition.end()); + } + + if (precondition.size()) { + auto [foundPrimalUse2, foundShadowUse2, foundDiffRet2] = + varNameToCondition[name]; + if (precondition != "true") { + if (foundPrimalUse2.size()) { + foundPrimalUse2 = + "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; + } + if (foundShadowUse2.size()) { + foundShadowUse2 = + "((" + foundShadowUse2 + ")&&(" + precondition + ")"; + } + } + if (usesPrimal) { + if (foundPrimalUse2.size() && + !(startsWith(foundPrimalUse, foundPrimalUse2) || + endsWith(foundPrimalUse, foundPrimalUse2))) { + if (foundPrimalUse.size() == 0) + foundPrimalUse = foundPrimalUse2; + else + foundPrimalUse += " || " + foundPrimalUse2; + } + if (foundShadowUse2.size() && + !(startsWith(foundShadowUse, foundShadowUse2) || + endsWith(foundShadowUse, foundShadowUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundShadowUse2; + else + foundShadowUse += " || " + foundShadowUse2; + } 
+ foundDiffRet |= foundDiffRet2; + } + if (usesShadow) { + if (foundPrimalUse2.size() && + !(startsWith(foundShadowUse, foundPrimalUse2) || + endsWith(foundShadowUse, foundPrimalUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundPrimalUse2; + else + foundShadowUse += " || " + foundPrimalUse2; + } + assert(!foundDiffRet2); + assert(foundShadowUse2 == ""); + } + } + } +} void handleUse( const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, @@ -1215,113 +1325,57 @@ void handleUse( bool usesShadow = Def->getValueAsBit("usesShadow"); bool usesCustom = Def->getValueAsBit("usesCustom"); - // We don't handle any custom primal/shadow - (void)usesCustom; - assert(!usesCustom); + if (Def->isSubClassOf("StaticSelect")) { + auto numArgs = resultTree->getNumArgs(); - for (auto argEn : llvm::enumerate(resultTree->getArgs())) { - auto name = resultTree->getArgNameStr(argEn.index()); + assert(numArgs == 2 || numArgs == 3); + auto condition = dyn_cast(Def->getValueInit("condition")); + assert(condition); + std::string conditionStr = condition->getValue().str(); - auto arg2 = dyn_cast(argEn.value()); + assert(!(StringRef(conditionStr).contains("imVal") && numArgs == 2)); - if (arg2) { - // Recursive use of shadow is unhandled - assert(!usesShadow); + // First one is a name, set imVal to it + if (numArgs == 3) { + if (isa(resultTree->getArg(0)) && resultTree->getArgName(0)) { + auto name = resultTree->getArgName(0)->getAsUnquotedString(); + conditionStr = ReplaceAll(conditionStr, "imVal", name); + } else + assert("Requires name for arg"); + } - std::string foundPrimalUse2 = ""; - std::string foundShadowUse2 = ""; + bool complexExpr = StringRef(conditionStr).contains(';'); + if (complexExpr) { + conditionStr = "({ " + conditionStr + " })"; + } - bool foundDiffRet2 = false; - // We set precondition to be false (aka "") if we do not need the - // primal, since we are now 
only recurring to set variables - // correctly. - if (name.size() || usesPrimal) - handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, - name.size() ? foundShadowUse2 : foundShadowUse, - name.size() ? foundDiffRet2 : foundDiffRet, - usesPrimal ? precondition : "", tree, varNameToCondition); + for (size_t i = numArgs == 3; i < numArgs; ++i) { + std::string conditionStr2 = + (i == numArgs - 1) ? ("!(" + conditionStr + ")") : conditionStr; + std::string precondition2; + if (precondition == "true") + precondition2 = conditionStr2; + else + precondition2 = "((" + precondition + ")&&(" + conditionStr2 + ")"; - if (name.size()) { - if (foundPrimalUse2.size() && - !(startsWith(foundPrimalUse, foundPrimalUse2) || - endsWith(foundPrimalUse, foundPrimalUse2))) { - if (foundPrimalUse.size() == 0) - foundPrimalUse = foundPrimalUse2; - else - foundPrimalUse += " || " + foundPrimalUse2; - } - if (foundShadowUse2.size() && - !(startsWith(foundShadowUse, foundShadowUse2) || - endsWith(foundShadowUse, foundShadowUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundShadowUse2; - else - foundShadowUse += " || " + foundShadowUse2; - } - foundDiffRet |= foundDiffRet2; + auto name = resultTree->getArgNameStr(i); + auto arg = resultTree->getArg(i); + handleUseArgument(name, arg, true, false, root, resultTree, + foundPrimalUse, foundShadowUse, foundDiffRet, + precondition2, tree, varNameToCondition); + } - varNameToCondition[name] = - std::make_tuple(foundPrimalUse2, foundShadowUse2, foundDiffRet2); - } - } else { - assert(name.size()); - - if (name.size()) { - auto found = varNameToCondition.find(name); - if (found == varNameToCondition.end()) { - llvm::errs() << "tree scope: " << *tree << "\n"; - llvm::errs() << "root scope: " << *root << "\n"; - llvm::errs() << "could not find var name: " << name << "\n"; - } - assert(found != varNameToCondition.end()); - } + return; + } - if (precondition.size()) { - auto [foundPrimalUse2, foundShadowUse2, 
foundDiffRet2] = - varNameToCondition[name]; - if (precondition != "true") { - if (foundPrimalUse2.size()) { - foundPrimalUse2 = - "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; - } - if (foundShadowUse2.size()) { - foundShadowUse2 = - "((" + foundShadowUse2 + ")&&(" + precondition + ")"; - } - } - if (usesPrimal) { - if (foundPrimalUse2.size() && - !(startsWith(foundPrimalUse, foundPrimalUse2) || - endsWith(foundPrimalUse, foundPrimalUse2))) { - if (foundPrimalUse.size() == 0) - foundPrimalUse = foundPrimalUse2; - else - foundPrimalUse += " || " + foundPrimalUse2; - } - if (foundShadowUse2.size() && - !(startsWith(foundShadowUse, foundShadowUse2) || - endsWith(foundShadowUse, foundShadowUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundShadowUse2; - else - foundShadowUse += " || " + foundShadowUse2; - } - foundDiffRet |= foundDiffRet2; - } - if (usesShadow) { - if (foundPrimalUse2.size() && - !(startsWith(foundShadowUse, foundPrimalUse2) || - endsWith(foundShadowUse, foundPrimalUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundPrimalUse2; - else - foundShadowUse += " || " + foundPrimalUse2; - } - assert(!foundDiffRet2); - assert(foundShadowUse2 == ""); - } - } - } + (void)usesCustom; + assert(!usesCustom); + + for (auto argEn : llvm::enumerate(resultTree->getArgs())) { + auto name = resultTree->getArgNameStr(argEn.index()); + handleUseArgument(name, argEn.value(), usesPrimal, usesShadow, root, + resultTree, foundPrimalUse, foundShadowUse, foundDiffRet, + precondition, tree, varNameToCondition); } } From e8bc18716cdaf190a259f2dcb2b91a2cf7489516 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 9 Jan 2025 17:55:42 -0500 Subject: [PATCH 43/45] Fix nametoordinal (#2221) --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 105 ++++++++++--------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 
50efdaeae61c..0c0a45c6eade 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -207,7 +207,7 @@ struct VariableSetting { StringMap> extractions; std::tuple> - lookup(StringRef name, const Record *pattern, const Init *resultRoot) { + lookup(StringRef name, const Record *pattern, const Init *resultRoot) const { auto ord = nameToOrdinal.find(name); if (ord == nameToOrdinal.end()) PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") + @@ -1192,14 +1192,16 @@ void handleUse( const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, const DagInit *tree, - StringMap> &varNameToCondition); + StringMap> &varNameToCondition, + const VariableSetting &nameToOrdinal); void handleUseArgument( StringRef name, const Init *arg, bool usesPrimal, bool usesShadow, const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, const DagInit *tree, - StringMap> &varNameToCondition) { + StringMap> &varNameToCondition, + const VariableSetting &nameToOrdinal) { auto arg2 = dyn_cast(arg); @@ -1218,7 +1220,8 @@ void handleUseArgument( handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, name.size() ? foundShadowUse2 : foundShadowUse, name.size() ? foundDiffRet2 : foundDiffRet, - usesPrimal ? precondition : "", tree, varNameToCondition); + usesPrimal ? 
precondition : "", tree, varNameToCondition, + nameToOrdinal); if (name.size()) { if (foundPrimalUse2.size() && @@ -1306,7 +1309,8 @@ void handleUse( const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, const DagInit *tree, - StringMap> &varNameToCondition) { + StringMap> &varNameToCondition, + const VariableSetting &nameToOrdinal) { auto opName = resultTree->getOperator()->getAsString(); auto Def = cast(resultTree->getOperator())->getDef(); if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) { @@ -1339,7 +1343,9 @@ void handleUse( if (numArgs == 3) { if (isa(resultTree->getArg(0)) && resultTree->getArgName(0)) { auto name = resultTree->getArgName(0)->getAsUnquotedString(); - conditionStr = ReplaceAll(conditionStr, "imVal", name); + auto [ord, isVec, ext] = nameToOrdinal.lookup(name, nullptr, nullptr); + assert(!isVec); + conditionStr = ReplaceAll(conditionStr, "imVal", ord); } else assert("Requires name for arg"); } @@ -1362,7 +1368,7 @@ void handleUse( auto arg = resultTree->getArg(i); handleUseArgument(name, arg, true, false, root, resultTree, foundPrimalUse, foundShadowUse, foundDiffRet, - precondition2, tree, varNameToCondition); + precondition2, tree, varNameToCondition, nameToOrdinal); } return; @@ -1375,16 +1381,57 @@ void handleUse( auto name = resultTree->getArgNameStr(argEn.index()); handleUseArgument(name, argEn.value(), usesPrimal, usesShadow, root, resultTree, foundPrimalUse, foundShadowUse, foundDiffRet, - precondition, tree, varNameToCondition); + precondition, tree, varNameToCondition, nameToOrdinal); } } +static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic, + StringRef origName) { + VariableSetting nameToOrdinal; + std::function)> insert = + [&](const DagInit *ptree, ArrayRef prev) { + unsigned i = 0; + for (auto tree : ptree->getArgs()) { + SmallVector next(prev.begin(), prev.end()); + next.push_back(i); 
+ if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (ptree->getArgNameStr(i).size()) { + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); + std::vector extractions; + if (prev.size() > 0) { + for (unsigned i = 1; i < next.size(); i++) { + extractions.push_back(next[i]); + } + } + nameToOrdinal.insert(ptree->getArgNameStr(i), op, false, + extractions); + } + i++; + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + nameToOrdinal.insert(tree->getNameStr(), + (Twine("(&") + origName + ")").str(), false, {}); + return nameToOrdinal; +} + void printDiffUse( raw_ostream &os, Twine prefix, const ListInit *argOps, StringRef origName, ActionType intrinsic, const DagInit *tree, StringMap> &varNameToCondition) { os << prefix << " // Rule " << *tree << "\n"; + VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName); + for (auto argOpEn : enumerate(*argOps)) { size_t argIdx = argOpEn.index(); if (auto resultRoot = dyn_cast(argOpEn.value())) { @@ -1417,7 +1464,8 @@ void printDiffUse( // hasDiffeRet(resultTree) handleUse(resultTree, resultTree, foundPrimalUse, foundShadowUse, - foundDiffRet, /*precondition*/ "true", tree, varNameToCondition); + foundDiffRet, /*precondition*/ "true", tree, varNameToCondition, + nameToOrdinal); os << prefix << " // Arg " << argIdx << " : " << *resultTree << "\n"; @@ -1587,45 +1635,6 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern, os << " mlir::Value dif = nullptr;\n"; } -static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic, - StringRef origName) { - VariableSetting nameToOrdinal; - std::function)> insert = - [&](const DagInit *ptree, ArrayRef prev) { - unsigned i = 0; - for (auto tree : ptree->getArgs()) { - SmallVector next(prev.begin(), prev.end()); - next.push_back(i); - if (auto dg = dyn_cast(tree)) - insert(dg, 
next); - - if (ptree->getArgNameStr(i).size()) { - std::string op; - if (intrinsic != MLIRDerivatives) - op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); - else - op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); - std::vector extractions; - if (prev.size() > 0) { - for (unsigned i = 1; i < next.size(); i++) { - extractions.push_back(next[i]); - } - } - nameToOrdinal.insert(ptree->getArgNameStr(i), op, false, - extractions); - } - i++; - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - nameToOrdinal.insert(tree->getNameStr(), - (Twine("(&") + origName + ")").str(), false, {}); - return nameToOrdinal; -} - static void emitReverseCommon(raw_ostream &os, const Record *pattern, const DagInit *tree, ActionType intrinsic, StringRef origName, const ListInit *argOps) { From 02c5f89b251318da0baace3f442f955af1129111 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 9 Jan 2025 18:01:53 -0500 Subject: [PATCH 44/45] Fix paren balancing --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 0c0a45c6eade..672cb8954e5b 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1352,17 +1352,17 @@ void handleUse( bool complexExpr = StringRef(conditionStr).contains(';'); if (complexExpr) { - conditionStr = "({ " + conditionStr + " })"; + conditionStr = " ({ " + conditionStr + " }) "; } for (size_t i = numArgs == 3; i < numArgs; ++i) { std::string conditionStr2 = - (i == numArgs - 1) ? ("!(" + conditionStr + ")") : conditionStr; + (i == numArgs - 1) ? 
(" !( " + conditionStr + " ) ") : conditionStr; std::string precondition2; if (precondition == "true") precondition2 = conditionStr2; else - precondition2 = "((" + precondition + ")&&(" + conditionStr2 + ")"; + precondition2 = "((" + precondition + ")&&(" + conditionStr2 + "))"; auto name = resultTree->getArgNameStr(i); auto arg = resultTree->getArg(i); From 565163635fe98c7e0da004b4ac6dd2cd45bc880c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 9 Jan 2025 18:03:51 -0500 Subject: [PATCH 45/45] More paren balancing --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 672cb8954e5b..fade744a7494 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1264,11 +1264,11 @@ void handleUseArgument( if (precondition != "true") { if (foundPrimalUse2.size()) { foundPrimalUse2 = - "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; + "((" + foundPrimalUse2 + ")&&(" + precondition + "))"; } if (foundShadowUse2.size()) { foundShadowUse2 = - "((" + foundShadowUse2 + ")&&(" + precondition + ")"; + "((" + foundShadowUse2 + ")&&(" + precondition + "))"; } } if (usesPrimal) {