Skip to content

Commit

Permalink
Add RemoveIV pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Jan 31, 2025
1 parent d89468e commit d614590
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 12 deletions.
64 changes: 62 additions & 2 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,63 @@ gentbl_cc_library(
tblgen = "//:enzymexlamlir-tblgen",
)

td_library(
name = "RaisingTransformOpsTdFiles",
srcs = [
"TransformOps/RaisingTransformOps.td",
],
deps = [
"@llvm-project//mlir:TransformDialectTdFiles",
]
)

gentbl_cc_library(
name = "RaisingTransformOpsIncGen",
tbl_outs = [(
["-gen-op-decls"],
"TransformOps/RaisingTransformOps.h.inc",
), (
["-gen-op-defs"],
"TransformOps/RaisingTransformOps.cpp.inc",
),
],
td_file = "TransformOps/RaisingTransformOps.td",
deps = [
":RaisingTransformOpsTdFiles",
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
)

gentbl_cc_library(
name = "RaisingTransformOpsImplIncGen",
tbl_outs = [(
["-gen-populate-raising-patterns-interface-impl"],
"TransformOps/RaisingTransformOpsImpl.cpp.inc"
)],
td_file = "TransformOps/RaisingTransformOps.td",
deps = [
":RaisingTransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

gentbl_cc_library(
name = "RaisingTransformPatternsIncGen",
tbl_outs = [
(
["-gen-populate-raising-patterns-func-decls"],
"TransformOps/RaisingTransformPatterns.h.inc",
), (
["-gen-populate-raising-patterns-func-defs"],
"TransformOps/RaisingTransformPatterns.cpp.inc",
)],
td_file = "TransformOps/RaisingTransformOps.td",
deps = [
":RaisingTransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

cc_library(
name = "TransformOps",
srcs = glob(["TransformOps/*.cpp"]),
Expand All @@ -132,6 +189,9 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectInterfaces",
":RaisingTransformOpsIncGen",
":RaisingTransformOpsImplIncGen",
":RaisingTransformPatternsIncGen",
":TransformOpsIncGen",
":TransformOpsImplIncGen",
":XLADerivatives",
Expand Down Expand Up @@ -273,7 +333,7 @@ gentbl_cc_library(
)

gentbl_cc_library(
name = "EnzyeHLOPatternsIncGen",
name = "EnzymeHLOPatternsIncGen",
tbl_outs = [
(
["-gen-populate-patterns-func-decls"],
Expand Down Expand Up @@ -312,7 +372,7 @@ cc_library(
deps = [
":EnzymeXLAOpsIncGen",
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
":EnzymeHLOPatternsIncGen",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:GPUPipelines",
"@llvm-project//llvm:Core",
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7557,6 +7557,8 @@ void mlir::transform::addConcatenateOpCanon(RewritePatternSet &patterns,
patterns.insert<ConcatenateOpCanon>(maxConstantExpansion, &context, benefit);
}

/////////////// End stablehlo patterns

namespace {
struct EnzymeHLOOptPass
: public enzyme::impl::EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/RegistryUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
namespace mlir {
namespace enzyme {
void registerEnzymeJaxTransformExtension(mlir::DialectRegistry &registry);
void registerRaisingTransformExtension(mlir::DialectRegistry &registry);
} // namespace enzyme
} // namespace mlir

Expand Down Expand Up @@ -124,6 +125,7 @@ void prepareRegistry(mlir::DialectRegistry &registry) {
mlir::linalg::registerTransformDialectExtension(registry);

mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
mlir::enzyme::registerRaisingTransformExtension(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
Expand Down
165 changes: 165 additions & 0 deletions src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
//===- RaisingTransformOps.cpp - Definition of raising transform extension ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp.inc"
#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOpsImpl.cpp.inc"

using namespace mlir;

namespace mlir {
namespace transform {

struct RemoveIVs : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
if (!forOp.getRegion().hasOneBlock())
return failure();
unsigned numIterArgs = forOp.getNumRegionIterArgs();
auto loc = forOp->getLoc();
bool changed = false;
llvm::SetVector<unsigned> removed;
llvm::MapVector<unsigned, Value> steps;
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (unsigned i = 0; i < numIterArgs; i++) {
auto ba = forOp.getRegionIterArgs()[i];
auto init = forOp.getInits()[i];
auto next = yield->getOperand(i);

auto increment = next.getDefiningOp<arith::AddIOp>();
if (!increment)
continue;

Value step = nullptr;
if (increment.getLhs() == ba) {
step = increment.getRhs();
} else {
step = increment.getLhs();
}
if (!step)
continue;

// If it dominates the loop entry
if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion()))
continue;

rewriter.setInsertionPointToStart(forOp.getBody());
Value iterNum = rewriter.create<arith::SubIOp>(
loc, forOp.getInductionVar(), forOp.getLowerBound());
iterNum = rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

Value replacementIV = rewriter.create<arith::MulIOp>(loc, iterNum, step);
replacementIV = rewriter.create<arith::AddIOp>(loc, replacementIV, init);

rewriter.replaceAllUsesWith(ba, replacementIV);

removed.insert(i);
steps.insert({i, step});
changed = true;
}

if (!changed)
return failure();

SmallVector<Value> newInits;
for (unsigned i = 0; i < numIterArgs; i++)
if (!removed.contains(i))
newInits.push_back(forOp.getInits()[i]);

rewriter.setInsertionPoint(forOp);
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newInits);
if (!newForOp.getRegion().empty())
newForOp.getRegion().front().erase();
assert(newForOp.getRegion().empty());
rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(),
newForOp.getRegion().begin());

SmallVector<Value> newYields;
for (unsigned i = 0; i < numIterArgs; i++)
if (!removed.contains(i))
newYields.push_back(yield->getOperand(i));

rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(yield, newYields);

llvm::BitVector toDelete(numIterArgs + 1);
for (unsigned i = 0; i < numIterArgs; i++)
if (removed.contains(i))
toDelete[i + 1] = true;
newForOp.getBody()->eraseArguments(toDelete);

rewriter.setInsertionPoint(newForOp);
unsigned curNewRes = 0;
for (unsigned i = 0; i < numIterArgs; i++) {
auto result = forOp->getResult(i);
if (removed.contains(i)) {
if (result.use_empty())
continue;

rewriter.setInsertionPointAfter(forOp.getOperation());
Value iterNum = rewriter.create<arith::SubIOp>(
loc, forOp.getUpperBound(), forOp.getLowerBound());
iterNum =
rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

Value afterLoop =
rewriter.create<arith::MulIOp>(loc, iterNum, steps[i]);
afterLoop =
rewriter.create<arith::AddIOp>(loc, afterLoop, forOp.getInits()[i]);

rewriter.replaceAllUsesWith(result, afterLoop);
} else {
rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++));
}
}

forOp->getParentOp()->dump();
rewriter.eraseOp(forOp);

return success();
}
};

} // namespace transform
} // namespace mlir

#include "src/enzyme_ad/jax/TransformOps/RaisingTransformPatterns.cpp.inc"

namespace {
class RaisingTransformExtension
: public transform::TransformDialectExtension<RaisingTransformExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaisingTransformExtension)
using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp.inc"
>();
}
};
} // namespace

void mlir::enzyme::registerRaisingTransformExtension(DialectRegistry &registry) {
registry.addExtensions<RaisingTransformExtension>();
}
23 changes: 23 additions & 0 deletions src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===---- RaisingTransformOps.h - Declarations of Transform extension ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "src/enzyme_ad/jax/TransformOps/OpInterfaces.h.inc"

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h.inc"
#include "src/enzyme_ad/jax/TransformOps/RaisingTransformPatterns.h.inc"

namespace mlir {
namespace enzyme {
void registerRaisingTransformExtension(mlir::DialectRegistry &registry);

} // namespace enzyme
} // namespace mlir
19 changes: 19 additions & 0 deletions src/enzyme_ad/jax/TransformOps/RaisingTransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"

class RaisingPatternOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect,
"apply_patterns.raising." # mnemonic,
// For some reason, inherited methods are not getting declared...
!listconcat(
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>],
traits)> {
let arguments = (ins OptionalAttr<I64Attr>:$benefit);
list<string> patterns = [];
let assemblyFormat = "attr-dict";
}

def ApplyRemoveIVsPatterns : RaisingPatternOp<
"remove_ivs"> {
let patterns = ["RemoveIVs"];
}
Loading

0 comments on commit d614590

Please sign in to comment.