diff --git a/enzyme/.bazelversion b/enzyme/.bazelversion new file mode 100644 index 00000000000..f22d756da39 --- /dev/null +++ b/enzyme/.bazelversion @@ -0,0 +1 @@ +6.5.0 diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp index 5308304f5b7..54845c740d3 100644 --- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -73,7 +73,7 @@ class AutoDiffCallFwd fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode, freeMemory, width, /* addedType */ nullptr, type_args, volatile_args, - /* augmented */ nullptr); + /* augmented */ nullptr, gutils->postpasses); SmallVector fwdArguments; @@ -173,7 +173,7 @@ class AutoDiffCallRev auto revFn = gutils->Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, - volatile_args, /*augmented*/ nullptr); + volatile_args, /*augmented*/ nullptr, gutils->postpasses); SmallVector revArguments; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index fbd337813bc..7a5770ccdaa 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -13,6 +13,9 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dominance.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" + #include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" @@ -78,7 +81,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( std::vector ArgActivity, MTypeAnalysis &TA, std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented) { + std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); @@ -105,7 +109,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto gutils = MDiffeGradientUtils::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, RetActivity, ArgActivity, addedType, - /*omp*/ false); + /*omp*/ false, postpasses); ForwardCachedFunctions[tup] = gutils->newFunc; insert_or_assign2( @@ -195,10 +199,19 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( if (!valid) return nullptr; - // if (PostOpt) - // PPC.optimizeIntermediate(nf); - // if (EnzymePrint) { - // llvm::errs() << nf << "\n"; - //} + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index c8cad6eee27..aef498d5227 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -196,14 +196,17 @@ class MEnzymeLogic { std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, - void *augmented); - - FunctionOpInterface CreateReverseDiff( - FunctionOpInterface fn, std::vector retType, - std::vector constants, MTypeAnalysis &TA, - std::vector returnPrimals, std::vector returnShadows, - DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented); + void *augmented, llvm::StringRef postpasses); + + FunctionOpInterface + CreateReverseDiff(FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, + std::vector returnPrimals, + std::vector returnShadows, DerivativeMode mode, + bool freeMemory, size_t width, mlir::Type addedType, + MFnTypeInfo type_args, std::vector volatile_args, + void *augmented, llvm::StringRef postpasses); + void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 0812a7ccde5..7ca0e9ea72f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -11,6 +11,8 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -182,7 +184,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( std::vector constants, MTypeAnalysis &TA, std::vector returnPrimals, std::vector returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -214,7 +217,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, - retType, constants, addedType); + retType, constants, addedType, postpasses); ReverseCachedFunctions[tup] = gutils->newFunc; @@ -254,5 +257,19 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( if (!res.succeeded()) return nullptr; + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 32cb5b79614..0dab1032af9 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -37,15 +37,15 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses) : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), returnPrimals(returnPrimals), - returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_), - RetDiffeTypes(ReturnActivity) {} + TA(TA_), TR(TR_), omp(omp), postpasses(postpasses), + returnPrimals(returnPrimals), returnShadows(returnShadows), width(width), + ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {} mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal( const mlir::Value originst) const { diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 1fac52caab3..085bd678f83 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -36,6 +36,7 @@ class MGradientUtils { MTypeAnalysis &TA; MTypeResults TR; bool omp; + llvm::StringRef postpasses; const llvm::ArrayRef returnPrimals; const llvm::ArrayRef returnShadows; @@ -58,7 +59,8 @@ class MGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp); + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses); void erase(Operation *op) { op->erase(); } void replaceOrigOpWith(Operation *op, ValueRange vals) { for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { @@ -113,11 +115,12 @@ class MDiffeGradientUtils : public MGradientUtils { ArrayRef RetActivity, ArrayRef ArgActivity, IRMapping &origToNew_, std::map &origToNewOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, RetActivity, ArgActivity, origToNew_, - origToNewOps_, mode, width, omp), + origToNewOps_, mode, width, omp, postpasses), initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor @@ -127,7 +130,7 @@ class MDiffeGradientUtils : public MGradientUtils { const llvm::ArrayRef returnPrimals, const llvm::ArrayRef returnShadows, ArrayRef RetActivity, ArrayRef ArgActivity, - mlir::Type additionalArg, bool omp) { + mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) { std::string prefix; switch (mode) { @@ -163,7 +166,8 @@ class MDiffeGradientUtils : public MGradientUtils { return new MDiffeGradientUtils( Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, RetActivity, - ArgActivity, originalToNew, originalToNewOps, mode, width, omp); + ArgActivity, originalToNew, originalToNewOps, mode, width, omp, + postpasses); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 793b073de0f..c9fe98bc5a5 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width) + DerivativeMode mode_, unsigned width, StringRef postpasses) : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, - mode_, width, /*omp*/ false) {} + mode_, width, /*omp*/ false, postpasses) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, const ArrayRef returnShadows, ArrayRef retType, ArrayRef constant_args, - mlir::Type additionalArg) { + mlir::Type additionalArg, llvm::StringRef postpasses) { std::string prefix; switch (mode_) { @@ -174,5 +174,5 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( return new MGradientUtilsReverse( Logic, newFunc, todiff, TA, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, retType, - constant_args, originalToNew, originalToNewOps, mode_, width); + constant_args, originalToNew, originalToNewOps, mode_, width, postpasses); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index b6b63c6d13d..7f2d26cba2e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -36,7 +36,8 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width); + DerivativeMode mode_, unsigned width, + llvm::StringRef postpasses); IRMapping mapReverseModeBlocks; @@ -64,12 +65,14 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { void createReverseModeBlocks(Region &oldFunc, Region &newFunc); - static MGradientUtilsReverse *CreateFromClone( - MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, - FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, - const ArrayRef returnPrimals, const ArrayRef returnShadows, - llvm::ArrayRef retType, - llvm::ArrayRef constant_args, mlir::Type additionalArg); + static MGradientUtilsReverse * + CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, + FunctionOpInterface todiff, MTypeAnalysis &TA, + MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, + const ArrayRef returnShadows, + llvm::ArrayRef retType, + llvm::ArrayRef constant_args, + mlir::Type additionalArg, llvm::StringRef postpasses); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index d83532db35a..c91f5400fef 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/PassManager.h" #define DEBUG_TYPE "enzyme" @@ -31,6 +32,18 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + mlir::OpPassManager pm; + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (!mlir::failed(result)) { + pm.getDependentDialects(registry); + } + + registry + .insert(); + } + static std::vector mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { std::vector retTypes; @@ -150,7 +163,7 @@ struct DifferentiatePass : public DifferentiatePassBase { FunctionOpInterface newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, returnPrimals, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, postpasses); if (!newFunc) return failure(); @@ -276,7 +289,7 @@ struct DifferentiatePass : public DifferentiatePassBase { Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals, returnShadows, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, postpasses); if (!newFunc) return failure(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 629a567815e..1e01c8f87bc 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -121,13 +121,13 @@ struct DifferentiateWrapperPass returnPrimal, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, ""); } else { newFunc = Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, ""); } if (!newFunc) { signalPassFailure(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index c5b4df76917..758f27946a7 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -19,6 +19,15 @@ def DifferentiatePass : Pass<"enzyme"> { "cf::ControlFlowDialect", "tensor::TensorDialect", ]; + let options = [ + Option< + /*C++ variable name=*/"postpasses", + /*CLI argument=*/"postpasses", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Optimization passes to apply to generated derivative functions" + >, + ]; let constructor = "mlir::enzyme::createDifferentiatePass()"; } diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir index f20989aa424..d663eea5afe 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir @@ -15,10 +15,10 @@ module { } // CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> // CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } -// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { // CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 // CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) // CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir b/enzyme/test/MLIR/ForwardMode/batched_for.mlir index 95557cb0b6f..3ec17ec50f5 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_for.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir @@ -18,7 +18,7 @@ module { } } -// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { // CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64> // CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index f06f86d2a04..d384bdd0933 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -16,9 +16,9 @@ module { // CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> -// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir index 11b75f634a6..2a565f9ff41 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -16,9 +16,9 @@ module { // CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> // CHECK-NEXT: } // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { -// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> -// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>