From d52a065862c32bb6ad9de98ff0491eed668c0486 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Fri, 24 Jan 2025 14:58:42 -0800 Subject: [PATCH 1/2] Fixup dependency canonicalizer where wait_all should never be the source for memref RW analysis --- mlir/lib/Dialect/AIR/IR/AIRDialect.cpp | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp index 27b248cae..a7d229b81 100644 --- a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp +++ b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp @@ -295,12 +295,23 @@ static LogicalResult CanonicalizeAsyncOpDeps(OpT op, auto memrefsReadBySinkOp = getAllMemrefsReadByOp(op.getOperation()); auto memrefsWrittenBySinkOp = getAllMemrefsWrittenByOp(op.getOperation()); // make a list of new async token operands + std::function, SmallVector &)> + getDirectDependenciesGreedily; + getDirectDependenciesGreedily = [&getDirectDependenciesGreedily]( + SmallVector depList, + SmallVector &directDeps) { + for (auto v : depList) { + if (auto wa = dyn_cast_if_present(v.getDefiningOp())) + getDirectDependenciesGreedily(wa.getAsyncDependencies(), directDeps); + else + directDeps.push_back(v); + } + return; + }; llvm::SetVector newAsyncDeps; // don't include duplicates - for (auto v : op.getAsyncDependencies()) { - // don't include wait_all ops with no operands - if (auto wa = dyn_cast_if_present(v.getDefiningOp())) - if (wa.getAsyncDependencies().size() == 0) - continue; + SmallVector directDeps; + getDirectDependenciesGreedily(op.getAsyncDependencies(), directDeps); + for (auto v : directDeps) { // don't include any false dependencies, i.e. sink does not depend on source // in RAW, WAR or WAW; RAR is a false dependency if (v.getDefiningOp()) { From 2f3230aefecbdcd10093f64661ded52c5c2d3cd1 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Fri, 24 Jan 2025 15:00:14 -0800 Subject: [PATCH 2/2] Add more mlir ir tests showing the memref RW-based dependency canonicalizer --- mlir/test/Dialect/AIR/air_canonicalize.mlir | 42 +++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/mlir/test/Dialect/AIR/air_canonicalize.mlir b/mlir/test/Dialect/AIR/air_canonicalize.mlir index e5b40dda7..ededb904e 100644 --- a/mlir/test/Dialect/AIR/air_canonicalize.mlir +++ b/mlir/test/Dialect/AIR/air_canonicalize.mlir @@ -363,6 +363,48 @@ func.func @chan_0(%arg0 : memref<4x1x64x64xbf16>, %arg1 : memref<1x4x64x64xbf16> return } +// CHECK: func.func @chan_1 +// CHECK: %[[TOKEN0:.*]] = air.channel.put async @channel_0 +// CHECK: %[[TOKEN1:.*]] = air.channel.put async @channel_1 +func.func @chan_1(%arg0 : memref<4x1x64x64xbf16>) { + %1 = air.channel.put async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + %2 = air.channel.put async [%1] @channel_1[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + return +} + +// CHECK: func.func @chan_2 +// CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0 +// CHECK: %[[TOKEN1:.*]] = air.channel.get async [%[[TOKEN0]]] @channel_1 +func.func @chan_2(%arg0 : memref<4x1x64x64xbf16>) { + %0 = air.channel.get async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + %1 = air.channel.get async [%0] @channel_1[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + return +} + +// CHECK: func.func @chan_3 +// CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0 +// CHECK: %[[TOKEN1:.*]] = air.channel.put async [%[[TOKEN0]]] @channel_1 +func.func @chan_3(%arg0 : memref<4x1x64x64xbf16>) { + %0 = air.channel.get async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + %1 = air.wait_all async [%0] + %2 = air.channel.put async [%1] @channel_1[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + return +} + +// CHECK: func.func @chan_4 +// CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0 +// CHECK: %[[TOKEN1:.*]] = air.channel.get async @channel_1 +// CHECK: %[[TOKEN2:.*]] = air.channel.put async [%[[TOKEN0]]] @channel_2 +// CHECK: %[[TOKEN3:.*]] = air.channel.put async [%[[TOKEN1]]] @channel_3 +func.func @chan_4(%arg0 : memref<4x1x64x64xbf16>, %arg1 : memref<4x1x64x64xbf16>) { + %0 = air.channel.get async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + %1 = air.channel.get async @channel_1[] (%arg1[] [] []) : (memref<4x1x64x64xbf16>) + %2 = air.wait_all async [%0, %1] + %3 = air.channel.put async [%2] @channel_2[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>) + %4 = air.channel.put async [%2] @channel_3[] (%arg1[] [] []) : (memref<4x1x64x64xbf16>) + return +} + // CHECK: func.func @dma_compose_subview // CHECK: air.dma_memcpy_nd (%{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}], %{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}] func.func @dma_compose_subview() {