From 4a3235b9a21f5e7f8248eeac7debc461a60f9e87 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 24 Feb 2025 20:35:32 +0000 Subject: [PATCH 1/6] Introduce loop versioning to remove masked tt.load operations. Signed-off-by: Tiotto, Ettore --- bin/RegisterTritonDialects.h | 2 + scripts/test-triton.sh | 6 +- .../kernel-03-matrix-multiplication.mlir | 2 +- third_party/intel/backend/compiler.py | 1 + third_party/intel/include/CMakeLists.txt | 2 +- .../intel/include/Dialect/CMakeLists.txt | 1 + .../include/Dialect/Triton/CMakeLists.txt | 1 + .../Dialect/Triton/Transforms/CMakeLists.txt | 3 + .../Dialect/Triton/Transforms/Passes.h | 25 ++ .../Dialect/Triton/Transforms/Passes.td | 32 ++ .../TritonIntelGPU/Transforms/Utility.h | 8 - third_party/intel/include/Utils/Utility.h | 20 + third_party/intel/lib/Dialect/CMakeLists.txt | 1 + .../intel/lib/Dialect/Triton/CMakeLists.txt | 1 + .../Dialect/Triton/Transforms/CMakeLists.txt | 11 + .../Dialect/Triton/Transforms/RemoveMasks.cpp | 307 +++++++++++++++ .../TritonIntelGPUTransforms/CMakeLists.txt | 1 + .../MaterializeBlockPointer.cpp | 7 +- .../lib/TritonIntelGPUTransforms/Utility.cpp | 45 --- .../TritonRaiseBlockPointer.cpp | 356 ++---------------- third_party/intel/lib/Utils/CMakeLists.txt | 1 + third_party/intel/lib/Utils/Utility.cpp | 94 +++++ third_party/intel/triton_xpu.cc | 3 +- 23 files changed, 541 insertions(+), 389 deletions(-) create mode 100644 third_party/intel/include/Dialect/Triton/CMakeLists.txt create mode 100644 third_party/intel/include/Dialect/Triton/Transforms/CMakeLists.txt create mode 100644 third_party/intel/include/Dialect/Triton/Transforms/Passes.h create mode 100644 third_party/intel/include/Dialect/Triton/Transforms/Passes.td create mode 100644 third_party/intel/include/Utils/Utility.h create mode 100644 third_party/intel/lib/Dialect/Triton/CMakeLists.txt create mode 100644 third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt create mode 100644 third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp create mode 100644 third_party/intel/lib/Utils/Utility.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 3f13ebc256..5ca8fc5d91 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,4 +1,5 @@ #pragma once +#include "intel/include/Dialect/Triton/Transforms/Passes.h" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" @@ -63,6 +64,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::test::registerTestMembarPass(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass(); + mlir::triton::intel::registerTritonIntelRemoveMasks(); mlir::triton::intel::registerTritonRaiseBlockPointer(); mlir::triton::registerAllocateSharedMemoryPass(); mlir::triton::registerTritonGPUGlobalScratchAllocationPass(); diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 69fffa2070..a52a573037 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -252,11 +252,11 @@ run_tutorial_tests() { run_tutorial_test "10-experimental-block-pointer" run_tutorial_test "10i-experimental-block-pointer" - echo "\n***************************************************" - echo "Running with TRITON_INTEL_RAISE_BLOCK_POINTER=ignore-masks" + echo "***************************************************" + echo "Running with 
TRITON_INTEL_RAISE_BLOCK_POINTER " echo "***************************************************" - TRITON_TEST_REPORTS=false TRITON_INTEL_RAISE_BLOCK_POINTER=ignore-masks \ + TRITON_TEST_REPORTS=false TRITON_INTEL_RAISE_BLOCK_POINTER=1 \ run_tutorial_test "03-matrix-multiplication" } diff --git a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir index 211c8b4f3a..eefab3a6ff 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s +// RUN: triton-opt %s -triton-intel-remove-masks -triton-raise-block-pointer -canonicalize | FileCheck %s module { tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 24b854a782..000bbde993 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -224,6 +224,7 @@ def make_ttir(mod, metadata, opt): pm.enable_debug() passes.common.add_inliner(pm) passes.ttir.add_combine(pm) + intel.passes.ttir.add_remove_masks(pm) if raise_block_ptr_flags['enabled']: ignore_masks = True if raise_block_ptr_flags['ignore-masks'] else False intel.passes.ttir.add_raise_block_pointer(pm, ignore_masks) diff --git a/third_party/intel/include/CMakeLists.txt b/third_party/intel/include/CMakeLists.txt index 193ba48916..acb84f7ad2 100644 --- a/third_party/intel/include/CMakeLists.txt +++ b/third_party/intel/include/CMakeLists.txt @@ -1,6 +1,6 @@ -add_subdirectory(TritonAnnotateModule) add_subdirectory(Dialect) add_subdirectory(GPUToTritonGEN) +add_subdirectory(TritonAnnotateModule) add_subdirectory(TritonGENToLLVM) add_subdirectory(TritonIntelGPUToLLVM) add_subdirectory(TritonRaiseBlockPointer) diff --git a/third_party/intel/include/Dialect/CMakeLists.txt b/third_party/intel/include/Dialect/CMakeLists.txt index fe07104135..a2727a5182 100644 --- a/third_party/intel/include/Dialect/CMakeLists.txt +++ b/third_party/intel/include/Dialect/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Triton) add_subdirectory(TritonGEN) add_subdirectory(TritonIntelGPU) diff --git a/third_party/intel/include/Dialect/Triton/CMakeLists.txt b/third_party/intel/include/Dialect/Triton/CMakeLists.txt new file mode 100644 index 0000000000..e31af32661 --- /dev/null +++ b/third_party/intel/include/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/third_party/intel/include/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/intel/include/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..2e9e7415e9 --- /dev/null +++ b/third_party/intel/include/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonIntel) +add_public_tablegen_target(TritonIntelTransformsIncGen) diff --git a/third_party/intel/include/Dialect/Triton/Transforms/Passes.h 
b/third_party/intel/include/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 0000000000..650cc87df9 --- /dev/null +++ b/third_party/intel/include/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,25 @@ +//===- Passes.h - Intel Pass Construction and Registration ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES_H +#define TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir::triton::intel { + +#define GEN_PASS_DECL +#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc" + +} // namespace mlir::triton::intel + +#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES_H \ No newline at end of file diff --git a/third_party/intel/include/Dialect/Triton/Transforms/Passes.td b/third_party/intel/include/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 0000000000..466932cf48 --- /dev/null +++ b/third_party/intel/include/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,32 @@ +//===-- Passes.td - Intel TritonDialect passes definition --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES +#define TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonIntelRemoveMasks + : Pass<"triton-intel-remove-masks", "mlir::ModuleOp"> { + let summary = "Remove masks from tt.load and tt.store operations"; + + let description = [{ + This pass attempts to remove the mask for tt.load and tt.store operations. + If the masked operation is in a loop, the pass attempts to find a loop + invariant condition equivalent to the mask condition, and then use it to + version the loop. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect" + ]; +} + +#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 3028cae8d6..2eed05a89e 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -59,14 +59,6 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args); -// This function folds the `op` operation and returns the constant value if it -// has successfully folded to a constant. Otherwise, it returns `std::nullopt`. -std::optional getFoldedConstantValue(Operation *op); - -// Return true if the `val` value is a constant containing a value equal to -// expected. 
-bool isConstant(Value val, const unsigned expected); - } // namespace mlir::triton::gpu::intel #endif // TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H diff --git a/third_party/intel/include/Utils/Utility.h b/third_party/intel/include/Utils/Utility.h new file mode 100644 index 0000000000..b99343726e --- /dev/null +++ b/third_party/intel/include/Utils/Utility.h @@ -0,0 +1,20 @@ +#ifndef TRITON_INTEL_UTILS_UTILITY_H +#define TRITON_INTEL_UTILS_UTILITY_H + +#include + +namespace mlir::triton::intel { + +// This function folds the `op` operation and returns the constant value if it +// has successfully folded to a constant. Otherwise, it returns `std::nullopt`. +std::optional getFoldedConstantValue(Operation *op); + +// Return true if the `val` value is a constant containing a value equal to +// expected. +bool isConstant(Value val, const unsigned expected); + +mlir::Value getFinalValue(Value value); + +} // namespace mlir::triton::intel + +#endif // TRITON_INTEL_UTILS_UTILITY_H diff --git a/third_party/intel/lib/Dialect/CMakeLists.txt b/third_party/intel/lib/Dialect/CMakeLists.txt index fe07104135..a2727a5182 100644 --- a/third_party/intel/lib/Dialect/CMakeLists.txt +++ b/third_party/intel/lib/Dialect/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Triton) add_subdirectory(TritonGEN) add_subdirectory(TritonIntelGPU) diff --git a/third_party/intel/lib/Dialect/Triton/CMakeLists.txt b/third_party/intel/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 0000000000..e31af32661 --- /dev/null +++ b/third_party/intel/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..3219f3ab2a --- /dev/null +++ b/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(TritonIntelTransforms + RemoveMasks.cpp + + DEPENDS + TritonIntelTransformsIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + TritonIR +) diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp new file mode 100644 index 0000000000..f171b3019c --- /dev/null +++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp @@ -0,0 +1,307 @@ +#include "intel/include/Dialect/Triton/Transforms/Passes.h" +#include "intel/include/Utils/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Verifier.h" +// #include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-intel-remove-masks" + +using namespace mlir; +namespace tt = mlir::triton; + +namespace mlir::triton::intel { +#define GEN_PASS_DEF_TRITONINTELREMOVEMASKS +#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc" +} // namespace mlir::triton::intel + +namespace { + +// Represent a versioning condition for a loop. 
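+// Given S (the length of a row/column) and BS (the block size), a loop can
+// be versioned on `(S % BS == 0 && S > BS)`: when the condition holds every
+// block is complete, so canonical bounds masks are always true and the
+// corresponding masked loads can be replaced with unmasked ones.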
+class VersioningCondition {
+public:
+  VersioningCondition(Value S, Value BS) : S(S), BS(BS) {
+    assert(isValid() && "Invalid values supplied");
+  }
+
+  // Create the condition: (S % BS == 0 && S > BS)
+  Value materialize(OpBuilder &builder, Location loc) const {
+    assert(S && BS && "Expecting valid values");
+    Value zero =
+        builder.createOrFold<arith::ConstantIntOp>(loc, 0, S.getType());
+    Value cmp1 = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq,
+        builder.create<arith::RemSIOp>(loc, S, BS), zero);
+    Value cmp2 =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, S, BS);
+    return builder.create<arith::AndIOp>(loc, cmp1, cmp2);
+  }
+
+private:
+  bool isValid() const {
+    Type SType = S.getType(), BSType = BS.getType();
+    if (!isa<IntegerType>(SType) || !isa<IntegerType>(BSType))
+      return false;
+
+    return cast<IntegerType>(SType).getWidth() ==
+           cast<IntegerType>(BSType).getWidth();
+  }
+
+  Value S;  // The length of a row/column.
+  Value BS; // The block size.
+};
+
+// Collects masked operations in a loop whose mask conditions can be made
+// redundant by versioning the loop.
+class MaskedOpsCollector {
+public:
+  using MaskedOperations = SmallPtrSet<Operation *, 8>;
+
+  MaskedOpsCollector(scf::ForOp &forOp) : forOp(forOp) {
+    assert(!forOp->template getParentOfType<scf::ForOp>() &&
+           "Nested loops not handled yet");
+    createVersioningCondition(forOp);
+  }
+
+  // Collect `tt.load` operations in the loop whose mask condition can be made
+  // loop invariant.
+  bool collectMaskedOps() {
+    assert(versioningCond && "Versioning condition should be valid");
+
+    // Collect masked loads in the loop if they have a canonical mask.
+    for (auto op : forOp.getOps<tt::LoadOp>()) {
+      Value mask = op.getMask();
+      if (mask && isValidMask(tt::intel::getFinalValue(mask)))
+        maskedOps.insert(op);
+    }
+
+    // TODO: collect masked stores in the loop if they have a canonical mask.
+    return maskedOps.size();
+  }
+
+  VersioningCondition *getVersioningCond() const {
+    return versioningCond.get();
+  }
+
+  const MaskedOperations &getMaskedOps() const { return maskedOps; }
+
+private:
+  // Note: this assumes the loop UB is in the canonical form `(N+END-1)/END`.
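+  // E.g. for a blocked loop over K the upper bound is typically computed as
+  //   %sum = arith.addi %K, %c31_i32 : i32
+  //   %ub  = arith.divsi %sum, %c32_i32 : i32   // (K+32-1)/32
+  // from which S = %K and BS = %c32_i32 are extracted below (the block size
+  // of 32 is illustrative).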
+  void createVersioningCondition(scf::ForOp &forOp) {
+    Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
+    Operation *defOp = ub.getDefiningOp();
+    auto divOp = cast<arith::DivSIOp>(defOp);
+    Operation *divLhsOp = divOp.getLhs().getDefiningOp();
+    auto divNumOp = cast<arith::AddIOp>(divLhsOp);
+    versioningCond = std::make_unique<VersioningCondition>(divNumOp.getLhs(),
+                                                           divOp.getRhs());
+  }
+
+  // Check whether a mask is in canonical form: (0..END) < N - i*END
+  bool isValidMask(Value mask) const {
+    assert(mask.getDefiningOp() && "Expected a valid mask operation");
+    auto cmpOp = cast<arith::CmpIOp>(mask.getDefiningOp());
+    arith::CmpIPredicate pred = cmpOp.getPredicate();
+    if (pred != arith::CmpIPredicate::slt)
+      return false;
+
+    Operation *lhs = tt::intel::getFinalValue(cmpOp.getLhs()).getDefiningOp();
+    Operation *rhs = tt::intel::getFinalValue(cmpOp.getRhs()).getDefiningOp();
+    if (!isa<tt::MakeRangeOp>(lhs) || !isa<arith::SubIOp>(rhs))
+      return false;
+
+    auto rangeOp = cast<tt::MakeRangeOp>(lhs);
+    unsigned end = rangeOp.getEnd();
+    assert(end > rangeOp.getStart() && "Invalid range");
+
+    auto subOp = cast<arith::SubIOp>(rhs);
+    Operation *subLhs = subOp.getLhs().getDefiningOp();
+    Operation *subRhs = subOp.getRhs().getDefiningOp();
+    if (subLhs || !isa<arith::MulIOp>(subRhs))
+      return false;
+
+    auto mulOp = cast<arith::MulIOp>(subRhs);
+    Operation *mulLhs = mulOp.getLhs().getDefiningOp();
+    Operation *mulRhs = mulOp.getRhs().getDefiningOp();
+    if (mulLhs && mulRhs)
+      return false;
+
+    if (!mulLhs && isa<arith::ConstantIntOp>(mulRhs))
+      return cast<arith::ConstantIntOp>(mulRhs).value() == end;
+    if (!mulRhs && isa<arith::ConstantIntOp>(mulLhs))
+      return cast<arith::ConstantIntOp>(mulLhs).value() == end;
+
+    return false;
+  }
+
+private:
+  // Masked operations in the loop that can have their mask dropped when the
+  // loop is versioned using the versioning condition associated with this
+  // class.
+  scf::ForOp &forOp;
+  MaskedOperations maskedOps;
+  std::unique_ptr<VersioningCondition> versioningCond = nullptr;
+};
+
+class LoopVersioner {
+public:
+  // Version the \p forOp loop with a condition that makes the masks collected
+  // by \p collector unnecessary.
+  // TODO: Extend the versioning region to encompass the downward exposed uses
+  // of the return values.
+  static bool version(scf::ForOp &forOp, MaskedOpsCollector &collector) {
+    assert(collector.getVersioningCond() &&
+           "Versioning condition should be present");
+
+    // Limitation: give up if the loop yields a tensor of pointers.
+    if (!canVersion(forOp))
+      return false;
+
+    // Collect loop results that are downward exposed.
+    auto getUsedResults = [](const scf::ForOp &forOp) {
+      SmallVector<Type> resTypes;
+      for (Value res : forOp->getResults()) {
+        if (!res.getUsers().empty())
+          resTypes.push_back(res.getType());
+      }
+      return resTypes;
+    };
+
+    // Create the versioning branch.
+    OpBuilder builder(forOp);
+    Location loc = forOp.getLoc();
+    Value versioningCond =
+        collector.getVersioningCond()->materialize(builder, loc);
+    auto ifOp =
+        builder.create<scf::IfOp>(loc, getUsedResults(forOp), versioningCond,
+                                  /*withThenRegion=*/true,
+                                  /*withElseRegion=*/true);
+
+    // Clone the original loop into the two if branches.
+    OpBuilder thenB = ifOp.getThenBodyBuilder();
+    OpBuilder elseB = ifOp.getElseBodyBuilder();
+
+    IRMapping map;
+    Operation *thenForLoop = thenB.clone(*forOp.getOperation(), map);
+    Operation *elseForLoop = elseB.clone(*forOp.getOperation());
+
+    // Collect results in 'clonedLoop' corresponding to the downward exposed
+    // results of 'forOp'.
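+    // E.g. if only result #1 of 'forOp' is used after the loop, only the
+    // clone's result #1 is yielded by the enclosing scf.if branch.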
+    auto pruneUnusedResults = [&](const scf::ForOp &forOp,
+                                  Operation *clonedLoop) {
+      SmallVector<Value> prunedResults;
+      for (auto [idx, val] : llvm::enumerate(forOp->getResults())) {
+        if (!val.getUsers().empty())
+          prunedResults.push_back(clonedLoop->getResult(idx));
+      }
+      return prunedResults;
+    };
+
+    // Create the yield operations for the two if branches.
+    thenB.create<scf::YieldOp>(loc, pruneUnusedResults(forOp, thenForLoop));
+    elseB.create<scf::YieldOp>(loc, pruneUnusedResults(forOp, elseForLoop));
+
+    // Drop the mask from candidate masked operations in the "then" region's
+    // cloned loop.
+    for (Operation *maskedOp : collector.getMaskedOps()) {
+      Operation *mappedOp = map.lookup(maskedOp);
+      if (auto loadOp = dyn_cast<tt::LoadOp>(mappedOp)) {
+        OpBuilder builder(mappedOp);
+        auto newLoad = builder.create<tt::LoadOp>(
+            loadOp.getLoc(), loadOp.getPtr(), loadOp.getCache(),
+            loadOp.getEvict(), loadOp.getIsVolatile());
+        mappedOp->replaceAllUsesWith(newLoad);
+        mappedOp->erase();
+      }
+      // TODO: stores
+    }
+
+    // Replace the uses of the original loop results.
+    unsigned idx = 0;
+    for (Value res : forOp.getResults()) {
+      if (!res.getUsers().empty())
+        res.replaceAllUsesWith(ifOp->getResult(idx++));
+    }
+
+    forOp.erase();
+
+    return true;
+  }
+
+  // Ensure the loop upper bound is in canonical form (N+END-1)/END.
+  static bool hasValidUpperBound(scf::ForOp &forOp) {
+    Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
+    Operation *defOp = ub.getDefiningOp();
+    if (!defOp || !isa<arith::DivSIOp>(defOp))
+      return false;
+
+    auto divOp = cast<arith::DivSIOp>(defOp);
+    Operation *divLhsOp = divOp.getLhs().getDefiningOp();
+    Operation *divRhsOp = divOp.getRhs().getDefiningOp();
+    if (!divLhsOp || !divRhsOp || !isa<arith::AddIOp>(divLhsOp) ||
+        !isa<arith::ConstantIntOp>(divRhsOp))
+      return false;
+
+    auto divNumOp = cast<arith::AddIOp>(divLhsOp);
+    auto divDenOp = cast<arith::ConstantIntOp>(divRhsOp);
+    Operation *addLhsOp = divNumOp.getLhs().getDefiningOp();
+    Operation *addRhsOp = divNumOp.getRhs().getDefiningOp();
+    if (addLhsOp || !isa<arith::ConstantIntOp>(addRhsOp) ||
+        (divDenOp.value() != cast<arith::ConstantIntOp>(addRhsOp).value() + 1))
+      return false;
+
+    return true;
+  }
+
+private:
+  // Currently we can version the loop only if it doesn't have downward
+  // exposed uses of return values that are a tensor of pointers.
+  // Note: this is because the results yielded by the two versioned loops
+  // would have different types: tensors of pointers are later changed to
+  // block pointers only in the 'then' version and are left as is in the
+  // 'else' version.
+  static bool canVersion(scf::ForOp &forOp) {
+    return llvm::any_of(forOp.getResults(), [](Value res) {
+      return !tt::isTensorPointerType(res.getType()) || res.getUsers().empty();
+    });
+  }
+};
+
+struct TritonIntelRemoveMasksBase
+    : tt::intel::impl::TritonIntelRemoveMasksBase<TritonIntelRemoveMasksBase> {
+public:
+  using Base::Base;
+  using IndexMapSet = std::map<int, std::set<int>>;
+
+  void runOnOperation() final {
+    ModuleOp moduleOp = getOperation();
+
+    // Attempt to version loops so that masked operations in the loop become
+    // superfluous.
+    moduleOp->walk([&](Operation *op) {
+      if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
+        // Nested loops aren't currently handled.
+        if (forOp->template getParentOfType<scf::ForOp>())
+          return WalkResult::advance();
+
+        // Ensure loop UB is in 'canonical' form.
+ if (!LoopVersioner::hasValidUpperBound(forOp)) + return WalkResult::advance(); + + MaskedOpsCollector collector(forOp); + if (collector.collectMaskedOps()) { + [[maybe_unused]] bool loopVersioned = + LoopVersioner::version(forOp, collector); + LLVM_DEBUG(if (loopVersioned) llvm::dbgs() << "Loop versioned\n"); + } + } + return WalkResult::advance(); + }); + + LLVM_DEBUG(llvm::dbgs() << "After versioning:\n" << moduleOp << "\n"); + assert(succeeded(verify(moduleOp)) && "Module verification failed"); + } +}; + +} // namespace \ No newline at end of file diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index 36a9701195..6d57997695 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -27,4 +27,5 @@ add_triton_library(TritonIntelGPUTransforms TritonGENIR TritonGPUIR TritonIntelGPUIR + TritonIntelUtils ) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp index c8faf07969..e0bb58df76 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp @@ -1,10 +1,9 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" -#include "mlir/Dialect/Arith/IR/Arith.h" +#include "intel/include/Utils/Utility.h" #include "mlir/IR/Visitors.h" #include "triton/Analysis/Utility.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include @@ -64,7 +63,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass Operation::operand_range strides = makeTensorPtrOp.getStrides(); int fastChangeDim = -1; for (size_t i = 0; i < strides.size(); ++i) { - if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) { + if (tt::intel::isConstant(strides[i], 1)) { fastChangeDim = i; break; } @@ -89,7 +88,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass fastChangeStride.print(llvm::dbgs()); llvm::dbgs() << "\n"; }); - if (!mlir::triton::gpu::intel::isConstant(fastChangeStride, 1)) + if (!tt::intel::isConstant(fastChangeStride, 1)) return; // Across Intel platforms, the strictest pitch restriction is to be a diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 909589c84f..5b59616f1b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -286,49 +286,4 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc, return call; } -static std::optional getIntAttr(const OpFoldResult ofr) { - if (auto attr = dyn_cast(ofr)) - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); - return std::nullopt; -} - -std::optional getFoldedConstantValue(Operation *op) { - SmallVector results; - if (failed(op->fold(results))) { - return std::nullopt; - } - - // If fold succeeded but `results` is empty, we give a second try, after the - // operands have been switched during the first call to `fold()`. 
- if (results.empty()) { - if (failed(op->fold(results))) { - return std::nullopt; - } - } - - if (results.size() != 1) { - return std::nullopt; - } - - auto intAttr = getIntAttr(results[0]); - if (intAttr.has_value()) { - return intAttr.value(); - } - - auto val = cast(results[0]); - auto constOp = val.getDefiningOp(); - if (!constOp) - return std::nullopt; - - return getIntAttr(constOp.getValue()); -} - -bool isConstant(Value val, const unsigned expected) { - auto defOp = val.getDefiningOp(); - if (!defOp) - return false; - return (getFoldedConstantValue(defOp) == expected); -} - } // namespace mlir::triton::gpu::intel diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index 949464cfd0..ba826582c8 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -5,8 +5,8 @@ // //===----------------------------------------------------------------------===// -#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "intel/include/TritonRaiseBlockPointer/Passes.h" +#include "intel/include/Utils/Utility.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Verifier.h" @@ -36,7 +36,6 @@ using namespace mlir; namespace tt = mlir::triton; -namespace ttgi = mlir::triton::gpu::intel; namespace mlir::triton::intel { #define GEN_PASS_DEF_TRITONRAISEBLOCKPOINTER @@ -115,50 +114,6 @@ Value findOrCreateMakeTensorPtr(Location loc, Value source, ValueRange shape, loc, source, zeros, strides, offsets, sizes, order); } -Value getFinalValue(Value value) { - Operation *defOp = value.getDefiningOp(); - if (!defOp) { - // look init values outside the loop - BlockArgument blockArg = dyn_cast(value); - Operation *parentOp = blockArg.getOwner()->getParentOp(); - if (scf::ForOp forOp = dyn_cast(parentOp)) - return getFinalValue(forOp.getInitArgs()[blockArg.getArgNumber() - 1]); - - return value; - } - - if (isa( - defOp)) - return getFinalValue(defOp->getOperand(0)); - - if (auto addOp = dyn_cast(defOp)) { - if (ttgi::isConstant(addOp.getLhs(), 0)) - return getFinalValue(addOp.getRhs()); - if (ttgi::isConstant(addOp.getRhs(), 0)) - return getFinalValue(addOp.getLhs()); - return addOp.getResult(); - } - - if (auto mulOp = dyn_cast(defOp)) { - if (ttgi::isConstant(mulOp.getLhs(), 1) || - ttgi::isConstant(mulOp.getRhs(), 0)) - return getFinalValue(mulOp.getRhs()); - if (ttgi::isConstant(mulOp.getRhs(), 1) || - ttgi::isConstant(mulOp.getLhs(), 0)) - return getFinalValue(mulOp.getLhs()); - return mulOp.getResult(); - } - - if (auto divOp = dyn_cast(defOp)) { - if (ttgi::isConstant(divOp.getRhs(), 1) || - ttgi::isConstant(divOp.getLhs(), 0)) - return getFinalValue(divOp.getLhs()); - return divOp.getResult(); - } - - return value; -} - // Data structure used to decode pointer arithmetics. Offsets, sizes, and // strides are in unit of elements in a linearly laid-out memory, which is the // same as pointer arithmetic operations in Triton language. Scalar is a @@ -200,7 +155,7 @@ struct PtrState { // When PtrState describes a non-block pointer, shape field indicates how // address wraps around. As a result, a constant 0 indicates no wrap // around (i.e. modulo) for the dimension. - return !ttgi::isConstant(shape[dim], 0); + return !tt::intel::isConstant(shape[dim], 0); } // @return true if addresses wrap around in any of the pointer dimension. 
@@ -231,7 +186,7 @@ struct PtrState { if (lhsState.scalar && rhsState.scalar) { scalar = builder.create(loc, lhsState.scalar, rhsState.scalar); - scalar = findOrCreateCast(loc, getFinalValue(scalar), + scalar = findOrCreateCast(loc, tt::intel::getFinalValue(scalar), lhsState.scalar.getType(), builder); } else if (lhsState.getRank() == 0) @@ -240,15 +195,15 @@ struct PtrState { for (unsigned i = 0; i < lhsState.getRank(); ++i) { Value newOffset = builder.create(loc, lhsState.offsets[i], rhsState.offsets[i]); - offsets.push_back(findOrCreateCast(loc, getFinalValue(newOffset), - lhsState.offsets[i].getType(), - builder)); + offsets.push_back( + findOrCreateCast(loc, tt::intel::getFinalValue(newOffset), + lhsState.offsets[i].getType(), builder)); Value newStride = builder.create(loc, lhsState.strides[i], rhsState.strides[i]); - strides.push_back(findOrCreateCast(loc, getFinalValue(newStride), - lhsState.strides[i].getType(), - builder)); + strides.push_back( + findOrCreateCast(loc, tt::intel::getFinalValue(newStride), + lhsState.strides[i].getType(), builder)); sizes.push_back(lhsState.sizes[i]); } @@ -285,7 +240,7 @@ struct PtrState { std::swap(lhs, rhs); for (unsigned i = 0; i < lhs->getRank(); ++i) { - if (!lhs->dimHasModulo(i) || ttgi::isConstant(rhs->offsets[i], 0)) { + if (!lhs->dimHasModulo(i) || tt::intel::isConstant(rhs->offsets[i], 0)) { shape.push_back(lhs->shape[i]); } else { op->emitRemark("TritonRaiseBlockPointer: do not support adding to " @@ -328,7 +283,7 @@ struct PtrState { findOrCreateCast(loc, rhs->scalar, builder.getIntegerType(offsetBitwidth), builder)); newOffset = - findOrCreateCast(loc, getFinalValue(newOffset), + findOrCreateCast(loc, tt::intel::getFinalValue(newOffset), builder.getIntegerType(offsetBitwidth), builder); Value newStride = builder.create( @@ -340,7 +295,7 @@ struct PtrState { builder.getIntegerType(shapeAndStridesBitwidth), builder)); newStride = findOrCreateCast( - loc, getFinalValue(newStride), + loc, tt::intel::getFinalValue(newStride), builder.getIntegerType(shapeAndStridesBitwidth), builder); Value newDim = builder.create( @@ -351,7 +306,7 @@ struct PtrState { findOrCreateCast(loc, rhs->scalar, builder.getIntegerType(shapeAndStridesBitwidth), builder)); - newDim = findOrCreateCast(loc, getFinalValue(newDim), + newDim = findOrCreateCast(loc, tt::intel::getFinalValue(newDim), builder.getIntegerType(shapeAndStridesBitwidth), builder); @@ -398,7 +353,7 @@ struct PtrState { // therefore we give up if none of the strides is one. bool noStrideIsOne = llvm::all_of(makeTPtrOp.getStrides(), [&](Value str) { - return !ttgi::isConstant(getFinalValue(str), 1); + return !tt::intel::isConstant(tt::intel::getFinalValue(str), 1); }); if (noStrideIsOne) return std::nullopt; @@ -410,7 +365,7 @@ struct PtrState { // 2b) offsets: (off0, 0) strides: (*, 1) ==> tt.advance ptr, (0, off0) bool allOffsetsNotZero = llvm::all_of(offsets, [&](Value offset) { - return !ttgi::isConstant(getFinalValue(offset), 0); + return !tt::intel::isConstant(tt::intel::getFinalValue(offset), 0); }); // Case 1: all offsets are non-zero. @@ -419,7 +374,7 @@ struct PtrState { "TODO: can we generate tt.advance ptr, (0, off0*str0 + off1) ?"); if (llvm::any_of(makeTPtrOp.getStrides(), [&](Value stride) { - return !ttgi::isConstant(getFinalValue(stride), 1); + return !tt::intel::isConstant(tt::intel::getFinalValue(stride), 1); })) return std::nullopt; @@ -432,11 +387,13 @@ struct PtrState { // Case 2: at least one offset is zero. 
assert(offsets.size() == 2 && "Expecting two offsets"); - bool zeroIdx = !ttgi::isConstant(getFinalValue(offsets[0]), 0); + bool zeroIdx = + !tt::intel::isConstant(tt::intel::getFinalValue(offsets[0]), 0); Value nonZeroOffset = offsets[!zeroIdx]; Value zeroOffset = offsets[zeroIdx]; - if (ttgi::isConstant(getFinalValue(makeTPtrOp.getStrides()[0]), 1)) + if (tt::intel::isConstant( + tt::intel::getFinalValue(makeTPtrOp.getStrides()[0]), 1)) newOffsets = {nonZeroOffset, zeroOffset}; else newOffsets = {zeroOffset, nonZeroOffset}; @@ -448,7 +405,7 @@ struct PtrState { private: Value computeOffset(Value offset, Value stride, OpBuilder &builder, Location loc) const { - if (ttgi::isConstant(stride, 0)) + if (tt::intel::isConstant(stride, 0)) return findOrCreateCast(loc, offset, builder.getIntegerType(offsetBitwidth), builder); @@ -458,7 +415,7 @@ struct PtrState { builder), findOrCreateCast(loc, stride, builder.getIntegerType(offsetBitwidth), builder)); - return findOrCreateCast(loc, getFinalValue(divOffset), + return findOrCreateCast(loc, tt::intel::getFinalValue(divOffset), builder.getIntegerType(offsetBitwidth), builder); } }; @@ -489,236 +446,6 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, } #endif -// Utility class aggregating information required to create a versioning -// condition. -class VersioningCondition { -public: - VersioningCondition(Value S, Value BS) : S(S), BS(BS) { - assert(isValid() && "Invalid values supplied"); - } - - // Create the condition: (S % BS == 0 && S > BS) - Value materialize(OpBuilder &builder, Location loc) const { - assert(S && BS && "Expecting valid values"); - Value zero = - builder.createOrFold(loc, 0, S.getType()); - Value cmp1 = builder.create( - loc, arith::CmpIPredicate::eq, - builder.create(loc, S, BS), zero); - Value cmp2 = - builder.create(loc, arith::CmpIPredicate::sgt, S, BS); - return builder.create(loc, cmp1, cmp2); - } - -private: - bool isValid() const { - Type SType = S.getType(), BSType = BS.getType(); - if (!isa(SType) || !isa(BSType)) - return false; - - return cast(SType).getWidth() == - cast(BSType).getWidth(); - } - - Value S; // The length of a row/column. - Value BS; // The block size. -}; - -// Utility class responsible for collecting masked operation in a loop that are -// amenable to having their mask dropped when the loop is versioned. -class MaskedOpsCollector { - friend class LoopVersioner; - -public: - bool collectMaskedOps(scf::ForOp &forOp) { - // Nested loop aren't currently handled. - if (forOp->template getParentOfType()) - return false; - - // Ensure the loop upper bound is in canonical form (N+END-1)/END. - if (!hasValidUpperBound(forOp)) - return false; - - assert(versioningCond && "Expecting a valid versioning condition"); - - // Collect masked loads in the loop if they have canonical mask. - for (auto op : forOp.getOps()) { - Value mask = op.getMask(); - if (mask && isValidMask(getFinalValue(mask))) - maskedOps.insert(op); - } - - // TODO: collect masked stores in the loop if they have canonical mask. - - return maskedOps.size(); - } - -private: - // Check whether the loop UB is in canonical form: (N+END-1)/END and create - // the versioning condition to use for the loop if so. 
- bool hasValidUpperBound(scf::ForOp &forOp) { - Value ub = getFinalValue(forOp.getUpperBound()); - Operation *defOp = ub.getDefiningOp(); - if (!defOp || !isa(defOp)) - return false; - - auto divOp = cast(defOp); - Operation *divLhsOp = divOp.getLhs().getDefiningOp(); - Operation *divRhsOp = divOp.getRhs().getDefiningOp(); - if (!divLhsOp || !divRhsOp || !isa(divLhsOp) || - !isa(divRhsOp)) - return false; - - auto divNumOp = cast(divLhsOp); - auto divDenOp = cast(divRhsOp); - Operation *addLhsOp = divNumOp.getLhs().getDefiningOp(); - Operation *addRhsOp = divNumOp.getRhs().getDefiningOp(); - if (addLhsOp || !isa(addRhsOp) || - (divDenOp.value() != cast(addRhsOp).value() + 1)) - return false; - - versioningCond = std::make_unique(divNumOp.getLhs(), - divOp.getRhs()); - return true; - } - - // Check whether a mask is in canonical form: (0..END) < N - i*END - bool isValidMask(Value mask) const { - assert(mask.getDefiningOp() && "Expected a valid mask operation"); - auto cmpOp = cast(mask.getDefiningOp()); - arith::CmpIPredicate pred = cmpOp.getPredicate(); - if (pred != arith::CmpIPredicate::slt) - return false; - - Operation *lhs = getFinalValue(cmpOp.getLhs()).getDefiningOp(); - Operation *rhs = getFinalValue(cmpOp.getRhs()).getDefiningOp(); - if (!isa(lhs) || !isa(rhs)) - return false; - - auto rangeOp = cast(lhs); - unsigned end = rangeOp.getEnd(); - assert(end > rangeOp.getStart() && "Invalid range"); - - auto subOp = cast(rhs); - Operation *subLhs = subOp.getLhs().getDefiningOp(); - Operation *subRhs = subOp.getRhs().getDefiningOp(); - if (subLhs || !isa(subRhs)) - return false; - - auto mulOp = cast(subRhs); - Operation *mulLhs = mulOp.getLhs().getDefiningOp(); - Operation *mulRhs = mulOp.getRhs().getDefiningOp(); - if (mulLhs && mulRhs) - return false; - - if (!mulLhs && isa(mulRhs)) - return cast(mulRhs).value() == end; - if (!mulRhs && isa(mulLhs)) - return cast(mulLhs).value() == end; - - return false; - } - -private: - using MaskedOperations = SmallPtrSet; - // Masked operations in the loop that can be have their mask dropped when the - // loop is versioned using the condition builder associated with this class. - MaskedOperations maskedOps; - std::unique_ptr versioningCond = nullptr; -}; - -class LoopVersioner { -public: - // TODO: Extend the versioning region to encompass the downward exposed uses - // of the return values. - bool version(scf::ForOp &forOp, MaskedOpsCollector &collector) const { - if (!canVersion(forOp)) - return false; - - // Collect loop results that are downward exposed. - auto getUsedResults = [](const scf::ForOp &forOp) { - SmallVector resTypes; - for (Value res : forOp->getResults()) { - if (!res.getUsers().empty()) - resTypes.push_back(res.getType()); - } - return resTypes; - }; - - // Create the versioning condition. - OpBuilder builder(forOp); - Location loc = forOp.getLoc(); - Value versioningCond = collector.versioningCond->materialize(builder, loc); - auto ifOp = - builder.create(loc, getUsedResults(forOp), versioningCond, - /*withThenRegion=*/true, - /*withElseRegion=*/true); - - // Clone the original loop into the 2 if branches. - OpBuilder thenB = ifOp.getThenBodyBuilder(); - OpBuilder elseB = ifOp.getElseBodyBuilder(); - - IRMapping map; - Operation *thenForLoop = thenB.clone(*forOp.getOperation(), map); - Operation *elseForLoop = elseB.clone(*forOp.getOperation()); - - // Collect results in 'clonedLoop' corresponding to downward exposed results - // 'forOp'. 
- auto pruneUnusedResults = [&](const scf::ForOp &forOp, - Operation *clonedLoop) { - SmallVector prunedResults; - for (auto [idx, val] : llvm::enumerate(forOp->getResults())) { - if (!val.getUsers().empty()) - prunedResults.push_back(clonedLoop->getResult(idx)); - } - return prunedResults; - }; - - // Create the yield operations for the two if branches. - thenB.create(loc, pruneUnusedResults(forOp, thenForLoop)); - elseB.create(loc, pruneUnusedResults(forOp, elseForLoop)); - - // Drop the mask from candidate masked operations in the "then" region's - // cloned loop. - for (Operation *maskedOp : collector.maskedOps) { - Operation *mappedOp = map.lookup(maskedOp); - if (auto loadOp = dyn_cast(mappedOp)) { - OpBuilder builder(mappedOp); - auto newLoad = builder.create( - loadOp.getLoc(), loadOp.getPtr(), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile()); - mappedOp->replaceAllUsesWith(newLoad); - mappedOp->erase(); - } - // TODO: stores - } - - // Replace the uses of the original loop results. - unsigned idx = 0; - for (Value res : forOp.getResults()) { - if (!res.getUsers().empty()) - res.replaceAllUsesWith(ifOp->getResult(idx++)); - } - - forOp.erase(); - - return true; - } - -private: - // Currently we can version the loop only is it doesn't have downward - // exposed uses of return values that are a tensor of pointers. - // Note: this is due to the fact the results yielded by the 2 versioning - // branches have different types for ptr (only in one versioned loop tensor of - // ptrs are changed to block ptrs) 'then' part of the versioning branch and - // leave them as is in the 'else' branch). - bool canVersion(scf::ForOp &forOp) const { - return llvm::any_of(forOp.getResults(), [](Value res) { - return !tt::isTensorPointerType(res.getType()) || res.getUsers().empty(); - }); - } -}; - struct TritonRaiseBlockPointer : tt::intel::impl::TritonRaiseBlockPointerBase { public: @@ -731,25 +458,6 @@ struct TritonRaiseBlockPointer // Drop the mask or version loops containing masked operations. if (IgnoreMasks) dropMasks(moduleOp); - else { - // Collect masked operations amenable to versioning in each loop. - moduleOp->walk([&](Operation *op) { - MaskedOpsCollector collector; - LoopVersioner loopVersioner; - if (scf::ForOp forOp = dyn_cast(op)) { - if (collector.collectMaskedOps(forOp)) { - [[maybe_unused]] bool loopVersioned = - loopVersioner.version(forOp, collector); - if (loopVersioned) - LLVM_DEBUG(llvm::dbgs() << "Loop versioned\n"); - } - } - return WalkResult::advance(); - }); - - LLVM_DEBUG(llvm::dbgs() << "After versioning:\n" << moduleOp << "\n"); - assert(succeeded(verify(moduleOp)) && "Module verification failed"); - } // Perform the transformation. 
if (failed(rewriteOp(moduleOp))) @@ -930,7 +638,7 @@ struct TritonRaiseBlockPointer } bool lookForMultiplyingValueInDefiningPath(Value &val, Value &ref) const { - if (Operation *defOp = getFinalValue(val).getDefiningOp()) { + if (Operation *defOp = tt::intel::getFinalValue(val).getDefiningOp()) { if (auto mulOp = dyn_cast(defOp)) { if ((mulOp.getLhs() == ref) || (mulOp.getRhs() == ref)) return true; @@ -946,8 +654,8 @@ struct TritonRaiseBlockPointer Operation *op1 = val1.getDefiningOp(); Operation *op2 = val2.getDefiningOp(); if (op1 && op2) { - std::optional intVal1 = ttgi::getFoldedConstantValue(op1); - std::optional intVal2 = ttgi::getFoldedConstantValue(op2); + std::optional intVal1 = tt::intel::getFoldedConstantValue(op1); + std::optional intVal2 = tt::intel::getFoldedConstantValue(op2); if (intVal1.has_value() && intVal2.has_value()) return intVal1.value() == intVal2.value(); } @@ -962,7 +670,7 @@ struct TritonRaiseBlockPointer SmallVector finalStrides; // check whether all strides are different, if not => skip for (auto stride : strides) { - Value currentVal = getFinalValue(stride); + Value currentVal = tt::intel::getFinalValue(stride); if (llvm::any_of(finalStrides, [&](Value val) { return areValuesEqual(val, currentVal); })) @@ -975,7 +683,7 @@ struct TritonRaiseBlockPointer // search for a mul to finalStride in the predecessors if (lookForMultiplyingValueInDefiningPath(operand, finalStride)) return axis; - if (ttgi::isConstant(finalStride, 1)) + if (tt::intel::isConstant(finalStride, 1)) return axis; ++axis; } @@ -1150,7 +858,7 @@ struct TritonRaiseBlockPointer auto scaledOffset = builder.createOrFold(loc, offsetCst, strideCst); state.offsets.push_back( - findOrCreateCast(loc, getFinalValue(scaledOffset), + findOrCreateCast(loc, tt::intel::getFinalValue(scaledOffset), builder.getIntegerType(offsetBitwidth), builder)); } state.strides = makeTPtrOp.getStrides(); @@ -1261,12 +969,8 @@ struct TritonRaiseBlockPointer Operation *definingOp = operand.getDefiningOp(); if (!definingOp) { - if (!knownPtrs.contains(operand)) { - llvm::errs() << "TritonRaiseBlockPointer: encountered addptr block " - "argument operand\n" - << operand << "\n"; + if (!knownPtrs.contains(operand)) return failure(); - } // This operand must be an iter-arg of an inner-loop in a multiple-level // nested loop, which means its PtrState must have already been @@ -1347,7 +1051,7 @@ struct TritonRaiseBlockPointer if (auto iter = knownPtrs.find(ptr); iter != knownPtrs.end()) { PtrState state = iter->second; for (int axis = 0; axis < state.shape.size(); ++axis) { - if (!ttgi::isConstant(state.shape[axis], 0)) + if (!tt::intel::isConstant(state.shape[axis], 0)) boundary.push_back(axis); } } diff --git a/third_party/intel/lib/Utils/CMakeLists.txt b/third_party/intel/lib/Utils/CMakeLists.txt index 491bb04b8a..2bc9e35db7 100644 --- a/third_party/intel/lib/Utils/CMakeLists.txt +++ b/third_party/intel/lib/Utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonIntelUtils LLVMIntr.cpp Mangling.cpp + Utility.cpp LINK_LIBS PUBLIC MLIRIR diff --git a/third_party/intel/lib/Utils/Utility.cpp b/third_party/intel/lib/Utils/Utility.cpp new file mode 100644 index 0000000000..62f55a74ec --- /dev/null +++ b/third_party/intel/lib/Utils/Utility.cpp @@ -0,0 +1,94 @@ + +#include "intel/include/Utils/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; + 
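+// Return the integer value held by an OpFoldResult that wraps an IntegerAttr
+// (e.g. the folded result of an `arith.constant 4 : i32`), or std::nullopt
+// otherwise.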
+static std::optional getIntAttr(const OpFoldResult ofr) { + if (auto attr = dyn_cast(ofr)) + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + return std::nullopt; +} + +namespace mlir::triton::intel { + +std::optional getFoldedConstantValue(Operation *op) { + SmallVector results; + if (failed(op->fold(results))) + return std::nullopt; + + // If fold succeeded but `results` is empty, we give a second try, after the + // operands have been switched during the first call to `fold()`. + if (results.empty()) { + if (failed(op->fold(results))) + return std::nullopt; + } + + if (results.size() != 1) + return std::nullopt; + + auto intAttr = getIntAttr(results[0]); + if (intAttr.has_value()) + return intAttr.value(); + + auto val = cast(results[0]); + auto constOp = val.getDefiningOp(); + if (!constOp) + return std::nullopt; + + return getIntAttr(constOp.getValue()); +} + +bool isConstant(Value val, const unsigned expected) { + if (auto defOp = val.getDefiningOp()) + return (getFoldedConstantValue(defOp) == expected); + return false; +} + +Value getFinalValue(Value value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + // look init values outside the loop + BlockArgument blockArg = dyn_cast(value); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (scf::ForOp forOp = dyn_cast(parentOp)) + return getFinalValue(forOp.getInitArgs()[blockArg.getArgNumber() - 1]); + + return value; + } + + if (isa( + defOp)) + return getFinalValue(defOp->getOperand(0)); + + if (auto addOp = dyn_cast(defOp)) { + if (isConstant(addOp.getLhs(), 0)) + return getFinalValue(addOp.getRhs()); + if (isConstant(addOp.getRhs(), 0)) + return getFinalValue(addOp.getLhs()); + return addOp.getResult(); + } + + if (auto mulOp = dyn_cast(defOp)) { + if (isConstant(mulOp.getLhs(), 1) || isConstant(mulOp.getRhs(), 0)) + return getFinalValue(mulOp.getRhs()); + if (isConstant(mulOp.getRhs(), 1) || isConstant(mulOp.getLhs(), 0)) + return getFinalValue(mulOp.getLhs()); + return mulOp.getResult(); + } + + if (auto divOp = dyn_cast(defOp)) { + if (isConstant(divOp.getRhs(), 1) || isConstant(divOp.getLhs(), 0)) + return getFinalValue(divOp.getLhs()); + return divOp.getResult(); + } + + return value; +} + +} // namespace mlir::triton::intel diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index bc90dfd725..85662320ab 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -8,10 +8,10 @@ #include "llvm/Passes/StandardInstrumentations.h" #include "llvm/Transforms/InstCombine/InstCombine.h" +#include "intel/include/Dialect/Triton/Transforms/Passes.h" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" -#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h" #include "intel/include/Target/LLVMIR/PostProcess.h" #include "intel/include/TritonAnnotateModule/Passes.h" @@ -66,6 +66,7 @@ static uint32_t findKernels(llvm::Module &M, } void init_triton_intel_passes_ttir(py::module &&m) { + ADD_PASS_WRAPPER_0("add_remove_masks", intel::createTritonIntelRemoveMasks); ADD_PASS_WRAPPER_OPT_1("add_raise_block_pointer", intel::createTritonRaiseBlockPointer, bool); ADD_PASS_WRAPPER_OPT_1("add_convert_to_ttgpuir_warp", From 06629ffcb1b73130395a55ef47878eea3da6353f Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 
24 Feb 2025 20:36:19 +0000 Subject: [PATCH 2/6] Fix precommit Signed-off-by: Tiotto, Ettore --- .../intel/include/Dialect/Triton/Transforms/Passes.h | 2 +- .../intel/include/Dialect/Triton/Transforms/Passes.td | 6 +++--- .../intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp | 2 +- .../intel/lib/TritonIntelGPUTransforms/CMakeLists.txt | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/intel/include/Dialect/Triton/Transforms/Passes.h b/third_party/intel/include/Dialect/Triton/Transforms/Passes.h index 650cc87df9..96bb65f612 100644 --- a/third_party/intel/include/Dialect/Triton/Transforms/Passes.h +++ b/third_party/intel/include/Dialect/Triton/Transforms/Passes.h @@ -22,4 +22,4 @@ namespace mlir::triton::intel { } // namespace mlir::triton::intel -#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES_H \ No newline at end of file +#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES_H diff --git a/third_party/intel/include/Dialect/Triton/Transforms/Passes.td b/third_party/intel/include/Dialect/Triton/Transforms/Passes.td index 466932cf48..e8be26c654 100644 --- a/third_party/intel/include/Dialect/Triton/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/Triton/Transforms/Passes.td @@ -16,10 +16,10 @@ def TritonIntelRemoveMasks let summary = "Remove masks from tt.load and tt.store operations"; let description = [{ - This pass attempts to remove the mask for tt.load and tt.store operations. - If the masked operation is in a loop, the pass attempts to find a loop + This pass attempts to remove the mask for tt.load and tt.store operations. + If the masked operation is in a loop, the pass attempts to find a loop invariant condition equivalent to the mask condition, and then use it to - version the loop. + version the loop. 
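+    For example, a canonical mask `(0..END) < N - i*END` in a loop with upper
+    bound `(N+END-1)/END` is always true in a loop version guarded by
+    `(N % END == 0 && N > END)`, so the mask can be dropped there.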
}]; let dependentDialects = [ diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp index f171b3019c..0c02c70356 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp @@ -304,4 +304,4 @@ struct TritonIntelRemoveMasksBase } }; -} // namespace \ No newline at end of file +} // namespace diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index 6d57997695..f398f09a85 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -27,5 +27,5 @@ add_triton_library(TritonIntelGPUTransforms TritonGENIR TritonGPUIR TritonIntelGPUIR - TritonIntelUtils + TritonIntelUtils ) From bde5a8ca9102afbde793906a9d59874ef05ca56d Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 24 Feb 2025 21:31:48 +0000 Subject: [PATCH 3/6] Fix assertion Signed-off-by: Tiotto, Ettore --- .../Dialect/Triton/Transforms/RemoveMasks.cpp | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp index 0c02c70356..d513053fb7 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp @@ -100,7 +100,9 @@ class MaskedOpsCollector { // Check whether a mask is in canonical form: (0..END) < N - i*END bool isValidMask(Value mask) const { - assert(mask.getDefiningOp() && "Expected a valid mask operation"); + if (!mask.getDefiningOp() || !isa(mask.getDefiningOp())) + return false; + auto cmpOp = cast(mask.getDefiningOp()); arith::CmpIPredicate pred = cmpOp.getPredicate(); if (pred != arith::CmpIPredicate::slt) @@ -122,25 +124,32 @@ class MaskedOpsCollector { return false; auto mulOp = cast(subRhs); - Operation *mulLhs = mulOp.getLhs().getDefiningOp(); - Operation *mulRhs = mulOp.getRhs().getDefiningOp(); - if (mulLhs && mulRhs) + Operation *defMulLhs = mulOp.getLhs().getDefiningOp(); + Operation *defMulRhs = mulOp.getRhs().getDefiningOp(); + if (defMulLhs && defMulRhs) return false; - if (!mulLhs && isa(mulRhs)) - return cast(mulRhs).value() == end; - if (!mulRhs && isa(mulLhs)) - return cast(mulLhs).value() == end; + std::optional loopIV = forOp.getSingleInductionVar(); + assert(loopIV.has_value() && "Failed to find loop induction variable"); + + if (!defMulLhs && mulOp.getLhs() == *loopIV && + isa(defMulRhs)) + return cast(defMulRhs).value() == end; + + if (!defMulRhs && mulOp.getRhs() == *loopIV && + isa(defMulLhs)) + return cast(defMulLhs).value() == end; return false; } private: - // Masked operations in the loop that can be have their mask dropped when the - // loop is versioned using the versioning condition associated with this - // class. scf::ForOp &forOp; + + // Masked operations that can be have their mask dropped when the loop is + // versioned using the versioning condition associated with this class. MaskedOperations maskedOps; + std::unique_ptr versioningCond = nullptr; }; @@ -285,8 +294,8 @@ struct TritonIntelRemoveMasksBase if (forOp->template getParentOfType()) return WalkResult::advance(); - // Ensure loop UB is in 'canonical' form. 
-        if (!LoopVersioner::hasValidUpperBound(forOp))
+        if (!forOp.getSingleInductionVar() ||
+            !LoopVersioner::hasValidUpperBound(forOp))
           return WalkResult::advance();
 
         MaskedOpsCollector collector(forOp);
From 22b34e14f64abc3bdc45be2b439eb22c49585413 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Mon, 24 Feb 2025 23:07:47 +0000
Subject: [PATCH 4/6] Fix assertion

Signed-off-by: Tiotto, Ettore
---
 third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
index d513053fb7..b910da9ee7 100644
--- a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
+++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
@@ -110,7 +110,7 @@ class MaskedOpsCollector {
 
     Operation *lhs = tt::intel::getFinalValue(cmpOp.getLhs()).getDefiningOp();
     Operation *rhs = tt::intel::getFinalValue(cmpOp.getRhs()).getDefiningOp();
-    if (!isa<tt::MakeRangeOp>(lhs) || !isa<arith::SubIOp>(rhs))
+    if (!lhs || !rhs || !isa<tt::MakeRangeOp>(lhs) || !isa<arith::SubIOp>(rhs))
       return false;
 
     auto rangeOp = cast<tt::MakeRangeOp>(lhs);
From 272d1f414205513c5454e27c147ed26eb038541b Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Sat, 1 Mar 2025 01:06:42 +0000
Subject: [PATCH 5/6] [Loop Specialization]: Specialize loops containing masked
 operations with loop invariant masks

Signed-off-by: Tiotto, Ettore
---
 .../kernel-03-matrix-multiplication.mlir      |   1 +
 third_party/intel/backend/compiler.py         |   2 +
 .../Dialect/Triton/Transforms/RemoveMasks.cpp | 434 ++++++++++++------
 .../TritonRaiseBlockPointer.cpp               |   9 +-
 third_party/intel/lib/Utils/Utility.cpp       |   1 +
 5 files changed, 306 insertions(+), 141 deletions(-)

diff --git a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir
index eefab3a6ff..e2c040a89b 100644
--- a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir
+++ b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir
@@ -1,6 +1,7 @@
 // RUN: triton-opt %s -triton-intel-remove-masks -triton-raise-block-pointer -canonicalize | FileCheck %s
 
 module {
+  // COM: Derived from tutorial 03-matrix-multiplication.
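+  // COM: The remove-masks pass is expected to version the K-loop on the
+  // COM: condition (K % BLOCK_SIZE_K == 0 && K > BLOCK_SIZE_K), so that the
+  // COM: masked tt.load ops in the 'then' version can drop their masks
+  // COM: (BLOCK_SIZE_K refers to the tutorial's block size constant).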
tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { %c31_i32 = arith.constant 31 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32> diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 54ed5ecb4d..efc89391e5 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -224,6 +224,8 @@ def make_ttir(mod, metadata, opt): pm.enable_debug() passes.common.add_inliner(pm) passes.ttir.add_combine(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) intel.passes.ttir.add_remove_masks(pm) if raise_block_ptr_flags['enabled']: ignore_masks = True if raise_block_ptr_flags['ignore-masks'] else False diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp index b910da9ee7..4671e10ffa 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp @@ -3,9 +3,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Verifier.h" -// #include "mlir/Pass/Pass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "triton-intel-remove-masks" @@ -19,87 +19,24 @@ namespace mlir::triton::intel { namespace { -// Represent a versioning condition for a loop. -class VersioningCondition { +// Abstract base class for mask validators. +// Mask validators are used to check whether a given mask has an expected form. +// Concreate subclasses define the expected form. +class MaskValidatorBase { public: - VersioningCondition(Value S, Value BS) : S(S), BS(BS) { - assert(isValid() && "Invalid values supplied"); - } - - // Create the condition: (S % BS == 0 && S > BS) - Value materialize(OpBuilder &builder, Location loc) const { - assert(S && BS && "Expecting valid values"); - Value zero = - builder.createOrFold(loc, 0, S.getType()); - Value cmp1 = builder.create( - loc, arith::CmpIPredicate::eq, - builder.create(loc, S, BS), zero); - Value cmp2 = - builder.create(loc, arith::CmpIPredicate::sgt, S, BS); - return builder.create(loc, cmp1, cmp2); - } - -private: - bool isValid() const { - Type SType = S.getType(), BSType = BS.getType(); - if (!isa(SType) || !isa(BSType)) - return false; + virtual ~MaskValidatorBase() = default; - return cast(SType).getWidth() == - cast(BSType).getWidth(); - } - - Value S; // The length of a row/column. - Value BS; // The block size. + // Check whether the given mask is valid. + virtual bool isValidMask(Value mask, scf::ForOp &forOp) const = 0; }; -// Collects masked operations conditions in a loop. -class MaskedOpsCollector { +// A mask validator which ensures that the mask can be reduced to the form: +// `END < N-i*END`. 
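+// I.e. `(0..END) < N - i*END` once splats/broadcasts are looked through; an
+// illustrative sketch of the IR shape being matched, where END is the block
+// size and %i the loop induction variable:
+//   %range = tt.make_range {start = 0 : i32, end = 32 : i32}
+//   %mul   = arith.muli %i, %c32_i32        // i * END
+//   %sub   = arith.subi %N, %mul            // N - i*END
+//   %mask  = arith.cmpi slt, %range, splat(%sub)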
+class CanonicalMaskValidator final : public MaskValidatorBase {
 public:
-  using MaskedOperations = SmallPtrSet;
+  virtual bool isValidMask(Value mask, scf::ForOp &forOp) const {
+    assert(mask && "Expecting a valid mask");
 
-  MaskedOpsCollector(scf::ForOp &forOp) : forOp(forOp) {
-    assert(!forOp->template getParentOfType() &&
-           "Nested loop not handled yet");
-    createVersioningCondition(forOp);
-  }
-
-  // Collect mask condition that can be made loop invariant for the `tt.load`
-  // operation in the given loop.
-  bool collectMaskedOps() {
-    assert(versioningCond && "Versioning condition should be valid");
-
-    // Collect masked loads in the loop if they have canonical mask.
-    for (auto op : forOp.getOps()) {
-      Value mask = op.getMask();
-      if (mask && isValidMask(tt::intel::getFinalValue(mask)))
-        maskedOps.insert(op);
-    }
-
-    // TODO: collect masked stores in the loop if they have canonical mask.
-    return maskedOps.size();
-  }
-
-  VersioningCondition *getVersioningCond() const {
-    return versioningCond.get();
-  };
-
-  const MaskedOperations &getMaskedOps() const { return maskedOps; };
-
-private:
-  // Note: this assumes the loop UB is in canonical form `(N+END-1)/END`.
-  void createVersioningCondition(scf::ForOp &forOp) {
-    Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
-    Operation *defOp = ub.getDefiningOp();
-    auto divOp = cast(defOp);
-    Operation *divLhsOp = divOp.getLhs().getDefiningOp();
-    auto divNumOp = cast(divLhsOp);
-    versioningCond = std::make_unique(divNumOp.getLhs(),
-                                      divOp.getRhs());
-  }
-
-  // Check whether a mask is in canonical form: (0..END) < N - i*END
-  bool isValidMask(Value mask) const {
     if (!mask.getDefiningOp() || !isa(mask.getDefiningOp()))
       return false;
 
@@ -143,14 +80,183 @@ class MaskedOpsCollector {
     return false;
   }
 
+  // Create the loop versioning condition, assumes the loop upper bound is of
+  // the form `(N+END-1)/END`.
+  Value getVersioningCond(scf::ForOp forOp) const {
+    assert(hasValidUpperBound(forOp) && "Invalid upper bound");
+
+    Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
+    Operation *defOp = ub.getDefiningOp();
+    auto divOp = cast(defOp);
+    Operation *divLhsOp = divOp.getLhs().getDefiningOp();
+    auto divNumOp = cast(divLhsOp);
+    Value lhs = divNumOp.getLhs();
+    Value rhs = divOp.getRhs();
+
+    OpBuilder builder(forOp);
+    Location loc = forOp.getLoc();
+    Value zero =
+        builder.createOrFold(loc, 0, lhs.getType());
+    Value cmp1 = builder.create(
+        loc, arith::CmpIPredicate::eq,
+        builder.create(loc, lhs, rhs), zero);
+    Value cmp2 =
+        builder.create(loc, arith::CmpIPredicate::sgt, lhs, rhs);
+    return builder.create(loc, cmp1, cmp2);
+  }
+
+  // Ensure the loop upper bound is in canonical form (N+END-1)/END.
+  static bool hasValidUpperBound(scf::ForOp &forOp) {
+    Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
+    Operation *defOp = ub.getDefiningOp();
+    if (!defOp || !isa(defOp))
+      return false;
+
+    auto divOp = cast(defOp);
+    Operation *divLhsOp = divOp.getLhs().getDefiningOp();
+    Operation *divRhsOp = divOp.getRhs().getDefiningOp();
+    if (!divLhsOp || !divRhsOp || !isa(divLhsOp) ||
+        !isa(divRhsOp))
+      return false;
+
+    auto divNumOp = cast(divLhsOp);
+    auto divDenOp = cast(divRhsOp);
+    Operation *addLhsOp = divNumOp.getLhs().getDefiningOp();
+    Operation *addRhsOp = divNumOp.getRhs().getDefiningOp();
+    if (addLhsOp || !isa(addRhsOp) ||
+        (divDenOp.value() != cast(addRhsOp).value() + 1))
+      return false;
+
+    return true;
+  }
+};
+
+// This mask validator ensures the mask is loop invariant.
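+//
+// Illustrative sketch (an editor's example): a mask such as
+// `[0..END] < splat(N)`, where N is defined above the loop, does not vary
+// across iterations; the loop can be versioned on a scalar condition derived
+// from it (here `END < N`, built by getVersioningCond below), and the masked
+// loads in the 'then' clone can then be replaced with unmasked ones.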
+class InvariantMaskValidator final : public MaskValidatorBase {
+public:
+  // The mask must have one of the forms:
+  //  - N < M (with i1 data type)
+  //  - [0..END] < splat(N)
+  //  - splat(N) < [0..END]
+  virtual bool isValidMask(Value mask, scf::ForOp &forOp) const {
+    assert(mask && "Expecting a valid mask");
+
+    if (!mask.getDefiningOp() || !isa(mask.getDefiningOp()))
+      return false;
+
+    auto cmpOp = cast(mask.getDefiningOp());
+    arith::CmpIPredicate pred = cmpOp.getPredicate();
+    if (pred != arith::CmpIPredicate::slt)
+      return false;
+
+    bool isInLoop = (cmpOp->getParentOfType() == forOp);
+    if (isInLoop)
+      return false;
+
+    Value lhsVal = tt::intel::getFinalValue(cmpOp.getLhs());
+    Value rhsVal = tt::intel::getFinalValue(cmpOp.getRhs());
+    Operation *lhs = tt::intel::getFinalValue(lhsVal).getDefiningOp();
+    Operation *rhs = tt::intel::getFinalValue(rhsVal).getDefiningOp();
+
+    if (!lhs && !rhs) {
+      assert(lhsVal.getType() == rhsVal.getType() && "Invalid types");
+      assert(isa(lhsVal.getType()) &&
+             cast(lhsVal.getType()).getWidth() == 1 &&
+             "Invalid type");
+      return true;
+    }
+
+    if (!rhs && isa(lhs)) {
+      [[maybe_unused]] auto rangeOp = cast(lhs);
+      assert(rangeOp.getStart() < rangeOp.getEnd() && "Invalid range");
+      return true;
+    }
+
+    if (!lhs && isa(rhs)) {
+      [[maybe_unused]] auto rangeOp = cast(rhs);
+      assert(rangeOp.getStart() < rangeOp.getEnd() && "Invalid range");
+      return true;
+    }
+
+    return false;
+  }
+
+  Value getVersioningCond(Value mask, scf::ForOp &forOp) const {
+    assert(isValidMask(mask, forOp) && "Invalid mask");
+
+    OpBuilder builder(forOp);
+    Location loc = forOp.getLoc();
+    auto cmpOp = cast(mask.getDefiningOp());
+    Value lhsVal = tt::intel::getFinalValue(cmpOp.getLhs());
+    Value rhsVal = tt::intel::getFinalValue(cmpOp.getRhs());
+    Operation *lhs = tt::intel::getFinalValue(lhsVal).getDefiningOp();
+    Operation *rhs = tt::intel::getFinalValue(rhsVal).getDefiningOp();
+
+    // N < M (with i1 data type)
+    if (!lhs && !rhs)
+      return builder.createOrFold(loc, arith::CmpIPredicate::slt,
+                                  lhsVal, rhsVal);
+
+    // [0..END] < splat(N)
+    if (!rhs && isa(lhs)) {
+      [[maybe_unused]] auto rangeOp = cast(lhs);
+      assert(rangeOp.getStart() < rangeOp.getEnd() && "Invalid range");
+      unsigned end = rangeOp.getEnd();
+      auto cstOp = builder.createOrFold(loc, end,
+                                        rhsVal.getType());
+      return builder.createOrFold(loc, arith::CmpIPredicate::slt,
+                                  cstOp, rhsVal);
+    }
+
+    // splat(N) < [0..END]
+    if (!lhs && isa(rhs)) {
+      [[maybe_unused]] auto rangeOp = cast(rhs);
+      assert(rangeOp.getStart() < rangeOp.getEnd() && "Invalid range");
+      unsigned start = rangeOp.getStart();
+      auto cstOp = builder.createOrFold(loc, start,
+                                        lhsVal.getType());
+      return builder.createOrFold(loc, arith::CmpIPredicate::slt,
+                                  lhsVal, cstOp);
+    }
+
+    llvm_unreachable("Unexpected mask");
+    return {};
+  }
+};
+
+// Collects masked operations in a loop that satisfy the condition imposed by
+// the mask validator associated with this class.
+template class MaskedOpsCollector {
+public:
+  using MaskedOperations = SmallPtrSet;
+
+  MaskedOpsCollector(scf::ForOp &forOp, MaskValidator &maskValidator)
+      : forOp(forOp), maskValidator(maskValidator) {}
+
+  bool collectMaskedOps() {
+    auto collectMaskedOps = [&](auto ops, MaskedOperations &maskedOps) {
+      for (Operation *op : ops) {
+        Value mask = isa(op)   ? cast(op).getMask()
+                     : isa(op) ?
+                                 cast(op).getMask()
+                                 : nullptr;
+        if (mask &&
+            maskValidator.isValidMask(tt::intel::getFinalValue(mask), forOp))
+          maskedOps.insert(op);
+      }
+    };
+
+    collectMaskedOps(forOp.getOps(), maskedOps);
+    collectMaskedOps(forOp.getOps(), maskedOps);
+    return maskedOps.size();
+  }
+
+  const MaskedOperations &getMaskedOps() const { return maskedOps; };
+  const MaskValidator &getMaskValidator() const { return maskValidator; }
+
 private:
   scf::ForOp &forOp;
-
-  // Masked operations that can have their mask dropped when the loop is
-  // versioned using the versioning condition associated with this class.
+  MaskValidator &maskValidator;
   MaskedOperations maskedOps;
-
-  std::unique_ptr versioningCond = nullptr;
 };
 
 class LoopVersioner {
@@ -159,15 +265,32 @@ class LoopVersioner {
   // by \p collector unnecessary.
   // TODO: Extend the versioning region to encompass the downward exposed uses
   // of the return values.
-  static bool version(scf::ForOp &forOp, MaskedOpsCollector &collector) {
-    assert(collector.getVersioningCond() &&
-           "Versioning condition should be present");
-
-    // Limitation: give up if the loop returns tensor of ptrs.
+  static bool version(scf::ForOp &forOp,
+                      MaskedOpsCollector &collector) {
+    // Limitation:
+    // Currently we can version the loop only if it doesn't have downward
+    // exposed uses of return values that are a tensor of pointers.
+    // Note: this is because the results yielded by the two versioning
+    // branches have different pointer types (tensors of pointers are changed
+    // to block pointers only in the 'then' branch and are left unchanged in
+    // the 'else' branch).
+    auto canVersion = [](scf::ForOp &forOp) {
+      return llvm::any_of(forOp.getResults(), [](Value res) {
+        return !tt::isTensorPointerType(res.getType()) ||
+               res.getUsers().empty();
+      });
+    };
     if (!canVersion(forOp))
      return false;
 
-    // Collect loop results that are downward exposed.
+    // Retrieve the versioning condition, bail out if it doesn't exist (in which
+    // case the loop upper bound is not in canonical form).
+    Value verCond = collector.getMaskValidator().getVersioningCond(forOp);
+    if (!verCond)
+      return false;
+
+    // This lambda is used to collect the types for the loop results that are
+    // downward exposed (i.e. used by other operations).
    auto getUsedResults = [](const scf::ForOp &forOp) {
      SmallVector resTypes;
      for (Value res : forOp->getResults()) {
@@ -180,23 +303,18 @@ class LoopVersioner {
     // Create the versioning branch.
     OpBuilder builder(forOp);
     Location loc = forOp.getLoc();
-    Value versioningCond =
-        collector.getVersioningCond()->materialize(builder, loc);
-    auto ifOp =
-        builder.create(loc, getUsedResults(forOp), versioningCond,
-                       /*withThenRegion=*/true,
-                       /*withElseRegion=*/true);
+    auto ifOp = builder.create(loc, getUsedResults(forOp), verCond,
+                               /*withThenRegion=*/true);
 
     // Clone the original loop into the 2 if branches.
-    OpBuilder thenB = ifOp.getThenBodyBuilder();
-    OpBuilder elseB = ifOp.getElseBodyBuilder();
-
     IRMapping map;
+    OpBuilder thenB = ifOp.getThenBodyBuilder();
     Operation *thenForLoop = thenB.clone(*forOp.getOperation(), map);
+    OpBuilder elseB = ifOp.getElseBodyBuilder();
     Operation *elseForLoop = elseB.clone(*forOp.getOperation());
 
     // Collect results in 'clonedLoop' corresponding to downward exposed results
-    // 'forOp'.
+    // of the given loop.
     auto pruneUnusedResults = [&](const scf::ForOp &forOp,
                                   Operation *clonedLoop) {
       SmallVector prunedResults;
@@ -211,11 +329,14 @@ class LoopVersioner {
     thenB.create(loc, pruneUnusedResults(forOp, thenForLoop));
     elseB.create(loc, pruneUnusedResults(forOp, elseForLoop));
 
-    // Drop the mask from candidate masked operations in the "then" region's
-    // cloned loop.
+    // Drop the mask from candidate masked operations in the "then" region.
     for (Operation *maskedOp : collector.getMaskedOps()) {
       Operation *mappedOp = map.lookup(maskedOp);
+
       if (auto loadOp = dyn_cast(mappedOp)) {
         OpBuilder builder(mappedOp);
         auto newLoad = builder.create(
             loadOp.getLoc(), loadOp.getPtr(), loadOp.getCache(),
@@ -234,46 +355,67 @@ class LoopVersioner {
     }
 
     forOp.erase();
-
     return true;
   }
 
-  // Ensure the loop upper bound is in canonical form (N+END-1)/END.
-  static bool hasValidUpperBound(scf::ForOp &forOp) {
-    Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
-    Operation *defOp = ub.getDefiningOp();
-    if (!defOp || !isa(defOp))
-      return false;
+  static bool version(scf::ForOp &forOp,
+                      MaskedOpsCollector &collector) {
+    // Collect the (loop invariant) mask conditions.
+    std::set maskConds;
+    for (Operation *maskedOp : collector.getMaskedOps()) {
+      if (auto loadOp = dyn_cast(maskedOp))
+        maskConds.insert(loadOp.getMask().getDefiningOp());
+      if (auto storeOp = dyn_cast(maskedOp))
+        maskConds.insert(storeOp.getMask().getDefiningOp());
+    }
 
-    auto divOp = cast(defOp);
-    Operation *divLhsOp = divOp.getLhs().getDefiningOp();
-    Operation *divRhsOp = divOp.getRhs().getDefiningOp();
-    if (!divLhsOp || !divRhsOp || !isa(divLhsOp) ||
-        !isa(divRhsOp))
-      return false;
+    // Combine the versioning conditions.
+    OpBuilder builder(forOp);
+    Location loc = forOp.getLoc();
+    auto it = maskConds.begin();
+    Value firstCond = (*it++)->getResult(0);
+    auto maskValidator = collector.getMaskValidator();
+    Value verCond = maskValidator.getVersioningCond(firstCond, forOp);
+    for (; it != maskConds.end(); ++it) {
+      Value nextCond = (*it)->getResult(0);
+      Value cond = maskValidator.getVersioningCond(nextCond, forOp);
+      verCond = builder.create(loc, verCond, cond);
+    }
 
-    auto divNumOp = cast(divLhsOp);
-    auto divDenOp = cast(divRhsOp);
-    Operation *addLhsOp = divNumOp.getLhs().getDefiningOp();
-    Operation *addRhsOp = divNumOp.getRhs().getDefiningOp();
-    if (addLhsOp || !isa(addRhsOp) ||
-        (divDenOp.value() != cast(addRhsOp).value() + 1))
-      return false;
+    auto ifOp = builder.create(loc, forOp.getResultTypes(), verCond,
+                               /*withThenRegion=*/true);
 
-    return true;
-  }
+    // Clone the original loop into the 2 if branches.
+    IRMapping map;
+    OpBuilder thenB = ifOp.getThenBodyBuilder();
+    Operation *thenForLoop = thenB.clone(*forOp.getOperation(), map);
+    OpBuilder elseB = ifOp.getElseBodyBuilder();
+    Operation *elseForLoop = elseB.clone(*forOp.getOperation());
 
-private:
-  // Currently we can version the loop only is it doesn't have downward
-  // exposed uses of return values that are a tensor of pointers.
-  // Note: this is due to the fact the results yielded by the 2 versioning
-  // branches have different types for ptr (only in one versioned loop tensor of
-  // ptrs are changed to block ptrs) 'then' part of the versioning branch and
-  // leave them as is in the 'else' branch).
-  static bool canVersion(scf::ForOp &forOp) {
-    return llvm::any_of(forOp.getResults(), [](Value res) {
-      return !tt::isTensorPointerType(res.getType()) || res.getUsers().empty();
-    });
+    // Drop the mask from candidate masked operations in the "then" region's
+    // cloned loop.
+    for (Operation *maskedOp : collector.getMaskedOps()) {
+      Operation *mappedOp = map.lookup(maskedOp);
+      if (auto loadOp = dyn_cast(mappedOp)) {
+        OpBuilder builder(mappedOp);
+        auto newLoad = builder.create(
+            loadOp.getLoc(), loadOp.getPtr(), loadOp.getCache(),
+            loadOp.getEvict(), loadOp.getIsVolatile());
+        mappedOp->replaceAllUsesWith(newLoad);
+        mappedOp->erase();
+      }
+      // TODO: stores
+    }
+
+    // Replace the uses of the original loop results.
+    unsigned idx = 0;
+    for (Value res : forOp.getResults()) {
+      if (!res.getUsers().empty())
+        res.replaceAllUsesWith(ifOp->getResult(idx++));
+    }
+
+    forOp.erase();
+    return true;
   }
 };
 
@@ -286,19 +428,37 @@ struct TritonIntelRemoveMasksBase
   void runOnOperation() final {
     ModuleOp moduleOp = getOperation();
 
-    // Attempt to version loops so that masked operations in the loop become
-    // superfluous.
+    // Version loops containing masked operations in canonical form.
     moduleOp->walk([&](Operation *op) {
       if (scf::ForOp forOp = dyn_cast(op)) {
         // Nested loops aren't currently handled.
         if (forOp->template getParentOfType())
          return WalkResult::advance();
 
-        if (!forOp.getSingleInductionVar() ||
-            !LoopVersioner::hasValidUpperBound(forOp))
+        if (!forOp.getSingleInductionVar())
+          return WalkResult::advance();
+
+        CanonicalMaskValidator maskValidator;
+        MaskedOpsCollector collector(forOp, maskValidator);
+        if (collector.collectMaskedOps()) {
+          [[maybe_unused]] bool loopVersioned =
+              LoopVersioner::version(forOp, collector);
+          LLVM_DEBUG(if (loopVersioned) llvm::dbgs() << "Loop versioned\n");
+        }
+      }
+      return WalkResult::advance();
+    });
+
+    // Version loops containing masked operations with a mask defined before the
+    // loop.
+    moduleOp->walk([&](Operation *op) {
+      if (scf::ForOp forOp = dyn_cast(op)) {
+        // Nested loops aren't currently handled.
+        if (forOp->template getParentOfType())
          return WalkResult::advance();
 
-        MaskedOpsCollector collector(forOp);
+        InvariantMaskValidator maskValidator;
+        MaskedOpsCollector collector(forOp, maskValidator);
        if (collector.collectMaskedOps()) {
          [[maybe_unused]] bool loopVersioned =
              LoopVersioner::version(forOp, collector);

diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
index ba826582c8..0430c852da 100644
--- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
+++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
@@ -695,12 +695,13 @@ struct TritonRaiseBlockPointer
     Operation *defOp = value.getDefiningOp();
     if (!defOp) {
       // look init values outside the loop
-      BlockArgument blockArg = dyn_cast(value);
+      BlockArgument blockArg = cast(value);
       Operation *parentOp = blockArg.getOwner()->getParentOp();
       scf::ForOp forOp = dyn_cast(parentOp);
-      return forOp ? hasExpandOpInDefiningPath(
-                         forOp.getInitArgs()[blockArg.getArgNumber() - 1])
-                   : false;
+      return forOp && !forOp.getInitArgs().empty()
+                 ? hasExpandOpInDefiningPath(
+                       forOp.getInitArgs()[blockArg.getArgNumber() - 1])
+                 : false;
     }
 
     if (isa(defOp))

diff --git a/third_party/intel/lib/Utils/Utility.cpp b/third_party/intel/lib/Utils/Utility.cpp
index 62f55a74ec..6b75b75f6b 100644
--- a/third_party/intel/lib/Utils/Utility.cpp
+++ b/third_party/intel/lib/Utils/Utility.cpp
@@ -51,6 +51,7 @@ bool isConstant(Value val, const unsigned expected) {
 }
 
 Value getFinalValue(Value value) {
+  assert(value && "Expecting a valid value");
   Operation *defOp = value.getDefiningOp();
   if (!defOp) {
     // look init values outside the loop

From f8e2da5e2c7a96577c8ac71b77f2f5d91cad8b5a Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Mon, 3 Mar 2025 22:42:33 +0000
Subject: [PATCH 6/6] Address code review comments

Signed-off-by: Tiotto, Ettore
---
 .../Dialect/Triton/Transforms/RemoveMasks.cpp | 45 +++++++++++--------
 1 file changed, 26 insertions(+), 19 deletions(-)

diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
index 4671e10ffa..5d53bf9229 100644
--- a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
+++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
@@ -21,20 +21,26 @@ namespace {
 
 // Abstract base class for mask validators.
 // Mask validators are used to check whether a given mask has an expected form.
-// Concreate subclasses define the expected form.
+// Concrete subclasses provide a member function used to select masked
+// operations that have a mask in a particular (e.g. desired) form.
+// Furthermore, concrete mask validator classes might also provide a member
+// function used to create the loop versioning condition.
 class MaskValidatorBase {
 public:
   virtual ~MaskValidatorBase() = default;
 
   // Check whether the given mask is valid.
-  virtual bool isValidMask(Value mask, scf::ForOp &forOp) const = 0;
+  virtual bool isValidMask(scf::ForOp &forOp, Value mask) const = 0;
+
+  // Create the loop versioning condition based on the mask.
+  virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const = 0;
 };
 
 // A mask validator which ensures that the mask can be reduced to the form:
 // `END < N-i*END`.
 class CanonicalMaskValidator final : public MaskValidatorBase {
 public:
-  virtual bool isValidMask(Value mask, scf::ForOp &forOp) const {
+  virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
     assert(mask && "Expecting a valid mask");
 
     if (!mask.getDefiningOp() || !isa(mask.getDefiningOp()))
       return false;
 
@@ -82,7 +88,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
   // Create the loop versioning condition, assumes the loop upper bound is of
   // the form `(N+END-1)/END`.
-  Value getVersioningCond(scf::ForOp forOp) const {
+  virtual Value getVersioningCond(scf::ForOp &forOp,
+                                  Value mask = nullptr) const {
     assert(hasValidUpperBound(forOp) && "Invalid upper bound");
 
     Value ub = tt::intel::getFinalValue(forOp.getUpperBound());
@@ -138,7 +145,7 @@ class InvariantMaskValidator final : public MaskValidatorBase {
   //  - N < M (with i1 data type)
   //  - [0..END] < splat(N)
   //  - splat(N) < [0..END]
-  virtual bool isValidMask(Value mask, scf::ForOp &forOp) const {
+  virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
     assert(mask && "Expecting a valid mask");
 
     if (!mask.getDefiningOp() || !isa(mask.getDefiningOp()))
       return false;
 
@@ -181,8 +188,8 @@ class InvariantMaskValidator final : public MaskValidatorBase {
     return false;
   }
 
-  Value getVersioningCond(Value mask, scf::ForOp &forOp) const {
-    assert(isValidMask(mask, forOp) && "Invalid mask");
+  virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
+    assert(isValidMask(forOp, mask) && "Invalid mask");
 
     OpBuilder builder(forOp);
     Location loc = forOp.getLoc();
@@ -240,7 +247,7 @@ template class MaskedOpsCollector {
                      : isa(op) ? cast(op).getMask()
                                  : nullptr;
         if (mask &&
-            maskValidator.isValidMask(tt::intel::getFinalValue(mask), forOp))
+            maskValidator.isValidMask(forOp, tt::intel::getFinalValue(mask)))
           maskedOps.insert(op);
       }
     };
 
@@ -271,9 +278,9 @@ class LoopVersioner {
     // Currently we can version the loop only if it doesn't have downward
     // exposed uses of return values that are a tensor of pointers.
     // Note: this is because the results yielded by the two versioning
-    // branches have different pointer types (tensors of pointers are changed
-    // to block pointers only in the 'then' branch and are left unchanged in
-    // the 'else' branch).
+    // branches have different pointer types (tensors of pointers are
+    // changed to block pointers only in the 'then' branch and are left
+    // unchanged in the 'else' branch).
     auto canVersion = [](scf::ForOp &forOp) {
       return llvm::any_of(forOp.getResults(), [](Value res) {
         return !tt::isTensorPointerType(res.getType()) ||
               res.getUsers().empty();
       });
     };
     if (!canVersion(forOp))
       return false;
 
-    // Retrieve the versioning condition, bail out if it doesn't exist (in which
-    // case the loop upper bound is not in canonical form).
+    // Retrieve the versioning condition, bail out if it doesn't exist (in
+    // which case the loop upper bound is not in canonical form).
     Value verCond = collector.getMaskValidator().getVersioningCond(forOp);
     if (!verCond)
       return false;
 
@@ -313,8 +320,8 @@ class LoopVersioner {
     OpBuilder elseB = ifOp.getElseBodyBuilder();
     Operation *elseForLoop = elseB.clone(*forOp.getOperation());
 
-    // Collect results in 'clonedLoop' corresponding to downward exposed results
-    // of the given loop.
+    // Collect results in 'clonedLoop' corresponding to downward exposed
+    // results of the given loop.
     auto pruneUnusedResults = [&](const scf::ForOp &forOp,
                                   Operation *clonedLoop) {
       SmallVector prunedResults;
@@ -375,10 +382,10 @@ class LoopVersioner {
     auto it = maskConds.begin();
     Value firstCond = (*it++)->getResult(0);
     auto maskValidator = collector.getMaskValidator();
-    Value verCond = maskValidator.getVersioningCond(firstCond, forOp);
+    Value verCond = maskValidator.getVersioningCond(forOp, firstCond);
     for (; it != maskConds.end(); ++it) {
       Value nextCond = (*it)->getResult(0);
-      Value cond = maskValidator.getVersioningCond(nextCond, forOp);
+      Value cond = maskValidator.getVersioningCond(forOp, nextCond);
       verCond = builder.create(loc, verCond, cond);
     }
 
@@ -449,8 +456,8 @@ struct TritonIntelRemoveMasksBase
         return WalkResult::advance();
       });
 
-    // Version loops containing masked operations with a mask defined before the
-    // loop.
+    // Version loops containing masked operations with a mask defined
+    // before the loop.
     moduleOp->walk([&](Operation *op) {
       if (scf::ForOp forOp = dyn_cast(op)) {
         // Nested loops aren't currently handled.
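
For reference, the end-to-end effect of the pass on a loop with versionable
masked loads can be sketched as follows (illustrative pseudo-IR, not the
verbatim output of the pass):

    %cond = ...               // e.g. N % END == 0 && N > END (canonical case),
                              // or END < N (invariant-mask case)
    scf.if %cond {
      scf.for ... {           // 'then' clone: masks dropped
        %v = tt.load %ptr
        ...
      }
    } else {
      scf.for ... {           // 'else' clone: the original masked loop
        %v = tt.load %ptr, %mask, %other
        ...
      }
    }

The pass can be exercised in isolation with
`triton-opt -triton-intel-remove-masks` (see the RUN line in the test above).
In the XPU TTIR pipeline it runs after CSE and LICM so that loop-invariant
mask computations are hoisted out of the loop before the
InvariantMaskValidator looks for them.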