Skip to content

Commit

Permalink
add herd kernel arg lowering to air-to-aie
Browse files Browse the repository at this point in the history
Add lowering of air.herd_load to npu.rtp_write
  • Loading branch information
fifield committed Nov 15, 2024
1 parent c7fc6f4 commit 9c9e58b
Show file tree
Hide file tree
Showing 12 changed files with 503 additions and 222 deletions.
7 changes: 4 additions & 3 deletions mlir/include/air/Conversion/AIRToAIESchedulingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ AIE::TileOp getPhysTileOpOrNull(AIE::DeviceOp aie_device, int col, int row);
AIE::TileOp getPhysTileOp(AIE::DeviceOp aie_device, int col, int row);

AIE::LockOp allocateLockOp(AIE::DeviceOp aie_device, AIE::TileOp tile,
int init = 0, int id = -1);
int init = 0, int id = -1, StringAttr name = nullptr);

std::stringstream
generateBufferNameInStringStream(std::string prefix, uint64_t &BufferId,
generateBufferNameInStringStream(StringRef prefix, uint64_t &BufferId,
mlir::StringAttr attr = nullptr, int x = -1,
int y = -1);

Expand Down Expand Up @@ -195,7 +195,8 @@ void simpleDMAChannelAllocation(std::vector<MemcpyBundleAsFlow> &memcpy_flows,
ShimDMAAllocator &shim_dma_alloc,
MemTileDMAAllocator &memtile_dma_alloc,
TileDMAAllocator &tile_dma_alloc);
template <typename T> int foundInVector(T item, std::vector<T> vec);
template <typename T>
int foundInVector(T item, std::vector<T> vec);
int getSCFForLoopDepth(Operation *o);
bool groupingMemcpysByLoop(std::vector<MemcpyBundleAsFlow> &memcpy_flows);

Expand Down
19 changes: 12 additions & 7 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,20 @@ class AIRHerdConversion : public ConversionPattern {
return failure();
}

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(op->getBlock());
rewriter.create<airrt::HerdLoadOp>(op->getLoc(), rewriter.getI64Type(),
herd_name_attr.getValue().str(),
/* operands */ SmallVector<Value>());
// Integer kernel operands are passed as arguments (runtime parameters) to
// the herd load op.
SmallVector<Value> args;
for (int i = 0, e = herd.getNumKernelOperands(); i < e; i++) {
Value o = herd.getKernelOperand(i);
if (o.use_empty())
continue;
if (llvm::isa<IntegerType, IndexType, FloatType>(o.getType()))
args.push_back(o);
}

rewriter.create<airrt::HerdLoadOp>(op->getLoc(), rewriter.getI64Type(),
herd_name_attr.getValue().str(), args);

SmallVector<Value, 4> deps;
for (auto &o : operands)
if (llvm::isa<airrt::EventType>(o.getType()))
Expand Down Expand Up @@ -853,7 +859,6 @@ class ScfParOpConversion : public OpConversionPattern<scf::ParallelOp> {
};

LogicalResult ScfParToAffineForConversion(Operation *op) {

func::FuncOp f = dyn_cast<func::FuncOp>(op);
if (!f)
return failure();
Expand Down
64 changes: 64 additions & 0 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,66 @@ struct HerdLoadToNpuPattern : public OpConversionPattern<HerdLoadOp> {
LogicalResult
matchAndRewrite(HerdLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto module = op->getParentOfType<ModuleOp>();

// get the size metadata associated with this herd load
int64_t size_x = -1;
int64_t size_y = -1;
int64_t loc_x = -1;
int64_t loc_y = -1;
module.walk([&](HerdMetadataOp metadata) {
// return the first match by name
if (metadata.getSymName() != op.getSymName())
return WalkResult::advance();
auto sxAttr = metadata->getAttrOfType<IntegerAttr>("size_x");
auto syAttr = metadata->getAttrOfType<IntegerAttr>("size_y");
auto lxAttr = metadata->getAttrOfType<IntegerAttr>("loc_x");
auto lyAttr = metadata->getAttrOfType<IntegerAttr>("loc_y");
if (sxAttr && syAttr && lxAttr && lyAttr) {
size_x = sxAttr.getInt();
size_y = syAttr.getInt();
loc_x = lxAttr.getInt();
loc_y = lyAttr.getInt();
} else {
metadata.emitWarning(
"airrt.herd_metadata missing size_x, size_y, loc_x, or loc_y.");
}
return WalkResult::interrupt();
});
if (size_x < 0 || size_y < 0 || loc_x < 0 || loc_y < 0) {
op.emitWarning(
"airrt.herd_metadata missing or incomplete.");
return failure();
}

// for each herd core, emit write_rtp ops for every herd operand
// followed by a write32 to the herd lock, setting it to 1.
for (int phys_x = loc_x; phys_x < size_x + loc_x; phys_x++) {
for (int phys_y = loc_y; phys_y < size_y + loc_y; phys_y++) {

for (int i = 0, e = op.getNumOperands(); i < e; i++) {
Value oper = adaptor.getOperands()[i];
if (!llvm::isa<IntegerType, IndexType, FloatType>(oper.getType()))
continue;

std::string name = "__air_herd_rtp_" + std::to_string(phys_x) + "_" +
std::to_string(phys_y);
auto constOp =
dyn_cast_if_present<arith::ConstantOp>(oper.getDefiningOp());
if (!constOp)
continue;
uint32_t v = cast<IntegerAttr>(constOp.getValue()).getInt();
rewriter.create<AIEX::NpuWriteRTPOp>(op.getLoc(), name, i, v);
}
// FIXME: this should depend on the metadata to enable and to get the id
if (op.getNumOperands())
rewriter.create<AIEX::NpuWrite32Op>(op.getLoc(), 0x0001F000, 0x1,
nullptr,
rewriter.getI32IntegerAttr(phys_x),
rewriter.getI32IntegerAttr(phys_y));
}
}
rewriter.eraseOp(op);
return success();
}
Expand Down Expand Up @@ -1350,6 +1410,10 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex());
auto col_num = builder.getI32IntegerAttr(1);
auto row_num = builder.getI32IntegerAttr(1);
// FIXME: setting the insertion point to the end is a hack for
// RTP POC, so that the sync is after the rtp
// writes and the herd lock aquire.
// builder.setInsertionPoint(dma->getBlock()->getTerminator());
builder.setInsertionPointAfter(dma);
builder.create<AIEX::NpuSyncOp>(dma->getLoc(), col, row, dir, chan,
col_num, row_num);
Expand Down
Loading

0 comments on commit 9c9e58b

Please sign in to comment.