From a03c1d4008369c27536ab8450e808e34d4514a5f Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sat, 21 Dec 2024 19:11:58 -0500
Subject: [PATCH 01/13] Fix forward rewrite (#2201)

---
 enzyme/Enzyme/CallDerivatives.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp
index 22df3dab9a7..db7ab6ac445 100644
--- a/enzyme/Enzyme/CallDerivatives.cpp
+++ b/enzyme/Enzyme/CallDerivatives.cpp
@@ -3236,6 +3236,13 @@ bool AdjointGenerator::handleKnownCallDerivatives(
         CI->setCallingConv(call.getCallingConv());
         CI->setTailCallKind(call.getTailCallKind());
         CI->setDebugLoc(dbgLoc);
+
+        if (funcName == "julia.gc_alloc_obj" ||
+            funcName == "jl_gc_alloc_typed" ||
+            funcName == "ijl_gc_alloc_typed") {
+          if (EnzymeShadowAllocRewrite)
+            EnzymeShadowAllocRewrite(wrap(CI), gutils);
+        }
         return CI;
       };

From 534a28595a9bf1f41d4a8222615f00e36611dc0f Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sun, 22 Dec 2024 14:56:08 -0500
Subject: [PATCH 02/13] Expand shadow_alloc_rewrite capabilities (#2202)
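The hook now also receives the original allocation call, the replication index
within a batched (width > 1) derivative, and the shadow produced for the
previous replica. A minimal client-side sketch against the widened signature;
only the extern symbol is Enzyme's, the handler name and body are illustrative:

```c++
#include "llvm-c/Core.h"
#include <cstdint>

extern "C" {
extern void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef,
                                        uint64_t, LLVMValueRef);
}

// Hypothetical handler: newShadow is the just-created shadow allocation,
// origAlloc the original allocation call, idx which replica is being emitted,
// and prev the shadow built for the previous replica (may be null).
static void myShadowAllocRewrite(LLVMValueRef newShadow, void *gutils,
                                 LLVMValueRef origAlloc, uint64_t idx,
                                 LLVMValueRef prev) {
  // e.g. retag the allocation or attach metadata to newShadow here.
}

static void installHook() { EnzymeShadowAllocRewrite = myShadowAllocRewrite; }
```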
---
 enzyme/Enzyme/CallDerivatives.cpp | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp
index db7ab6ac445..60503c2b0fe 100644
--- a/enzyme/Enzyme/CallDerivatives.cpp
+++ b/enzyme/Enzyme/CallDerivatives.cpp
@@ -29,7 +29,8 @@ using namespace llvm;
 
 extern "C" {
-void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr;
+void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t,
+                                 LLVMValueRef) = nullptr;
 }
 
 void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
@@ -3014,6 +3015,9 @@ bool AdjointGenerator::handleKnownCallDerivatives(
                 bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ));
           } else {
             bool zeroed = false;
+            uint64_t idx = 0;
+            Value *prev = nullptr;
+            ;
             auto rule = [&]() {
               Value *anti =
                   bb.CreateCall(call.getFunctionType(), call.getCalledOperand(),
@@ -3059,7 +3063,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
                   funcName == "jl_gc_alloc_typed" ||
                   funcName == "ijl_gc_alloc_typed") {
                 if (EnzymeShadowAllocRewrite)
-                  EnzymeShadowAllocRewrite(wrap(anti), gutils);
+                  EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call),
+                                           idx, wrap(prev));
               }
             }
             if (Mode == DerivativeMode::ReverseModeCombined ||
@@ -3075,6 +3080,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
                 zeroed = true;
               }
             }
+            idx++;
+            prev = anti;
             return anti;
           };

@@ -3224,6 +3231,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
       args.push_back(gutils->getNewFromOriginal(arg));
     }
 
+    uint64_t idx = 0;
+    Value *prev = gutils->getNewFromOriginal(&call);
    auto rule = [&]() {
       SmallVector<ValueType, 2> BundleTypes(args.size(), ValueType::Primal);
 
@@ -3241,8 +3250,11 @@ bool AdjointGenerator::handleKnownCallDerivatives(
           funcName == "jl_gc_alloc_typed" ||
           funcName == "ijl_gc_alloc_typed") {
         if (EnzymeShadowAllocRewrite)
-          EnzymeShadowAllocRewrite(wrap(CI), gutils);
+          EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx,
+                                   wrap(prev));
       }
+      idx++;
+      prev = CI;
       return CI;
     };

From f5767ef45ac7071ac44dd75d77771475377328e5 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sun, 22 Dec 2024 18:03:34 -0500
Subject: [PATCH 03/13] Fix activity analysis store (#2203)

---
 enzyme/Enzyme/ActivityAnalysis.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp
index 2a22512732b..88ceb7e9f47 100644
--- a/enzyme/Enzyme/ActivityAnalysis.cpp
+++ b/enzyme/Enzyme/ActivityAnalysis.cpp
@@ -2797,8 +2797,16 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
       if (isa<AllocaInst>(TmpOrig) || isAllocationCall(TmpOrig, TLI)) {
         done.insert(
             std::make_tuple((User *)SI, SI->getPointerOperand(), UA));
+        // If we are capturing a variable v, we need to check any loads or
+        // stores into that variable, even if we are checking only for
+        // stores.
+        auto UA2 = UA;
+        if (UA == UseActivity::OnlyStores ||
+            UA == UseActivity::OnlyNonPointerStores ||
+            UA == UseActivity::AllStores)
+          UA2 = UseActivity::None;
         for (const auto a : TmpOrig->users()) {
-          todo.push_back(std::make_tuple(a, TmpOrig, UA));
+          todo.push_back(std::make_tuple(a, TmpOrig, UA2));
         }
         AllocaSet.insert(TmpOrig);
         if (EnzymePrintActivity)

From 79ac40699969e416e6088bb99dfdd523f0d7f40e Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sun, 22 Dec 2024 20:14:21 -0500
Subject: [PATCH 04/13] Non-power of two cache switch (#2204)
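A self-contained sketch of the rule this switch introduces (it mirrors the
hasNoCache check below; the function name is made up, not Enzyme API):

```c++
#include "llvm/Support/MathExtras.h"

// true  -> the integer value is skipped by Enzyme's cache (recomputed instead)
// false -> it is cached as before
static bool skipIntCache(unsigned bitWidth, bool nonPower2CacheFlag) {
  return !llvm::isPowerOf2_64(bitWidth) && !nonPower2CacheFlag;
}

// skipIntCache(24, /*flag=*/false) == true:  an i24 is no longer cached
// skipIntCache(64, /*flag=*/false) == false: an i64 behaves as before
// skipIntCache(24, /*flag=*/true)  == false: the flag restores caching
```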
---
 enzyme/Enzyme/Utils.cpp | 4 ++++
 enzyme/Enzyme/Utils.h   | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp
index 933a22304e2..f4655bac845 100644
--- a/enzyme/Enzyme/Utils.cpp
+++ b/enzyme/Enzyme/Utils.cpp
@@ -102,6 +102,10 @@ llvm::cl::opt<bool> EnzymeMemmoveWarning(
 llvm::cl::opt<bool> EnzymeRuntimeError(
     "enzyme-runtime-error", cl::init(false), cl::Hidden,
     cl::desc("Emit Runtime errors instead of compile time ones"));
+
+llvm::cl::opt<bool> EnzymeNonPower2Cache(
+    "enzyme-non-power2-cache", cl::init(false), cl::Hidden,
+    cl::desc("Enable caching of integers which are not a power of 2"));
 }
 
 void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h
index 02ce4b8b47e..089a99c8691 100644
--- a/enzyme/Enzyme/Utils.h
+++ b/enzyme/Enzyme/Utils.h
@@ -91,6 +91,7 @@ enum class ErrorType {
 extern "C" {
 /// Print additional debug info relevant to performance
 extern llvm::cl::opt<bool> EnzymePrintPerf;
+extern llvm::cl::opt<bool> EnzymeNonPower2Cache;
 extern llvm::cl::opt<bool> EnzymeStrongZero;
 extern llvm::cl::opt<bool> EnzymeBlasCopy;
 extern llvm::cl::opt<bool> EnzymeLapackCopy;
@@ -1194,6 +1195,10 @@ static inline bool hasNoCache(llvm::Value *op) {
       }
     }
   }
+  if (auto IT = dyn_cast<IntegerType>(op->getType()))
+    if (!isPowerOf2_64(IT->getBitWidth()) && !EnzymeNonPower2Cache)
+      return true;
+
   return false;
 }

From 1ca3ceb12723b7246fe19016bca6e91fba37f1c7 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sun, 22 Dec 2024 23:11:13 -0500
Subject: [PATCH 05/13] ActivityAnalysis: consider atomicrmw (#2205)

* ActivityAnalysis: consider atomicrmw

* fix

* fix
---
 enzyme/Enzyme/ActivityAnalysis.cpp | 22 +++++++++++++++++++++-
 enzyme/Enzyme/GradientUtils.cpp    | 10 ++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp
index 88ceb7e9f47..67cbb5e9ddf 100644
--- a/enzyme/Enzyme/ActivityAnalysis.cpp
+++ b/enzyme/Enzyme/ActivityAnalysis.cpp
@@ -1554,6 +1554,26 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
           ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(TmpOrig);
       }
     }
+  } else if (auto RMW = dyn_cast<AtomicRMWInst>(TmpOrig)) {
+    if (directions == UP) {
+      if (isConstantValue(TR, RMW->getPointerOperand())) {
+        InsertConstantValue(TR, Val);
+        return true;
+      }
+    } else {
+      if (UpHypothesis->isConstantValue(TR, RMW->getPointerOperand())) {
+        InsertConstantValue(TR, Val);
+        insertConstantsFrom(TR, *UpHypothesis);
+        return true;
+      }
+    }
+    if (EnzymeEnableRecursiveHypotheses) {
+      ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(Val);
+      if (TmpOrig != Val) {
+        ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(
+            TmpOrig);
+      }
+    }
   } else if (auto op = dyn_cast<CallInst>(TmpOrig)) {
     if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") ||
         op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex,
@@ -1940,7 +1960,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
             isRefSet(AARes)) {
           if (EnzymePrintActivity)
             llvm::errs() << "potential active load: " << *I << "\n";
-          if (isa<LoadInst>(I) || isNVLoad(I)) {
+          if (isa<LoadInst>(I) || isNVLoad(I) || isa<AtomicRMWInst>(I)) {
             // If the ref'ing value is a load check if the loaded value is
             // active
             if (!Hypothesis->isConstantValue(TR, I)) {
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp
index c169365371a..414374dd779 100644
--- a/enzyme/Enzyme/GradientUtils.cpp
+++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -3822,6 +3822,9 @@ bool GradientUtils::legalRecompute(const Value *val,
     }
   }
 
+  if (isa<AtomicRMWInst>(val))
+    return false;
+
   if (auto phi = dyn_cast<PHINode>(val)) {
     if (auto uiv = hasUninverted(val)) {
      if (auto dli = dyn_cast_or_null<LoadInst>(uiv)) {
@@ -3835,6 +3838,13 @@ bool GradientUtils::legalRecompute(const Value *val,
       }
     }
 
+    auto found = fictiousPHIs.find(const_cast<PHINode *>(phi));
+    if (found != fictiousPHIs.end()) {
+      auto orig = found->second;
+      if (isa<AtomicRMWInst>(orig))
+        return false;
+    }
+
     if (phi->getNumIncomingValues() == 0) {
       llvm::errs() << *oldFunc << "\n";
       llvm::errs() << *newFunc << "\n";

From 59540625e212697ea4ab1d7beeecaa0eeec14426 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Wed, 25 Dec 2024 20:05:23 -0500
Subject: [PATCH 06/13] Pass if is unnecessary (#2207)
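The shadow-allocation hook gains a final flag reporting whether the original
allocation is actually needed. A sketch of the resulting signature; only the
extern symbol is Enzyme's, the parameter names are illustrative:

```c++
extern "C" {
// used != 0 means the original call is not in Enzyme's set of unnecessary
// instructions, i.e. its primal value is genuinely required.
extern void (*EnzymeShadowAllocRewrite)(LLVMValueRef newShadow, void *gutils,
                                        LLVMValueRef origAlloc, uint64_t idx,
                                        LLVMValueRef prev, uint8_t used);
}
```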
---
 enzyme/Enzyme/CallDerivatives.cpp | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp
index 60503c2b0fe..bc5a095cfae 100644
--- a/enzyme/Enzyme/CallDerivatives.cpp
+++ b/enzyme/Enzyme/CallDerivatives.cpp
@@ -30,7 +30,7 @@ using namespace llvm;
 
 extern "C" {
 void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t,
-                                 LLVMValueRef) = nullptr;
+                                 LLVMValueRef, uint8_t) = nullptr;
 }
 
 void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
@@ -3062,9 +3062,12 @@ bool AdjointGenerator::handleKnownCallDerivatives(
             if (funcName == "julia.gc_alloc_obj" ||
                 funcName == "jl_gc_alloc_typed" ||
                 funcName == "ijl_gc_alloc_typed") {
-              if (EnzymeShadowAllocRewrite)
+              if (EnzymeShadowAllocRewrite) {
+                bool used = unnecessaryInstructions.find(&call) ==
+                            unnecessaryInstructions.end();
                 EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call),
-                                         idx, wrap(prev));
+                                         idx, wrap(prev), used);
+              }
             }
           }
           if (Mode == DerivativeMode::ReverseModeCombined ||
@@ -3249,9 +3252,12 @@ bool AdjointGenerator::handleKnownCallDerivatives(
           funcName == "jl_gc_alloc_typed" ||
           funcName == "ijl_gc_alloc_typed") {
-        if (EnzymeShadowAllocRewrite)
+        if (EnzymeShadowAllocRewrite) {
+          bool used = unnecessaryInstructions.find(&call) ==
+                      unnecessaryInstructions.end();
           EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx,
-                                   wrap(prev));
+                                   wrap(prev), used);
+        }
       }
       idx++;
       prev = CI;

From 8e79483d4c2d4cb2a6ccf1354be0595c6658a73d Mon Sep 17 00:00:00 2001
From: Kiran Shila
Date: Fri, 27 Dec 2024 10:48:13 -0800
Subject: [PATCH 07/13] Add nixpkgs to README (#2208)

---
 Readme.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/Readme.md b/Readme.md
index 7dd210b8b7e..544e12dfc5d 100644
--- a/Readme.md
+++ b/Readme.md
@@ -39,6 +39,10 @@ brew install enzyme
 ```
 spack install enzyme
 ```
+[Nix](https://nixos.org/)
+```
+nix-shell -p enzyme
+```
 
 To get involved or if you have questions, please join our
 [mailing list](https://groups.google.com/d/forum/enzyme-dev).

From eeb6200dafad352aa44ba163e3e9cd4f4eae5a8f Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx@users.noreply.github.com>
Date: Fri, 27 Dec 2024 21:14:15 +0100
Subject: [PATCH 08/13] Batched autodiff (#2181)

* add type conversions for width != 1.

  This still requires changes in the tblgen-generated derivative files.
  For example, createForwardModeTangent in MulFOpFwdDerivative could be
  altered like this:

  ```
  LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder,
                                         MGradientUtils *gutils) const {
    auto op = cast<arith::MulFOp>(op0);
    if (gutils->width != 1) {
      auto newop = gutils->getNewFromOriginal(op0);
      for (auto res : newop->getResults()) {
        res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
      }
    }
    gutils->eraseIfUnused(op);
    if (gutils->isConstantInstruction(op))
      return success();
    mlir::Value res = nullptr;
    if (!gutils->isConstantValue(op->getOperand(0))) {
      auto dif = gutils->invertPointerM(op->getOperand(0), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          dif.dump();
          // TODO: gutils->makeBatched(...)
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        itmp.dump();
        if (!res)
          res = itmp;
        else {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    if (!gutils->isConstantValue(op->getOperand(1))) {
      auto dif = gutils->invertPointerM(op->getOperand(1), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          dif.dump();
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    assert(res);
    gutils->setDiffe(op->getResult(0), res, builder);
    return success();
  }
  ```

* add code to tblgen generator, this eventually needs to be a single function call.

* a test and formatting

* use tensor splatop

* remove stale enzyme-tblgen changes

* do the simple batching in enzyme-tblgen

* include tensor in all AutoDiffOpInterfaceImpls

* add enzyme broadcastop

* getShadowType for TensorTypeInterface

* create broadcastop in enzyme-tblgen

* Revert "include tensor in all AutoDiffOpInterfaceImpls"

  This reverts commit c06ed01709b51bff5b794a7e4dc83b63510b9a84.

* test

* DenseI64ArrayAttr for shape instead of scalar width

* `llvm::SmallVector` --> `ArrayRef`

* formatting

* use getShadowType in BroadcastOp builder

  Co-authored-by: Billy Moses

* unstructured control flow test

* scf.for

* formatting

* support `scf.if` test

* formatting

* forgotten includes

---------

Co-authored-by: Jules Merckx
Co-authored-by: Billy Moses
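For orientation, a sketch of what the new op's C++ builder lets a derivative
emitter write (`builder`, `loc`, and `scalar` are assumed from context; this
is not code from the patch):

```c++
// Broadcast a scalar shadow to the batch width: an f64 input with shape {2}
// yields a tensor<2xf64> value through the BroadcastOp builder added below.
mlir::Value broadcastToWidth(mlir::OpBuilder &builder, mlir::Location loc,
                             mlir::Value scalar, int64_t width) {
  return builder.create<mlir::enzyme::BroadcastOp>(
      loc, scalar, llvm::ArrayRef<int64_t>{width});
}
```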
---
 enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td            | 16 +++++++
 enzyme/Enzyme/MLIR/Dialect/Ops.cpp                 | 15 +++++++
 .../ArithAutoDiffOpInterfaceImpl.cpp               |  8 ++++
 .../BuiltinAutoDiffTypeInterfaceImpl.cpp           | 16 +++++--
 .../CoreDialectsAutoDiffImplementations.cpp        |  8 ++--
 .../CoreDialectsAutoDiffImplementations.h          |  1 +
 enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp    |  6 ++-
 enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp    |  3 +-
 enzyme/Enzyme/MLIR/Passes/CMakeLists.txt           |  1 +
 enzyme/Enzyme/MLIR/Passes/Passes.h                 |  5 +++
 enzyme/Enzyme/MLIR/Passes/Passes.td                |  3 +-
 enzyme/Enzyme/MLIR/enzymemlir-opt.cpp              |  1 +
 enzyme/test/MLIR/ForwardMode/batched_branch.mlir   | 26 +++++++++++
 enzyme/test/MLIR/ForwardMode/batched_for.mlir      | 33 ++++++++++++++
 enzyme/test/MLIR/ForwardMode/batched_if.mlir       | 43 +++++++++++++++++++
 enzyme/test/MLIR/ForwardMode/batched_scalar.mlir   | 26 +++++++++++
 enzyme/test/MLIR/ForwardMode/batched_tensor.mlir   | 26 +++++++++++
 enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp       | 13 +++++-
 18 files changed, 239 insertions(+), 11 deletions(-)
 create mode 100644 enzyme/test/MLIR/ForwardMode/batched_branch.mlir
 create mode 100644 enzyme/test/MLIR/ForwardMode/batched_for.mlir
 create mode 100644 enzyme/test/MLIR/ForwardMode/batched_if.mlir
 create mode 100644 enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
 create mode 100644 enzyme/test/MLIR/ForwardMode/batched_tensor.mlir

diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
index be139fb3d8b..72672a95940 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> {
 
 }
 
+def BroadcastOp : Enzyme_Op<"broadcast"> {
+  let description = [{
+    Broadcast the operand by adding extra dimensions with sizes provided by the `shape` attribute to the front.
+    For scalar operands, a ranked tensor is created.
+
+    NOTE: Only works for scalar and *ranked* tensor operands for now.
+  }];
+
+  let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape);
+  let results = (outs AnyRankedTensor:$output);
+
+  let builders = [
+    OpBuilder<(ins "Value":$input, "ArrayRef<int64_t>":$shape)>
+  ];
+}
+
 #endif // ENZYME_OPS
diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp
index 3e318542730..7e48db2d583 100644
--- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp
+++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/Debug.h"
 
@@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
+                        ArrayRef<int64_t> shape) {
+  auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
+  auto resultTy = input.getType();
+  for (auto s : llvm::reverse(shape)) {
+    resultTy = resultTy.cast<AutoDiffTypeInterface>().getShadowType(s);
+  }
+  build(builder, result, resultTy, input, shapeAttr);
+}
diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp
index 9b27503d79d..8d3650969d0 100644
--- a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp
@@ -17,6 +17,7 @@
 #include "Interfaces/GradientUtilsReverse.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface(
     arith::ConstantOp::attachInterface(*context);
   });
 }
+
+void mlir::enzyme::registerTensorDialectAutoDiffInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
+    registerInterfaces(context);
+  });
+}
diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
index d2d6ddfe19b..7c72b97d093 100644
--- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
@@ -45,8 +45,11 @@ class FloatTypeInterface
   }
 
   Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+    if (width > 1) {
+      return RankedTensorType::get({width}, self);
+    } else {
+      return self;
+    }
   }
 
   bool isMutable(Type self) const { return false; }
@@ -106,7 +109,14 @@ class TensorTypeInterface
   }
 
   Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
+    if (width != 1) {
+      auto tenType = self.cast<TensorType>();
+      auto shape = tenType.getShape();
+      SmallVector<int64_t> newShape;
+      newShape.push_back(width);
+      newShape.append(shape.begin(), shape.end());
+      return RankedTensorType::get(newShape, tenType.getElementType());
+    }
     return self;
   }
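[Editorial aside: taken together, the two interface implementations above give
batched shadows the shapes f64 -> tensor<2xf64> and tensor<10xf64> ->
tensor<2x10xf64> for width 2. A minimal caller-side sketch, assuming the
interface's usual C++ namespace:

```c++
// Query the width-aware shadow type of any type implementing the interface.
mlir::Type shadowOf(mlir::Type type, unsigned width) {
  return type.cast<mlir::enzyme::AutoDiffTypeInterface>().getShadowType(width);
}
```
]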
diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp
index 355808cdbcc..f727dca2f87 100644
--- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp
@@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,
         newVals.push_back(gutils->invertPointerM(op, builder));
       } else {
         Type retTy =
-            arg.getType().cast<AutoDiffTypeInterface>().getShadowType();
+            arg.getType().cast<AutoDiffTypeInterface>().getShadowType(
+                gutils->width);
         auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
             builder, op.getLoc());
         newVals.push_back(toret);
@@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler(
       if (auto iface =
              dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
         if (!iface.isMutable()) {
-          Type retTy = iface.getShadowType();
+          Type retTy = iface.getShadowType(gutils->width);
           auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
               builder, operand.get().getLoc());
           newOperands.push_back(toret);
@@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
                    << result.getType() << "\n";
       return failure();
     }
-    newOpResultTypes.push_back(typeIface.getShadowType());
+    newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
   }
 
   SmallVector<Value> newOperands;
@@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
   enzyme::registerCFDialectAutoDiffInterface(registry);
   enzyme::registerLinalgDialectAutoDiffInterface(registry);
   enzyme::registerFuncDialectAutoDiffInterface(registry);
+  enzyme::registerTensorDialectAutoDiffInterface(registry);
 }
diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h
index d6f28ccfc73..650f6c6326b 100644
--- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h
+++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h
@@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
 void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
 void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
 void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);
+void registerTensorDialectAutoDiffInterface(DialectRegistry &registry);
 
 void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);
diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
index 69cfad436cf..5ec908f1268 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
@@ -245,9 +245,11 @@ FunctionOpInterface CloneFunctionWithReturns(
       mlir::Value val = blk.getArgument(i);
       mlir::Value dval;
       if (i == ArgActivity.size() - 1)
-        dval = blk.addArgument(val.getType(), val.getLoc());
+        dval = blk.addArgument(getShadowType(val.getType(), width),
+                               val.getLoc());
       else
-        dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(),
+        dval = blk.insertArgument(blk.args_begin() + i + 1,
+                                  getShadowType(val.getType(), width),
                                   val.getLoc());
       ptrInputs.map(oval, dval);
     }
diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
index 1ec4212dc5a..32cb5b79614 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@@ -108,7 +108,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v,
     return invertedPointers.lookupOrNull(v);
 
   if (isConstantValue(v)) {
-    if (auto iface = v.getType().dyn_cast<AutoDiffTypeInterface>()) {
+    if (auto iface =
+            getShadowType(v.getType()).dyn_cast<AutoDiffTypeInterface>()) {
      OpBuilder::InsertionGuard guard(Builder2);
       if (auto op = v.getDefiningOp())
         Builder2.setInsertionPoint(getNewFromOriginal(op));
diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
index 0445fc43064..99db4d80034 100644
--- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
+++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
   MLIRFuncDialect
   MLIRFuncTransforms
   MLIRGPUDialect
+  MLIRTensorDialect
   MLIRIR
   MLIRLLVMDialect
   MLIRMathDialect
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h
index 58c43be236d..fb6df3e2208 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.h
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 #include "Dialect/Dialect.h"
 
@@ -80,6 +81,10 @@ namespace affine {
 class AffineDialect;
 } // end namespace affine
 
+namespace tensor {
+class TensorDialect;
+} // end namespace tensor
+
 namespace LLVM {
 class LLVMDialect;
 } // end namespace LLVM
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td
index 6458e63b273..c5b4df76917 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.td
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> {
   let dependentDialects = [
     "arith::ArithDialect",
     "complex::ComplexDialect",
-    "cf::ControlFlowDialect"
+    "cf::ControlFlowDialect",
+    "tensor::TensorDialect",
   ];
   let constructor = "mlir::enzyme::createDifferentiatePass()";
 }
diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
index 0e6bdf7b101..99e7243129b 100644
--- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
+++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
@@ -67,6 +67,7 @@ int main(int argc, char **argv) {
   registry.insert();
   registry.insert();
   registry.insert();
+  registry.insert<mlir::tensor::TensorDialect>();
   registry.insert();
   registry.insert();
diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir
new file mode 100644
index 00000000000..f20989aa424
--- /dev/null
+++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir
@@ -0,0 +1,26 @@
+// RUN: %eopt --enzyme %s | FileCheck %s
+
+module {
+  func.func @square(%x : f64, %y : f64) -> f64 {
+    %c = arith.cmpf ult, %x, %y : f64
+    cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64)
+
+  ^blk2(%r : f64):
+    return %r : f64
+  }
+  func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> {
+    %r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>)
+    return %r : tensor<2xf64>
+  }
+}
+
+// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64>
+// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
+// CHECK-NEXT: }
+// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64
+// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>)
+// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0
+// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
+// CHECK-NEXT: }
diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir b/enzyme/test/MLIR/ForwardMode/batched_for.mlir
new file mode 100644
index 00000000000..95557cb0b6f
--- /dev/null
+++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir
@@ -0,0 +1,33 @@
+// RUN: %eopt --enzyme %s | FileCheck %s
+
+module {
+  func.func @square(%x : f64) -> f64 {
+    %cst = arith.constant 10.000000e+00 : f64
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
+    %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) {
+      %n = arith.addf %arg2, %x : f64
+      scf.yield %n : f64
+    }
+    return %r : f64
+  }
+  func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
+    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
+    return %r : tensor<2xf64>
+  }
+}
+
+// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64>
+// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) {
+// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64>
+// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64
+// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64>
+// CHECK-NEXT: }
diff --git a/enzyme/test/MLIR/ForwardMode/batched_if.mlir b/enzyme/test/MLIR/ForwardMode/batched_if.mlir
new file mode 100644
index 00000000000..33c9e1b9fe8
--- /dev/null
+++ b/enzyme/test/MLIR/ForwardMode/batched_if.mlir
@@ -0,0 +1,43 @@
+// RUN: %eopt --enzyme %s | FileCheck %s
+
+module {
+  func.func @square(%x : f64, %c : i1) -> f64 {
+    %c2 = arith.constant 2.000000e+00 : f64
+    %c10 = arith.constant 10.000000e+00 : f64
+    %r:2 = scf.if %c -> (f64, f64) {
+      %mul = arith.mulf %x, %x : f64
+      scf.yield %mul, %c2 : f64, f64
+    } else {
+      %add = arith.addf %x, %x : f64
+      scf.yield %add, %c10 : f64, f64
+    }
+    %res = arith.mulf %r#0, %r#1 : f64
+    return %res : f64
+  }
+  func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> {
+    %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>)
+    return %r : tensor<2xf64>
+  }
+}
+
+// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> {
+// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64
+// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64
+// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) {
+// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
+// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64>
+// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
+// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64>
+// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64>
+// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
+// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64>
+// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64
+// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
+// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64>
+// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64
+// CHECK-NEXT: return %[[r2]] : tensor<2xf64>
+// CHECK-NEXT: }
diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
new file mode 100644
index 00000000000..f06f86d2a04
--- /dev/null
+++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
@@ -0,0 +1,26 @@
+// RUN: %eopt --enzyme %s | FileCheck %s
+
+module {
+  func.func @square(%x : f64) -> f64{
+    %y = arith.mulf %x, %x : f64
+    return %y : f64
+  }
+  func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
+    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
+    return %r : tensor<2xf64>
+  }
+}
+
+// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64>
+// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
+// CHECK-NEXT: }
+// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64>
+// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64>
+// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64>
+// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
+// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
+// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
+// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
+// CHECK-NEXT: }
diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
new file mode 100644
index 00000000000..11b75f634a6
--- /dev/null
+++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
@@ -0,0 +1,26 @@
+// RUN: %eopt --enzyme %s | FileCheck %s
+
+module {
+  func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
+    %y = arith.mulf %x, %x : tensor<10xf64>
+    return %y : tensor<10xf64>
+  }
+  func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> {
+    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>)
+    return %r : tensor<2x10xf64>
+  }
+}
+
+// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
+// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
+// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64>
+// CHECK-NEXT: }
+// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
+// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
+// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64>
+// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
+// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64>
+// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64>
+// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>
+// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64>
+// CHECK-NEXT: }
diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
index 900c5c813cd..dccbc7b7923 100644
--- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
+++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
@@ -275,8 +275,19 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os,
       os << ord;
     }
     if (!vecValue && !startsWith(ord, "local")) {
-      if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives))
+      if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
         os << ")";
+        if (intrinsic == MLIRDerivatives) {
+          os << ";\n";
+          os << "if (gutils->width != 1) {\n"
+             << "  " << argName << "_" << (idx - 1)
+             << " = builder.create<enzyme::BroadcastOp>(\n"
+             << "    op.getLoc(),\n"
+             << "    " << argName << "_" << (idx - 1) << ",\n"
+             << "    llvm::SmallVector<int64_t>({gutils->width}));\n"
+             << "}";
+        }
+      }
 
       if (lookup && intrinsic != MLIRDerivatives)
         os << ", " << builder << ")";

From 29fe86bde582cc4d8a433e5f0064138cc76b1715 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sat, 28 Dec 2024 20:09:40 -0500
Subject: [PATCH 09/13] slice activity analysis (#2209)

---
 enzyme/Enzyme/ActivityAnalysis.cpp | 35 ++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp
index 67cbb5e9ddf..67972fdcc8c 100644
--- a/enzyme/Enzyme/ActivityAnalysis.cpp
+++ b/enzyme/Enzyme/ActivityAnalysis.cpp
@@ -579,6 +579,11 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
   if (Name == "jl_reshape_array" || Name == "ijl_reshape_array")
     return val != CI->getArgOperand(1);
 
+  // Only the 0-th arg impacts activity
+  if (Name == "jl_genericmemory_copy_slice" ||
+      Name == "ijl_genericmemory_copy_slice")
+    return val != CI->getArgOperand(0);
+
   // Allocations, deallocations, and c++ guards don't impact the activity
   // of arguments
   if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI))
@@ -660,6 +665,13 @@ static inline void propagateArgumentInformation(
     return;
   }
 
+  // Only the 0-th arg impacts activity
+  if (Name == "jl_genericmemory_copy_slice" ||
+      Name == "ijl_genericmemory_copy_slice") {
+    propagateFromOperand(CI.getArgOperand(1));
+    return;
+  }
+
   // Only the 1-th arg impacts activity
   if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") {
     propagateFromOperand(CI.getArgOperand(1));
     return;
   }
@@ -1601,6 +1613,29 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
     return true;
   }
 
+  if (funcName == "jl_genericmemory_copy_slice" ||
+      funcName == "ijl_genericmemory_copy_slice") {
+    if (directions == UP) {
+      if (isConstantValue(TR, op->getArgOperand(0))) {
+        InsertConstantValue(TR, Val);
+        return true;
+      }
+    } else {
+      if (UpHypothesis->isConstantValue(TR, op->getArgOperand(0))) {
+        InsertConstantValue(TR, Val);
+        insertConstantsFrom(TR, *UpHypothesis);
+        return true;
+      }
+    }
+    if (EnzymeEnableRecursiveHypotheses) {
+      ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert(Val);
+      if (TmpOrig != Val) {
+        ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert(
+            TmpOrig);
+      }
+    }
+  }
+
   // If requesting empty unknown functions to be considered inactive,
   // abide by those rules
   if (called && EnzymeEmptyFnInactive && called->empty() &&

From bdda0ce16544f503f06a627c3db0579d876a7e04 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sat, 28 Dec 2024 21:12:20 -0500
Subject: [PATCH 10/13] Fix batched shadow reverse runtime AA (#2210)
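In the batched (width > 1) case the inverted pointer is an aggregate holding
one shadow per replica, so the runtime-activity comparison below has to pull
out lane 0 before comparing against the primal pointer. extractMeta is
Enzyme's helper for that projection; roughly, as a sketch (the real helper
also folds through insertvalue chains and constant aggregates):

```c++
#include "llvm/IR/IRBuilder.h"

// Rough equivalent of extractMeta(Builder2, batchedShadow, 0): take replica 0
// out of the [width x T] shadow aggregate.
llvm::Value *firstLane(llvm::IRBuilder<> &B, llvm::Value *batchedShadow) {
  return B.CreateExtractValue(batchedShadow, {0});
}
```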
---
 enzyme/Enzyme/DiffeGradientUtils.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp
index eba0de11f54..e88e762e7f0 100644
--- a/enzyme/Enzyme/DiffeGradientUtils.cpp
+++ b/enzyme/Enzyme/DiffeGradientUtils.cpp
@@ -1179,9 +1179,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
       // the pointers and conditionally execute.
       if ((!isa<AllocaInst>(basePtr) && !isAllocationCall(basePtr, TLI)) &&
          runtimeActivity && !merge) {
-        Value *shadow = Builder2.CreateICmpNE(
-            lookupM(getNewFromOriginal(origptr), Builder2),
-            lookupM(invertPointerM(origptr, Builder2), Builder2));
+        Value *primal_val = lookupM(getNewFromOriginal(origptr), Builder2);
+        Value *shadow_val =
+            lookupM(invertPointerM(origptr, Builder2), Builder2);
+        if (getWidth() != 1) {
+          shadow_val = extractMeta(Builder2, shadow_val, 0);
+        }
+        Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val);
 
         BasicBlock *current = Builder2.GetInsertBlock();
         BasicBlock *conditional =

From 7cf9e90de211ed525aca58e0930b14a827e20bba Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sat, 28 Dec 2024 22:38:20 -0500
Subject: [PATCH 11/13] Fix copyslice v2 (#2211)

* Fix copyslice v2

* fix
---
 enzyme/Enzyme/ActivityAnalysis.cpp | 30 ++++++------------------------
 1 file changed, 6 insertions(+), 24 deletions(-)

diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp
index 67972fdcc8c..44e88abe351 100644
--- a/enzyme/Enzyme/ActivityAnalysis.cpp
+++ b/enzyme/Enzyme/ActivityAnalysis.cpp
@@ -668,7 +668,7 @@ static inline void propagateArgumentInformation(
   // Only the 0-th arg impacts activity
   if (Name == "jl_genericmemory_copy_slice" ||
       Name == "ijl_genericmemory_copy_slice") {
-    propagateFromOperand(CI.getArgOperand(1));
+    propagateFromOperand(CI.getArgOperand(0));
     return;
   }
 
@@ -1613,29 +1613,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
     return true;
   }
 
-  if (funcName == "jl_genericmemory_copy_slice" ||
-      funcName == "ijl_genericmemory_copy_slice") {
-    if (directions == UP) {
-      if (isConstantValue(TR, op->getArgOperand(0))) {
-        InsertConstantValue(TR, Val);
-        return true;
-      }
-    } else {
-      if (UpHypothesis->isConstantValue(TR, op->getArgOperand(0))) {
-        InsertConstantValue(TR, Val);
-        insertConstantsFrom(TR, *UpHypothesis);
-        return true;
-      }
-    }
-    if (EnzymeEnableRecursiveHypotheses) {
-      ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert(Val);
-      if (TmpOrig != Val) {
-        ReEvaluateValueIfInactiveValue[op->getArgOperand(0)].insert(
-            TmpOrig);
-      }
-    }
-  }
-
   // If requesting empty unknown functions to be considered inactive,
   // abide by those rules
   if (called && EnzymeEmptyFnInactive && called->empty() &&
@@ -2751,6 +2728,11 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
         if (AllocaSet.count(TmpOrig)) {
          continue;
        }
+        // We are literally storing our value into ourselves [or relevant
+        // derived pointer]
+        if (TmpOrig == val) {
+          continue;
+        }
         if (isa<AllocaInst>(TmpOrig)) {
          newAllocaSet.insert(TmpOrig);
          continue;
        }

From 8e36e65e6570827796acf3ae7df421fc244ca8aa Mon Sep 17 00:00:00 2001
From: William Moses
Date: Wed, 1 Jan 2025 14:18:42 -0500
Subject: [PATCH 12/13] Improve cast ft error (#2213)

---
 enzyme/Enzyme/AdjointGenerator.h | 46 +++++++++++++++++---------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h
index 96b8494302e..655bdca6943 100644
--- a/enzyme/Enzyme/AdjointGenerator.h
+++ b/enzyme/Enzyme/AdjointGenerator.h
@@ -1380,31 +1380,33 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator>
         ss << "Cannot deduce adding type (cast) of " << I;
         EmitNoTypeError(str, I, gutils, Builder2);
       }
-      assert(FT);
 
-      auto rule = [&](Value *dif) {
-        if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
-            I.getOpcode() == CastInst::CastOps::FPExt) {
-          return Builder2.CreateFPCast(dif, op0->getType());
-        } else if (I.getOpcode() == CastInst::CastOps::BitCast) {
-          return Builder2.CreateBitCast(dif, op0->getType());
-        } else if (I.getOpcode() == CastInst::CastOps::Trunc) {
-          // TODO CHECK THIS
-          return Builder2.CreateZExt(dif, op0->getType());
-        } else {
-          std::string s;
-          llvm::raw_string_ostream ss(s);
-          ss << *I.getParent()->getParent() << "\n";
-          ss << "cannot handle above cast " << I << "\n";
-          EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
-          return (llvm::Value *)UndefValue::get(op0->getType());
-        }
-      };
+      if (FT) {
+
+        auto rule = [&](Value *dif) {
+          if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
+              I.getOpcode() == CastInst::CastOps::FPExt) {
+            return Builder2.CreateFPCast(dif, op0->getType());
+          } else if (I.getOpcode() == CastInst::CastOps::BitCast) {
+            return Builder2.CreateBitCast(dif, op0->getType());
+          } else if (I.getOpcode() == CastInst::CastOps::Trunc) {
+            // TODO CHECK THIS
+            return Builder2.CreateZExt(dif, op0->getType());
+          } else {
+            std::string s;
+            llvm::raw_string_ostream ss(s);
+            ss << *I.getParent()->getParent() << "\n";
+            ss << "cannot handle above cast " << I << "\n";
+            EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
+            return (llvm::Value *)UndefValue::get(op0->getType());
+          }
+        };
 
-      Value *dif = diffe(&I, Builder2);
-      Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);
+        Value *dif = diffe(&I, Builder2);
+        Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);
 
-      addToDiffe(orig_op0, diff, Builder2, FT);
+        addToDiffe(orig_op0, diff, Builder2, FT);
+      }
     }
 
     Type *diffTy = gutils->getShadowType(I.getType());

From 7bc73fa8291cfed08456ae93e6f460060a6c8344 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Wed, 1 Jan 2025 16:42:51 -0500
Subject: [PATCH 13/13] MLIR: post optimization pipeline (#2214)

* MLIR: post optimization pipeline

* build start

* fix

* fix

* fix build

* format

* fixup
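The new trailing parameter threads a textual pass pipeline through every
derivative-generation entry point. A sketch of the caller's side, mirroring
the call sites updated below (the pipeline string "canonicalize,cse" is just
an example; anything accepted by mlir::parsePassPipeline works):

```c++
// Hypothetical use inside a pass: run cleanup passes over each generated
// derivative before it is returned.
FunctionOpInterface newFunc = Logic.CreateForwardDiff(
    fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
    /*addedType*/ nullptr, type_args, volatile_args,
    /*augmented*/ nullptr, /*postpasses=*/"canonicalize,cse");
```

From the command line the same string arrives through the new pass option,
e.g. something like --enzyme='postpasses=canonicalize,cse' on enzymemlir-opt
(exact quoting follows mlir-opt's pass-option syntax).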
enzyme/test/MLIR/ForwardMode/batched_for.mlir | 2 +- .../test/MLIR/ForwardMode/batched_scalar.mlir | 4 +-- .../test/MLIR/ForwardMode/batched_tensor.mlir | 4 +-- 16 files changed, 113 insertions(+), 50 deletions(-) create mode 100644 enzyme/.bazelversion diff --git a/enzyme/.bazelversion b/enzyme/.bazelversion new file mode 100644 index 00000000000..f22d756da39 --- /dev/null +++ b/enzyme/.bazelversion @@ -0,0 +1 @@ +6.5.0 diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp index 5308304f5b7..54845c740d3 100644 --- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -73,7 +73,7 @@ class AutoDiffCallFwd fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode, freeMemory, width, /* addedType */ nullptr, type_args, volatile_args, - /* augmented */ nullptr); + /* augmented */ nullptr, gutils->postpasses); SmallVector fwdArguments; @@ -173,7 +173,7 @@ class AutoDiffCallRev auto revFn = gutils->Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, - volatile_args, /*augmented*/ nullptr); + volatile_args, /*augmented*/ nullptr, gutils->postpasses); SmallVector revArguments; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index fbd337813bc..7a5770ccdaa 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -13,6 +13,9 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dominance.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" + #include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" @@ -78,7 +81,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( std::vector ArgActivity, MTypeAnalysis &TA, std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented) { + std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); @@ -105,7 +109,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto gutils = MDiffeGradientUtils::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, RetActivity, ArgActivity, addedType, - /*omp*/ false); + /*omp*/ false, postpasses); ForwardCachedFunctions[tup] = gutils->newFunc; insert_or_assign2( @@ -195,10 +199,19 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( if (!valid) return nullptr; - // if (PostOpt) - // PPC.optimizeIntermediate(nf); - // if (EnzymePrint) { - // llvm::errs() << nf << "\n"; - //} + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index c8cad6eee27..aef498d5227 100644 --- 
a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -196,14 +196,17 @@ class MEnzymeLogic { std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, - void *augmented); - - FunctionOpInterface CreateReverseDiff( - FunctionOpInterface fn, std::vector retType, - std::vector constants, MTypeAnalysis &TA, - std::vector returnPrimals, std::vector returnShadows, - DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented); + void *augmented, llvm::StringRef postpasses); + + FunctionOpInterface + CreateReverseDiff(FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, + std::vector returnPrimals, + std::vector returnShadows, DerivativeMode mode, + bool freeMemory, size_t width, mlir::Type addedType, + MFnTypeInfo type_args, std::vector volatile_args, + void *augmented, llvm::StringRef postpasses); + void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 0812a7ccde5..7ca0e9ea72f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -11,6 +11,8 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -182,7 +184,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( std::vector constants, MTypeAnalysis &TA, std::vector returnPrimals, std::vector returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -214,7 +217,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, - retType, constants, addedType); + retType, constants, addedType, postpasses); ReverseCachedFunctions[tup] = gutils->newFunc; @@ -254,5 +257,19 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( if (!res.succeeded()) return nullptr; + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 32cb5b79614..0dab1032af9 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -37,15 +37,15 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses) : 
newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), returnPrimals(returnPrimals), - returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_), - RetDiffeTypes(ReturnActivity) {} + TA(TA_), TR(TR_), omp(omp), postpasses(postpasses), + returnPrimals(returnPrimals), returnShadows(returnShadows), width(width), + ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {} mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal( const mlir::Value originst) const { diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 1fac52caab3..085bd678f83 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -36,6 +36,7 @@ class MGradientUtils { MTypeAnalysis &TA; MTypeResults TR; bool omp; + llvm::StringRef postpasses; const llvm::ArrayRef returnPrimals; const llvm::ArrayRef returnShadows; @@ -58,7 +59,8 @@ class MGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp); + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses); void erase(Operation *op) { op->erase(); } void replaceOrigOpWith(Operation *op, ValueRange vals) { for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { @@ -113,11 +115,12 @@ class MDiffeGradientUtils : public MGradientUtils { ArrayRef RetActivity, ArrayRef ArgActivity, IRMapping &origToNew_, std::map &origToNewOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, RetActivity, ArgActivity, origToNew_, - origToNewOps_, mode, width, omp), + origToNewOps_, mode, width, omp, postpasses), initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor @@ -127,7 +130,7 @@ class MDiffeGradientUtils : public MGradientUtils { const llvm::ArrayRef returnPrimals, const llvm::ArrayRef returnShadows, ArrayRef RetActivity, ArrayRef ArgActivity, - mlir::Type additionalArg, bool omp) { + mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) { std::string prefix; switch (mode) { @@ -163,7 +166,8 @@ class MDiffeGradientUtils : public MGradientUtils { return new MDiffeGradientUtils( Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, RetActivity, - ArgActivity, originalToNew, originalToNewOps, mode, width, omp); + ArgActivity, originalToNew, originalToNewOps, mode, width, omp, + postpasses); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 793b073de0f..c9fe98bc5a5 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width) + DerivativeMode mode_, unsigned width, StringRef 
postpasses) : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, - mode_, width, /*omp*/ false) {} + mode_, width, /*omp*/ false, postpasses) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, const ArrayRef returnShadows, ArrayRef retType, ArrayRef constant_args, - mlir::Type additionalArg) { + mlir::Type additionalArg, llvm::StringRef postpasses) { std::string prefix; switch (mode_) { @@ -174,5 +174,5 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( return new MGradientUtilsReverse( Logic, newFunc, todiff, TA, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, retType, - constant_args, originalToNew, originalToNewOps, mode_, width); + constant_args, originalToNew, originalToNewOps, mode_, width, postpasses); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index b6b63c6d13d..7f2d26cba2e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -36,7 +36,8 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width); + DerivativeMode mode_, unsigned width, + llvm::StringRef postpasses); IRMapping mapReverseModeBlocks; @@ -64,12 +65,14 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { void createReverseModeBlocks(Region &oldFunc, Region &newFunc); - static MGradientUtilsReverse *CreateFromClone( - MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, - FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, - const ArrayRef returnPrimals, const ArrayRef returnShadows, - llvm::ArrayRef retType, - llvm::ArrayRef constant_args, mlir::Type additionalArg); + static MGradientUtilsReverse * + CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, + FunctionOpInterface todiff, MTypeAnalysis &TA, + MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, + const ArrayRef returnShadows, + llvm::ArrayRef retType, + llvm::ArrayRef constant_args, + mlir::Type additionalArg, llvm::StringRef postpasses); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index d83532db35a..c91f5400fef 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/PassManager.h" #define DEBUG_TYPE "enzyme" @@ -31,6 +32,18 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + mlir::OpPassManager pm; + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (!mlir::failed(result)) { + pm.getDependentDialects(registry); + } + + registry + .insert(); + } + static std::vector mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { std::vector retTypes; 
@@ -150,7 +163,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
       FunctionOpInterface newFunc = Logic.CreateForwardDiff(
           fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
           /*addedType*/ nullptr, type_args, volatile_args,
-          /*augmented*/ nullptr);
+          /*augmented*/ nullptr, postpasses);
 
       if (!newFunc)
         return failure();
@@ -276,7 +289,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
           Logic.CreateReverseDiff(fn, retType, arg_activities, TA,
                                   returnPrimals, returnShadows, mode,
                                   freeMemory, width, /*addedType*/ nullptr,
                                   type_args, volatile_args,
-                                  /*augmented*/ nullptr);
+                                  /*augmented*/ nullptr, postpasses);
 
       if (!newFunc)
         return failure();
diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
index 629a567815e..1e01c8f87bc 100644
--- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
@@ -121,13 +121,13 @@ struct DifferentiateWrapperPass
                                        returnPrimal, mode, freeMemory, width,
                                        /*addedType*/ nullptr, type_args,
                                        volatile_args,
-                                       /*augmented*/ nullptr);
+                                       /*augmented*/ nullptr, "");
     } else {
       newFunc = Logic.CreateReverseDiff(
          fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode,
          freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args,
-         /*augmented*/ nullptr);
+         /*augmented*/ nullptr, "");
     }
     if (!newFunc) {
       signalPassFailure();
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td
index c5b4df76917..758f27946a7 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.td
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -19,6 +19,15 @@ def DifferentiatePass : Pass<"enzyme"> {
     "cf::ControlFlowDialect",
     "tensor::TensorDialect",
   ];
+  let options = [
+    Option<
+      /*C++ variable name=*/"postpasses",
+      /*CLI argument=*/"postpasses",
+      /*type=*/"std::string",
+      /*default=*/"",
+      /*description=*/"Optimization passes to apply to generated derivative functions"
+    >,
+  ];
   let constructor = "mlir::enzyme::createDifferentiatePass()";
 }
 
diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir
index f20989aa424..d663eea5afe 100644
--- a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir
+++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir
@@ -15,10 +15,10 @@ module {
 }
 
 // CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> {
-// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64>
+// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64>
 // CHECK-NEXT: return %[[i0]] : tensor<2xf64>
 // CHECK-NEXT: }
-// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> {
 // CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64
 // CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>)
 // CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0
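The CHECK-line churn above reflects that the batching width is folded into the name of the generated wrapper: with `width = 2`, the forward-mode clone of `@square` is emitted as `@fwddiffe2square` rather than `@fwddiffesquare`. The remaining batched ForwardMode tests below receive the same rename, matching patterns of the form:

    // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {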
diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir b/enzyme/test/MLIR/ForwardMode/batched_for.mlir
index 95557cb0b6f..3ec17ec50f5 100644
--- a/enzyme/test/MLIR/ForwardMode/batched_for.mlir
+++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir
@@ -18,7 +18,7 @@ module {
   }
 }
 
-// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
 // CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64>
 // CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
index f06f86d2a04..d384bdd0933 100644
--- a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
+++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
@@ -16,9 +16,9 @@ module {
 // CHECK-NEXT: return %[[i0]] : tensor<2xf64>
 // CHECK-NEXT: }
 // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
-// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64>
+// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
 // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64>
-// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64>
+// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
 // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
 // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
 // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
index 11b75f634a6..2a565f9ff41 100644
--- a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
+++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
@@ -16,9 +16,9 @@ module {
 // CHECK-NEXT: return %[[i0]] : tensor<2x10xf64>
 // CHECK-NEXT: }
 // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
-// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
+// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
 // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64>
-// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
+// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
 // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64>
 // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64>
 // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>
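Beyond the `fwddiffe2square` rename, the batched_scalar and batched_tensor updates fix two syntactic problems in the expected `enzyme.broadcast` lines: the generic (quoted) operation form prints a functional type, so the operand type must be parenthesized as `(f64) -> tensor<2xf64>` rather than `f64 -> tensor<2xf64>`, and an SSA operand takes a single `%`, so `%%[[arg0]]` could never match the printed IR. A well-formed line, as the updated tests now expect:

    %s0 = "enzyme.broadcast"(%x) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>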