From eb3bfae7b4e0da4a23a6f2b8e824888570788164 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Thu, 16 Jan 2025 22:25:44 +0100 Subject: [PATCH 1/6] don't batch the same function twice. --- enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp index 11a0fb3180f..f4a86a2075d 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp @@ -28,6 +28,18 @@ using namespace enzyme; namespace { +struct BatchCacheKey { + FunctionOpInterface function; + SmallVector batchSizes; + + // for use in std::map: + bool operator<(const BatchCacheKey &other) const { + if (const_cast(function).getName() != const_cast(other.function).getName()) + return const_cast(function).getName() < const_cast(other.function).getName(); + return batchSizes < other.batchSizes; + } +}; + static mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef batchSizes) { auto T = cast(Ty); @@ -146,6 +158,9 @@ batchCloneFunction(FunctionOpInterface F, Twine name, struct BatchPass : public BatchPassBase { void runOnOperation() override; + // Cache mapping original function and batch sizes to batched function + std::map batchedFunctionCache; + template LogicalResult HandleBatch(SymbolTableCollection &symbolTable, T CI) { SmallVector args; @@ -153,11 +168,22 @@ struct BatchPass : public BatchPassBase { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - FunctionOpInterface newFunc = - batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape()); - + BatchCacheKey key{fn, SmallVector(CI.getBatchShape().begin(), + CI.getBatchShape().end())}; + + // Check if we already have a batched version + auto it = batchedFunctionCache.find(key); + FunctionOpInterface newFunc; + + if (it != batchedFunctionCache.end()) { + newFunc = it->second; + } else { + // Create new batched function and store in cache + newFunc = batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape()); if (!newFunc) return failure(); + batchedFunctionCache[key] = newFunc; + } OpBuilder builder(CI); auto dCI = From be50d150467af012491ebb631ec59d3e9023b42d Mon Sep 17 00:00:00 2001 From: jumerckx Date: Thu, 16 Jan 2025 22:42:05 +0100 Subject: [PATCH 2/6] batch functions that are called within batched function --- enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp | 76 +++++++++++++++++-- 1 file changed, 71 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp index f4a86a2075d..5e131f49108 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp @@ -40,6 +40,12 @@ struct BatchCacheKey { } }; +static FunctionOpInterface +batchCloneFunction(FunctionOpInterface F, Twine name, + llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache); + + static mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef batchSizes) { auto T = cast(Ty); @@ -49,8 +55,58 @@ static mlir::TensorType applyBatchSizes(mlir::Type Ty, return T2; } +static LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, + IRMapping &mapper, + llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache) { + // Get the called function + auto moduleOp = callOp->getParentOfType(); + auto calledFunc = dyn_cast( + moduleOp.lookupSymbol(callOp.getCallee())); + if (!calledFunc) + return failure(); + + // Create cache key for this function and batch size combination + BatchCacheKey key{calledFunc, SmallVector(batchSizes.begin(), + batchSizes.end())}; + + // Look up or create batched version of the called function + FunctionOpInterface batchedFunc; + auto it = batchedFunctionCache.find(key); + if (it != batchedFunctionCache.end()) { + batchedFunc = it->second; + } else { + batchedFunc = batchCloneFunction(calledFunc, + "batched_" + calledFunc.getName(), + batchSizes, + batchedFunctionCache); + if (!batchedFunc) + return failure(); + batchedFunctionCache[key] = batchedFunc; + } + + // Create new call operation to the batched function + SmallVector newOperands; + for (auto operand : callOp->getOperands()) + newOperands.push_back(mapper.lookup(operand)); + + auto newCall = builder.create( + callOp.getLoc(), + batchedFunc.getName(), + batchedFunc.getResultTypes(), + newOperands); + + // Map the results + for (auto [oldResult, newResult] : + llvm::zip(callOp.getResults(), newCall.getResults())) + mapper.map(oldResult, newResult); + + return success(); +} + static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, - llvm::ArrayRef batchSizes) { + llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache) { // For each block in src, generate a corresponding block in the dest region. for (auto &blk : *src) { auto newBlk = new Block(); @@ -69,6 +125,13 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, OpBuilder builder(&newBlk, newBlk.end()); for (auto &src : blk) { + if (auto callOp = dyn_cast(&src)) { + if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes, + batchedFunctionCache))) + continue; + } + + if (auto ifaceOp = dyn_cast(&src)) { auto res = ifaceOp.createBatch(builder, mapper, batchSizes); if (res.succeeded()) @@ -101,7 +164,7 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, // Clone the regions. for (auto &&[oldReg, newReg] : llvm::zip(src.getRegions(), newOp->getRegions())) { - batchCloneRegion(&oldReg, &newReg, mapper, batchSizes); + batchCloneRegion(&oldReg, &newReg, mapper, batchSizes, batchedFunctionCache); } // Remember the mapping of any results. @@ -115,7 +178,8 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, static FunctionOpInterface batchCloneFunction(FunctionOpInterface F, Twine name, - llvm::ArrayRef batchSizes) { + llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache) { assert(!F.getFunctionBody().empty()); auto FTy = F.getFunctionType().cast(); @@ -150,7 +214,7 @@ batchCloneFunction(FunctionOpInterface F, Twine name, auto &newReg = NewF.getFunctionBody(); IRMapping mapper; - batchCloneRegion(&origReg, &newReg, mapper, batchSizes); + batchCloneRegion(&origReg, &newReg, mapper, batchSizes, batchedFunctionCache); return NewF; } @@ -179,7 +243,9 @@ struct BatchPass : public BatchPassBase { newFunc = it->second; } else { // Create new batched function and store in cache - newFunc = batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape()); + newFunc = batchCloneFunction(fn, "batched_" + fn.getName(), + CI.getBatchShape(), + batchedFunctionCache); if (!newFunc) return failure(); batchedFunctionCache[key] = newFunc; From 5e403e0233e76682489289c58ff1a3d5556f8926 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 17 Jan 2025 09:59:08 +0100 Subject: [PATCH 3/6] initial tests --- enzyme/test/MLIR/Batch/cachefunction.mlir | 19 +++++++++++++++ enzyme/test/MLIR/Batch/call.mlir | 28 +++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 enzyme/test/MLIR/Batch/cachefunction.mlir create mode 100644 enzyme/test/MLIR/Batch/call.mlir diff --git a/enzyme/test/MLIR/Batch/cachefunction.mlir b/enzyme/test/MLIR/Batch/cachefunction.mlir new file mode 100644 index 00000000000..bc7bae1b0ba --- /dev/null +++ b/enzyme/test/MLIR/Batch/cachefunction.mlir @@ -0,0 +1,19 @@ +module { + func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> { + return %arg0 : tensor<16xf32> + } + func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) { + %2 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> + %3 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> + return + } +} + +// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) { +// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: %[[v1:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> { +// CHECK-NEXT: return %[[arg0]] : tensor<4x16xf32> +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/Batch/call.mlir b/enzyme/test/MLIR/Batch/call.mlir new file mode 100644 index 00000000000..faac106f8f2 --- /dev/null +++ b/enzyme/test/MLIR/Batch/call.mlir @@ -0,0 +1,28 @@ +// RUN: %eopt -enzyme-batch %s | FileCheck %s + +module { + func.func private @g(%arg0: tensor<16xf32>) -> tensor<16xf32> { + return %arg0 : tensor<16xf32> + } + func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> { + %1 = func.call @g(%arg0) : (tensor<16xf32>) -> tensor<16xf32> + return %1 : tensor<16xf32> + } + func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) { + %2 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> + return + } +} + +// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) { +// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> { +// CHECK-NEXT: %[[v0:.+]] = call @batched_g(%[[arg0]]) : (tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32> +// CHECK-NEXT: } +// CHECK: func.func private @batched_g(%[[arg0:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> { +// CHECK-NEXT: return %[[arg0]] : tensor<4x16xf32> +// CHECK-NEXT: } + From d3aab16b85d7e1d245a8f5152242bbc7635ef0ef Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 17 Jan 2025 10:46:27 +0100 Subject: [PATCH 4/6] support recursive functions --- enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp index 5e131f49108..dd2e8159417 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp @@ -209,6 +209,10 @@ batchCloneFunction(FunctionOpInterface F, Twine name, SymbolTable table(parent); table.insert(NewF); SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private); + + // Add the function to the cache BEFORE processing its body to support recursion. + BatchCacheKey key{F, SmallVector(batchSizes.begin(), batchSizes.end())}; + batchedFunctionCache[key] = NewF; auto &origReg = F.getFunctionBody(); auto &newReg = NewF.getFunctionBody(); @@ -246,9 +250,9 @@ struct BatchPass : public BatchPassBase { newFunc = batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape(), batchedFunctionCache); - if (!newFunc) + if (!newFunc) { return failure(); - batchedFunctionCache[key] = newFunc; + } } OpBuilder builder(CI); From 01548746eb476b69e7d4444bf20b74f23dd121f8 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 17 Jan 2025 10:48:03 +0100 Subject: [PATCH 5/6] recursion test --- enzyme/test/MLIR/Batch/recursion.mlir | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 enzyme/test/MLIR/Batch/recursion.mlir diff --git a/enzyme/test/MLIR/Batch/recursion.mlir b/enzyme/test/MLIR/Batch/recursion.mlir new file mode 100644 index 00000000000..4f9bfb31eef --- /dev/null +++ b/enzyme/test/MLIR/Batch/recursion.mlir @@ -0,0 +1,21 @@ +// RUN: %eopt -enzyme-batch %s | FileCheck %s + +module { + func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> { + %0 = func.call @f(%arg0, %arg1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> + return %0 : tensor<16xf32> + } + func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) { + %0 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> + return + } +} + +// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) { +// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> { +// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32> +// CHECK-NEXT: } \ No newline at end of file From 0e7e722019b71f4c9258bbef9cb2a76d0ea2fe3b Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 17 Jan 2025 10:49:49 +0100 Subject: [PATCH 6/6] formatting --- enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp | 92 +++++++++---------- 1 file changed, 45 insertions(+), 47 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp index dd2e8159417..471cd9f4615 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp @@ -31,20 +31,20 @@ namespace { struct BatchCacheKey { FunctionOpInterface function; SmallVector batchSizes; - + // for use in std::map: bool operator<(const BatchCacheKey &other) const { - if (const_cast(function).getName() != const_cast(other.function).getName()) - return const_cast(function).getName() < const_cast(other.function).getName(); + if (const_cast(function).getName() != + const_cast(other.function).getName()) + return const_cast(function).getName() < + const_cast(other.function).getName(); return batchSizes < other.batchSizes; } }; -static FunctionOpInterface -batchCloneFunction(FunctionOpInterface F, Twine name, - llvm::ArrayRef batchSizes, - std::map &batchedFunctionCache); - +static FunctionOpInterface batchCloneFunction( + FunctionOpInterface F, Twine name, llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache); static mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef batchSizes) { @@ -55,31 +55,30 @@ static mlir::TensorType applyBatchSizes(mlir::Type Ty, return T2; } -static LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, - IRMapping &mapper, - llvm::ArrayRef batchSizes, - std::map &batchedFunctionCache) { +static LogicalResult handleCallOp( + func::CallOp callOp, OpBuilder &builder, IRMapping &mapper, + llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache) { // Get the called function auto moduleOp = callOp->getParentOfType(); - auto calledFunc = dyn_cast( - moduleOp.lookupSymbol(callOp.getCallee())); + auto calledFunc = + dyn_cast(moduleOp.lookupSymbol(callOp.getCallee())); if (!calledFunc) return failure(); // Create cache key for this function and batch size combination - BatchCacheKey key{calledFunc, SmallVector(batchSizes.begin(), - batchSizes.end())}; - + BatchCacheKey key{calledFunc, + SmallVector(batchSizes.begin(), batchSizes.end())}; + // Look up or create batched version of the called function FunctionOpInterface batchedFunc; auto it = batchedFunctionCache.find(key); if (it != batchedFunctionCache.end()) { batchedFunc = it->second; } else { - batchedFunc = batchCloneFunction(calledFunc, - "batched_" + calledFunc.getName(), - batchSizes, - batchedFunctionCache); + batchedFunc = + batchCloneFunction(calledFunc, "batched_" + calledFunc.getName(), + batchSizes, batchedFunctionCache); if (!batchedFunc) return failure(); batchedFunctionCache[key] = batchedFunc; @@ -90,23 +89,22 @@ static LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, for (auto operand : callOp->getOperands()) newOperands.push_back(mapper.lookup(operand)); - auto newCall = builder.create( - callOp.getLoc(), - batchedFunc.getName(), - batchedFunc.getResultTypes(), - newOperands); + auto newCall = + builder.create(callOp.getLoc(), batchedFunc.getName(), + batchedFunc.getResultTypes(), newOperands); // Map the results - for (auto [oldResult, newResult] : + for (auto [oldResult, newResult] : llvm::zip(callOp.getResults(), newCall.getResults())) mapper.map(oldResult, newResult); return success(); } -static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, - llvm::ArrayRef batchSizes, - std::map &batchedFunctionCache) { +static void batchCloneRegion( + Region *src, Region *dest, IRMapping &mapper, + llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache) { // For each block in src, generate a corresponding block in the dest region. for (auto &blk : *src) { auto newBlk = new Block(); @@ -126,12 +124,11 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, for (auto &src : blk) { if (auto callOp = dyn_cast(&src)) { - if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes, - batchedFunctionCache))) + if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes, + batchedFunctionCache))) continue; } - if (auto ifaceOp = dyn_cast(&src)) { auto res = ifaceOp.createBatch(builder, mapper, batchSizes); if (res.succeeded()) @@ -164,7 +161,8 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, // Clone the regions. for (auto &&[oldReg, newReg] : llvm::zip(src.getRegions(), newOp->getRegions())) { - batchCloneRegion(&oldReg, &newReg, mapper, batchSizes, batchedFunctionCache); + batchCloneRegion(&oldReg, &newReg, mapper, batchSizes, + batchedFunctionCache); } // Remember the mapping of any results. @@ -176,10 +174,9 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper, } } -static FunctionOpInterface -batchCloneFunction(FunctionOpInterface F, Twine name, - llvm::ArrayRef batchSizes, - std::map &batchedFunctionCache) { +static FunctionOpInterface batchCloneFunction( + FunctionOpInterface F, Twine name, llvm::ArrayRef batchSizes, + std::map &batchedFunctionCache) { assert(!F.getFunctionBody().empty()); auto FTy = F.getFunctionType().cast(); @@ -209,9 +206,11 @@ batchCloneFunction(FunctionOpInterface F, Twine name, SymbolTable table(parent); table.insert(NewF); SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private); - - // Add the function to the cache BEFORE processing its body to support recursion. - BatchCacheKey key{F, SmallVector(batchSizes.begin(), batchSizes.end())}; + + // Add the function to the cache BEFORE processing its body to support + // recursion. + BatchCacheKey key{F, + SmallVector(batchSizes.begin(), batchSizes.end())}; batchedFunctionCache[key] = NewF; auto &origReg = F.getFunctionBody(); @@ -236,22 +235,21 @@ struct BatchPass : public BatchPassBase { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - BatchCacheKey key{fn, SmallVector(CI.getBatchShape().begin(), + BatchCacheKey key{fn, SmallVector(CI.getBatchShape().begin(), CI.getBatchShape().end())}; - + // Check if we already have a batched version auto it = batchedFunctionCache.find(key); FunctionOpInterface newFunc; - + if (it != batchedFunctionCache.end()) { newFunc = it->second; } else { // Create new batched function and store in cache newFunc = batchCloneFunction(fn, "batched_" + fn.getName(), - CI.getBatchShape(), - batchedFunctionCache); + CI.getBatchShape(), batchedFunctionCache); if (!newFunc) { - return failure(); + return failure(); } }