Skip to content

Commit

Permalink
Switch to converting only filtered foralls into parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx committed Jan 20, 2025
1 parent fb3211e commit 408294a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 57 deletions.
101 changes: 44 additions & 57 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,18 +861,6 @@ class ScfParToSegmentConversion : public OpRewritePattern<scf::ParallelOp> {
llvm::SmallSet<air::SegmentOp, 2> &replacementOps;
};

/// Rewrite pattern that turns an scf.forall into an scf.parallel after
/// bufferization. Only foralls that produce no results are matched; any
/// forall with results is rejected so other patterns may handle it.
class SCFForAllToParallelOp : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Result-free foralls are handed to the shared conversion helper.
    if (forallOp.getNumResults() == 0)
      return forallToParallelLoop(rewriter, forallOp);
    return failure();
  }
};

struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {

CopyToDmaPass() = default;
Expand Down Expand Up @@ -1072,6 +1060,28 @@ static void getSegmentNames(ModuleOp module) {
}
}

// Convert every scf.forall found in `filteredOps` into an scf.parallel.
//
// Entries that are not scf.forall are left untouched. Each forall that is
// converted successfully is removed from `filteredOps` and the scf.parallel
// produced by forallToParallelLoop is inserted in its place, so that users of
// `filteredOps` see the replacement ops.
//
// Returns failure() as soon as one conversion fails. Foralls converted before
// the failing one are still swapped for their replacements in `filteredOps`,
// so the set does not keep entries for ops that were already replaced.
LogicalResult
ConvertForallToParallelInFilteredOps(SmallPtrSet<Operation *, 8> &filteredOps,
                                     mlir::MLIRContext *context) {
  IRRewriter rewriter(context);
  // The set must not be mutated while it is being iterated; record the
  // changes here and apply them once the loop is done.
  SmallVector<Operation *> fErased, fAdded;
  LogicalResult status = success();
  for (auto *op : filteredOps) {
    auto forall = dyn_cast<scf::ForallOp>(op);
    if (!forall)
      continue;
    scf::ParallelOp newPar;
    if (failed(forallToParallelLoop(rewriter, forall, &newPar))) {
      status = failure();
      break;
    }
    // Only record the swap once the conversion has actually succeeded.
    fErased.push_back(op);
    fAdded.push_back(newPar);
  }
  for (auto *e : fErased) {
    // Keep the erase outside of assert(): with NDEBUG the assert expression
    // is not evaluated, which would silently skip the erase.
    bool wasPresent = filteredOps.erase(e);
    (void)wasPresent;
    assert(wasPresent && "converted forall must have been in filteredOps");
  }
  filteredOps.insert(fAdded.begin(), fAdded.end());
  return status;
}

