[compiler/StableHloExt] Incorporate some simplification patterns from upstream StableHLO (#384)

Our next LLVM & StableHLO upgrade will incorporate some additional
simplification patterns. However, the upgrade is large and will not fully
land until mid next week.

Until then, incorporate some critical concat and slice simplification
patterns that simplify shape-calculation IR. Solves #381.
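
For illustration, this is the kind of rewrite the slice-of-concat pattern
enables (shapes borrowed from the new test case below):

  %0 = stablehlo.concatenate %a, %b, %c, dim = 0
         : (tensor<1xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<5xi32>
  %1 = stablehlo.slice %0 [1:5] : (tensor<5xi32>) -> tensor<4xi32>

simplifies (after the resulting identity slice folds away) to

  %1 = stablehlo.concatenate %b, %c, dim = 0
         : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>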
    
GitOrigin-RevId: 2a875a35547b89c96d530878907312a0f898e508
christopherbate authored Nov 16, 2024
1 parent d4a3058 commit 49fede3
Showing 2 changed files with 169 additions and 0 deletions.
@@ -29,6 +29,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/TypeInference.h"
@@ -1064,6 +1065,10 @@ struct AbsorbTensorCastProducer : public RewritePattern {
};
} // namespace

/// Populates patterns that are temporarily reproduced here from upstream
/// commits we have not yet integrated.
static void populateFutureUpstreamPatterns(RewritePatternSet &patterns);

void stablehlo_ext::populateStableHloAbsorbTensorCastPatterns(
RewritePatternSet &patterns) {
patterns.add<AbsorbTensorCastProducer>(patterns.getContext());
@@ -1108,6 +1113,7 @@ class ConstantFoldingPass
SqrtOpFolder
>(ctx);
// clang-format on
populateFutureUpstreamPatterns(patterns);
populateStableHloAbsorbTensorCastPatterns(patterns);
stablehlo::populateStablehloCanonicalizationPatterns(ctx, &patterns);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
@@ -1124,3 +1130,150 @@ class ConstantFoldingPass
}
};
} // namespace

//===----------------------------------------------------------------------===//
/// The patterns below this point are reproduced from
/// https://github.com/openxla/stablehlo/commit/5d15ab064f165cc6773ef4ba949ac083ae8e1fea,
/// which has landed upstream but is not yet in our currently pinned StableHLO
/// commit. The patterns can be removed in the next StableHLO upgrade.
///
//===----------------------------------------------------------------------===//

///
/// In cases where a concat is fed into a slice, it is possible the concat
/// can be simplified or bypassed. This pattern checks which inputs to the
/// concat are used by the slice, either reducing the number of concatenated
/// values or removing the concat entirely. Pattern:
/// slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
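///
/// For example (illustrative shapes; the pattern emits a concat of only the
/// needed operands plus an adjusted slice, and other patterns fold the
/// single-operand concat):
///
///   %0 = stablehlo.concatenate %a, %b, dim = 0
///          : (tensor<2xi32>, tensor<3xi32>) -> tensor<5xi32>
///   %1 = stablehlo.slice %0 [3:5] : (tensor<5xi32>) -> tensor<2xi32>
///
/// becomes
///
///   %0 = stablehlo.concatenate %b, dim = 0 : (tensor<3xi32>) -> tensor<3xi32>
///   %1 = stablehlo.slice %0 [1:3] : (tensor<3xi32>) -> tensor<2xi32>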
struct SimplifySliceOfConcat : public OpRewritePattern<SliceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(SliceOp slice,
PatternRewriter &rewriter) const override {
RankedTensorType resultTy = slice.getType();
if (!resultTy.hasStaticShape())
return rewriter.notifyMatchFailure(slice, "result shape not static");

auto concat = slice.getOperand().getDefiningOp<ConcatenateOp>();
if (!concat)
return rewriter.notifyMatchFailure(slice, "slice input not concat");

RankedTensorType concatType = concat.getType();
uint64_t dimension = concat.getDimension();

ArrayRef<int64_t> start = slice.getStartIndices();
ArrayRef<int64_t> limit = slice.getLimitIndices();

int64_t sliceStart = start[dimension];
int64_t sliceLimit = limit[dimension];

// Determine which of the concat's inputs the slice actually reads, and how
// the slice bounds must be updated once only the minimally required inputs
// are kept.
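// For example (hypothetical sizes): concatenating sizes {2, 3, 4} along
// `dimension` and slicing [4, 7) only reads the last two operands, so
// `subsetStart`/`subsetEnd` select operands 1..2, `frontOffset` becomes 2,
// and the slice bounds shift to [2, 5) on the smaller concat.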
int64_t runningSize = 0;
int64_t frontOffset = concatType.getShape()[dimension];

auto subsetStart = concat.operand_end();
auto subsetEnd = concat.operand_end();
for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
Value input = *it;
auto inputTy = cast<RankedTensorType>(input.getType());
if (inputTy.isDynamicDim(dimension))
return rewriter.notifyMatchFailure(
slice, "concat input has dynamic dimension");

int64_t dimSize = inputTy.getShape()[dimension];

// If this position is in the slice, it's the start of the subset, and we
// need to update the start and limit values.
if (runningSize + dimSize > sliceStart &&
subsetStart == concat.operand_end()) {
subsetStart = it;
frontOffset = runningSize;
}

// Determine the last required offset.
if (runningSize < sliceLimit) {
subsetEnd = it + 1;
}

runningSize += dimSize;
}

auto subsetSize = subsetEnd - subsetStart;
// We need all of the inputs, so there is nothing to optimize.
if (subsetSize == concat.getNumOperands())
return rewriter.notifyMatchFailure(slice,
"slice needs all concat inputs");

// If there's nothing to slice, the output is an empty tensor and this is
// dead code. We do nothing here and rely on other passes to clean it up.
if (subsetSize == 0)
return rewriter.notifyMatchFailure(slice, "slice is empty");

if (subsetSize > 1 && !concat.getResult().hasOneUse())
return rewriter.notifyMatchFailure(slice,
"slice is not the only concat user");

auto concatRange = OperandRange(subsetStart, subsetEnd);
auto newConcat = rewriter.create<ConcatenateOp>(
concat.getLoc(), concatRange, concat.getDimension());

SmallVector<int64_t> newStart(start);
SmallVector<int64_t> newLimit(limit);
newStart[dimension] -= frontOffset;
newLimit[dimension] -= frontOffset;

rewriter.replaceOpWithNewOp<SliceOp>(
slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart),
rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides());
return success();
}
};

/// Flatten sequential concatenations, as long as the producing (inner)
/// concatenation either has a single use or has <= 32 elements.
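///
/// For example (illustrative shapes), when the inner concatenate is only
/// used by the outer one:
///
///   %0 = stablehlo.concatenate %a, %b, dim = 0
///          : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
///   %1 = stablehlo.concatenate %0, %c, dim = 0
///          : (tensor<4xi32>, tensor<2xi32>) -> tensor<6xi32>
///
/// is flattened to
///
///   %1 = stablehlo.concatenate %a, %b, %c, dim = 0
///          : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6xi32>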
class SimplifyConcatOfConcatPattern
: public OpRewritePattern<stablehlo::ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
auto getFlattenedOperands = [&](const Value &val) -> ValueRange {
auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp());
if (!definingOp || definingOp.getDimension() != op.getDimension())
return val;
if (definingOp->hasOneUse())
return definingOp.getInputs();
if (!definingOp.getType().hasStaticShape())
return val;
if (definingOp.getType().getNumElements() <= 32)
return definingOp.getInputs();
return val;
};

bool needToFlatten = false;
int operandCount = 0;
for (Value val : op.getInputs()) {
ValueRange result = getFlattenedOperands(val);
if (result.size() != 1 || result[0] != val)
needToFlatten = true;
operandCount += result.size();
}
if (!needToFlatten)
return rewriter.notifyMatchFailure(op, "no need to flatten");

llvm::SmallVector<Value, 6> newOperands;
newOperands.reserve(operandCount);
for (Value operand : op.getInputs())
llvm::append_range(newOperands, getFlattenedOperands(operand));

rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
return success();
}
};

void populateFutureUpstreamPatterns(RewritePatternSet &patterns) {
patterns.add<SimplifySliceOfConcat, SimplifyConcatOfConcatPattern>(
patterns.getContext());
}
mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir (16 additions, 0 deletions)
@@ -402,6 +402,22 @@ func.func @concat_simplify_single_operand_requires_cast(%arg0: tensor<4xi32>) ->

// -----

func.func @concat_slice_concat(%arg0: tensor<1xi32>, %arg1: tensor<3xi32>, %arg2: tensor<1xi32>) -> tensor<5xi32> {
%0 = stablehlo.concatenate %arg0, %arg1, %arg2, dim = 0 : (tensor<1xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<5xi32>
%1 = stablehlo.slice %0 [1:5] : (tensor<5xi32>) -> tensor<4xi32>
%2 = stablehlo.constant dense<1> : tensor<1xi32>
%3 = stablehlo.concatenate %2, %1, dim = 0 : (tensor<1xi32>, tensor<4xi32>) -> tensor<5xi32>
return %3 : tensor<5xi32>
}

// CHECK-LABEL: func.func @concat_slice_concat
// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<3xi32>, %[[arg2:.+]]: tensor<1xi32>) -> tensor<5xi32>
// CHECK: %[[c:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
// CHECK: %[[v0:.+]] = stablehlo.concatenate %[[c]], %[[arg1]], %[[arg2]], dim = 0
// CHECK: return %[[v0]] : tensor<5xi32>

// -----

func.func @bitwise_or_fold_lhs(%arg0: tensor<5xi8>, %arg1: tensor<5xi1>, %arg2: tensor<5xi32>) -> (tensor<5xi8>, tensor<5xi1>, tensor<5xi32>, tensor<5xi32>){
%0 = stablehlo.constant dense<[255, 255, 255, 255, 255]> : tensor<5xi8>
%1 = stablehlo.or %0, %arg0 : tensor<5xi8>
