Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drive enzyme ops removal using pattern rewriter #2229

Merged
merged 14 commits into from
Jan 28, 2025
Prev Previous commit
Next Next commit
workaround driver problems
Pangoraw committed Jan 22, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 37277e06ae3c4059b766701a6dc6526cc4679805
Original file line number Diff line number Diff line change
@@ -71,8 +71,7 @@ struct ForOpEnzymeOpsRemover
auto forOp = cast<scf::ForOp>(op);
scf::ForOp otherForOp; // where caches pops are

if (removeOpsWithinBlock(forOp.getBody(), rewriter).failed())
return failure();
(void)removeOpsWithinBlock(forOp.getBody(), rewriter);

// Gradients whose values need to be passed as iteration variables.
llvm::SetVector<Value> updatedGradients;
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
Original file line number Diff line number Diff line change
@@ -49,14 +49,14 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other,
mlir::LogicalResult
mlir::enzyme::removeOpsWithinBlock(mlir::Block *block,
mlir::PatternRewriter &rewriter) {
bool valid = true;
bool matched = false;

for (auto &it : *block) {
mlir::Operation *op = &it;
if (auto iface = dyn_cast<mlir::enzyme::EnzymeOpsRemoverOpInterface>(op)) {
valid &= iface.removeEnzymeOps(rewriter).succeeded();
matched |= iface.removeEnzymeOps(rewriter).succeeded();
}
}

return success(valid);
return success(matched);
}
11 changes: 8 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@

#include "mlir/IR/Dominance.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>

using namespace mlir;
using namespace enzyme;
@@ -307,7 +308,8 @@ static void applyPatterns(Operation *op) {
InitSimplify>(op->getContext());

GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(op, std::move(patterns), config);
config.fold = true;
(void)applyPatternsGreedily(op, std::move(patterns), config);
}

struct RemoveUnusedEnzymeOpsPass
@@ -324,8 +326,11 @@ struct RemoveUnusedEnzymeOpsPass
// of the interface will erase operations which are not only nested in the
// currently matched operation (for example, the for op where the pops are
// located). As such, we use the greedy driver with the option to run only
// on the pre-existing operations. This prevents the driver from running
// indefinitely.
// on the pre-existing and new operations. This prevents the driver from
// running indefinitely.
//
// TODO: improve this driver for the case where the newly created ops need
// .... to be applied the pattern as well.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyPatternsGreedily(op, std::move(patterns), config);