Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIRRtToNPU: Rewrite BufferMemrefToFuncArgs as pattern and apply greedily #853

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/air/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ SmallVector<Operation *> cloneDefiningOpsInRegion(OpBuilder builder,
SmallVectorImpl<Value> &opers,
IRMapping &remap);

// Buffer all allocations of memref directly within the func op's body into the
// func op's arguments.
void populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns);

} // namespace air
} // namespace xilinx

Expand Down
58 changes: 3 additions & 55 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,11 +1025,10 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
// Unroll any affine for loops
unrollAffineFors(module);

// Buffer npu.dma_memcpy_nd memref to function's argument list.
BufferMemrefToFuncArgs(module);

// Cast buffers to i32 types
// Cast buffers to i32 types; buffer npu.dma_memcpy_nd memref to function's
// argument list.
RewritePatternSet castPattern(ctx);
air::populateBufferMemrefToFuncArgsPattern(castPattern);
castPattern.add(CastFunctionArgs);
(void)applyPatternsAndFoldGreedily(module, std::move(castPattern));

Expand Down Expand Up @@ -1631,57 +1630,6 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
chanToIdMap[col]++));
});
}

// Buffers npu.dma_memcpy_op memref as function argument
void BufferMemrefToFuncArgs(ModuleOp module) {
module.walk([&](mlir::func::FuncOp f) { BufferMemrefToFuncArgs(f); });
}
void BufferMemrefToFuncArgs(func::FuncOp funcOp) {
if (!funcOp)
return;

// Collect illegal dma ops whose memrefs are not in function's arguments.
SmallVector<Type, 6> memrefTypes;
SmallVector<Value, 6> memrefs;
funcOp.walk([&](AIEX::NpuDmaMemcpyNdOp dma) {
Value memref = dma.getMemref();
auto args = funcOp.getArguments();
// if the memref is an arg, return
if (std::find(args.begin(), args.end(), memref) != args.end())
return;
// if the memref is the result of a cast of an arg, return
if (auto cast = dyn_cast_or_null<UnrealizedConversionCastOp>(
memref.getDefiningOp())) {
if (std::find(args.begin(), args.end(), cast.getOperand(0)) !=
args.end())
return;
else
memref = cast.getOperand(0);
}
// push back if unique
if (std::find(memrefs.begin(), memrefs.end(), memref) == memrefs.end()) {
memrefs.push_back(memref);
memrefTypes.push_back(memref.getType());
}
});

// Append memref to function's arguments.
auto functionType = funcOp.getFunctionType();
auto newArgTypes = llvm::to_vector<6>(
llvm::concat<const Type>(functionType.getInputs(), memrefTypes));
auto newFunctionType = FunctionType::get(funcOp.getContext(), newArgTypes,
functionType.getResults());
funcOp.setType(newFunctionType);

// Add the new arguments to the entry block if the function is not external.
if (!funcOp.isExternal()) {
Location loc = funcOp.getLoc();
for (Value v : memrefs) {
auto newArg = funcOp.front().addArgument(v.getType(), loc);
v.replaceAllUsesWith(newArg);
}
}
}
};

} // namespace
Expand Down
54 changes: 54 additions & 0 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1693,3 +1693,57 @@ air::cloneDefiningOpsInRegion(OpBuilder builder, Region *region,
clonedOps.push_back(builder.clone(*op, remap));
return clonedOps;
}

// Buffer all allocations of L3 memref directly within the func op's body into
// the func op's arguments.
struct BufferMemrefToFuncArgsPattern : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern<func::FuncOp>::OpRewritePattern;

LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {

if (funcOp.isExternal())
return failure();

SmallVector<Type, 6> memrefTypes;
llvm::SetVector<Value> memrefs;
for (auto &op : funcOp.getFunctionBody().getOps()) {
if (isa<CastOpInterface>(op))
continue;
for (auto res : op.getResults()) {
MemRefType resType = dyn_cast<MemRefType>(res.getType());
if (!resType)
continue;
if (resType.getMemorySpaceAsInt() == (int)air::MemorySpace::L3)
memrefs.insert(res);
}
}
for (auto memref : memrefs)
memrefTypes.push_back(memref.getType());
if (memrefs.empty())
return failure();

// Append memref to function's arguments.
auto functionType = funcOp.getFunctionType();
auto newArgTypes = llvm::to_vector<6>(
llvm::concat<const Type>(functionType.getInputs(), memrefTypes));
auto newFunctionType = FunctionType::get(funcOp.getContext(), newArgTypes,
functionType.getResults());
funcOp.setType(newFunctionType);

// Add the new arguments to the entry block if the function is not external.
Location loc = funcOp.getLoc();
for (Value v : memrefs) {
auto newArg = funcOp.front().addArgument(v.getType(), loc);
v.replaceAllUsesWith(newArg);
}
return success();
}

private:
};

void air::populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns.insert<BufferMemrefToFuncArgsPattern>(ctx);
}
Loading