Skip to content

Commit

Permalink
Add NormalizeLoop pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
tyb0807 authored and wsmoses committed Jan 31, 2025
1 parent 8b9b4ec commit 2e7f8b4
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
Expand Down Expand Up @@ -139,6 +141,97 @@ struct RemoveIVs : public OpRewritePattern<scf::ForOp> {
}
};

/// Erases every operation contained in `block` through `rewriter`.
///
/// Operations are removed starting from the end of the block and moving
/// towards its beginning, so an operation is only erased after the operations
/// that follow it are gone. Each operation must be use-free at the moment it
/// is erased.
static inline void clearBlock(mlir::Block *block,
                              mlir::RewriterBase &rewriter) {
  // Take a snapshot of the operations in reverse order first so that erasing
  // them does not interfere with the traversal of the block.
  llvm::SmallVector<mlir::Operation *> pending;
  for (mlir::Operation &nested : llvm::reverse(*block))
    pending.push_back(&nested);

  for (mlir::Operation *nested : pending) {
    assert(nested->use_empty() && "expected 'op' to have no uses");
    rewriter.eraseOp(nested);
  }
}

/// Materializes the integer constant `v` of type `ty` at the rewriter's
/// current insertion point and returns its result value.
///
/// `index`-typed constants are built with `arith::ConstantIndexOp`; every
/// other type is forwarded to `arith::ConstantIntOp`.
static mlir::Value createConstantInt(RewriterBase &rewriter, Location loc,
                                     Type ty, int64_t v) {
  if (!ty.isIndex())
    return rewriter.create<arith::ConstantIntOp>(loc, v, ty);
  return rewriter.create<arith::ConstantIndexOp>(loc, v);
}

/// Extracts the integer produced by `op` when `op` is a recognized constant.
///
/// Recognized forms are `arith::ConstantIntOp`, `arith::ConstantIndexOp`, and
/// an `LLVM::ConstantOp` whose value is an `IntegerAttr` (read sign-extended).
/// Returns an empty optional for a null `op` or any other operation.
static std::optional<int64_t> getConstant(Operation *op) {
  if (!op)
    return std::nullopt;

  if (auto intConstant = dyn_cast<arith::ConstantIntOp>(op))
    return intConstant.value();

  if (auto indexConstant = dyn_cast<arith::ConstantIndexOp>(op))
    return indexConstant.value();

  if (auto llvmConstant = dyn_cast<LLVM::ConstantOp>(op)) {
    // LLVM constants may hold non-integer attributes; only integers qualify.
    if (auto integerAttr = dyn_cast<IntegerAttr>(llvmConstant.getValue()))
      return integerAttr.getValue().getSExtValue();
  }

  return std::nullopt;
}

/// Extracts the integer constant that defines `v`, if any.
///
/// Values without a defining operation (e.g. block arguments) yield an empty
/// optional; otherwise the result is whatever the `Operation *` overload
/// reports for the defining operation.
static std::optional<int64_t> getConstant(Value v) {
  if (Operation *producer = v.getDefiningOp())
    return getConstant(producer);
  return std::nullopt;
}

/// Checks whether `op` already has the shape interchange patterns expect: a
/// lower bound that is the constant 0 and a step that is the constant 1.
///
/// A bound or step that is not a recognizable constant makes the loop count
/// as not normalized.
static bool isNormalized(scf::ForOp op) {
  const std::optional<int64_t> lowerBound = getConstant(op.getLowerBound());
  const std::optional<int64_t> stride = getConstant(op.getStep());

  const bool bothConstant = lowerBound.has_value() && stride.has_value();
  if (!bothConstant)
    return false;

  return lowerBound.value() == 0 && stride.value() == 1;
}

// Debug category for the LLVM_DEBUG statements that follow in this file.
#define DEBUG_TYPE "normalize-loop"
// Shorthand for the LLVM debug output stream used by those statements.
#define DBGS llvm::dbgs

