Cache batched functions and recursively batch. (#2222)
* don't batch the same function twice.

* batch functions that are called within a batched function

* initial tests

* support recursive functions

* recursion test

* formatting
jumerckx authored Jan 17, 2025
1 parent ec3a788 commit 5c632cc
Showing 4 changed files with 174 additions and 12 deletions.
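
The pattern behind these changes, in miniature: batched clones are memoized by (callee, batch sizes), and a clone is registered in the cache before its body is walked, so recursive calls resolve to the clone under construction. Below is a rough standalone C++ sketch of that idea, not the pass itself; it uses toy strings in place of the MLIR types, and getOrCreateBatched and the callees map are illustrative stand-ins, not names from the patch.

// Rough sketch of the caching idea: memoize clones by (name, batch sizes),
// and register each clone BEFORE walking its body so recursion terminates.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Key = std::pair<std::string, std::vector<int64_t>>;
using CallGraph = std::map<std::string, std::vector<std::string>>;

std::string getOrCreateBatched(const std::string &fn,
                               const std::vector<int64_t> &sizes,
                               const CallGraph &callees,
                               std::map<Key, std::string> &cache) {
  Key key{fn, sizes};
  if (auto it = cache.find(key); it != cache.end())
    return it->second; // cache hit: reuse the existing clone
  std::string batched = "batched_" + fn;
  cache[key] = batched; // register before "cloning the body"
  // "Clone the body": recursively batch every function this one calls.
  // A self-call hits the cache entry we just inserted and terminates.
  if (auto c = callees.find(fn); c != callees.end())
    for (const std::string &callee : c->second)
      getOrCreateBatched(callee, sizes, callees, cache);
  return batched;
}

int main() {
  // @f calls @g and itself, mirroring call.mlir and recursion.mlir below.
  CallGraph callees{{"f", {"g", "f"}}};
  std::map<Key, std::string> cache;
  std::cout << getOrCreateBatched("f", {4}, callees, cache) << "\n";
  std::cout << getOrCreateBatched("f", {4}, callees, cache) << "\n"; // hit
}

In the patch itself, the registration happens in batchCloneFunction and the lookup in handleCallOp, as the diff below shows.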
118 changes: 106 additions & 12 deletions enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp
@@ -28,6 +28,24 @@ using namespace enzyme;

namespace {

struct BatchCacheKey {
  FunctionOpInterface function;
  SmallVector<int64_t> batchSizes;

  // for use in std::map:
  bool operator<(const BatchCacheKey &other) const {
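    // FunctionOpInterface::getName() is not const-qualified, so the
    // comparison goes through const_cast.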
    if (const_cast<FunctionOpInterface &>(function).getName() !=
        const_cast<FunctionOpInterface &>(other.function).getName())
      return const_cast<FunctionOpInterface &>(function).getName() <
             const_cast<FunctionOpInterface &>(other.function).getName();
    return batchSizes < other.batchSizes;
  }
};

static FunctionOpInterface batchCloneFunction(
    FunctionOpInterface F, Twine name, llvm::ArrayRef<int64_t> batchSizes,
    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);

static mlir::TensorType applyBatchSizes(mlir::Type Ty,
                                        llvm::ArrayRef<int64_t> batchSizes) {
  auto T = dyn_cast<TensorType>(Ty);
@@ -41,8 +59,56 @@ static mlir::TensorType applyBatchSizes(mlir::Type Ty,
  return T2;
}

static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
                             llvm::ArrayRef<int64_t> batchSizes) {
static LogicalResult handleCallOp(
    func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
    llvm::ArrayRef<int64_t> batchSizes,
    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
  // Get the called function
  auto moduleOp = callOp->getParentOfType<ModuleOp>();
  auto calledFunc =
      dyn_cast<FunctionOpInterface>(moduleOp.lookupSymbol(callOp.getCallee()));
  if (!calledFunc)
    return failure();

  // Create cache key for this function and batch size combination
  BatchCacheKey key{calledFunc,
                    SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};

  // Look up or create batched version of the called function
  FunctionOpInterface batchedFunc;
  auto it = batchedFunctionCache.find(key);
  if (it != batchedFunctionCache.end()) {
    batchedFunc = it->second;
  } else {
    batchedFunc =
        batchCloneFunction(calledFunc, "batched_" + calledFunc.getName(),
                           batchSizes, batchedFunctionCache);
    if (!batchedFunc)
      return failure();
    batchedFunctionCache[key] = batchedFunc;
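    // Note: batchCloneFunction has already registered this key (it inserts
    // the clone before cloning the body), so this store overwrites the
    // entry with the same value.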
  }

  // Create new call operation to the batched function
  SmallVector<Value> newOperands;
  for (auto operand : callOp->getOperands())
    newOperands.push_back(mapper.lookup(operand));

  auto newCall =
      builder.create<func::CallOp>(callOp.getLoc(), batchedFunc.getName(),
                                   batchedFunc.getResultTypes(), newOperands);

  // Map the results
  for (auto [oldResult, newResult] :
       llvm::zip(callOp.getResults(), newCall.getResults()))
    mapper.map(oldResult, newResult);

  return success();
}

static void batchCloneRegion(
    Region *src, Region *dest, IRMapping &mapper,
    llvm::ArrayRef<int64_t> batchSizes,
    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
  // For each block in src, generate a corresponding block in the dest region.
  for (auto &blk : *src) {
    auto newBlk = new Block();
@@ -61,6 +127,12 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
    OpBuilder builder(&newBlk, newBlk.end());
    for (auto &src : blk) {

      if (auto callOp = dyn_cast<func::CallOp>(&src)) {
        if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes,
                                   batchedFunctionCache)))
          continue;
      }

      if (auto ifaceOp = dyn_cast<BatchOpInterface>(&src)) {
        auto res = ifaceOp.createBatch(builder, mapper, batchSizes);
        if (res.succeeded())
@@ -93,7 +165,8 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
      // Clone the regions.
      for (auto &&[oldReg, newReg] :
           llvm::zip(src.getRegions(), newOp->getRegions())) {
        batchCloneRegion(&oldReg, &newReg, mapper, batchSizes);
        batchCloneRegion(&oldReg, &newReg, mapper, batchSizes,
                         batchedFunctionCache);
      }

      // Remember the mapping of any results.
@@ -105,9 +178,9 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
  }
}

static FunctionOpInterface
batchCloneFunction(FunctionOpInterface F, Twine name,
                   llvm::ArrayRef<int64_t> batchSizes) {
static FunctionOpInterface batchCloneFunction(
    FunctionOpInterface F, Twine name, llvm::ArrayRef<int64_t> batchSizes,
    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
  assert(!F.getFunctionBody().empty());

  auto FTy = F.getFunctionType().cast<FunctionType>();
@@ -138,30 +211,51 @@ batchCloneFunction(FunctionOpInterface F, Twine name,
  table.insert(NewF);
  SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);

  // Add the function to the cache BEFORE processing its body to support
  // recursion.
  BatchCacheKey key{F,
                    SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
  batchedFunctionCache[key] = NewF;
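  // Recursive calls to F within its own body will now find NewF in the cache
  // (via handleCallOp) instead of triggering another clone.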

  auto &origReg = F.getFunctionBody();
  auto &newReg = NewF.getFunctionBody();

  IRMapping mapper;
  batchCloneRegion(&origReg, &newReg, mapper, batchSizes);
  batchCloneRegion(&origReg, &newReg, mapper, batchSizes, batchedFunctionCache);

  return NewF;
}

struct BatchPass : public BatchPassBase<BatchPass> {
  void runOnOperation() override;

  // Cache mapping original function and batch sizes to batched function
  std::map<BatchCacheKey, FunctionOpInterface> batchedFunctionCache;
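  // The cache persists across HandleBatch invocations within one pass run, so
  // repeated enzyme.batch ops on the same function and shape (see
  // cachefunction.mlir below) share a single batched clone.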

  template <typename T>
  LogicalResult HandleBatch(SymbolTableCollection &symbolTable, T CI) {
    SmallVector<mlir::Value, 2> args;

    auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
    auto fn = cast<FunctionOpInterface>(symbolOp);

    FunctionOpInterface newFunc =
        batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape());

    if (!newFunc)
      return failure();
    BatchCacheKey key{fn, SmallVector<int64_t>(CI.getBatchShape().begin(),
                                               CI.getBatchShape().end())};

    // Check if we already have a batched version
    auto it = batchedFunctionCache.find(key);
    FunctionOpInterface newFunc;

    if (it != batchedFunctionCache.end()) {
      newFunc = it->second;
    } else {
      // Create new batched function and store in cache
      newFunc = batchCloneFunction(fn, "batched_" + fn.getName(),
                                   CI.getBatchShape(), batchedFunctionCache);
      if (!newFunc) {
        return failure();
      }
    }

    OpBuilder builder(CI);
    auto dCI =
19 changes: 19 additions & 0 deletions enzyme/test/MLIR/Batch/cachefunction.mlir
@@ -0,0 +1,19 @@
module {
  func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
    return %arg0 : tensor<16xf32>
  }
  func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
    %2 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
    %3 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
    return
  }
}

// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: %[[v1:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: return %[[arg0]] : tensor<4x16xf32>
// CHECK-NEXT: }
28 changes: 28 additions & 0 deletions enzyme/test/MLIR/Batch/call.mlir
@@ -0,0 +1,28 @@
// RUN: %eopt -enzyme-batch %s | FileCheck %s

module {
  func.func private @g(%arg0: tensor<16xf32>) -> tensor<16xf32> {
    return %arg0 : tensor<16xf32>
  }
  func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
    %1 = func.call @g(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
    return %1 : tensor<16xf32>
  }
  func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
    %2 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
    return
  }
}

// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: %[[v0:.+]] = call @batched_g(%[[arg0]]) : (tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32>
// CHECK-NEXT: }
// CHECK: func.func private @batched_g(%[[arg0:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: return %[[arg0]] : tensor<4x16xf32>
// CHECK-NEXT: }

21 changes: 21 additions & 0 deletions enzyme/test/MLIR/Batch/recursion.mlir
@@ -0,0 +1,21 @@
// RUN: %eopt -enzyme-batch %s | FileCheck %s

module {
  func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
    %0 = func.call @f(%arg0, %arg1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
    return %0 : tensor<16xf32>
  }
  func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
    %0 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
    return
  }
}

// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32>
// CHECK-NEXT: }
