Cache batched functions and recursively batch. #2222

Merged · 6 commits · Jan 17, 2025
Changes from 2 commits
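
The change memoizes batched clones in a std::map keyed by the original function and the requested batch sizes, and routes nested func.call ops through the same cache so callees are batched recursively and each (function, batch sizes) pair is cloned only once. Below is a minimal standalone sketch of that memoization pattern; Function, CacheKey, and getOrCreateBatched are hypothetical stand-ins for illustration, not the types in this diff.

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the function op being batched.
struct Function {
  std::string name;
};

// Composite cache key: which function, batched with which sizes.
struct CacheKey {
  std::string funcName;
  std::vector<std::int64_t> batchSizes;
  bool operator<(const CacheKey &other) const {
    if (funcName != other.funcName)
      return funcName < other.funcName;
    return batchSizes < other.batchSizes;
  }
};

// Memoized "batching": clone once per (function, batch sizes) pair and reuse
// the clone for every later call site, including nested ones.
Function &getOrCreateBatched(const Function &f,
                             const std::vector<std::int64_t> &sizes,
                             std::map<CacheKey, Function> &cache) {
  CacheKey key{f.name, sizes};
  auto it = cache.find(key);
  if (it != cache.end())
    return it->second;                     // cache hit: reuse the existing clone
  Function batched{"batched_" + f.name};   // cache miss: build the clone once
  return cache.emplace(key, std::move(batched)).first->second;
}

int main() {
  std::map<CacheKey, Function> cache;
  Function square{"square"};
  Function &a = getOrCreateBatched(square, {4, 8}, cache);
  Function &b = getOrCreateBatched(square, {4, 8}, cache);
  return (&a == &b) ? 0 : 1;               // same clone reused on the second call
}

An ordered std::map is used rather than a hash map so the key only needs operator<, mirroring the BatchCacheKey comparison in the diff below.
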
106 changes: 99 additions & 7 deletions enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp
@@ -28,6 +28,24 @@ using namespace enzyme;

namespace {

struct BatchCacheKey {
FunctionOpInterface function;
SmallVector<int64_t> batchSizes;

// Ordering for use as a std::map key; getName() is not const-qualified on
// FunctionOpInterface, hence the const_casts.
bool operator<(const BatchCacheKey &other) const {
auto lhsName = const_cast<FunctionOpInterface &>(function).getName();
auto rhsName = const_cast<FunctionOpInterface &>(other.function).getName();
if (lhsName != rhsName)
return lhsName < rhsName;
return batchSizes < other.batchSizes;
}
};

static FunctionOpInterface
batchCloneFunction(FunctionOpInterface F, Twine name,
llvm::ArrayRef<int64_t> batchSizes,
std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);

static mlir::TensorType applyBatchSizes(mlir::Type Ty,
llvm::ArrayRef<int64_t> batchSizes) {
auto T = cast<TensorType>(Ty);
@@ -37,8 +55,58 @@ static mlir::TensorType applyBatchSizes(mlir::Type Ty,
return T2;
}

static LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder,
IRMapping &mapper,
llvm::ArrayRef<int64_t> batchSizes,
std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
// Get the called function
auto moduleOp = callOp->getParentOfType<ModuleOp>();
auto calledFunc = dyn_cast<FunctionOpInterface>(
moduleOp.lookupSymbol(callOp.getCallee()));
if (!calledFunc)
return failure();

// Create cache key for this function and batch size combination
BatchCacheKey key{calledFunc, SmallVector<int64_t>(batchSizes.begin(),
batchSizes.end())};

// Look up or create batched version of the called function
FunctionOpInterface batchedFunc;
auto it = batchedFunctionCache.find(key);
if (it != batchedFunctionCache.end()) {
batchedFunc = it->second;
} else {
batchedFunc = batchCloneFunction(calledFunc,
"batched_" + calledFunc.getName(),
batchSizes,
batchedFunctionCache);
if (!batchedFunc)
return failure();
batchedFunctionCache[key] = batchedFunc;
}

// Create new call operation to the batched function
SmallVector<Value> newOperands;
for (auto operand : callOp->getOperands())
newOperands.push_back(mapper.lookup(operand));

auto newCall = builder.create<func::CallOp>(
callOp.getLoc(),
batchedFunc.getName(),
batchedFunc.getResultTypes(),
newOperands);

// Map the results
for (auto [oldResult, newResult] :
llvm::zip(callOp.getResults(), newCall.getResults()))
mapper.map(oldResult, newResult);

return success();
}

static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
llvm::ArrayRef<int64_t> batchSizes) {
llvm::ArrayRef<int64_t> batchSizes,
std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
// For each block in src, generate a corresponding block in the dest region.
for (auto &blk : *src) {
auto newBlk = new Block();
Expand All @@ -57,6 +125,13 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
OpBuilder builder(&newBlk, newBlk.end());
for (auto &src : blk) {

if (auto callOp = dyn_cast<func::CallOp>(&src)) {
if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes,
batchedFunctionCache)))
continue;
}

if (auto ifaceOp = dyn_cast<BatchOpInterface>(&src)) {
auto res = ifaceOp.createBatch(builder, mapper, batchSizes);
if (res.succeeded())
@@ -89,7 +164,7 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
// Clone the regions.
for (auto &&[oldReg, newReg] :
llvm::zip(src.getRegions(), newOp->getRegions())) {
batchCloneRegion(&oldReg, &newReg, mapper, batchSizes);
batchCloneRegion(&oldReg, &newReg, mapper, batchSizes, batchedFunctionCache);
}

// Remember the mapping of any results.
@@ -103,7 +178,8 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,

static FunctionOpInterface
batchCloneFunction(FunctionOpInterface F, Twine name,
llvm::ArrayRef<int64_t> batchSizes) {
llvm::ArrayRef<int64_t> batchSizes,
std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
assert(!F.getFunctionBody().empty());

auto FTy = F.getFunctionType().cast<FunctionType>();
@@ -138,26 +214,42 @@ batchCloneFunction(FunctionOpInterface F, Twine name,
auto &newReg = NewF.getFunctionBody();

IRMapping mapper;
batchCloneRegion(&origReg, &newReg, mapper, batchSizes);
batchCloneRegion(&origReg, &newReg, mapper, batchSizes, batchedFunctionCache);

return NewF;
}

struct BatchPass : public BatchPassBase<BatchPass> {
void runOnOperation() override;

// Cache mapping original function and batch sizes to batched function
std::map<BatchCacheKey, FunctionOpInterface> batchedFunctionCache;

template <typename T>
LogicalResult HandleBatch(SymbolTableCollection &symbolTable, T CI) {
SmallVector<mlir::Value, 2> args;

auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
auto fn = cast<FunctionOpInterface>(symbolOp);

FunctionOpInterface newFunc =
batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape());

BatchCacheKey key{fn, SmallVector<int64_t>(CI.getBatchShape().begin(),
CI.getBatchShape().end())};

// Check if we already have a batched version
auto it = batchedFunctionCache.find(key);
FunctionOpInterface newFunc;

if (it != batchedFunctionCache.end()) {
newFunc = it->second;
} else {
// Create new batched function and store in cache
newFunc = batchCloneFunction(fn, "batched_" + fn.getName(),
CI.getBatchShape(),
batchedFunctionCache);
if (!newFunc)
return failure();
batchedFunctionCache[key] = newFunc;
}

OpBuilder builder(CI);
auto dCI =