/// Rewrites an `scf.for` that is nested directly inside an `scf.parallel` or
/// `affine.parallel` into an equivalent loop running from 0 with step 1.
///
/// The replacement loop iterates `(ub - lb - 1) /u step + 1` times and
/// recomputes the original induction value as `lb + newIV * step` at the top
/// of its body. The init operands, the iter_args block arguments and the loop
/// results are carried over unchanged.
///
/// The pattern fails to match when the loop is already normalized (see
/// `isNormalized`) or when its parent is not one of the two parallel ops.
///
/// NOTE(review): the trip count is computed with an unsigned division of
/// `ub - lb - 1`. That looks correct only when `ub > lb` and `step > 0`; for a
/// zero-trip loop (`ub <= lb`) the subtraction wraps around and, for steps
/// other than 1, the new loop would get a non-zero trip count. Confirm that
/// callers guarantee at least one iteration, or guard the computation.
struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    using namespace arith;

    // Report each bail-out reason separately so debug output is accurate.
    if (isNormalized(op)) {
      LLVM_DEBUG(DBGS() << "[normalize-loop] loop already normalized\n");
      return failure();
    }
    if (!isa<scf::ParallelOp, affine::AffineParallelOp>(op->getParentOp())) {
      LLVM_DEBUG(DBGS() << "[normalize-loop] loop is not directly nested in "
                           "a parallel op\n");
      return failure();
    }

    const Location loc = op.getLoc();
    const Type ivType = op.getInductionVar().getType();
    Value lowerBound = op.getLowerBound();
    Value upperBound = op.getUpperBound();
    Value step = op.getStep();

    // All bound computations are emitted right before the original loop.
    rewriter.setInsertionPoint(op);
    Value zero = createConstantInt(rewriter, loc, ivType, 0);
    Value one = createConstantInt(rewriter, loc, ivType, 1);

    // tripCount = (ub - lb - 1) /u step + 1, i.e. ceil((ub - lb) / step) for
    // a positive range (see the NOTE on the struct for the zero-trip case).
    Value difference = rewriter.create<SubIOp>(loc, upperBound, lowerBound);
    Value differenceMinusOne = rewriter.create<SubIOp>(loc, difference, one);
    Value quotient = rewriter.create<DivUIOp>(loc, differenceMinusOne, step);
    Value tripCount = rewriter.create<AddIOp>(loc, quotient, one);
    // rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());

    auto newForOp =
        rewriter.create<scf::ForOp>(loc, zero, tripCount, one, op.getInits());
    // Empty the freshly built body: the original body is moved in wholesale
    // below, so anything the builder placed here would be redundant.
    clearBlock(newForOp.getBody(), rewriter);

    // Rebuild the original induction value from the normalized one.
    rewriter.setInsertionPointToStart(newForOp.getBody());
    Value scaled =
        rewriter.create<MulIOp>(loc, newForOp.getInductionVar(), step);
    Value iv = rewriter.create<AddIOp>(loc, lowerBound, scaled);

    // Map the old block arguments onto the new ones; slot 0 is the induction
    // variable, which is replaced by the recomputed original value.
    SmallVector<Value> newArgs(newForOp.getRegion().args_begin(),
                               newForOp.getRegion().args_end());
    newArgs[0] = iv;
    rewriter.inlineBlockBefore(op.getBody(), newForOp.getBody(),
                               newForOp.getBody()->end(), newArgs);
    rewriter.replaceOp(op, newForOp->getResults());
    return success();
  }
};

} // namespace transform
} // namespace mlir

Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/RaisingTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ def ApplyRemoveIVsPatterns : RaisingPatternOp<
"remove_ivs"> {
let patterns = ["RemoveIVs"];
}

// Pattern-descriptor op named "normalize_loop" that contributes the
// `NormalizeLoop` rewrite pattern. The lit test spells it
// `transform.apply_patterns.raising.normalize_loop`.
def ApplyNormalizeLoopPatterns : RaisingPatternOp<
"normalize_loop"> {
let patterns = ["NormalizeLoop"];
}
38 changes: 38 additions & 0 deletions test/lit_tests/patterns/normalize_loop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: enzymexlamlir-opt %s -split-input-file -allow-unregistered-dialect --transform-interpreter | FileCheck %s

module {
// Payload: an scf.for nested directly in an scf.parallel, with a non-zero
// lower bound (5) and a non-unit step (3). It visits 5, 8, 11, 14 and 17,
// i.e. five iterations, so the normalized loop is expected to run from 0 to 5
// with step 1 and to rebuild %i as `new_iv * 3 + 5`.
func.func @test_normalize_loop() {
%c5 = arith.constant 5 : index
%c20 = arith.constant 20 : index
%c3 = arith.constant 3 : index
scf.parallel (%arg0) = (%c5) to (%c20) step (%c3) {
scf.for %i = %c5 to %c20 step %c3 {
"test.test"() : () -> ()
"test.test1"(%i) : (index) -> ()
}
}
return
}

// Transform script: matches every func.func in the payload and applies the
// normalize_loop pattern set to it.
builtin.module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg2: !transform.any_op) {
%4 = transform.structured.match ops{["func.func"]} in %arg2 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %4 {
transform.apply_patterns.raising.normalize_loop
} : !transform.any_op
transform.yield
}
}
}

// CHECK-LABEL: func @test_normalize_loop(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK-DAG: %[[C3:.*]] = arith.constant 3
// CHECK-DAG: %[[C5:.*]] = arith.constant 5
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C5]] step %[[C1]]
// CHECK: %[[MUL:.*]] = arith.muli %[[I]], %[[C3]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C5]]
// CHECK: "test.test"
// CHECK: "test.test1"(%[[ADD]])
// CHECK: scf.reduce

0 comments on commit 2e7f8b4

Please sign in to comment.