From 3e9921a2998f99bf03cd92e27a07728ca8d3e1c4 Mon Sep 17 00:00:00 2001 From: BuildKite Date: Thu, 30 Jan 2025 23:27:06 +0100 Subject: [PATCH 01/12] Add LLVM to Affine access pass --- src/enzyme_ad/jax/BUILD | 3 + src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td | 11 + .../jax/Passes/LLVMToAffineAccess.cpp | 1192 +++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 10 + .../jax/TransformOps/RaisingTransformOps.cpp | 279 ++-- .../jax/TransformOps/RaisingTransformOps.h | 20 + .../raising/llvm_to_affine_access.mlir | 66 + 7 files changed, 1436 insertions(+), 145 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp create mode 100644 test/lit_tests/raising/llvm_to_affine_access.mlir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 61edb95dc..c40006320 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -395,6 +395,9 @@ cc_library( ":EnzymeXLAOpsIncGen", ":EnzymeXLAPassesIncGen", ":EnzymeHLOPatternsIncGen", + ":RaisingTransformOpsImplIncGen", + ":RaisingTransformOpsIncGen", + ":RaisingTransformPatternsIncGen", ":RaisingTransformOps", "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:GPUPipelines", diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td index e474c823e..24d1b6a18 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td @@ -128,4 +128,15 @@ def Pointer2MemrefOp : EnzymeXLA_Op<"pointer2memref", [ }]; } +def AffineScopeOp : EnzymeXLA_Op<"scope", [ + AffineScope, + AutomaticAllocationScope, + RecursiveMemoryEffects, + ]>, + Arguments<(ins Variadic:$operands)>, + Results<(outs Variadic:$results)> { + let summary = "Inline affine scope"; + let regions = (region SizedRegion<1>:$region); +} + #endif // ENZYMEXLA_OPS diff --git a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp new file mode 100644 index 000000000..962710f39 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp @@ -0,0 +1,1192 @@ +#include "Passes.h" + +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include 
"src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" +#include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h" + +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" + +#include "Utils.h" + +#include +#include +#include + +#define DEBUG_TYPE "llvm-to-affine-access" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LLVMTOAFFINEACCESSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; + +using PtrVal = TypedValue; +using MemRefVal = MemrefValue; + +static mlir::Value createConstantInt(RewriterBase &rewriter, Location loc, + Type ty, int64_t v) { + if (ty.isIndex()) + return rewriter.create(loc, v); + else + return rewriter.create(loc, v, ty); +} + +static std::optional getConstant(Operation *op) { + if (auto cst = dyn_cast_or_null(op)) { + return cst.value(); + } else if (auto cst = dyn_cast_or_null(op)) { + return cst.value(); + } else if (auto cst = dyn_cast_or_null(op)) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getValue().getSExtValue(); + } + return {}; +} + +static std::optional getConstant(Value v) { + Operation *op = v.getDefiningOp(); + if (op) + return getConstant(op); + return {}; +} + +static LogicalResult +convertLLVMAllocaToMemrefAlloca(LLVM::AllocaOp alloc, RewriterBase &rewriter, + const DataLayout &dataLayout) { + if (!alloc.getRes().hasOneUse()) + return failure(); + + auto sizeVal = getConstant(alloc.getArraySize()); + if (!sizeVal) + return failure(); + + Type elType = rewriter.getI8Type(); + int64_t elNum = dataLayout.getTypeSize(alloc.getElemType()) * (*sizeVal); + + auto ptr2memref = + dyn_cast(alloc.getRes().use_begin()->getOwner()); + if (!ptr2memref) + return failure(); + + assert(elType == ptr2memref.getResult().getType().getElementType()); + + SmallVector sizes = {elNum}; + auto memrefType = + MemRefType::get(sizes, elType, MemRefLayoutAttrInterface{}, + ptr2memref.getResult().getType().getMemorySpace()); + auto newAlloca = + rewriter.create(alloc->getLoc(), memrefType); + rewriter.replaceAllUsesWith(ptr2memref.getResult(), newAlloca.getResult()); + rewriter.eraseOp(ptr2memref); + rewriter.eraseOp(alloc); + return success(); +} + +namespace { +struct ConvertLLVMAllocaToMemrefAlloca + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + const DataLayoutAnalysis &dl; + ConvertLLVMAllocaToMemrefAlloca(MLIRContext *context, + const DataLayoutAnalysis &dl) + : OpRewritePattern(context), dl(dl) {} + + LogicalResult matchAndRewrite(LLVM::AllocaOp alloc, + PatternRewriter &rewriter) const override { + auto dataLayout = dl.getAtOrAbove(alloc); + return convertLLVMAllocaToMemrefAlloca(alloc, rewriter, dataLayout); + } +}; +} // namespace + +static Value convertToIndex(Value v) { + OpBuilder builder(v.getContext()); + if (v.getType() == builder.getIndexType()) + return v; + if (auto ba = dyn_cast(v)) + builder.setInsertionPointToStart(ba.getOwner()); + else + builder.setInsertionPointAfter(v.getDefiningOp()); + return builder + .create(v.getLoc(), builder.getIndexType(), v) + .getResult(); +} + +static MemRefVal convertToMemref(PtrVal addr) { + OpBuilder 
builder(addr.getContext());
+  if (auto ba = dyn_cast(addr))
+    builder.setInsertionPointToStart(ba.getOwner());
+  else
+    builder.setInsertionPointAfter(addr.getDefiningOp());
+  Attribute addrSpace;
+  if (addr.getType().getAddressSpace() == 0)
+    addrSpace = nullptr;
+  else
+    addrSpace = IntegerAttr::get(IntegerType::get(addr.getContext(), 64),
+                                 addr.getType().getAddressSpace());
+  // TODO we can actually plug in the size of the memref here if `addr` is
+  // defined by an llvm.alloca
+  auto ptr2memref = builder.create(
+      addr.getLoc(),
+      MemRefType::get({ShapedType::kDynamic}, builder.getI8Type(),
+                      MemRefLayoutAttrInterface{}, Attribute(addrSpace)), addr);
+  return cast(ptr2memref.getResult());
+}
+
+template struct ConverterBase {
+  DenseMap map;
+  To operator()(From p) {
+    auto it = map.find(p);
+    if (it != map.end())
+      return it->getSecond();
+    auto converted = F(p);
+    map.insert({p, converted});
+    return converted;
+  }
+  SmallVector operator()(ValueRange range) {
+    return llvm::map_to_vector(range, [&](From v) { return (*this)(v); });
+  }
+};
+
+using MemrefConverter = ConverterBase;
+using IndexConverter = ConverterBase;
+
+static BlockArgument getScopeRemap(enzymexla::AffineScopeOp scope, Value v) {
+  for (unsigned i = 0; i < scope->getNumOperands(); i++)
+    if (scope->getOperand(i) == v)
+      return scope.getRegion().begin()->getArgument(i);
+  return nullptr;
+}
+
+/// See llvm/Support/Alignment.h
+static AffineExpr alignTo(AffineExpr expr, uint64_t a) {
+  return (expr + a - 1).floorDiv(a) * a;
+}
+
+// TODO To preserve correctness, we need to keep track of values for which
+// converting indexing to the index type preserves the semantics, i.e. no
+// overflows or underflows or truncation etc and insert a runtime guard against
+// that
+struct AffineExprBuilder {
+  AffineExprBuilder(Operation *user, bool legalizeSymbols)
+      : user(user), legalizeSymbols(legalizeSymbols) {}
+  Operation *user;
+
+  SmallPtrSet illegalSymbols;
+
+  DenseMap symToPos;
+  DenseMap dimToPos;
+  SmallVector symbolOperands;
+  SmallVector dimOperands;
+
+  // Options
+  bool legalizeSymbols;
+
+  SmallVector symbolsForScope;
+  unsigned scopedIllegalSymbols = 0;
+  bool scoped = false;
+
+  bool isLegal() {
+    return illegalSymbols.size() == 0 ||
+           (illegalSymbols.size() == scopedIllegalSymbols && scoped);
+  }
+
+  void collectSymbolsForScope(Region *region, SmallPtrSetImpl &symbols) {
+    assert(region->getBlocks().size() == 1);
+    SmallVector newExprs;
+    if (!region->isAncestor(user->getParentRegion()))
+      return;
+    // An illegal symbol will be legalized either by being defined at the top
+    // level in a region, or by remapping it in the scope
+    for (auto sym : illegalSymbols) {
+      assert(sym.getParentRegion()->isAncestor(region));
+      bool isOutsideRegion = sym.getParentRegion()->isProperAncestor(region);
+      auto ba = dyn_cast(sym);
+      bool isTopLevelBlockArg = ba && ba.getOwner()->getParent() == region;
+      [[maybe_unused]] bool isTopLevelOp =
+          !ba && sym.getParentRegion() == region;
+      assert((unsigned)isOutsideRegion + (unsigned)isTopLevelBlockArg +
+                 (unsigned)isTopLevelOp ==
+             1);
+      scopedIllegalSymbols++;
+      if (isOutsideRegion || isTopLevelBlockArg)
+        symbols.insert(sym);
+    }
+    if (!region->isProperAncestor(user->getParentRegion()))
+      return;
+    // We redefine dims to be symbols in this scope
+    for (auto dim : dimOperands) {
+      if (dim.getParentRegion()->isProperAncestor(region)) {
+        symbols.insert(dim);
+        symbolsForScope.push_back(dim);
+      }
+    }
+    // TODO we may have a state like this:
+    //
+    // func.func () {
+    //   %sym = ...
+    //   region: {
+    //   ...
+    //   }
+    // }
+    //
+    // and `sym` was not marked illegal because func.func is an affine scope.
+    // Should we rescope it to the new scope?
+  }
+
+  AffineExpr rescopeExprImpl(AffineExpr expr, enzymexla::AffineScopeOp scope) {
+    auto newExpr = expr;
+    for (auto sym : symbolsForScope) {
+      unsigned dimPos = getDimPosition(sym);
+      assert(dimOperands[dimPos] == sym);
+      BlockArgument newSym = getScopeRemap(scope, sym);
+      assert(newSym);
+      unsigned newSymPos = getSymbolPosition(newSym);
+      AffineExpr dimExpr = getAffineDimExpr(dimPos, user->getContext());
+      AffineExpr newSymExpr = getAffineDimExpr(newSymPos, user->getContext());
+      newExpr = newExpr.replace(dimExpr, newSymExpr);
+    }
+    for (auto sym : illegalSymbols) {
+      if (sym.getParentRegion() == &scope.getRegion())
+        continue;
+      BlockArgument newSym = getScopeRemap(scope, sym);
+      assert(newSym);
+      auto it = llvm::find(symbolOperands, sym);
+      assert(it != symbolOperands.end());
+      *it = newSym;
+    }
+    return newExpr;
+  }
+
+  void rescopeExpr(enzymexla::AffineScopeOp scope) {
+    expr = rescopeExprImpl(expr, scope);
+    assert(!scoped);
+    scoped = true;
+  }
+
+  unsigned getPosition(Value v, SmallVectorImpl &operands,
+                       DenseMap &toPos) {
+    auto it = toPos.find(v);
+    if (it != toPos.end())
+      return it->getSecond();
+    unsigned newPos = operands.size();
+    toPos.insert({v, newPos});
+    operands.push_back(v);
+    return newPos;
+  }
+
+  unsigned getSymbolPosition(Value v) {
+    return getPosition(v, symbolOperands, symToPos);
+  }
+  unsigned getDimPosition(Value v) {
+    return getPosition(v, dimOperands, dimToPos);
+  }
+
+  template
+  inline FailureOr buildPassthrough(Operation *op) {
+    if (isa(op)) {
+      assert(op->getNumOperands() == 1);
+      return buildExpr(op->getOperand(0));
+    }
+    return failure();
+  }
+
+  template
+  inline FailureOr
+  buildBinOpExpr(Operation *op,
+                 AffineExpr (AffineExpr::*handler)(AffineExpr) const) {
+    if (isa(op)) {
+      assert(op->getNumOperands() == 2);
+      auto lhs = buildExpr(op->getOperand(0));
+      auto rhs = buildExpr(op->getOperand(1));
+      if (failed(lhs) || failed(rhs))
+        return failure();
+      return ((*lhs).*handler)(*rhs);
+    }
+    return failure();
+  }
+
+  // TODO test this
+  FailureOr buildShift(Operation *op) {
+    if (op->getNumOperands() != 2)
+      return failure();
+    auto rhs = getConstant(op->getOperand(1));
+    if (!rhs)
+      return failure();
+    auto lhs = buildExpr(op->getOperand(0));
+    if (failed(lhs))
+      return failure();
+    if (isa(op)) {
+      return (*lhs) * getAffineConstantExpr(1 << (*rhs), op->getContext());
+    } else if (isa(
+                   op)) {
+      return (*lhs).floorDiv(
+          getAffineConstantExpr(1 << (*rhs), op->getContext()));
+    }
+    return failure();
+  }
+
+  FailureOr buildExpr(Value v) {
+    auto context = v.getContext();
+    Operation *op = v.getDefiningOp();
+    auto cst = getConstant(op);
+    if (cst)
+      return getAffineConstantExpr(*cst, context);
+    bool isIndexTy = v.getType().isIndex();
+    Value oldV = v;
+    if (!isIndexTy)
+      v = convertToIndex(v);
+    if (affine::isValidSymbol(v)) {
+      return getAffineSymbolExpr(getSymbolPosition(v), v.getContext());
+    } else if (affine::isValidDim(v)) {
+      return getAffineDimExpr(getDimPosition(v), v.getContext());
+    }
+    if (!isIndexTy) {
+      v.getDefiningOp()->erase();
+      v = oldV;
+    }
+
+    if (op) {
+      // clang-format off
+#define RIS(X) do { auto res = X; if (succeeded(res)) return *res; } while (0)
+      RIS((buildBinOpExpr(
+          op, &AffineExpr::operator+)));
+      RIS((buildBinOpExpr(
+          op, &AffineExpr::operator-)));
+      RIS((buildBinOpExpr(
+          op, &AffineExpr::operator%)));
+      // TODO need to check that we don't end up
with dim * dim or other invalid + // expression + RIS((buildBinOpExpr( + op, &AffineExpr::operator*))); + RIS((buildBinOpExpr( + op, &AffineExpr::floorDiv))); + RIS((buildPassthrough< + LLVM::ZExtOp, LLVM::SExtOp, LLVM::TruncOp, + arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp, + arith::IndexCastOp, arith::IndexCastUIOp>(op))); + RIS((buildShift(op))); +#undef RIS + // clang-format on + } + + // TODO We may find an affine op reduction block arg - we may be able to + // handle them + + for (auto &use : v.getUses()) { + if (auto affineScope = + dyn_cast(use.getOwner())) { + if (affineScope->isAncestor(user)) + // TODO should we try to find the inner-most one? + return getAffineSymbolExpr( + getSymbolPosition(affineScope.getRegion().front().getArgument( + use.getOperandNumber())), + v.getContext()); + } + } + + if (legalizeSymbols) { + illegalSymbols.insert(v); + return getAffineSymbolExpr(getSymbolPosition(v), context); + } + + return failure(); + } + + FailureOr getExpr(llvm::PointerUnion index) { + auto constIndex = dyn_cast(index); + if (constIndex) { + return getAffineConstantExpr(constIndex.getInt(), user->getContext()); + } else { + auto expr = buildExpr(cast(index)); + if (succeeded(expr)) + expr->dump(); + return expr; + } + } + + AffineExpr expr; + LogicalResult build(llvm::PointerUnion index) { + auto mexpr = getExpr(index); + if (failed(mexpr)) + return failure(); + expr = *mexpr; + return success(); + } + + struct MapAndOperands { + AffineMap map; + SmallVector operands; + }; + AffineExpr getExpr() { + assert(isLegal()); + return expr; + } + MapAndOperands getMap() { + assert(isLegal()); + AffineMap map = AffineMap::get(dimOperands.size(), symbolOperands.size(), + expr, user->getContext()); + auto concat = llvm::concat(dimOperands, symbolOperands); + SmallVector operands = + SmallVector(concat.begin(), concat.end()); + affine::canonicalizeMapAndOperands(&map, &operands); + map = simplifyAffineMap(map); + return {map, operands}; + } +}; + +struct AffineAccessBuilder : AffineExprBuilder { +private: + struct AffineAccess { + PtrVal base; + AffineExpr expr; + }; + +public: + AffineAccessBuilder(Operation *accessOp, bool legalizeSymbols) + : AffineExprBuilder(accessOp, legalizeSymbols) {} + + PtrVal base = nullptr; + + LogicalResult build(const DataLayout &dataLayout, PtrVal addr) { + auto aa = buildAffineAccess(dataLayout, addr); + if (failed(aa)) + return failure(); + expr = aa->expr; + base = aa->base; + + LLVM_DEBUG(llvm::dbgs() << "Built expr: " << expr << "\n"); + return success(); + } + + AffineExprBuilder::MapAndOperands getMap() { + return AffineExprBuilder::getMap(); + } + + PtrVal getBase() { + assert(base); + return base; + } + + void rescope(enzymexla::AffineScopeOp scope) { + if (!scope->isAncestor(user)) + return; + rescopeExpr(scope); + } + +private: + std::optional getGepAffineExpr(const DataLayout &dataLayout, + LLVM::GEPOp gep) { + // TODO what happens if we get a negative index + auto indicesRange = gep.getIndices(); + auto indices = SmallVector::value_type>( + indicesRange.begin(), indicesRange.end()); + assert(indices.size() > 0); + Type currentType = gep.getElemType(); + auto expr = getExpr(indices[0]); + if (failed(expr)) + return std::nullopt; + AffineExpr offset = (*expr) * dataLayout.getTypeSize(currentType); + + for (auto index : llvm::drop_begin(indices)) { + bool shouldCancel = + TypeSwitch(currentType) + .Case([&](LLVM::LLVMArrayType arrayType) { + auto expr = getExpr(index); + if (failed(expr)) + return true; + offset = offset + (*expr) * 
dataLayout.getTypeSize( + arrayType.getElementType()); + currentType = arrayType.getElementType(); + return false; + }) + .Case([&](LLVM::LLVMStructType structType) { + ArrayRef body = structType.getBody(); + int64_t indexInt; + auto constIndex = dyn_cast(index); + if (constIndex) + indexInt = constIndex.getInt(); + else + return true; + + for (uint32_t i : llvm::seq(indexInt)) { + if (!structType.isPacked()) + offset = alignTo(offset, + dataLayout.getTypeABIAlignment(body[i])); + offset = offset + dataLayout.getTypeSize(body[i]); + } + + // Align for the current type as well. + if (!structType.isPacked()) + offset = alignTo( + offset, dataLayout.getTypeABIAlignment(body[indexInt])); + currentType = body[indexInt]; + return false; + }) + .Default([&](Type type) { + LLVM_DEBUG(llvm::dbgs() + << "Unsupported type for offset computations" << type + << "\n"); + return true; + }); + + if (shouldCancel) + return std::nullopt; + } + + LLVM_DEBUG(llvm::dbgs() << "offset " << offset << "\n"); + + return offset; + } + + FailureOr buildAffineAccess(const DataLayout &dataLayout, + PtrVal addr) { + if (auto gep = dyn_cast_or_null(addr.getDefiningOp())) { + LLVM_DEBUG(llvm::dbgs() << "gep " << gep << "\n"); + auto base = cast(gep.getBase()); + + auto gepExpr = getGepAffineExpr(dataLayout, gep); + if (!gepExpr) + return failure(); + + auto aa = buildAffineAccess(dataLayout, base); + if (failed(aa)) + return failure(); + + AffineAccess newAA; + newAA.base = aa->base; + newAA.expr = aa->expr + *gepExpr; + LLVM_DEBUG(llvm::dbgs() << "added " << newAA.expr << "\n"); + return newAA; + } else if (auto addrSpaceCast = dyn_cast_or_null( + addr.getDefiningOp())) { + return buildAffineAccess(dataLayout, + cast(addrSpaceCast.getArg())); + } + + AffineAccess aa; + aa.base = addr; + aa.expr = getAffineConstantExpr(0, addr.getContext()); + LLVM_DEBUG(llvm::dbgs() << "base " << aa.expr << "\n"); + return aa; + } +}; + +struct AffineForBuilder { +public: + AffineForBuilder(scf::ForOp forOp, bool legalizeSymbols) + : lbBuilder(forOp, legalizeSymbols), ubBuilder(forOp, legalizeSymbols), + forOp(forOp) {} + + AffineExprBuilder lbBuilder; + AffineExprBuilder ubBuilder; + + scf::ForOp forOp; + int64_t step; + + void collectSymbolsForScope(Region *region, SmallPtrSetImpl &symbols) { + lbBuilder.collectSymbolsForScope(region, symbols); + ubBuilder.collectSymbolsForScope(region, symbols); + } + + SmallPtrSet getIllegalSymbols() { + auto set = lbBuilder.illegalSymbols; + set.insert(ubBuilder.illegalSymbols.begin(), + ubBuilder.illegalSymbols.end()); + return set; + } + + LogicalResult build() { + auto cstStep = getConstant(forOp.getStep()); + if (!cstStep) + return failure(); + step = *cstStep; + + if (failed(ubBuilder.build(forOp.getUpperBound())) || + failed(lbBuilder.build(forOp.getLowerBound()))) + return failure(); + + return success(); + } + + AffineExprBuilder::MapAndOperands getUbMap() { return ubBuilder.getMap(); } + + AffineExprBuilder::MapAndOperands getLbMap() { return lbBuilder.getMap(); } + + int64_t getStep() { return step; } + + void rescope(enzymexla::AffineScopeOp scope) { + if (!scope->isAncestor(forOp)) + return; + SmallVector newExprs; + + lbBuilder.rescopeExpr(scope); + ubBuilder.rescopeExpr(scope); + } +}; + +struct AffineIfBuilder { +public: + scf::IfOp ifOp; + bool legalizeSymbols; + AffineIfBuilder(scf::IfOp ifOp, bool legalizeSymbols) + : ifOp(ifOp), legalizeSymbols(legalizeSymbols) {} + + struct Constraint { + arith::CmpIPredicate pred; + struct Side { + Value val; + AffineExprBuilder builder; + }; + 
Side rhs, lhs;
+  };
+
+  struct SetAndOperands {
+    IntegerSet set;
+    SmallVector operands;
+  } sao;
+
+  SmallVector constraints;
+
+  LogicalResult build() {
+    Value cond = ifOp.getCondition();
+
+    if (failed(getConstraints(cond, constraints)))
+      return failure();
+
+    for (auto &c : constraints) {
+      for (auto side : {&c.lhs, &c.rhs}) {
+        auto &builder = side->builder;
+        if (failed(builder.build(side->val)))
+          return failure();
+      }
+    }
+
+    return success();
+  }
+
+  void collectSymbolsForScope(Region *region, SmallPtrSetImpl &symbols) {
+    for (auto &c : constraints) {
+      c.lhs.builder.collectSymbolsForScope(region, symbols);
+      c.rhs.builder.collectSymbolsForScope(region, symbols);
+    }
+  }
+
+  SmallPtrSet getIllegalSymbols() {
+    SmallPtrSet set;
+    for (auto &c : constraints) {
+      set.insert(c.lhs.builder.illegalSymbols.begin(),
+                 c.lhs.builder.illegalSymbols.end());
+      set.insert(c.rhs.builder.illegalSymbols.begin(),
+                 c.rhs.builder.illegalSymbols.end());
+    }
+    return set;
+  }
+
+  void rescope(enzymexla::AffineScopeOp scope) {
+    if (!scope->isAncestor(ifOp))
+      return;
+    SmallVector newExprs;
+
+    for (auto &c : constraints) {
+      c.lhs.builder.rescopeExpr(scope);
+      c.rhs.builder.rescopeExpr(scope);
+    }
+  }
+
+  SetAndOperands getSet() {
+    SmallVector eqs;
+    SmallVector exprs;
+    unsigned numDims = 0;
+    unsigned numSymbols = 0;
+    SmallVector dimOperands;
+    SmallVector symbolOperands;
+
+    auto getExpr = [&](AffineExprBuilder &builder) {
+      auto lhs = builder.getExpr();
+      lhs = lhs.shiftDims(builder.dimOperands.size(), numDims);
+      lhs = lhs.shiftSymbols(builder.symbolOperands.size(), numSymbols);
+      numDims += builder.dimOperands.size();
+      numSymbols += builder.symbolOperands.size();
+      dimOperands.append(builder.dimOperands);
+      symbolOperands.append(builder.symbolOperands);
+      return lhs;
+    };
+
+    for (auto &c : constraints) {
+
+      auto lhs = getExpr(c.lhs.builder);
+      auto rhs = getExpr(c.rhs.builder);
+
+      AffineExpr expr = getAffineConstantExpr(0, ifOp->getContext());
+      switch (c.pred) {
+      case arith::CmpIPredicate::eq:
+        exprs.push_back(rhs - lhs);
+        eqs.push_back(true);
+        break;
+      case arith::CmpIPredicate::ne:
+        llvm_unreachable("no ne");
+        break;
+      case arith::CmpIPredicate::slt:
+      case arith::CmpIPredicate::ult:
+        expr = expr - 1;
+        [[fallthrough]];
+      case arith::CmpIPredicate::sle:
+      case arith::CmpIPredicate::ule:
+        expr = expr + lhs - rhs;
+        exprs.push_back(expr);
+        eqs.push_back(false);
+        break;
+      case arith::CmpIPredicate::sgt:
+      case arith::CmpIPredicate::ugt:
+        expr = expr - 1;
+        [[fallthrough]];
+      case arith::CmpIPredicate::sge:
+      case arith::CmpIPredicate::uge:
+        expr = expr + rhs - lhs;
+        exprs.push_back(expr);
+        eqs.push_back(false);
+        break;
+      }
+    }
+    sao.set = IntegerSet::get(numDims, numSymbols, exprs, eqs);
+    sao.operands = dimOperands;
+    sao.operands.append(symbolOperands);
+    affine::canonicalizeSetAndOperands(&sao.set, &sao.operands);
+    return sao;
+  }
+
+  LogicalResult getConstraints(Value conjunction,
+                               SmallVectorImpl &constraints) {
+    Operation *op = conjunction.getDefiningOp();
+    if (!op)
+      return failure();
+    if (isa(op)) {
+      auto lhs = op->getOperand(0);
+      auto rhs = op->getOperand(1);
+      if (succeeded(getConstraints(lhs, constraints)) &&
+          succeeded(getConstraints(rhs, constraints)))
+        return success();
+      else
+        return failure();
+    }
+    if (auto cmp = dyn_cast(op)) {
+      // TODO there is a way to make this work with ne, but it is annoying to
+      // think through, ignore for now.
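+      // Illustrative sketch (not taken from this patch's tests): a guard
+      // such as
+      //   %c = arith.cmpi sle, %i, %n : index
+      //   scf.if %c { ... }
+      // produces a single Constraint with pred = sle, which getSet() turns
+      // into the inequality `n - i >= 0`, i.e. the integer set
+      //   affine_set<(d0)[s0] : (s0 - d0 >= 0)>
+      // over operands (%i)[%n] on the resulting affine.if. Strict
+      // predicates (slt/ult/sgt/ugt) subtract 1 from the constant term
+      // first, and `ne` is rejected below because an IntegerSet is a
+      // conjunction of constraints and cannot express a disjunction.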
+      if (cmp.getPredicate() == arith::CmpIPredicate::ne)
+        return failure();
+      constraints.emplace_back(
+          Constraint{cmp.getPredicate(),
+                     {cmp.getLhs(), AffineExprBuilder(ifOp, legalizeSymbols)},
+                     {cmp.getRhs(), AffineExprBuilder(ifOp, legalizeSymbols)}});
+      return success();
+    }
+    return failure();
+  }
+};
+
+// TODO this works for single-block regions where SSA values are not used across
+// blocks but will fail when a value defined in `block` is used in another
+// block.
+static enzymexla::AffineScopeOp appendToScope(enzymexla::AffineScopeOp oldScope,
+                                              ValueRange operands) {
+  IRRewriter rewriter(oldScope);
+  assert(llvm::all_of(operands, [&](Value a) {
+    return llvm::all_of(oldScope->getOperands(),
+                        [&](Value b) { return a != b; });
+  }));
+  SmallVector newOperands(oldScope->getOperands());
+  Block *b = &oldScope.getRegion().front();
+  for (Value v : operands) {
+    if (llvm::find(newOperands, v) == newOperands.end()) {
+      b->addArgument(v.getType(), v.getLoc());
+      newOperands.push_back(v);
+    }
+  }
+  auto scope = rewriter.create(
+      oldScope.getLoc(), oldScope->getResultTypes(), newOperands);
+  rewriter.inlineRegionBefore(oldScope.getRegion(), scope.getRegion(),
+                              scope.getRegion().begin());
+  rewriter.replaceOp(oldScope, scope);
+  return scope;
+}
+
+template SmallVector getLocs(T values) {
+  return llvm::map_to_vector(values, [](Value v) { return v.getLoc(); });
+}
+
+static enzymexla::AffineScopeOp insertAffineScope(Block *block,
+                                                  ValueRange operands) {
+  assert(block->getParent()->getBlocks().size() == 1);
+
+  assert(!isa(block->getParentOp()));
+  if (auto scope = dyn_cast(block->front())) {
+    assert(scope->getNextNode() == scope->getBlock()->getTerminator());
+    return appendToScope(scope, operands);
+  }
+
+  IRRewriter rewriter(block->getParentOp()->getContext());
+  rewriter.setInsertionPointToStart(block);
+  auto scope = rewriter.create(
+      block->getParentOp()->getLoc(), block->getTerminator()->getOperandTypes(),
+      operands);
+  Block *innerBlock = rewriter.createBlock(
+      &scope.getRegion(), {}, operands.getTypes(), getLocs(operands));
+  while (scope->getNextNode() != &block->back())
+    rewriter.moveOpBefore(scope->getNextNode(), innerBlock, innerBlock->end());
+  rewriter.setInsertionPointToEnd(innerBlock);
+  Operation *terminator = block->getTerminator();
+  rewriter.create(terminator->getLoc(),
+                  terminator->getOperands());
+  terminator->setOperands(scope->getResults());
+  return scope;
+}
+
+static constexpr bool useVectorLoadStore = true;
+
+static Operation *createVectorStore(OpBuilder &b, Location loc, Type ty,
+                                    TypedValue v, MemRefVal m,
+                                    AffineMap map, ValueRange mapOperands) {
+  if (useVectorLoadStore) {
+    auto vs =
+        b.create(loc, v, m, map, mapOperands);
+    vs->setAttr("polymer.access.type", TypeAttr::get(ty));
+    return vs;
+  }
+  llvm_unreachable("");
+}
+
+static Value createVectorLoad(OpBuilder &b, Location loc, Type ty,
+                              VectorType vty, MemRefVal m, AffineMap map,
+                              ValueRange mapOperands) {
+  if (useVectorLoadStore) {
+    auto vl =
+        b.create(loc, vty, m, map, mapOperands);
+    vl->setAttr("polymer.access.type", TypeAttr::get(ty));
+    return vl;
+  }
+  llvm_unreachable("");
+}
+
+namespace mlir {
+LogicalResult
+convertLLVMToAffineAccess(Operation *op,
+                          const DataLayoutAnalysis &dataLayoutAnalysis,
+                          bool legalizeSymbols) {
+  if (!legalizeSymbols && !op->hasTrait()) {
+    LLVM_DEBUG(llvm::errs() << "Must be called with an affine scope root when "
+                               "not legalizing symbols\n");
+    return failure();
+  }
+
+  MLIRContext *context = op->getContext();
+
+  MemrefConverter mc;
+ 
IndexConverter ic; + + // TODO Pretty slow but annoying to implement as we wrap the operation in + // the callback + while (true) { + auto res = op->walk([&](scf::ForOp forOp) { + AffineForBuilder forBuilder(forOp, legalizeSymbols); + if (failed(forBuilder.build())) + return WalkResult::advance(); + LLVM_DEBUG(llvm::dbgs() << "Converting\n" << forOp << "\n"); + if (legalizeSymbols) { + SmallPtrSet blocksToScope; + for (auto illegalSym : forBuilder.getIllegalSymbols()) + blocksToScope.insert(illegalSym.getParentBlock()); + for (Block *b : blocksToScope) { + SmallPtrSet symbols; + forBuilder.collectSymbolsForScope(b->getParent(), symbols); + SmallVector symbolsVec(symbols.begin(), symbols.end()); + auto scope = insertAffineScope(b, symbolsVec); + forBuilder.rescope(scope); + } + } + IRRewriter rewriter(forOp); + auto lb = forBuilder.getLbMap(); + auto ub = forBuilder.getUbMap(); + auto affineForOp = rewriter.create( + forOp.getLoc(), ic(lb.operands), lb.map, ic(ub.operands), ub.map, + forBuilder.getStep(), forOp.getInitArgs()); + if (!affineForOp.getRegion().empty()) + affineForOp.getRegion().front().erase(); + Block *block = forOp.getBody(); + SmallVector blockArgTypes = {rewriter.getIndexType()}; + auto iterArgTypes = forOp.getInitArgs().getTypes(); + blockArgTypes.insert(blockArgTypes.end(), iterArgTypes.begin(), + iterArgTypes.end()); + SmallVector blockArgLocs = + getLocs(forOp.getBody()->getArguments()); + auto newBlock = rewriter.createBlock(&affineForOp.getRegion(), {}, + blockArgTypes, blockArgLocs); + SmallVector newBlockArgs(newBlock->getArguments()); + auto origIVType = forOp.getInductionVar().getType(); + if (origIVType != rewriter.getIndexType()) { + rewriter.setInsertionPointToStart(newBlock); + newBlockArgs[0] = rewriter.create( + newBlockArgs[0].getLoc(), origIVType, newBlockArgs[0]); + } + rewriter.inlineBlockBefore(block, newBlock, newBlock->end(), + newBlockArgs); + rewriter.replaceOp(forOp, affineForOp); + auto yield = cast(newBlock->getTerminator()); + rewriter.setInsertionPoint(yield); + rewriter.replaceOpWithNewOp(yield, + yield.getOperands()); + return WalkResult::interrupt(); + }); + if (!res.wasInterrupted()) + break; + } + + while (true) { + auto res = op->walk([&](scf::IfOp ifOp) { + AffineIfBuilder ifBuilder(ifOp, legalizeSymbols); + if (failed(ifBuilder.build())) + return WalkResult::advance(); + LLVM_DEBUG(llvm::dbgs() << "Converting\n" << ifOp << "\n"); + if (legalizeSymbols) { + SmallPtrSet blocksToScope; + for (auto illegalSym : ifBuilder.getIllegalSymbols()) + blocksToScope.insert(illegalSym.getParentBlock()); + for (Block *b : blocksToScope) { + SmallPtrSet symbols; + ifBuilder.collectSymbolsForScope(b->getParent(), symbols); + SmallVector symbolsVec(symbols.begin(), symbols.end()); + auto scope = insertAffineScope(b, symbolsVec); + ifBuilder.rescope(scope); + } + } + IRRewriter rewriter(ifOp); + auto sao = ifBuilder.getSet(); + auto affineIfOp = rewriter.create( + ifOp.getLoc(), ifOp.getResultTypes(), sao.set, ic(sao.operands), + ifOp.elseBlock()); + for (auto [newRegion, oldRegion] : + llvm::zip(affineIfOp.getRegions(), ifOp.getRegions())) { + if (!newRegion->empty()) + newRegion->front().erase(); + if (oldRegion->empty()) + continue; + Block *block = &oldRegion->front(); + auto newBlock = rewriter.createBlock(newRegion); + rewriter.inlineBlockBefore(block, newBlock, newBlock->end(), {}); + auto yield = cast(newBlock->getTerminator()); + rewriter.setInsertionPoint(yield); + rewriter.replaceOpWithNewOp(yield, + yield.getOperands()); + } + 
rewriter.replaceOp(ifOp, affineIfOp); + return WalkResult::interrupt(); + }); + if (!res.wasInterrupted()) + break; + } + + SmallVector> accessBuilders; + auto handleOp = [&](Operation *op, PtrVal addr) { + LLVM_DEBUG(llvm::dbgs() << "Building affine access for " << op + << " for address " << addr << "\n"); + accessBuilders.push_back( + std::make_unique(op, legalizeSymbols)); + AffineAccessBuilder &aab = *accessBuilders.back(); + auto dl = dataLayoutAnalysis.getAtOrAbove(op); + auto res = aab.build(dl, addr); + if (failed(res)) + accessBuilders.pop_back(); + }; + op->walk([&](LLVM::StoreOp store) { + PtrVal addr = store.getAddr(); + handleOp(store, addr); + }); + op->walk([&](LLVM::LoadOp load) { + PtrVal addr = load.getAddr(); + handleOp(load, addr); + }); + + // TODO should also gather other mem operations such as memory intrinsics + // TODO should we shrink the scope to where no other memory operations + // exist? + + if (legalizeSymbols) { + SmallPtrSet blocksToScope; + for (auto &aabp : accessBuilders) + for (auto illegalSym : aabp->illegalSymbols) + blocksToScope.insert(illegalSym.getParentBlock()); + SmallPtrSet innermostBlocks; + for (Block *b : blocksToScope) { + SmallVector toRemove; + bool isInnermost = true; + for (Block *existing : innermostBlocks) { + if (existing->getParent()->isProperAncestor(b->getParent())) + toRemove.push_back(existing); + if (b->getParent()->isAncestor(existing->getParent())) + isInnermost = false; + } + for (Block *r : toRemove) + innermostBlocks.erase(r); + if (isInnermost) + innermostBlocks.insert(b); + } + + // TODO this looks terribly slow + for (Block *b : innermostBlocks) { + SmallPtrSet symbols; + for (auto &aabp : accessBuilders) + aabp->collectSymbolsForScope(b->getParent(), symbols); + SmallVector symbolsVec(symbols.begin(), symbols.end()); + auto scope = insertAffineScope(b, symbolsVec); + for (auto &aabp : accessBuilders) { + aabp->rescope(scope); + } + } + } + + IRMapping mapping; + for (auto &aabp : accessBuilders) { + AffineAccessBuilder &aab = *aabp; + // TODO add a test where some operations are left illegal + if (!aab.isLegal()) + continue; + + auto mao = aab.getMap(); + + auto dl = dataLayoutAnalysis.getAtOrAbove(aab.user); + if (auto load = dyn_cast(aab.user)) { + IRRewriter rewriter(load); + auto vty = VectorType::get({(int64_t)dl.getTypeSize(load.getType())}, + rewriter.getI8Type()); + auto vecLoad = + createVectorLoad(rewriter, load.getLoc(), load.getType(), vty, + mc(aab.getBase()), mao.map, ic(mao.operands)); + Operation *newLoad; + if (isa(load.getType())) { + Type intTy = rewriter.getIntegerType( + (int64_t)dl.getTypeSize(load.getType()) * 8); + auto cast = + rewriter.create(load.getLoc(), intTy, vecLoad); + newLoad = rewriter.create(load.getLoc(), + load.getType(), cast); + } else { + newLoad = rewriter.create(load.getLoc(), + load.getType(), vecLoad); + } + mapping.map(load, newLoad); + } else if (auto store = dyn_cast(aab.user)) { + Type ty = store.getValue().getType(); + IRRewriter rewriter(store); + auto vty = + VectorType::get({(int64_t)dl.getTypeSize(ty)}, rewriter.getI8Type()); + Value v; + if (isa(ty)) { + Type intTy = rewriter.getIntegerType((int64_t)dl.getTypeSize(ty) * 8); + v = rewriter.create(store.getLoc(), intTy, + store.getValue()); + v = rewriter.create(store.getLoc(), vty, v); + } else { + v = rewriter.create(store.getLoc(), vty, + store.getValue()); + } + Operation *newStore = createVectorStore( + rewriter, store.getLoc(), ty, cast>(v), + mc(aab.base), mao.map, ic(mao.operands)); + 
mapping.map(store.getOperation(), newStore); + } else { + llvm_unreachable(""); + } + } + + IRRewriter rewriter(context); + for (auto &&[oldOp, newOp] : mapping.getOperationMap()) { + rewriter.replaceOp(oldOp, newOp); + } + + RewritePatternSet patterns(context); + patterns.insert(context, dataLayoutAnalysis); + GreedyRewriteConfig config; + return applyPatternsAndFoldGreedily(op, std::move(patterns), config); +} +} // namespace mlir + +namespace mlir { +void populateRemoveIVPatterns(RewritePatternSet &patterns) { + patterns.insert( + patterns.getContext()); +} +} // namespace mlir + +// This should be scheduled on individual functions +struct LLVMToAffineAccessPass + : public enzyme::impl::LLVMToAffineAccessPassBase { + using LLVMToAffineAccessPassBase::LLVMToAffineAccessPassBase; + + void runOnOperation() override { + auto context = &getContext(); + RewritePatternSet patterns(context); + populateRemoveIVPatterns(patterns); + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + signalPassFailure(); + return; + } + Operation *op = getOperation(); + const auto &dataLayoutAnalysis = getAnalysis(); + if (failed(convertLLVMToAffineAccess(op, dataLayoutAnalysis, true))) { + signalPassFailure(); + return; + } + } +}; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index d2ed6eb11..8cd7f2c7e 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -387,4 +387,14 @@ def EnzymeLiftControlFlowToSCFPass : Pass<"enzyme-lift-cf-to-scf"> { ]; } +def LLVMToAffineAccessPass : Pass<"llvm-to-affine-access"> { + let summary = ""; + let dependentDialects = [ + "memref::MemRefDialect", + "affine::AffineDialect", + "vector::VectorDialect", + "enzymexla::EnzymeXLADialect", + ]; +} + #endif diff --git a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp index 1505aa04d..c07b3eb86 100644 --- a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp +++ b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp @@ -27,118 +27,113 @@ using namespace mlir; namespace mlir { namespace transform { -struct RemoveIVs : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(scf::ForOp forOp, - PatternRewriter &rewriter) const override { - if (!forOp.getRegion().hasOneBlock()) - return failure(); - unsigned numIterArgs = forOp.getNumRegionIterArgs(); - auto loc = forOp->getLoc(); - bool changed = false; - llvm::SetVector removed; - llvm::MapVector steps; - auto yield = cast(forOp.getBody()->getTerminator()); - for (unsigned i = 0; i < numIterArgs; i++) { - auto ba = forOp.getRegionIterArgs()[i]; - auto init = forOp.getInits()[i]; - auto next = yield->getOperand(i); - - auto increment = next.getDefiningOp(); - if (!increment) - continue; +LogicalResult RemoveIVs::matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const { + if (!forOp.getRegion().hasOneBlock()) + return failure(); + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + auto loc = forOp->getLoc(); + bool changed = false; + llvm::SetVector removed; + llvm::MapVector steps; + auto yield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < numIterArgs; i++) { + auto ba = forOp.getRegionIterArgs()[i]; + auto init = forOp.getInits()[i]; + auto next = yield->getOperand(i); + + auto increment = next.getDefiningOp(); + if (!increment) + continue; + + Value step = nullptr; + if (increment.getLhs() 
== ba) { + step = increment.getRhs(); + } else { + step = increment.getLhs(); + } + if (!step) + continue; - Value step = nullptr; - if (increment.getLhs() == ba) { - step = increment.getRhs(); - } else { - step = increment.getLhs(); - } - if (!step) - continue; + // If it dominates the loop entry + if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion())) + continue; - // If it dominates the loop entry - if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion())) - continue; + rewriter.setInsertionPointToStart(forOp.getBody()); + Value iterNum = rewriter.create(loc, forOp.getInductionVar(), + forOp.getLowerBound()); + iterNum = rewriter.create(loc, iterNum, forOp.getStep()); - rewriter.setInsertionPointToStart(forOp.getBody()); - Value iterNum = rewriter.create( - loc, forOp.getInductionVar(), forOp.getLowerBound()); - iterNum = rewriter.create(loc, iterNum, forOp.getStep()); + Value replacementIV = rewriter.create(loc, iterNum, step); + replacementIV = rewriter.create(loc, replacementIV, init); - Value replacementIV = rewriter.create(loc, iterNum, step); - replacementIV = rewriter.create(loc, replacementIV, init); + rewriter.replaceAllUsesWith(ba, replacementIV); - rewriter.replaceAllUsesWith(ba, replacementIV); + removed.insert(i); + steps.insert({i, step}); + changed = true; + } - removed.insert(i); - steps.insert({i, step}); - changed = true; - } + if (!changed) + return failure(); + + SmallVector newInits; + for (unsigned i = 0; i < numIterArgs; i++) + if (!removed.contains(i)) + newInits.push_back(forOp.getInits()[i]); + + rewriter.setInsertionPoint(forOp); + auto newForOp = rewriter.create(loc, forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep(), newInits); + if (!newForOp.getRegion().empty()) + newForOp.getRegion().front().erase(); + assert(newForOp.getRegion().empty()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().begin()); + + SmallVector newYields; + for (unsigned i = 0; i < numIterArgs; i++) + if (!removed.contains(i)) + newYields.push_back(yield->getOperand(i)); + + rewriter.setInsertionPoint(yield); + rewriter.replaceOpWithNewOp(yield, newYields); + + llvm::BitVector toDelete(numIterArgs + 1); + for (unsigned i = 0; i < numIterArgs; i++) + if (removed.contains(i)) + toDelete[i + 1] = true; + newForOp.getBody()->eraseArguments(toDelete); + + rewriter.setInsertionPoint(newForOp); + unsigned curNewRes = 0; + for (unsigned i = 0; i < numIterArgs; i++) { + auto result = forOp->getResult(i); + if (removed.contains(i)) { + if (result.use_empty()) + continue; - if (!changed) - return failure(); - - SmallVector newInits; - for (unsigned i = 0; i < numIterArgs; i++) - if (!removed.contains(i)) - newInits.push_back(forOp.getInits()[i]); - - rewriter.setInsertionPoint(forOp); - auto newForOp = rewriter.create(loc, forOp.getLowerBound(), - forOp.getUpperBound(), - forOp.getStep(), newInits); - if (!newForOp.getRegion().empty()) - newForOp.getRegion().front().erase(); - assert(newForOp.getRegion().empty()); - rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), - newForOp.getRegion().begin()); - - SmallVector newYields; - for (unsigned i = 0; i < numIterArgs; i++) - if (!removed.contains(i)) - newYields.push_back(yield->getOperand(i)); - - rewriter.setInsertionPoint(yield); - rewriter.replaceOpWithNewOp(yield, newYields); - - llvm::BitVector toDelete(numIterArgs + 1); - for (unsigned i = 0; i < numIterArgs; i++) - if (removed.contains(i)) - toDelete[i + 1] = true; - 
newForOp.getBody()->eraseArguments(toDelete); - - rewriter.setInsertionPoint(newForOp); - unsigned curNewRes = 0; - for (unsigned i = 0; i < numIterArgs; i++) { - auto result = forOp->getResult(i); - if (removed.contains(i)) { - if (result.use_empty()) - continue; - - rewriter.setInsertionPointAfter(forOp.getOperation()); - Value iterNum = rewriter.create( - loc, forOp.getUpperBound(), forOp.getLowerBound()); - iterNum = - rewriter.create(loc, iterNum, forOp.getStep()); - - Value afterLoop = - rewriter.create(loc, iterNum, steps[i]); - afterLoop = - rewriter.create(loc, afterLoop, forOp.getInits()[i]); - - rewriter.replaceAllUsesWith(result, afterLoop); - } else { - rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++)); - } - } + rewriter.setInsertionPointAfter(forOp.getOperation()); + Value iterNum = rewriter.create(loc, forOp.getUpperBound(), + forOp.getLowerBound()); + iterNum = rewriter.create(loc, iterNum, forOp.getStep()); - forOp->getParentOp()->dump(); - rewriter.eraseOp(forOp); + Value afterLoop = rewriter.create(loc, iterNum, steps[i]); + afterLoop = + rewriter.create(loc, afterLoop, forOp.getInits()[i]); - return success(); + rewriter.replaceAllUsesWith(result, afterLoop); + } else { + rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++)); + } } -}; + + forOp->getParentOp()->dump(); + rewriter.eraseOp(forOp); + + return success(); +} static inline void clearBlock(mlir::Block *block, mlir::RewriterBase &rewriter) { @@ -187,53 +182,47 @@ static bool isNormalized(scf::ForOp op) { #define DEBUG_TYPE "normalize-loop" #define DBGS llvm::dbgs -struct NormalizeLoop : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ForOp op, - PatternRewriter &rewriter) const override { - using namespace arith; - if (isNormalized(op) || - !isa(op->getParentOp())) { - LLVM_DEBUG(DBGS() << "[normalize-loop] loop already normalized\n"); - return failure(); - } - - rewriter.setInsertionPoint(op); - Value zero = createConstantInt(rewriter, op.getLoc(), - op.getInductionVar().getType(), 0); - Value one = createConstantInt(rewriter, op.getLoc(), - op.getInductionVar().getType(), 1); - - Value difference = rewriter.create(op.getLoc(), op.getUpperBound(), - op.getLowerBound()); - Value tripCount = rewriter.create( - op.getLoc(), - rewriter.create( - op.getLoc(), rewriter.create(op.getLoc(), difference, one), - op.getStep()), - one); - // rewriter.create(op.getLoc(), difference, op.getStep()); - auto newForOp = rewriter.create(op.getLoc(), zero, tripCount, - one, op.getInits()); - clearBlock(newForOp.getBody(), rewriter); - rewriter.setInsertionPointToStart(newForOp.getBody()); - Value scaled = rewriter.create( - op.getLoc(), newForOp.getInductionVar(), op.getStep()); - Value iv = rewriter.create(op.getLoc(), op.getLowerBound(), scaled); - SmallVector newArgs(newForOp.getRegion().args_begin(), - newForOp.getRegion().args_end()); - newArgs[0] = iv; - rewriter.inlineBlockBefore(op.getBody(), newForOp.getBody(), - newForOp.getBody()->end(), newArgs); - rewriter.replaceOp(op, newForOp->getResults()); - return success(); +LogicalResult NormalizeLoop::matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const { + using namespace arith; + if (isNormalized(op) || + !isa(op->getParentOp())) { + LLVM_DEBUG(DBGS() << "[normalize-loop] loop already normalized\n"); + return failure(); } -}; + + rewriter.setInsertionPoint(op); + Value zero = createConstantInt(rewriter, op.getLoc(), + op.getInductionVar().getType(), 0); + Value 
one = createConstantInt(rewriter, op.getLoc(), + op.getInductionVar().getType(), 1); + + Value difference = rewriter.create(op.getLoc(), op.getUpperBound(), + op.getLowerBound()); + Value tripCount = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), rewriter.create(op.getLoc(), difference, one), + op.getStep()), + one); + auto newForOp = rewriter.create(op.getLoc(), zero, tripCount, one, + op.getInits()); + clearBlock(newForOp.getBody(), rewriter); + rewriter.setInsertionPointToStart(newForOp.getBody()); + Value scaled = rewriter.create( + op.getLoc(), newForOp.getInductionVar(), op.getStep()); + Value iv = rewriter.create(op.getLoc(), op.getLowerBound(), scaled); + SmallVector newArgs(newForOp.getRegion().args_begin(), + newForOp.getRegion().args_end()); + newArgs[0] = iv; + rewriter.inlineBlockBefore(op.getBody(), newForOp.getBody(), + newForOp.getBody()->end(), newArgs); + rewriter.replaceOp(op, newForOp->getResults()); + return success(); +} } // namespace transform } // namespace mlir - #include "src/enzyme_ad/jax/TransformOps/RaisingTransformPatterns.cpp.inc" namespace { diff --git a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h index 2027640c7..c3b1ec505 100644 --- a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h +++ b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" @@ -20,3 +21,22 @@ void registerRaisingTransformExtension(mlir::DialectRegistry ®istry); } // namespace enzyme } // namespace mlir + +namespace mlir { +namespace transform { +struct RemoveIVs : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::scf::ForOp forOp, + mlir::PatternRewriter &rewriter) const override; +}; + +struct NormalizeLoop : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override; +}; + +} // namespace transform +} // namespace mlir diff --git a/test/lit_tests/raising/llvm_to_affine_access.mlir b/test/lit_tests/raising/llvm_to_affine_access.mlir new file mode 100644 index 000000000..5ca1040f6 --- /dev/null +++ b/test/lit_tests/raising/llvm_to_affine_access.mlir @@ -0,0 +1,66 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(llvm-to-affine-access)" | FileCheck %s + +func.func @test_load_store_conversion(%arg0: !llvm.ptr<1>, %idx: i64) { + %0 = llvm.getelementptr inbounds %arg0[%idx] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64 + %1 = llvm.load %0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 + %2 = llvm.mul %1, %1 : i64 + + llvm.store %2, %0 {alignment = 1 : i64} : i64, !llvm.ptr<1> + + return +} + +// CHECK-LABEL: func @test_load_store_conversion +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1> +// CHECK-SAME: %[[ARG1:.*]]: i64 +// CHECK: %[[MEMREF:.*]] = "enzymexla.pointer2memref"(%[[ARG0]]) {{.*}} memref +// CHECK: %[[IDX:.*]] = arith.index_cast %[[ARG1]] +// CHECK: affine.vector_load %[[MEMREF]][symbol(%[[IDX]]) * 8] {{.*}} vector<8xi8> +// CHECK: affine.vector_store + +// ----- + +func.func @test_multidim_load_store(%arg0: !llvm.ptr<1>, %idx1: i64, %idx2: i64) { + %c1 = llvm.mlir.constant(1 : index) : i64 + %ptr = llvm.getelementptr %arg0[%idx1, %idx2] : 
(!llvm.ptr<1>, i64, i64) -> !llvm.ptr<1>, !llvm.array<8 x i64> + %val = llvm.load %ptr : !llvm.ptr<1> -> i64 + + %idx1p1 = llvm.add %idx1, %c1 : i64 + %idx2p1 = llvm.add %idx2, %c1 : i64 + + %ptr_str = llvm.getelementptr %arg0[%idx1p1, %idx2p1] : (!llvm.ptr<1>, i64, i64) -> !llvm.ptr<1>, !llvm.array<8 x i64> + llvm.store %val, %ptr_str : i64, !llvm.ptr<1> + + return +} + +// CHECK-LABEL: func @test_multidim_load_store +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: i64, +// CHECK-SAME: %[[ARG2:.*]]: i64 +// CHECK: %[[MEMREF:.*]] = "enzymexla.pointer2memref"(%[[ARG0]]) {{.*}} memref +// CHECK-DAG: %[[IDX1:.*]] = arith.index_cast %[[ARG1]] +// CHECK-DAG: %[[IDX2:.*]] = arith.index_cast %[[ARG2]] +// CHECK: affine.vector_load %[[MEMREF]][symbol(%[[IDX1]]) * 64 + symbol(%[[IDX2]]) * 8] {{.*}} vector<8xi8> +// CHECK: affine.vector_store + +// ----- + +func.func @test_struct_access(%arg0: !llvm.ptr) { + %ptr = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i64)> + %val = llvm.load %ptr : !llvm.ptr -> i64 + + llvm.store %val, %ptr : i64, !llvm.ptr + + return +} + +// CHECK-LABEL: func @test_struct_access +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr +// CHECK: %[[MEMREF:.*]] = "enzymexla.pointer2memref"(%[[ARG0]]) {{.*}} memref +// CHECK: affine.vector_load %[[MEMREF]][0] {{.*}} vector<8xi8> +// CHECK: affine.vector_store + +// CHEC-K: %[[MEMREF:.*]] = enzymexla.at_addr %arg0 : !llvm.ptr to memref +// CHEC-K: %[[LOAD:.*]] = memref.load %[[MEMREF]][%c0] : memref +// CHEC-K: memref.store %[[VAL]], %[[MEMREF]][%c0] : memref From 9b7367f21c21fac82af5b474fad8cdbd38e78739 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 10:40:43 +0900 Subject: [PATCH 02/12] Fix build --- src/enzyme_ad/jax/BUILD | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index c40006320..080549ab6 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -193,6 +193,10 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformDialectInterfaces", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", ":RaisingTransformOpsIncGen", ":RaisingTransformOpsImplIncGen", ":RaisingTransformPatternsIncGen", @@ -417,6 +421,7 @@ cc_library( "@llvm-project//mlir:MemorySlotInterfaces", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:CommonFolders", @@ -433,6 +438,7 @@ cc_library( "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformDialectInterfaces", "@llvm-project//mlir:TransformDialectTransforms", @@ -445,6 +451,7 @@ cc_library( "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:Transforms", From 62d90f459df165cb180f08fd7b5d5d9b9b024336 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 10:41:05 +0900 Subject: [PATCH 03/12] Put stray print in LLVM_DEBUG --- src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp index 962710f39..e5388accc 100644 --- a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp +++ b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp @@ -449,8 +449,7 @@ struct AffineExprBuilder { return getAffineConstantExpr(constIndex.getInt(), user->getContext()); } else { auto expr = buildExpr(cast(index)); - if (succeeded(expr)) - expr->dump(); + LLVM_DEBUG(if (succeeded(expr)) expr->dump()); return expr; } } From 400ca9688bf001eec726b182867ae220846924cd Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 10:49:07 +0900 Subject: [PATCH 04/12] Disable legalization --- src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp index e5388accc..2ab7adec2 100644 --- a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp +++ b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp @@ -1183,7 +1183,9 @@ struct LLVMToAffineAccessPass } Operation *op = getOperation(); const auto &dataLayoutAnalysis = getAnalysis(); - if (failed(convertLLVMToAffineAccess(op, dataLayoutAnalysis, true))) { + // TODO in order to enable legalization we need to add an enzymexla.yield op + // to terminate it with + if (failed(convertLLVMToAffineAccess(op, dataLayoutAnalysis, false))) { signalPassFailure(); return; } From 1f51d823fe782d79af3f59c6ae45c6dc7b96df1a Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 11:54:12 +0900 Subject: [PATCH 05/12] Fix build --- src/enzyme_ad/jax/BUILD | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 080549ab6..30fe3a862 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -186,16 +186,12 @@ cc_library( "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransformOps", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformDialectInterfaces", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFUtils", ":RaisingTransformOpsIncGen", ":RaisingTransformOpsImplIncGen", From f1562a387a7d2bea330700ac1862eaea64272666 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 11:58:14 +0900 Subject: [PATCH 06/12] Use affine parallel in cpu lowering --- src/enzyme_ad/jax/Passes/LowerKernel.cpp | 26 ++++++++++---- test/lit_tests/lowering/cpu.mlir | 44 +++++++++++------------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index 6a17eb526..10838d163 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -290,7 +290,19 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc, IRMapping map; map.map(op.getArguments(), entryBlock.getArguments()); - auto par = builder.create(loc, inits, finals, incs); + auto context = loc.getContext(); + SmallVector idMaps, zeroMaps; + auto zeroMap = AffineMap::getConstantMap(0, context); + zeroMaps.insert(zeroMaps.begin(), 6, zeroMap); + for (unsigned i = 0; i < 6; i++) { + auto idMap = 
AffineMap::get(0, 6, getAffineSymbolExpr(i, context)); + idMaps.push_back(idMap); + } + + SmallVector steps(6, 1); + auto par = builder.create( + loc, TypeRange(), ArrayRef(), zeroMaps, + ValueRange(), idMaps, finals, steps); builder.create(loc); @@ -322,21 +334,21 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc, executeRegion->walk([&](NVVM::BlockIdXOp idxOp) { OpBuilder rewriter(idxOp); auto rep = rewriter.create(op.getLoc(), idxOp.getType(), - par.getInductionVars()[0]); + par.getIVs()[0]); idxOp.replaceAllUsesWith(rep.getResult()); idxOp.erase(); }); executeRegion->walk([&](NVVM::BlockIdYOp idxOp) { OpBuilder rewriter(idxOp); auto rep = rewriter.create(op.getLoc(), idxOp.getType(), - par.getInductionVars()[1]); + par.getIVs()[1]); idxOp.replaceAllUsesWith(rep.getResult()); idxOp.erase(); }); executeRegion->walk([&](NVVM::BlockIdZOp idxOp) { OpBuilder rewriter(idxOp); auto rep = rewriter.create(op.getLoc(), idxOp.getType(), - par.getInductionVars()[2]); + par.getIVs()[2]); idxOp.replaceAllUsesWith(rep.getResult()); idxOp.erase(); }); @@ -345,21 +357,21 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc, executeRegion->walk([&](NVVM::ThreadIdXOp idxOp) { OpBuilder rewriter(idxOp); auto rep = rewriter.create(op.getLoc(), idxOp.getType(), - par.getInductionVars()[3]); + par.getIVs()[3]); idxOp.replaceAllUsesWith(rep.getResult()); idxOp.erase(); }); executeRegion->walk([&](NVVM::ThreadIdYOp idxOp) { OpBuilder rewriter(idxOp); auto rep = rewriter.create(op.getLoc(), idxOp.getType(), - par.getInductionVars()[4]); + par.getIVs()[4]); idxOp.replaceAllUsesWith(rep.getResult()); idxOp.erase(); }); executeRegion->walk([&](NVVM::ThreadIdZOp idxOp) { OpBuilder rewriter(idxOp); auto rep = rewriter.create(op.getLoc(), idxOp.getType(), - par.getInductionVars()[5]); + par.getIVs()[5]); idxOp.replaceAllUsesWith(rep.getResult()); idxOp.erase(); }); diff --git a/test/lit_tests/lowering/cpu.mlir b/test/lit_tests/lowering/cpu.mlir index 6218a68ef..fbea489a5 100644 --- a/test/lit_tests/lowering/cpu.mlir +++ b/test/lit_tests/lowering/cpu.mlir @@ -29,29 +29,27 @@ module { } } -// CHECK: func.func private @kern$par0(%arg0: !llvm.ptr<1>) { -// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %0 = llvm.mlir.constant(63 : i32) : i32 -// CHECK-NEXT: %c0 = arith.constant 0 : index -// CHECK-NEXT: %c1 = arith.constant 1 : index -// CHECK-NEXT: %c40 = arith.constant 40 : index -// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c40) step (%c1) { -// CHECK-NEXT: scf.execute_region { -// CHECK-NEXT: %1 = llvm.icmp "ugt" %c0_i32, %0 : i32 -// CHECK-NEXT: llvm.cond_br %1, ^bb2, ^bb1 -// CHECK-NEXT: ^bb1: // pred: ^bb0 -// CHECK-NEXT: %2 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 -// CHECK-NEXT: %3 = llvm.mul %2, %2 : i64 -// CHECK-NEXT: llvm.store %3, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1> -// CHECK-NEXT: scf.yield -// CHECK-NEXT: ^bb2: // pred: ^bb0 -// CHECK-NEXT: llvm.call fastcc @throw_boundserror_2676() : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: scf.reduce -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK: func.func private @kern$par0(%arg0: !llvm.ptr<1>) { +// CHECK-NEXT: %0 = llvm.mlir.constant(63 : i32) : i32 +// CHECK-NEXT: affine.parallel (%arg1, %arg2, %arg3, %arg4, %arg5, %arg6) = (0, 0, 0, 0, 0, 0) to (1, 1, 1, 1, 1, 40) { +// CHECK-NEXT: scf.execute_region { +// CHECK-NEXT: %1 = arith.index_cast %arg4 : index to i32 +// CHECK-NEXT: %2 = llvm.icmp "ugt" %1, %0 : i32 
+// CHECK-NEXT: llvm.cond_br %2, ^bb2, ^bb1 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %3 = llvm.zext %1 : i32 to i64 +// CHECK-NEXT: %4 = llvm.getelementptr inbounds %arg0[%3] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64 +// CHECK-NEXT: %5 = llvm.load %4 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 +// CHECK-NEXT: %6 = llvm.mul %5, %5 : i64 +// CHECK-NEXT: llvm.store %6, %4 {alignment = 1 : i64} : i64, !llvm.ptr<1> +// CHECK-NEXT: scf.yield +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: llvm.call fastcc @throw_boundserror_2676() : () -> () +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } // CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { // CHECK-NEXT: %0 = enzymexla.jit_call @kern$par0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> From 65912df74bf650385b196a14c346cf0359f606d2 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 13:50:41 +0900 Subject: [PATCH 07/12] Add converting from memref to memref --- .../jax/Passes/LLVMToAffineAccess.cpp | 143 +++++++++++++++++- test/lit_tests/raising/cpu.mlir | 45 ++++++ 2 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 test/lit_tests/raising/cpu.mlir diff --git a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp index 2ab7adec2..954056ed1 100644 --- a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp +++ b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp @@ -56,6 +56,7 @@ #include "Utils.h" #include +#include #include #include @@ -113,8 +114,8 @@ convertLLVMAllocaToMemrefAlloca(LLVM::AllocaOp alloc, RewriterBase &rewriter, Type elType = rewriter.getI8Type(); int64_t elNum = dataLayout.getTypeSize(alloc.getElemType()) * (*sizeVal); - auto ptr2memref = - dyn_cast(alloc.getRes().use_begin()->getOwner()); + auto ptr2memref = dyn_cast( + alloc.getRes().use_begin()->getOwner()); if (!ptr2memref) return failure(); @@ -133,6 +134,121 @@ convertLLVMAllocaToMemrefAlloca(LLVM::AllocaOp alloc, RewriterBase &rewriter, } namespace { + +struct ConvertToTypedMemref + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + const DataLayoutAnalysis &dl; + ConvertToTypedMemref(MLIRContext *context, const DataLayoutAnalysis &dl) + : OpRewritePattern(context), dl(dl) {} + + LogicalResult matchAndRewrite(enzymexla::Pointer2MemrefOp p2m, + PatternRewriter &rewriter) const override { + LLVM_DEBUG(llvm::dbgs() << "Checking " << p2m << "\n"); + TypedValue memref = p2m.getResult(); + bool allGood = true; + Type type = nullptr; + int64_t allSize = 0; + TypedValue newMemref = nullptr; + auto getNewMemref = [&]() { + OpBuilder::InsertionGuard g(rewriter); + if (!newMemref) { + rewriter.setInsertionPoint(p2m); + auto newp2m = rewriter.create( + p2m.getLoc(), + MemRefType::get({ShapedType::kDynamic}, type, + MemRefLayoutAttrInterface{}, + memref.getType().getMemorySpace()), + p2m.getSource()); + newMemref = newp2m.getResult(); + } + return newMemref; + }; + + SmallVector toErase; + + IRMapping mapping; + for (auto &use : memref.getUses()) { + auto checkTypeAndAlignment = [&](int64_t size, Type t, AffineExpr expr) { + allSize = size; + if (!expr.isMultipleOf(size)) + return failure(); + if (!type) { + type = t; + return success(); + } + if (type == t) { + return success(); + } + return failure(); + }; + if (auto load = dyn_cast(use.getOwner())) { + assert(load.getValue().hasOneUse()); + Operation *user = *load.getValue().user_begin(); + 
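// NOTE: the raw i8 vector load is expected to feed exactly one + // single-result user that reassembles the typed value (enforced by + // the asserts here); that result is retyped to an affine.load on the + // new typed memref. +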
assert(user->getNumResults() == 1); + assert(load.getType().getRank() == 1); + assert(load.getMemRefType().getRank() == 1); + auto size = load.getType().getShape()[0]; + assert(size != ShapedType::kDynamic); + auto map = load.getMap(); + auto expr = map.getResults()[0]; + auto value = user->getResult(0); + if (checkTypeAndAlignment(size, value.getType(), expr).failed()) { + allGood = false; + break; + } + rewriter.setInsertionPoint(load); + auto newMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), + {expr.floorDiv(size)}, load.getContext()); + auto newLoad = rewriter.create( + load.getLoc(), getNewMemref(), newMap, load.getMapOperands()); + mapping.map(value, newLoad.getValue()); + toErase.push_back(user); + toErase.push_back(load); + } else if (auto store = + dyn_cast(use.getOwner())) { + Operation *user = store.getValue().getDefiningOp(); + auto size = store.getValue().getType().getShape()[0]; + assert(size != ShapedType::kDynamic); + auto map = store.getMap(); + auto expr = map.getResults()[0]; + auto value = user->getOperand(0); + if (checkTypeAndAlignment(size, value.getType(), expr).failed()) { + allGood = false; + break; + } + rewriter.setInsertionPoint(store); + auto newMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), + {expr.floorDiv(size)}, store.getContext()); + auto newStore = rewriter.create( + store.getLoc(), value, getNewMemref(), newMap, + store.getMapOperands()); + toErase.push_back(store); + toErase.push_back(user); + } else { + allGood = false; + break; + } + } + + if (!allGood) + return failure(); + + if (type == rewriter.getI8Type()) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "all good " << allGood << "\n"); + + for (auto &m : mapping.getValueMap()) + rewriter.replaceAllUsesWith(m.getFirst(), m.getSecond()); + + for (Operation *op : toErase) + rewriter.eraseOp(op); + + return failure(); + } +}; + struct ConvertLLVMAllocaToMemrefAlloca : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -179,7 +295,8 @@ static MemRefVal convertToMemref(PtrVal addr) { auto ptr2memref = builder.create( addr.getLoc(), MemRefType::get({ShapedType::kDynamic}, builder.getI8Type(), - MemRefLayoutAttrInterface{}, Attribute(addrSpace)), addr); + MemRefLayoutAttrInterface{}, Attribute(addrSpace)), + addr); return cast(ptr2memref.getResult()); } @@ -1152,10 +1269,22 @@ convertLLVMToAffineAccess(Operation *op, rewriter.replaceOp(oldOp, newOp); } - RewritePatternSet patterns(context); - patterns.insert(context, dataLayoutAnalysis); - GreedyRewriteConfig config; - return applyPatternsAndFoldGreedily(op, std::move(patterns), config); + { + RewritePatternSet patterns(context); + patterns.insert(context, dataLayoutAnalysis); + GreedyRewriteConfig config; + if (applyPatternsAndFoldGreedily(op, std::move(patterns), config).failed()) + return failure(); + } + { + RewritePatternSet patterns(context); + patterns.insert(context, + dataLayoutAnalysis); + GreedyRewriteConfig config; + if (applyPatternsAndFoldGreedily(op, std::move(patterns), config).failed()) + return failure(); + } + return success(); } } // namespace mlir diff --git a/test/lit_tests/raising/cpu.mlir b/test/lit_tests/raising/cpu.mlir new file mode 100644 index 000000000..7972d06eb --- /dev/null +++ b/test/lit_tests/raising/cpu.mlir @@ -0,0 +1,45 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(convert-llvm-to-cf,enzyme-lift-cf-to-scf,libdevice-funcs-raise,canonicalize,llvm-to-affine-access)" | FileCheck %s +module { + llvm.func internal unnamed_addr fastcc 
@throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} { + llvm.unreachable + } + func.func private @kern$par0(%arg0: !llvm.ptr<1>) { + %0 = llvm.mlir.constant(63 : i32) : i32 + affine.parallel (%arg1, %arg2, %arg3, %arg4, %arg5, %arg6) = (0, 0, 0, 0, 0, 0) to (1, 1, 1, 1, 1, 40) { + scf.execute_region { + %1 = arith.index_cast %arg4 : index to i32 + %2 = llvm.icmp "ugt" %1, %0 : i32 + llvm.cond_br %2, ^bb2, ^bb1 + ^bb1: // pred: ^bb0 + %3 = llvm.zext %1 : i32 to i64 + %4 = llvm.getelementptr inbounds %arg0[%3] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64 + %5 = llvm.load %4 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 + %6 = llvm.mul %5, %5 : i64 + llvm.store %6, %4 {alignment = 1 : i64} : i64, !llvm.ptr<1> + scf.yield + ^bb2: // pred: ^bb0 + llvm.call fastcc @throw_boundserror_2676() : () -> () + scf.yield + } + } + return + } + func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { + %0 = enzymexla.jit_call @kern$par0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> + return %0 : tensor<64xi64> + } +} + +// CHECK: func.func private @kern$par0(%arg0: !llvm.ptr<1>) { +// CHECK-NEXT: %0 = "enzymexla.pointer2memref"(%arg0) : (!llvm.ptr<1>) -> memref +// CHECK-NEXT: affine.parallel (%arg1, %arg2, %arg3, %arg4, %arg5, %arg6) = (0, 0, 0, 0, 0, 0) to (1, 1, 1, 1, 1, 40) { +// CHECK-NEXT: affine.if #set(%arg4) { +// CHECK-NEXT: llvm.call fastcc @throw_boundserror_2676() : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: %1 = affine.load %0[%arg4] : memref +// CHECK-NEXT: %2 = arith.muli %1, %1 : i64 +// CHECK-NEXT: affine.store %2, %0[%arg4] : memref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } From b76a15a44923419747bcf2a265fcbd48d0d73f21 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 15:28:36 +0900 Subject: [PATCH 08/12] Propagate caller tensor types to memref types in callees --- src/enzyme_ad/jax/Passes/Passes.td | 6 ++ .../jax/Passes/PropagateConstantBound.cpp | 102 +++++++++++++++++- test/lit_tests/annotate_func_args.mlir | 5 +- test/lit_tests/propagate_tensor_types.mlir | 23 ++++ 4 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 test/lit_tests/propagate_tensor_types.mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 8cd7f2c7e..33de9c794 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -34,6 +34,12 @@ def PropagateConstantBoundsPass "mlir::LLVM::LLVMDialect", "mlir::NVVM::NVVMDialect" ]; + let options = [Option< + /*C++ variable name=*/"tensor_types", + /*CLI argument=*/"tensor_types", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to propagate tensor types">]; } def ArithRaisingPass : Pass<"arith-raise"> { diff --git a/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp b/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp index c1feb37e0..6a1902f75 100644 --- a/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp +++ b/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp @@ -11,6 +11,7 @@ #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "stablehlo/dialect/StablehloOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/Matchers.h" @@ -33,6 +34,7 @@ namespace { struct PropagateConstantBoundsPass : public enzyme::impl::PropagateConstantBoundsPassBase< PropagateConstantBoundsPass> { + using 
PropagateConstantBoundsPassBase::PropagateConstantBoundsPassBase; static int32_t getSizeInBytes(Type ty) { int32_t bitWidth = 0; @@ -263,7 +265,105 @@ struct PropagateConstantBoundsPass } } } + + if (!tensor_types) + continue; + + bool changed = false; + SmallVector newTypes; + SmallVector indices; + for (auto [index, argTy, arg] : + llvm::enumerate(callee.getArgumentTypes(), callee.getArguments())) { + assert(!callers.empty()); + if (auto ptrTy = dyn_cast(argTy)) { + bool allMatch = true; + ArrayRef shape; + Type elTy; + for (auto caller : callers) { + Value param = caller.getOperands()[index]; + Type type = param.getType(); + + if (auto rtt = dyn_cast(type)) { + auto thisShape = rtt.getShape(); + auto thisElTy = rtt.getElementType(); + if (!elTy) { + elTy = thisElTy; + shape = thisShape; + } else if (shape != thisShape || elTy != thisElTy) { + allMatch = false; + break; + } + } else { + allMatch = false; + break; + } + } + if (allMatch) { + + Attribute addrSpace; + if (ptrTy.getAddressSpace() == 0) + addrSpace = nullptr; + else + addrSpace = + IntegerAttr::get(IntegerType::get(arg.getContext(), 64), + ptrTy.getAddressSpace()); + + MemRefType memrefTy = MemRefType::get({ShapedType::kDynamic}, elTy, + // TODO do we need a layout? + // (a null layout attribute here means the default identity layout) + MemRefLayoutAttrInterface{}, + Attribute(addrSpace)); + + changed = true; + newTypes.push_back(memrefTy); + indices.push_back(index); + } else { + newTypes.push_back(argTy); + } + } + } + + if (changed) { + SmallVector newFtyArgs; + if (auto fty = + dyn_cast(callee.getFunctionType())) { + if (fty.getReturnType() == LLVM::LLVMVoidType::get(ctx) && + !fty.isVarArg()) { + auto newLLVMFty = LLVM::LLVMFunctionType::get( + fty.getReturnType(), newTypes, fty.isVarArg()); + + Block *entry = &callee.getFunctionBody().front(); + builder.setInsertionPointToStart(entry); + for (auto index : indices) { + auto newType = newTypes[index]; + auto oldArg = callee.getArgument(index); + auto newArg = + entry->insertArgument(index, newType, oldArg.getLoc()); + auto newPtr = builder.create( + newArg.getLoc(), oldArg.getType(), newArg); + oldArg.replaceAllUsesWith(newPtr); + entry->eraseArgument(index + 1); + } + + builder.setInsertionPoint(callee); + auto newFty = FunctionType::get(ctx, newTypes, TypeRange{}); + auto newF = builder.create( + callee.getLoc(), callee.getNameAttr(), newFty); + newF.getBlocks().splice(newF.getBlocks().begin(), + callee.getFunctionBody().getBlocks()); + + newF.setAllArgAttrs(callee.getArgAttrsAttr()); + newF.setResAttrsAttr(callee.getResAttrsAttr()); + + // TODO collect useful function attributes from here e.g. calling + // convention, visibility etc. we should probably just filter out + // the problematic ones such as `function_type`. 
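+ // NOTE: oldAttrs below is collected but not yet reapplied to the new function.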
+ auto oldAttrs = cast(callee)->getAttrs(); + + callee->erase(); + } + } + } } } }; -} // end namespace \ No newline at end of file +} // end namespace diff --git a/test/lit_tests/annotate_func_args.mlir b/test/lit_tests/annotate_func_args.mlir index d72bad403..09c8d5b57 100644 --- a/test/lit_tests/annotate_func_args.mlir +++ b/test/lit_tests/annotate_func_args.mlir @@ -1,4 +1,4 @@ -// RUN: enzymexlamlir-opt %s --propagate-constant-bounds --split-input-file | FileCheck %s +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(propagate-constant-bounds{tensor_types=false})" --split-input-file | FileCheck %s // CHECK-LABEL: ptx_kernelcc @foo // CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 16 : i64, llvm.noalias @@ -141,4 +141,5 @@ func.func @main(%arg0: tensor<5xcomplex>) { %c_8 = stablehlo.constant dense<4> : tensor enzymexla.kernel_call @foo blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<5xcomplex>) -> () return -} \ No newline at end of file +} + diff --git a/test/lit_tests/propagate_tensor_types.mlir b/test/lit_tests/propagate_tensor_types.mlir new file mode 100644 index 000000000..2ae198b94 --- /dev/null +++ b/test/lit_tests/propagate_tensor_types.mlir @@ -0,0 +1,23 @@ +// RUN: enzymexlamlir-opt %s --propagate-constant-bounds --split-input-file | FileCheck %s + +llvm.func @use(%arg0: !llvm.ptr<1>) + +// CHECK: func.func @foo(%arg0: memref {llvm.align = 128 : i32, llvm.dereferenceable = 40 : i64, llvm.noalias, llvm.nocapture, llvm.nofree}) { +// CHECK-NEXT: %0 = "enzymexla.memref2pointer"(%arg0) : (memref) -> !llvm.ptr<1> +// CHECK-NEXT: llvm.call @use(%0) : (!llvm.ptr<1>) -> () +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +llvm.func ptx_kernelcc @foo(%arg0: !llvm.ptr<1> {llvm.align = 32, llvm.nocapture, llvm.nofree}) { + llvm.call @use(%arg0) : (!llvm.ptr<1>) -> () + llvm.return +} + +func.func @main(%arg0: tensor<5xcomplex>) { + %c_4 = stablehlo.constant dense<1> : tensor + %c_5 = stablehlo.constant dense<2> : tensor + %c_6 = stablehlo.constant dense<3> : tensor + %c_8 = stablehlo.constant dense<4> : tensor + enzymexla.kernel_call @foo blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<5xcomplex>) -> () + return +} From b6c55ee461e7468723c2294df8692e4735e413be Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 15:33:06 +0900 Subject: [PATCH 09/12] Must not fold before trying to type the memrefs --- src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp index 954056ed1..dd16a0e00 100644 --- a/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp +++ b/src/enzyme_ad/jax/Passes/LLVMToAffineAccess.cpp @@ -1273,7 +1273,8 @@ convertLLVMToAffineAccess(Operation *op, RewritePatternSet patterns(context); patterns.insert(context, dataLayoutAnalysis); GreedyRewriteConfig config; - if (applyPatternsAndFoldGreedily(op, std::move(patterns), config).failed()) + config.fold = false; + if (applyPatternsGreedily(op, std::move(patterns), config).failed()) return failure(); } { From b093273dc92eadfd7d577ee12fe0697147e274c4 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 15:36:54 +0900 Subject: [PATCH 10/12] Fix test --- .../raising/llvm_to_affine_access.mlir | 55 ++++++++++++------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/test/lit_tests/raising/llvm_to_affine_access.mlir 
b/test/lit_tests/raising/llvm_to_affine_access.mlir index 5ca1040f6..7fccfd8ee 100644 --- a/test/lit_tests/raising/llvm_to_affine_access.mlir +++ b/test/lit_tests/raising/llvm_to_affine_access.mlir @@ -1,5 +1,16 @@ // RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(llvm-to-affine-access)" | FileCheck %s +// CHECK-LABEL: func.func @test_load_store_conversion( +// CHECK-SAME: %[[VAL_0:[^:]*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[VAL_1:[^:]*]]: i64) { +// CHECK: %[[VAL_2:.*]] = "enzymexla.pointer2memref"(%[[VAL_0]]) : (!llvm.ptr<1>) -> memref +// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_1]] : i64 to index +// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_1]] : i64 to index +// CHECK: %[[VAL_5:.*]] = affine.load %[[VAL_2]][symbol(%[[VAL_3]])] : memref +// CHECK: %[[VAL_6:.*]] = llvm.mul %[[VAL_5]], %[[VAL_5]] : i64 +// CHECK: affine.store %[[VAL_6]], %[[VAL_2]][symbol(%[[VAL_4]])] : memref +// CHECK: return +// CHECK: } func.func @test_load_store_conversion(%arg0: !llvm.ptr<1>, %idx: i64) { %0 = llvm.getelementptr inbounds %arg0[%idx] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64 %1 = llvm.load %0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 @@ -10,16 +21,25 @@ func.func @test_load_store_conversion(%arg0: !llvm.ptr<1>, %idx: i64) { return } -// CHECK-LABEL: func @test_load_store_conversion -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1> -// CHECK-SAME: %[[ARG1:.*]]: i64 -// CHECK: %[[MEMREF:.*]] = "enzymexla.pointer2memref"(%[[ARG0]]) {{.*}} memref -// CHECK: %[[IDX:.*]] = arith.index_cast %[[ARG1]] -// CHECK: affine.vector_load %[[MEMREF]][symbol(%[[IDX]]) * 8] {{.*}} vector<8xi8> -// CHECK: affine.vector_store // ----- +// CHECK-LABEL: func.func @test_multidim_load_store( +// CHECK-SAME: %[[VAL_0:[^:]*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[VAL_1:[^:]*]]: i64, +// CHECK-SAME: %[[VAL_2:[^:]*]]: i64) { +// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[VAL_4:.*]] = "enzymexla.pointer2memref"(%[[VAL_0]]) : (!llvm.ptr<1>) -> memref +// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_2]] : i64 to index +// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_1]] : i64 to index +// CHECK: %[[VAL_7:.*]] = affine.load %[[VAL_4]][symbol(%[[VAL_6]]) * 8 + symbol(%[[VAL_5]])] : memref +// CHECK: %[[VAL_8:.*]] = llvm.add %[[VAL_1]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i64 to index +// CHECK: %[[VAL_10:.*]] = llvm.add %[[VAL_2]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : i64 to index +// CHECK: affine.store %[[VAL_7]], %[[VAL_4]][symbol(%[[VAL_9]]) * 8 + symbol(%[[VAL_11]])] : memref +// CHECK: return +// CHECK: } func.func @test_multidim_load_store(%arg0: !llvm.ptr<1>, %idx1: i64, %idx2: i64) { %c1 = llvm.mlir.constant(1 : index) : i64 %ptr = llvm.getelementptr %arg0[%idx1, %idx2] : (!llvm.ptr<1>, i64, i64) -> !llvm.ptr<1>, !llvm.array<8 x i64> @@ -34,18 +54,16 @@ func.func @test_multidim_load_store(%arg0: !llvm.ptr<1>, %idx1: i64, %idx2: i64) return } -// CHECK-LABEL: func @test_multidim_load_store -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>, -// CHECK-SAME: %[[ARG1:.*]]: i64, -// CHECK-SAME: %[[ARG2:.*]]: i64 -// CHECK: %[[MEMREF:.*]] = "enzymexla.pointer2memref"(%[[ARG0]]) {{.*}} memref -// CHECK-DAG: %[[IDX1:.*]] = arith.index_cast %[[ARG1]] -// CHECK-DAG: %[[IDX2:.*]] = arith.index_cast %[[ARG2]] -// CHECK: affine.vector_load %[[MEMREF]][symbol(%[[IDX1]]) * 64 + symbol(%[[IDX2]]) * 8] {{.*}} vector<8xi8> -// CHECK: affine.vector_store // ----- +// CHECK-LABEL: func.func @test_struct_access( +// CHECK-SAME: 
%[[VAL_0:[^:]*]]: !llvm.ptr) { +// CHECK: %[[VAL_1:.*]] = "enzymexla.pointer2memref"(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK: %[[VAL_2:.*]] = affine.load %[[VAL_1]][0] : memref +// CHECK: affine.store %[[VAL_2]], %[[VAL_1]][0] : memref +// CHECK: return +// CHECK: } func.func @test_struct_access(%arg0: !llvm.ptr) { %ptr = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i64)> %val = llvm.load %ptr : !llvm.ptr -> i64 @@ -55,11 +73,6 @@ func.func @test_struct_access(%arg0: !llvm.ptr) { return } -// CHECK-LABEL: func @test_struct_access -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr -// CHECK: %[[MEMREF:.*]] = "enzymexla.pointer2memref"(%[[ARG0]]) {{.*}} memref -// CHECK: affine.vector_load %[[MEMREF]][0] {{.*}} vector<8xi8> -// CHECK: affine.vector_store // CHEC-K: %[[MEMREF:.*]] = enzymexla.at_addr %arg0 : !llvm.ptr to memref // CHEC-K: %[[LOAD:.*]] = memref.load %[[MEMREF]][%c0] : memref From df99e9d900de7257a8eb8e0bf4f77d3c4c869da4 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 6 Feb 2025 16:39:50 +0900 Subject: [PATCH 11/12] clang-format --- src/enzyme_ad/jax/Passes/LowerKernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index 10838d163..a67f4fe2b 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -301,8 +301,8 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc, SmallVector steps(6, 1); auto par = builder.create( - loc, TypeRange(), ArrayRef(), zeroMaps, - ValueRange(), idMaps, finals, steps); + loc, TypeRange(), ArrayRef(), zeroMaps, + ValueRange(), idMaps, finals, steps); builder.create(loc); From 43555dda70be458dd728427c622c1cd6cc9f822b Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Fri, 7 Feb 2025 18:13:48 +0900 Subject: [PATCH 12/12] Revert "Propagate caller tensor types to memref types in callees" This reverts commit b76a15a44923419747bcf2a265fcbd48d0d73f21. 
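For reference, the reverted commit taught propagate-constant-bounds to retype raw pointer kernel arguments from their call sites: when every enzymexla.kernel_call site passed a ranked tensor with one common shape and element type, the !llvm.ptr argument was rewritten to a dynamically sized memref and the original pointer was rematerialized in the entry block via enzymexla.memref2pointer. A rough before/after sketch, reconstructed from the deleted propagate_tensor_types.mlir test (the complex<f32> element type is inferred from the 40-byte llvm.dereferenceable annotation on a 5-element argument, not spelled out in the patch):

  // Before: the kernel receives an untyped pointer.
  llvm.func ptx_kernelcc @foo(%arg0: !llvm.ptr<1> {llvm.align = 32}) {
    llvm.call @use(%arg0) : (!llvm.ptr<1>) -> ()
    llvm.return
  }

  // After: the argument is retyped from the callers' tensor operands and
  // the pointer is recovered in the entry block.
  func.func @foo(%arg0: memref<?xcomplex<f32>, 1> {llvm.align = 128 : i32}) {
    %0 = "enzymexla.memref2pointer"(%arg0) : (memref<?xcomplex<f32>, 1>) -> !llvm.ptr<1>
    llvm.call @use(%0) : (!llvm.ptr<1>) -> ()
    llvm.return
  }

The revert also drops the tensor_types pass option that gated this rewrite and restores annotate_func_args.mlir to the plain --propagate-constant-bounds invocation.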
--- src/enzyme_ad/jax/Passes/Passes.td | 6 -- .../jax/Passes/PropagateConstantBound.cpp | 102 +----------------- test/lit_tests/annotate_func_args.mlir | 5 +- test/lit_tests/propagate_tensor_types.mlir | 23 ---- 4 files changed, 3 insertions(+), 133 deletions(-) delete mode 100644 test/lit_tests/propagate_tensor_types.mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 33de9c794..8cd7f2c7e 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -34,12 +34,6 @@ def PropagateConstantBoundsPass "mlir::LLVM::LLVMDialect", "mlir::NVVM::NVVMDialect" ]; - let options = [Option< - /*C++ variable name=*/"tensor_types", - /*CLI argument=*/"tensor_types", - /*type=*/"bool", - /*default=*/"true", - /*description=*/"Whether to propagate tensor types">]; } def ArithRaisingPass : Pass<"arith-raise"> { diff --git a/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp b/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp index 6a1902f75..c1feb37e0 100644 --- a/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp +++ b/src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp @@ -11,7 +11,6 @@ #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "stablehlo/dialect/StablehloOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/Matchers.h" @@ -34,7 +33,6 @@ namespace { struct PropagateConstantBoundsPass : public enzyme::impl::PropagateConstantBoundsPassBase< PropagateConstantBoundsPass> { - using PropagateConstantBoundsPassBase::PropagateConstantBoundsPassBase; static int32_t getSizeInBytes(Type ty) { int32_t bitWidth = 0; @@ -265,105 +263,7 @@ struct PropagateConstantBoundsPass } } } - - if (!tensor_types) - continue; - - bool changed = false; - SmallVector newTypes; - SmallVector indices; - for (auto [index, argTy, arg] : - llvm::enumerate(callee.getArgumentTypes(), callee.getArguments())) { - assert(!callers.empty()); - if (auto ptrTy = dyn_cast(argTy)) { - bool allMatch = true; - ArrayRef shape; - Type elTy; - for (auto caller : callers) { - Value param = caller.getOperands()[index]; - Type type = param.getType(); - - if (auto rtt = dyn_cast(type)) { - auto thisShape = rtt.getShape(); - auto thisElTy = rtt.getElementType(); - if (!elTy) { - elTy = thisElTy; - shape = thisShape; - } else if (shape != thisShape || elTy != thisElTy) { - allMatch = false; - break; - } - } else { - allMatch = false; - break; - } - } - if (allMatch) { - - Attribute addrSpace; - if (ptrTy.getAddressSpace() == 0) - addrSpace = nullptr; - else - addrSpace = - IntegerAttr::get(IntegerType::get(arg.getContext(), 64), - ptrTy.getAddressSpace()); - - MemRefType memrefTy = MemRefType::get({ShapedType::kDynamic}, elTy, - // TODO do we need a layout? 
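- // (a null layout attribute here means the default identity layout)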
- MemRefLayoutAttrInterface{}, - Attribute(addrSpace)); - - changed = true; - newTypes.push_back(memrefTy); - indices.push_back(index); - } else { - newTypes.push_back(argTy); - } - } - } - - if (changed) { - SmallVector newFtyArgs; - if (auto fty = - dyn_cast(callee.getFunctionType())) { - if (fty.getReturnType() == LLVM::LLVMVoidType::get(ctx) && - !fty.isVarArg()) { - auto newLLVMFty = LLVM::LLVMFunctionType::get( - fty.getReturnType(), newTypes, fty.isVarArg()); - - Block *entry = &callee.getFunctionBody().front(); - builder.setInsertionPointToStart(entry); - for (auto index : indices) { - auto newType = newTypes[index]; - auto oldArg = callee.getArgument(index); - auto newArg = - entry->insertArgument(index, newType, oldArg.getLoc()); - auto newPtr = builder.create( - newArg.getLoc(), oldArg.getType(), newArg); - oldArg.replaceAllUsesWith(newPtr); - entry->eraseArgument(index + 1); - } - - builder.setInsertionPoint(callee); - auto newFty = FunctionType::get(ctx, newTypes, TypeRange{}); - auto newF = builder.create( - callee.getLoc(), callee.getNameAttr(), newFty); - newF.getBlocks().splice(newF.getBlocks().begin(), - callee.getFunctionBody().getBlocks()); - - newF.setAllArgAttrs(callee.getArgAttrsAttr()); - newF.setResAttrsAttr(callee.getResAttrsAttr()); - - // TODO collect useful function attributes from here e.g. calling - // convention, visibility etc. we should probably just filter out - // the problematic ones such as `function_type`. - // NOTE: oldAttrs below is collected but not yet reapplied to the new function. - auto oldAttrs = cast(callee)->getAttrs(); - - callee->erase(); - } - } - } } } }; -} // end namespace +} // end namespace \ No newline at end of file diff --git a/test/lit_tests/annotate_func_args.mlir b/test/lit_tests/annotate_func_args.mlir index 09c8d5b57..d72bad403 100644 --- a/test/lit_tests/annotate_func_args.mlir +++ b/test/lit_tests/annotate_func_args.mlir @@ -1,4 +1,4 @@ -// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(propagate-constant-bounds{tensor_types=false})" --split-input-file | FileCheck %s +// RUN: enzymexlamlir-opt %s --propagate-constant-bounds --split-input-file | FileCheck %s // CHECK-LABEL: ptx_kernelcc @foo // CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 16 : i64, llvm.noalias @@ -141,5 +141,4 @@ func.func @main(%arg0: tensor<5xcomplex>) { %c_8 = stablehlo.constant dense<4> : tensor enzymexla.kernel_call @foo blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<5xcomplex>) -> () return -} - +} \ No newline at end of file diff --git a/test/lit_tests/propagate_tensor_types.mlir b/test/lit_tests/propagate_tensor_types.mlir deleted file mode 100644 index 2ae198b94..000000000 --- a/test/lit_tests/propagate_tensor_types.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: enzymexlamlir-opt %s --propagate-constant-bounds --split-input-file | FileCheck %s - -llvm.func @use(%arg0: !llvm.ptr<1>) - -// CHECK: func.func @foo(%arg0: memref {llvm.align = 128 : i32, llvm.dereferenceable = 40 : i64, llvm.noalias, llvm.nocapture, llvm.nofree}) { -// CHECK-NEXT: %0 = "enzymexla.memref2pointer"(%arg0) : (memref) -> !llvm.ptr<1> -// CHECK-NEXT: llvm.call @use(%0) : (!llvm.ptr<1>) -> () -// CHECK-NEXT: llvm.return -// CHECK-NEXT: } - -llvm.func ptx_kernelcc @foo(%arg0: !llvm.ptr<1> {llvm.align = 32, llvm.nocapture, llvm.nofree}) { - llvm.call @use(%arg0) : (!llvm.ptr<1>) -> () - llvm.return -} - -func.func @main(%arg0: tensor<5xcomplex>) { - %c_4 = stablehlo.constant dense<1> : tensor - %c_5 = stablehlo.constant dense<2> : tensor - %c_6 = stablehlo.constant dense<3> : tensor - %c_8 = 
stablehlo.constant dense<4> : tensor - enzymexla.kernel_call @foo blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<5xcomplex>) -> () - return -}