struct ParallelToHerdPass
: public air::impl::ParallelToHerdBase<ParallelToHerdPass> {

Expand All @@ -1092,19 +1102,6 @@ struct ParallelToHerdPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Ensure that air.dma_memcpy_nd ops between L1 and L2 are within at least
// two parent scf.parallel loops.
module.walk([&](air::DmaMemcpyNdOp op) {
Expand Down Expand Up @@ -1159,7 +1156,7 @@ struct ParallelToHerdPass
SmallPtrSet<Operation *, 8> filteredOps;
llvm::SmallSet<air::HerdOp, 2> replacementOps;
module.walk([&](Operation *op) {
if (!isa<scf::ParallelOp, affine::AffineParallelOp>(op))
if (!isa<scf::ForallOp, scf::ParallelOp, affine::AffineParallelOp>(op))
return;
// skip parallel op already inside herd
if (op->getParentOfType<air::HerdOp>())
Expand All @@ -1176,13 +1173,17 @@ struct ParallelToHerdPass
int parallel_depth = 0;
Operation *par = op;
while ((par = par->getParentOp()))
if (isa<scf::ParallelOp, affine::AffineParallelOp>(par))
if (isa<scf::ForallOp, scf::ParallelOp, affine::AffineParallelOp>(par))
parallel_depth++;
if (parallel_depth != clAssignDepth)
return;
filteredOps.insert(op);
});

// Convert forall to parallel in filtered ops
if (failed(ConvertForallToParallelInFilteredOps(filteredOps, context)))
signalPassFailure();

RewritePatternSet patterns(context);
patterns.add<AffineParToHerdConversion>(context);
patterns.add<ScfParToHerdConversion>(context, filteredOps, replacementOps,
Expand Down Expand Up @@ -1236,25 +1237,14 @@ struct ParallelToLaunchPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

llvm::SmallVector<air::LaunchOp> launchOps;
module.walk([&](air::LaunchOp op) { launchOps.push_back(op); });

llvm::SmallSet<Operation *, 8> filteredOps;
llvm::SmallSet<air::LaunchOp, 2> replacementOps;
module.walk([&](scf::ParallelOp op) {
module.walk([&](Operation *op) {
if (!isa<scf::ForallOp, scf::ParallelOp>(op))
return;
if (op->getParentOfType<air::HerdOp>())
return;
if (op->getParentOfType<air::LaunchOp>())
Expand All @@ -1271,13 +1261,17 @@ struct ParallelToLaunchPass
int parallel_depth = 0;
Operation *par = op;
while ((par = par->getParentOp()))
if (isa<scf::ParallelOp>(par))
if (isa<scf::ForallOp, scf::ParallelOp>(par))
parallel_depth++;
if (parallel_depth != clAssignDepth)
return;
filteredOps.insert(op);
});

// Convert forall to parallel in filtered ops
if (failed(ConvertForallToParallelInFilteredOps(filteredOps, context)))
signalPassFailure();

RewritePatternSet patterns(context);
patterns.add<ScfParToLaunchConversion>(context, filteredOps, replacementOps,
clHasSegment);
Expand Down Expand Up @@ -1332,25 +1326,14 @@ struct ParallelToSegmentPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

llvm::SmallVector<air::SegmentOp> segmentOps;
module.walk([&](air::SegmentOp op) { segmentOps.push_back(op); });

llvm::SmallSet<Operation *, 8> filteredOps;
llvm::SmallSet<air::SegmentOp, 2> replacementOps;
module.walk([&](scf::ParallelOp op) {
module.walk([&](Operation *op) {
if (!isa<scf::ForallOp, scf::ParallelOp>(op))
return;
if (op->getParentOfType<air::HerdOp>())
return;
if (op->getParentOfType<air::SegmentOp>())
Expand All @@ -1367,13 +1350,17 @@ struct ParallelToSegmentPass
int parallel_depth = 0;
Operation *par = op;
while ((par = par->getParentOp()))
if (isa<scf::ParallelOp>(par))
if (isa<scf::ForallOp, scf::ParallelOp>(par))
parallel_depth++;
if (parallel_depth != clAssignDepth)
return;
filteredOps.insert(op);
});

// Convert forall to parallel in filtered ops
if (failed(ConvertForallToParallelInFilteredOps(filteredOps, context)))
signalPassFailure();

RewritePatternSet patterns(context);
patterns.add<ScfParToSegmentConversion>(context, filteredOps,
replacementOps);
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,29 @@ func.func @scf4() {

// -----

// Three nested scf.forall loops, each over 0..4 with step 1, whose innermost
// body copies one element from %src to %dst. The expected output below is
// three nested air.herd ops.
// CHECK-LABEL: func.func @scf5() {
// CHECK: air.herd @herd_{{.*}} {
// CHECK: air.herd @herd_{{.*}} {
// CHECK: air.herd @herd_{{.*}} {
// CHECK: }
// CHECK: }
// CHECK: }
func.func @scf5() {
%src = memref.alloc() : memref<4x4x4xi32, 2 : i32>
%dst = memref.alloc() : memref<4x4x4xi32, 2 : i32>
scf.forall (%i) = (0) to (4) step (1) {
scf.forall (%j) = (0) to (4) step (1) {
scf.forall (%k) = (0) to (4) step (1) {
%0 = memref.load %src[%i, %j, %k] : memref<4x4x4xi32, 2 : i32>
memref.store %0, %dst[%i, %j, %k] : memref<4x4x4xi32, 2 : i32>
}
}
}
return
}

// -----

// This test demonstrates that while forming air.herd we look through func.call ops, fetch
// the corresponding function declaration's 'link_with' attribute and attach it to the newly
// formed air.herd op.
Expand Down

0 comments on commit 408294a

Please sign in to comment.