Skip to content

Commit

Permalink
Convert forall to parallel before par-to-herd/launch (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Jan 17, 2025
1 parent 9f75b46 commit 1646603
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 218 deletions.
268 changes: 50 additions & 218 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,122 +610,6 @@ class ScfParToHerdConversion : public OpRewritePattern<scf::ParallelOp> {
int firstDim;
};

class ScfForallToHerdConversion : public OpRewritePattern<scf::ForallOp> {
public:
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

ScfForallToHerdConversion(MLIRContext *ctx,
SmallPtrSet<Operation *, 8> &filteredOps,
llvm::SmallSet<air::HerdOp, 2> &replacementOps,
int firstDim)
: OpRewritePattern(ctx), filteredOps(filteredOps),
replacementOps(replacementOps), firstDim(firstDim){};

LogicalResult matchAndRewrite(scf::ForallOp parOp,
PatternRewriter &rewriter) const override {

scf::ForallOp op = parOp;

if (!filteredOps.contains(op))
return failure();

auto loc = op.getLoc();

if (op.getRank() > 2) {
unsigned split_idx = op.getRank() - 2;
SmallVector<OpFoldResult> outerLowerBounds, outerUpperBounds, outerSteps;
SmallVector<OpFoldResult> innerLowerBounds, innerUpperBounds, innerSteps;

for (unsigned i = 0, e = split_idx; i < e; ++i) {
outerLowerBounds.push_back(op.getMixedLowerBound()[i]);
outerUpperBounds.push_back(op.getMixedUpperBound()[i]);
outerSteps.push_back(op.getMixedStep()[i]);
}
auto outerLoop = rewriter.create<scf::ParallelOp>(
loc, getValueOrCreateConstantIndexOp(rewriter, loc, outerLowerBounds),
getValueOrCreateConstantIndexOp(rewriter, loc, outerUpperBounds),
getValueOrCreateConstantIndexOp(rewriter, loc, outerSteps));
for (unsigned i = 0, e = split_idx; i < e; ++i)
op.getInductionVars()[i].replaceAllUsesWith(
outerLoop.getInductionVars()[i]);

rewriter.setInsertionPointToStart(outerLoop.getBody());

for (unsigned i = split_idx, e = op.getRank(); i < e; ++i) {
innerLowerBounds.push_back(op.getMixedLowerBound()[i]);
innerUpperBounds.push_back(op.getMixedUpperBound()[i]);
innerSteps.push_back(op.getMixedStep()[i]);
}
auto innerLoop = rewriter.create<scf::ForallOp>(
loc, innerLowerBounds, innerUpperBounds, innerSteps, ValueRange(),
std::nullopt);
for (unsigned i = split_idx, e = op.getRank(); i < e; ++i)
op.getInductionVars()[i].replaceAllUsesWith(
innerLoop.getInductionVars()[i - split_idx]);

auto &body = op.getBody()->getOperations();
innerLoop.getBody()->getOperations().splice(
innerLoop.getBody()->begin(), body, body.begin(), --body.end());
op = innerLoop;
}

SmallVector<int, 2> bounds{1, 1};
for (unsigned int i = 0; i < op.getRank(); i++) {
int64_t ub_int = op.getStaticUpperBound()[i];
int64_t step_int = op.getStaticStep()[i];
bounds[i] = ub_int / step_int;
}
SmallVector<Value, 4> args;
SmallVector<Value, 4> constants;
llvm::SetVector<Value> region_args;
getUsedValuesDefinedAbove(op.getRegion(), region_args);
for (Value v : region_args) {
if (isa_and_present<arith::ConstantOp>(v.getDefiningOp()))
constants.push_back(v);
else
args.push_back(v);
}

int idx0 = firstDim;
int idx1 = (firstDim + 1) % 2;
SmallVector<Value, 2> dims{
rewriter.create<arith::ConstantIndexOp>(loc, bounds[idx0]),
rewriter.create<arith::ConstantIndexOp>(loc, bounds[idx1])};
auto herdOp = rewriter.create<air::HerdOp>(op.getLoc(), dims, args);
auto &bb = herdOp.getBody().front();
auto ivs = op.getInductionVars();

propagateLinkWith(op, herdOp);

ivs[0].replaceAllUsesWith(herdOp.getIds()[idx0]);
if (op.getRank() == 2)
ivs[1].replaceAllUsesWith(herdOp.getIds()[idx1]);

auto &body = op.getBody()->getOperations();
bb.getOperations().splice(bb.begin(), body, body.begin(), --body.end());
rewriter.setInsertionPointToStart(&herdOp.getRegion().front());
replaceAllUsesOfConstsInRegionWithNew(constants, rewriter,
herdOp.getRegion());

int i = 0;
auto kernel_args = herdOp.getKernelArguments();
for (Value v : args)
replaceAllUsesInRegionWith(v, kernel_args[i++], herdOp.getRegion());

if (op != parOp)
rewriter.eraseOp(op);
rewriter.eraseOp(parOp);
replacementOps.insert(herdOp);

return success();
}

private:
llvm::SmallPtrSet<Operation *, 8> &filteredOps;
llvm::SmallSet<air::HerdOp, 2> &replacementOps;
int firstDim;
};

LogicalResult
getMemrefBackwardSlices(Value &memref, Operation *&memrefAlloc,
SmallVector<Operation *> &backwardSlices) {
Expand Down Expand Up @@ -997,107 +881,37 @@ class ScfParToLaunchConversion : public OpRewritePattern<scf::ParallelOp> {
bool generateSegment;
};

class ScfForallToLaunchConversion : public OpRewritePattern<scf::ForallOp> {
public:
/// Pattern to rewriter scf.forall -> scf.parallel after bufferization.
class SCFForAllToParallelOp : public OpRewritePattern<scf::ForallOp> {
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

ScfForallToLaunchConversion(MLIRContext *ctx,
llvm::SmallSet<Operation *, 8> &filteredOps,
llvm::SmallSet<air::LaunchOp, 2> &replacementOps,
bool generateSegment)
: OpRewritePattern(ctx), filteredOps(filteredOps),
replacementOps(replacementOps), generateSegment(generateSegment){};

LogicalResult matchAndRewrite(scf::ForallOp forOp,
LogicalResult matchAndRewrite(scf::ForallOp forallOp,
PatternRewriter &rewriter) const override {

scf::ForallOp op = forOp;

if (!filteredOps.contains(op))
if (forallOp.getNumResults() != 0) {
return failure();

// if (failed(normalizeScfParallel(op, rewriter)))
// return failure();

auto loc = op.getLoc();

SmallVector<int, 4> bounds(op.getRank(), 1);
for (unsigned int i = 0; i < op.getRank(); i++) {
int64_t lb_int = op.getStaticLowerBound()[i];
int64_t ub_int = op.getStaticUpperBound()[i];
int64_t step_int = op.getStaticStep()[i];

// must start at 0
if (lb_int)
return failure();

// step must divide upper bound evenly
if (ub_int % step_int)
return failure();

ub_int = ub_int / step_int;
bounds[i] = ub_int;
}

SmallVector<Value, 4> args;
SmallVector<Value, 4> constants;
llvm::SetVector<Value> region_args;
getUsedValuesDefinedAbove(op.getRegion(), region_args);
for (Value v : region_args) {
if (isa_and_present<arith::ConstantOp>(v.getDefiningOp()))
constants.push_back(v);
else
args.push_back(v);
}

SmallVector<Value, 4> sizes;
for (auto b : bounds)
sizes.push_back(rewriter.create<arith::ConstantIndexOp>(loc, b));
auto launch = rewriter.create<air::LaunchOp>(op.getLoc(), sizes, args);

rewriter.setInsertionPointToStart(&launch.getRegion().front());

if (generateSegment) {
auto segment = generateEmptySegmentOp(rewriter, op, launch);
replaceAllUsesOfConstsInRegionWithNew(constants, rewriter,
segment.getRegion());
int i = 0;
auto kernel_args = segment.getKernelArguments();
kernel_args = kernel_args.drop_front(
launch.getIds().size() +
launch.getSize().size()); // Launch's induction vars
for (Value v : args)
replaceAllUsesInRegionWith(v, kernel_args[i++], segment.getRegion());
} else {
auto &bb = launch.getBody().front();
auto ivs = op.getInductionVars();

for (int i = 0, e = ivs.size(); i < e; i++) {
ivs[i].replaceAllUsesWith(launch.getIds()[i]);
}

auto &body = op.getBody()->getOperations();
bb.getOperations().splice(bb.begin(), body, body.begin(), --body.end());
replaceAllUsesOfConstsInRegionWithNew(constants, rewriter,
launch.getRegion());
int i = 0;
auto kernel_args = launch.getKernelArguments();
for (Value v : args)
replaceAllUsesInRegionWith(v, kernel_args[i++], launch.getRegion());
}

if (op != forOp)
op.erase();
rewriter.eraseOp(forOp);
replacementOps.insert(launch);

Location loc = forallOp.getLoc();
SmallVector<Value> lowerBounds = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedLowerBound());
SmallVector<Value> upperBounds = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedUpperBound());
SmallVector<Value> step =
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
auto parallelOp = rewriter.create<scf::ParallelOp>(
loc, lowerBounds, upperBounds, step, ValueRange{},
[&](OpBuilder &b, Location bodyLoc, ValueRange ivs,
ValueRange regionArgs) {});
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
parallelOp.getRegion().begin());
rewriter.eraseBlock(&parallelOp.getRegion().back());
// Fixup the terminator
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());
rewriter.replaceOp(forallOp, parallelOp);
return success();
}

private:
llvm::SmallSet<Operation *, 8> &filteredOps;
llvm::SmallSet<air::LaunchOp, 2> &replacementOps;
bool generateSegment;
};

struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
Expand Down Expand Up @@ -1319,6 +1133,19 @@ struct ParallelToHerdPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Ensure that air.dma_memcpy_nd ops between L1 and L2 are within at least
// two parent scf.parallel loops.
module.walk([&](air::DmaMemcpyNdOp op) {
Expand Down Expand Up @@ -1401,8 +1228,6 @@ struct ParallelToHerdPass
patterns.add<AffineParToHerdConversion>(context);
patterns.add<ScfParToHerdConversion>(context, filteredOps, replacementOps,
clFirstDim);
patterns.add<ScfForallToHerdConversion>(context, filteredOps,
replacementOps, clFirstDim);

ConversionTarget target(*context);

Expand Down Expand Up @@ -1448,6 +1273,19 @@ struct ParallelToLaunchPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

llvm::SmallVector<air::LaunchOp> launchOps;
module.walk([&](air::LaunchOp op) { launchOps.push_back(op); });

Expand Down Expand Up @@ -1538,8 +1376,6 @@ struct ParallelToLaunchPass
RewritePatternSet patterns(context);
patterns.add<ScfParToLaunchConversion>(context, filteredOps, replacementOps,
clHasSegment);
patterns.add<ScfForallToLaunchConversion>(context, filteredOps,
replacementOps, clHasSegment);

ConversionTarget target(*context);

Expand Down Expand Up @@ -1611,8 +1447,6 @@ transform::ParToHerdOp::applyToOne(transform::TransformRewriter &rewriter,
filteredOps.insert(target);
patterns.add<ScfParToHerdConversion>(ctx, filteredOps, herdOps,
getFirstDim());
patterns.add<ScfForallToHerdConversion>(ctx, filteredOps, herdOps,
getFirstDim());
(void)applyPatternsGreedily(
target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
std::move(patterns));
Expand All @@ -1639,8 +1473,6 @@ transform::ParToLaunchOp::applyToOne(transform::TransformRewriter &rewriter,
filteredOps.insert(target);
patterns.add<ScfParToLaunchConversion>(ctx, filteredOps, launchOps,
getHasAirSegment());
patterns.add<ScfForallToLaunchConversion>(ctx, filteredOps, launchOps,
getHasAirSegment());
(void)applyPatternsGreedily(
target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
std::move(patterns));
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,41 @@ func.func @scf2() {

// -----

// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-LABEL: func.func @scf3() {
// CHECK: air.herd @herd_0 tile (%[[VAL_0:.*]], %[[VAL_1:.*]]) in (%{{.*}}=%c3{{.*}}, %{{.*}}=%c2{{.*}})
// CHECK: affine.apply [[$MAP0]](%[[VAL_1]])
// CHECK: affine.apply [[$MAP1]](%[[VAL_0]])
func.func @scf3() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
scf.forall (%i, %j) = (%c1, %c0) to (%c4, %c4)
step (%c1, %c2) {
%2 = arith.muli %i, %j : index
}
return
}

// -----

// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-LABEL: func.func @scf4() {
// CHECK: air.herd @herd_0 tile (%[[VAL_0:.*]], %[[VAL_1:.*]]) in (%{{.*}}=%c3{{.*}}, %{{.*}}=%c2{{.*}})
// CHECK: affine.apply [[$MAP0]](%[[VAL_1]])
// CHECK: affine.apply [[$MAP1]](%[[VAL_0]])
func.func @scf4() {
scf.forall (%i, %j) = (1, 0) to (4, 4) step (1, 2) {
%2 = arith.muli %i, %j : index
}
return
}

// -----

// This test demonstrates that while forming air.herd we look through func.call ops, fetch
// the corresponding function declaration's 'link_with' attribute and attach it to the newly
// formed air.herd op.
Expand Down

0 comments on commit 1646603

Please sign in to comment.