diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index 57a53a94b..0aaee556e 100644 --- a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -610,122 +610,6 @@ class ScfParToHerdConversion : public OpRewritePattern { int firstDim; }; -class ScfForallToHerdConversion : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - ScfForallToHerdConversion(MLIRContext *ctx, - SmallPtrSet &filteredOps, - llvm::SmallSet &replacementOps, - int firstDim) - : OpRewritePattern(ctx), filteredOps(filteredOps), - replacementOps(replacementOps), firstDim(firstDim){}; - - LogicalResult matchAndRewrite(scf::ForallOp parOp, - PatternRewriter &rewriter) const override { - - scf::ForallOp op = parOp; - - if (!filteredOps.contains(op)) - return failure(); - - auto loc = op.getLoc(); - - if (op.getRank() > 2) { - unsigned split_idx = op.getRank() - 2; - SmallVector outerLowerBounds, outerUpperBounds, outerSteps; - SmallVector innerLowerBounds, innerUpperBounds, innerSteps; - - for (unsigned i = 0, e = split_idx; i < e; ++i) { - outerLowerBounds.push_back(op.getMixedLowerBound()[i]); - outerUpperBounds.push_back(op.getMixedUpperBound()[i]); - outerSteps.push_back(op.getMixedStep()[i]); - } - auto outerLoop = rewriter.create( - loc, getValueOrCreateConstantIndexOp(rewriter, loc, outerLowerBounds), - getValueOrCreateConstantIndexOp(rewriter, loc, outerUpperBounds), - getValueOrCreateConstantIndexOp(rewriter, loc, outerSteps)); - for (unsigned i = 0, e = split_idx; i < e; ++i) - op.getInductionVars()[i].replaceAllUsesWith( - outerLoop.getInductionVars()[i]); - - rewriter.setInsertionPointToStart(outerLoop.getBody()); - - for (unsigned i = split_idx, e = op.getRank(); i < e; ++i) { - innerLowerBounds.push_back(op.getMixedLowerBound()[i]); - innerUpperBounds.push_back(op.getMixedUpperBound()[i]); - innerSteps.push_back(op.getMixedStep()[i]); - } - auto innerLoop = rewriter.create( - loc, innerLowerBounds, innerUpperBounds, innerSteps, ValueRange(), - std::nullopt); - for (unsigned i = split_idx, e = op.getRank(); i < e; ++i) - op.getInductionVars()[i].replaceAllUsesWith( - innerLoop.getInductionVars()[i - split_idx]); - - auto &body = op.getBody()->getOperations(); - innerLoop.getBody()->getOperations().splice( - innerLoop.getBody()->begin(), body, body.begin(), --body.end()); - op = innerLoop; - } - - SmallVector bounds{1, 1}; - for (unsigned int i = 0; i < op.getRank(); i++) { - int64_t ub_int = op.getStaticUpperBound()[i]; - int64_t step_int = op.getStaticStep()[i]; - bounds[i] = ub_int / step_int; - } - SmallVector args; - SmallVector constants; - llvm::SetVector region_args; - getUsedValuesDefinedAbove(op.getRegion(), region_args); - for (Value v : region_args) { - if (isa_and_present(v.getDefiningOp())) - constants.push_back(v); - else - args.push_back(v); - } - - int idx0 = firstDim; - int idx1 = (firstDim + 1) % 2; - SmallVector dims{ - rewriter.create(loc, bounds[idx0]), - rewriter.create(loc, bounds[idx1])}; - auto herdOp = rewriter.create(op.getLoc(), dims, args); - auto &bb = herdOp.getBody().front(); - auto ivs = op.getInductionVars(); - - propagateLinkWith(op, herdOp); - - ivs[0].replaceAllUsesWith(herdOp.getIds()[idx0]); - if (op.getRank() == 2) - ivs[1].replaceAllUsesWith(herdOp.getIds()[idx1]); - - auto &body = op.getBody()->getOperations(); - bb.getOperations().splice(bb.begin(), body, body.begin(), --body.end()); - rewriter.setInsertionPointToStart(&herdOp.getRegion().front()); - replaceAllUsesOfConstsInRegionWithNew(constants, rewriter, - herdOp.getRegion()); - - int i = 0; - auto kernel_args = herdOp.getKernelArguments(); - for (Value v : args) - replaceAllUsesInRegionWith(v, kernel_args[i++], herdOp.getRegion()); - - if (op != parOp) - rewriter.eraseOp(op); - rewriter.eraseOp(parOp); - replacementOps.insert(herdOp); - - return success(); - } - -private: - llvm::SmallPtrSet &filteredOps; - llvm::SmallSet &replacementOps; - int firstDim; -}; - LogicalResult getMemrefBackwardSlices(Value &memref, Operation *&memrefAlloc, SmallVector &backwardSlices) { @@ -997,107 +881,37 @@ class ScfParToLaunchConversion : public OpRewritePattern { bool generateSegment; }; -class ScfForallToLaunchConversion : public OpRewritePattern { -public: +/// Pattern to rewriter scf.forall -> scf.parallel after bufferization. +class SCFForAllToParallelOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - ScfForallToLaunchConversion(MLIRContext *ctx, - llvm::SmallSet &filteredOps, - llvm::SmallSet &replacementOps, - bool generateSegment) - : OpRewritePattern(ctx), filteredOps(filteredOps), - replacementOps(replacementOps), generateSegment(generateSegment){}; - - LogicalResult matchAndRewrite(scf::ForallOp forOp, + LogicalResult matchAndRewrite(scf::ForallOp forallOp, PatternRewriter &rewriter) const override { - - scf::ForallOp op = forOp; - - if (!filteredOps.contains(op)) + if (forallOp.getNumResults() != 0) { return failure(); - - // if (failed(normalizeScfParallel(op, rewriter))) - // return failure(); - - auto loc = op.getLoc(); - - SmallVector bounds(op.getRank(), 1); - for (unsigned int i = 0; i < op.getRank(); i++) { - int64_t lb_int = op.getStaticLowerBound()[i]; - int64_t ub_int = op.getStaticUpperBound()[i]; - int64_t step_int = op.getStaticStep()[i]; - - // must start at 0 - if (lb_int) - return failure(); - - // step must divide upper bound evenly - if (ub_int % step_int) - return failure(); - - ub_int = ub_int / step_int; - bounds[i] = ub_int; - } - - SmallVector args; - SmallVector constants; - llvm::SetVector region_args; - getUsedValuesDefinedAbove(op.getRegion(), region_args); - for (Value v : region_args) { - if (isa_and_present(v.getDefiningOp())) - constants.push_back(v); - else - args.push_back(v); - } - - SmallVector sizes; - for (auto b : bounds) - sizes.push_back(rewriter.create(loc, b)); - auto launch = rewriter.create(op.getLoc(), sizes, args); - - rewriter.setInsertionPointToStart(&launch.getRegion().front()); - - if (generateSegment) { - auto segment = generateEmptySegmentOp(rewriter, op, launch); - replaceAllUsesOfConstsInRegionWithNew(constants, rewriter, - segment.getRegion()); - int i = 0; - auto kernel_args = segment.getKernelArguments(); - kernel_args = kernel_args.drop_front( - launch.getIds().size() + - launch.getSize().size()); // Launch's induction vars - for (Value v : args) - replaceAllUsesInRegionWith(v, kernel_args[i++], segment.getRegion()); - } else { - auto &bb = launch.getBody().front(); - auto ivs = op.getInductionVars(); - - for (int i = 0, e = ivs.size(); i < e; i++) { - ivs[i].replaceAllUsesWith(launch.getIds()[i]); - } - - auto &body = op.getBody()->getOperations(); - bb.getOperations().splice(bb.begin(), body, body.begin(), --body.end()); - replaceAllUsesOfConstsInRegionWithNew(constants, rewriter, - launch.getRegion()); - int i = 0; - auto kernel_args = launch.getKernelArguments(); - for (Value v : args) - replaceAllUsesInRegionWith(v, kernel_args[i++], launch.getRegion()); } - - if (op != forOp) - op.erase(); - rewriter.eraseOp(forOp); - replacementOps.insert(launch); - + Location loc = forallOp.getLoc(); + SmallVector lowerBounds = getValueOrCreateConstantIndexOp( + rewriter, loc, forallOp.getMixedLowerBound()); + SmallVector upperBounds = getValueOrCreateConstantIndexOp( + rewriter, loc, forallOp.getMixedUpperBound()); + SmallVector step = + getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); + auto parallelOp = rewriter.create( + loc, lowerBounds, upperBounds, step, ValueRange{}, + [&](OpBuilder &b, Location bodyLoc, ValueRange ivs, + ValueRange regionArgs) {}); + rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), + parallelOp.getRegion().begin()); + rewriter.eraseBlock(¶llelOp.getRegion().back()); + // Fixup the terminator + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front()); + rewriter.replaceOpWithNewOp( + parallelOp.getRegion().front().getTerminator()); + rewriter.replaceOp(forallOp, parallelOp); return success(); } - -private: - llvm::SmallSet &filteredOps; - llvm::SmallSet &replacementOps; - bool generateSegment; }; struct CopyToDmaPass : public air::impl::CopyToDmaBase { @@ -1319,6 +1133,19 @@ struct ParallelToHerdPass LLVM_DEBUG(llvm::outs() << "input\n"); LLVM_DEBUG(module.print(llvm::outs())); + // Preprocessing: convert forall to parallel. + RewritePatternSet preprocPatterns(context); + preprocPatterns.add(context); + ConversionTarget preprocTarget(*context); + preprocTarget.addLegalDialect(); + preprocTarget.addIllegalOp(); + if (failed(applyPartialConversion(module, preprocTarget, + std::move(preprocPatterns)))) { + signalPassFailure(); + } + LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n"); + LLVM_DEBUG(module.print(llvm::outs())); + // Ensure that air.dma_memcpy_nd ops between L1 and L2 are within at least // two parent scf.parallel loops. module.walk([&](air::DmaMemcpyNdOp op) { @@ -1401,8 +1228,6 @@ struct ParallelToHerdPass patterns.add(context); patterns.add(context, filteredOps, replacementOps, clFirstDim); - patterns.add(context, filteredOps, - replacementOps, clFirstDim); ConversionTarget target(*context); @@ -1448,6 +1273,19 @@ struct ParallelToLaunchPass LLVM_DEBUG(llvm::outs() << "input\n"); LLVM_DEBUG(module.print(llvm::outs())); + // Preprocessing: convert forall to parallel. + RewritePatternSet preprocPatterns(context); + preprocPatterns.add(context); + ConversionTarget preprocTarget(*context); + preprocTarget.addLegalDialect(); + preprocTarget.addIllegalOp(); + if (failed(applyPartialConversion(module, preprocTarget, + std::move(preprocPatterns)))) { + signalPassFailure(); + } + LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n"); + LLVM_DEBUG(module.print(llvm::outs())); + llvm::SmallVector launchOps; module.walk([&](air::LaunchOp op) { launchOps.push_back(op); }); @@ -1538,8 +1376,6 @@ struct ParallelToLaunchPass RewritePatternSet patterns(context); patterns.add(context, filteredOps, replacementOps, clHasSegment); - patterns.add(context, filteredOps, - replacementOps, clHasSegment); ConversionTarget target(*context); @@ -1611,8 +1447,6 @@ transform::ParToHerdOp::applyToOne(transform::TransformRewriter &rewriter, filteredOps.insert(target); patterns.add(ctx, filteredOps, herdOps, getFirstDim()); - patterns.add(ctx, filteredOps, herdOps, - getFirstDim()); (void)applyPatternsGreedily( target->getParentWithTrait(), std::move(patterns)); @@ -1639,8 +1473,6 @@ transform::ParToLaunchOp::applyToOne(transform::TransformRewriter &rewriter, filteredOps.insert(target); patterns.add(ctx, filteredOps, launchOps, getHasAirSegment()); - patterns.add(ctx, filteredOps, launchOps, - getHasAirSegment()); (void)applyPatternsGreedily( target->getParentWithTrait(), std::move(patterns)); diff --git a/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir b/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir index 2447e2fea..8a8b8811e 100644 --- a/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir +++ b/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir @@ -49,6 +49,41 @@ func.func @scf2() { // ----- +// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-LABEL: func.func @scf3() { +// CHECK: air.herd @herd_0 tile (%[[VAL_0:.*]], %[[VAL_1:.*]]) in (%{{.*}}=%c3{{.*}}, %{{.*}}=%c2{{.*}}) +// CHECK: affine.apply [[$MAP0]](%[[VAL_1]]) +// CHECK: affine.apply [[$MAP1]](%[[VAL_0]]) +func.func @scf3() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + scf.forall (%i, %j) = (%c1, %c0) to (%c4, %c4) + step (%c1, %c2) { + %2 = arith.muli %i, %j : index + } + return +} + +// ----- + +// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-LABEL: func.func @scf4() { +// CHECK: air.herd @herd_0 tile (%[[VAL_0:.*]], %[[VAL_1:.*]]) in (%{{.*}}=%c3{{.*}}, %{{.*}}=%c2{{.*}}) +// CHECK: affine.apply [[$MAP0]](%[[VAL_1]]) +// CHECK: affine.apply [[$MAP1]](%[[VAL_0]]) +func.func @scf4() { + scf.forall (%i, %j) = (1, 0) to (4, 4) step (1, 2) { + %2 = arith.muli %i, %j : index + } + return +} + +// ----- + // This test demonstrates that while forming air.herd we look through func.call ops, fetch // the corresponding function declaration's 'link_with' attribute and attach it to the newly // formed air.herd op.