diff --git a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp index cd629d840..b7d4393f3 100644 --- a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp +++ b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp @@ -8,6 +8,8 @@ #include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -139,6 +141,97 @@ struct RemoveIVs : public OpRewritePattern { } }; +static inline void clearBlock(mlir::Block *block, + mlir::RewriterBase &rewriter) { + for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { + assert(op.use_empty() && "expected 'op' to have no uses"); + rewriter.eraseOp(&op); + } +} + +static mlir::Value createConstantInt(RewriterBase &rewriter, Location loc, + Type ty, int64_t v) { + if (ty.isIndex()) + return rewriter.create(loc, v); + else + return rewriter.create(loc, v, ty); +} + +static std::optional getConstant(Operation *op) { + if (auto cst = dyn_cast_or_null(op)) { + return cst.value(); + } else if (auto cst = dyn_cast_or_null(op)) { + return cst.value(); + } else if (auto cst = dyn_cast_or_null(op)) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getValue().getSExtValue(); + } + return {}; +} + +static std::optional getConstant(Value v) { + Operation *op = v.getDefiningOp(); + if (op) + return getConstant(op); + return {}; +} + +/// Returns `true` if the loop has a form expected by interchange patterns. +static bool isNormalized(scf::ForOp op) { + auto lb = getConstant(op.getLowerBound()); + auto step = getConstant(op.getStep()); + if (!lb || !step) + return false; + return *lb == 0 && *step == 1; +} + +#define DEBUG_TYPE "normalize-loop" +#define DBGS llvm::dbgs + +struct NormalizeLoop : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + using namespace arith; + if (isNormalized(op) || + !isa(op->getParentOp())) { + LLVM_DEBUG(DBGS() << "[normalize-loop] loop already normalized\n"); + return failure(); + } + + rewriter.setInsertionPoint(op); + Value zero = createConstantInt(rewriter, op.getLoc(), + op.getInductionVar().getType(), 0); + Value one = createConstantInt(rewriter, op.getLoc(), + op.getInductionVar().getType(), 1); + + Value difference = rewriter.create(op.getLoc(), op.getUpperBound(), + op.getLowerBound()); + Value tripCount = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), rewriter.create(op.getLoc(), difference, one), + op.getStep()), + one); + // rewriter.create(op.getLoc(), difference, op.getStep()); + auto newForOp = rewriter.create(op.getLoc(), zero, tripCount, + one, op.getInits()); + clearBlock(newForOp.getBody(), rewriter); + rewriter.setInsertionPointToStart(newForOp.getBody()); + Value scaled = rewriter.create( + op.getLoc(), newForOp.getInductionVar(), op.getStep()); + Value iv = rewriter.create(op.getLoc(), op.getLowerBound(), scaled); + SmallVector newArgs(newForOp.getRegion().args_begin(), + newForOp.getRegion().args_end()); + newArgs[0] = iv; + rewriter.inlineBlockBefore(op.getBody(), newForOp.getBody(), + newForOp.getBody()->end(), newArgs); + rewriter.replaceOp(op, newForOp->getResults()); + return success(); + } +}; + } // namespace transform } // namespace mlir diff --git a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.td b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.td index 0bbc3af93..09fa35716 100644 --- a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.td @@ -17,3 +17,8 @@ def ApplyRemoveIVsPatterns : RaisingPatternOp< "remove_ivs"> { let patterns = ["RemoveIVs"]; } + +def ApplyNormalizeLoopPatterns : RaisingPatternOp< + "normalize_loop"> { + let patterns = ["NormalizeLoop"]; +} diff --git a/test/lit_tests/patterns/normalize_loop.mlir b/test/lit_tests/patterns/normalize_loop.mlir new file mode 100644 index 000000000..741aaa28b --- /dev/null +++ b/test/lit_tests/patterns/normalize_loop.mlir @@ -0,0 +1,38 @@ +// RUN: enzymexlamlir-opt %s -split-input-file -allow-unregistered-dialect --transform-interpreter | FileCheck %s + +module { + func.func @test_normalize_loop() { + %c5 = arith.constant 5 : index + %c20 = arith.constant 20 : index + %c3 = arith.constant 3 : index + scf.parallel (%arg0) = (%c5) to (%c20) step (%c3) { + scf.for %i = %c5 to %c20 step %c3 { + "test.test"() : () -> () + "test.test1"(%i) : (index) -> () + } + } + return + } + + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg2: !transform.any_op) { + %4 = transform.structured.match ops{["func.func"]} in %arg2 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %4 { + transform.apply_patterns.raising.normalize_loop + } : !transform.any_op + transform.yield + } + } +} + +// CHECK-LABEL: func @test_normalize_loop( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C5]] step %[[C1]] +// CHECK: %[[MUL:.*]] = arith.muli %[[I]], %[[C3]] +// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C5]] +// CHECK: "test.test" +// CHECK: "test.test1"(%[[ADD]]) +// CHECK: scf.reduce