diff --git a/mlir/include/air/Conversion/Passes.td b/mlir/include/air/Conversion/Passes.td index 99d8ab14d..97c073fc7 100644 --- a/mlir/include/air/Conversion/Passes.td +++ b/mlir/include/air/Conversion/Passes.td @@ -24,8 +24,10 @@ def ParallelToHerd : Pass<"air-par-to-herd", "ModuleOp"> { }]; let options = [ Option<"clAssignDepth", "depth", "int", - /*default=*/"-1", - "Given a nest of parallel for loops, which depth to map to air.herd">, + /*default=*/"-2", + "Given a nest of parallel for loops, which depth to map to air.herd. " + "-1 means converting the innermost parallel loop; any other negative " + "value means converting all parallel loops">, Option<"clFirstDim", "first-dim", "int", /*default=*/"0", "Which herd dimension to map to first. Can be zero or one. If set to " @@ -49,8 +51,10 @@ def ParallelToLaunch : Pass<"air-par-to-launch", "ModuleOp"> { }]; let options = [ Option<"clAssignDepth", "depth", "int", - /*default=*/"-1", - "Given a nest of parallel for loops, which depth to map to air.launch">, + /*default=*/"-2", + "Given a nest of parallel for loops, which depth to map to air.launch. " + "-1 means converting the innermost parallel loop; any other negative " + "value means converting all parallel loops">, Option<"clHasSegment", "has-air-segment", "bool", /*default=*/"false", "Whether to create an air.segment op in generated air.launch " "regions">, @@ -68,8 +72,10 @@ def ParallelToSegment : Pass<"air-par-to-segment", "ModuleOp"> { }]; let options = [ Option<"clAssignDepth", "depth", "int", - /*default=*/"-1", - "Given a nest of parallel for loops, which depth to map to air.segment">, + /*default=*/"-2", + "Given a nest of parallel for loops, which depth to map to air.segment. " + "-1 means converting the innermost parallel loop; any other negative " + "value means converting all parallel loops">, ]; } diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index 763f49859..3e575fa31 100644 --- 
a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -1165,11 +1165,25 @@ struct ParallelToHerdPass if (llvm::any_of(hierOps, [op](Operation *h) { return op->isProperAncestor(h); })) return; + // Depth = -1: only convert ops with no parallel op nested inside + if (clAssignDepth == -1) { + SmallVector parOpsInOp; + op->walk([&parOpsInOp](Operation *o) { + if (isa(o)) + parOpsInOp.push_back(o); + }); + if (parOpsInOp.size() > 1) + return; + filteredOps.insert(op); + return; + } + // Any other negative depth skips the nesting check and converts all + // parallel ops if (clAssignDepth < 0) { filteredOps.insert(op); return; } - // the number of nested scf.parallel above this one + // the number of nested parallel above this one int parallel_depth = 0; Operation *par = op; while ((par = par->getParentOp())) @@ -1253,11 +1267,25 @@ struct ParallelToLaunchPass return op->isProperAncestor(l); })) return; + // Depth = -1: only convert ops with no parallel op nested inside + if (clAssignDepth == -1) { + SmallVector parOpsInOp; + op->walk([&parOpsInOp](Operation *o) { + if (isa(o)) + parOpsInOp.push_back(o); + }); + if (parOpsInOp.size() > 1) + return; + filteredOps.insert(op); + return; + } + // Any other negative depth skips the nesting check and converts all + // parallel ops if (clAssignDepth < 0) { filteredOps.insert(op); return; } - // the number of nested scf.parallel above this one + // the number of nested parallel above this one int parallel_depth = 0; Operation *par = op; while ((par = par->getParentOp())) @@ -1342,11 +1370,25 @@ struct ParallelToSegmentPass return op->isProperAncestor(s); })) return; + // Depth = -1: only convert ops with no parallel op nested inside + if (clAssignDepth == -1) { + SmallVector parOpsInOp; + op->walk([&parOpsInOp](Operation *o) { + if (isa(o)) + parOpsInOp.push_back(o); + }); + if (parOpsInOp.size() > 1) + return; + filteredOps.insert(op); + return; + } + // Any other negative depth skips the nesting check and converts all + // 
parallel ops if (clAssignDepth < 0) { filteredOps.insert(op); return; } - // the number of nested scf.parallel above this one + // the number of nested parallel above this one int parallel_depth = 0; Operation *par = op; while ((par = par->getParentOp())) diff --git a/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir b/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir index 72ce513ac..abe4e09f3 100644 --- a/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir +++ b/mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir @@ -6,6 +6,9 @@ //===----------------------------------------------------------------------===// // RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd %s | FileCheck %s +// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=-1" %s | FileCheck %s --check-prefix=DEPTHM1 +// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=0" %s | FileCheck %s --check-prefix=DEPTH0 +// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=1" %s | FileCheck %s --check-prefix=DEPTH1 // CHECK-LABEL: func.func @scf0() { // CHECK: air.herd @herd_0 tile (%{{.*}}, %{{.*}}) in (%{{.*}}=%c2{{.*}}, %{{.*}}=%c2{{.*}}) @@ -98,6 +101,27 @@ func.func @scf4() { // CHECK: } // CHECK: } // CHECK: } +// DEPTHM1-LABEL: func.func @scf5() { +// DEPTHM1: scf.forall {{.*}} { +// DEPTHM1: scf.forall {{.*}} { +// DEPTHM1: air.herd @herd_{{.*}} { +// DEPTHM1: } +// DEPTHM1: } +// DEPTHM1: } +// DEPTH0-LABEL: func.func @scf5() { +// DEPTH0: air.herd @herd_{{.*}} { +// DEPTH0: scf.forall {{.*}} { +// DEPTH0: scf.forall {{.*}} { +// DEPTH0: } +// DEPTH0: } +// DEPTH0: } +// DEPTH1-LABEL: func.func @scf5() { +// DEPTH1: scf.forall {{.*}} { +// DEPTH1: air.herd @herd_{{.*}} { +// DEPTH1: scf.forall {{.*}} { +// DEPTH1: } +// DEPTH1: } +// DEPTH1: } func.func @scf5() { %src = memref.alloc() : memref<4x4x4xi32, 2 : i32> %dst = memref.alloc() : memref<4x4x4xi32, 2 : i32> diff --git 
a/mlir/test/Conversion/ConvertToAIR/scf_parallel_to_herd.mlir b/mlir/test/Conversion/ConvertToAIR/scf_parallel_to_herd.mlir index ed36cf332..c46e5c1d6 100644 --- a/mlir/test/Conversion/ConvertToAIR/scf_parallel_to_herd.mlir +++ b/mlir/test/Conversion/ConvertToAIR/scf_parallel_to_herd.mlir @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// // RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd %s | FileCheck %s +// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=-1" %s | FileCheck %s --check-prefix=DEPTHM1 +// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=0" %s | FileCheck %s --check-prefix=DEPTH0 // CHECK-LABEL: func.func @scf0() { // CHECK: %[[C2:.*]] = arith.constant 2 : index @@ -201,6 +203,34 @@ module { // CHECK: return // CHECK: } // CHECK: } +// DEPTHM1-LABEL: @shared_herd_name +// DEPTHM1: scf.parallel {{.*}} { +// DEPTHM1: air.herd @herd_0 +// DEPTHM1: } +// DEPTHM1: air.herd @herd_0 +// DEPTHM1: } +// DEPTHM1: air.herd @herd_0 +// DEPTHM1: } +// DEPTHM1: scf.reduce +// DEPTHM1: } +// DEPTHM1: return +// DEPTHM1: } +// DEPTHM1: } +// DEPTH0-LABEL: @shared_herd_name +// DEPTH0: air.herd @herd_0 +// DEPTH0: scf.parallel {{.*}} +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: scf.parallel {{.*}} +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: scf.parallel {{.*}} +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: } +// DEPTH0: return +// DEPTH0: } +// DEPTH0: } module { func.func @shared_herd_name(%arg0: memref<512x1024xbf16>, %arg1: memref<1024x512xbf16>, %arg2: memref<512x512xbf16>) { %c32 = arith.constant 32 : index @@ -248,6 +278,29 @@ module { // CHECK: return // CHECK: } // CHECK: } +// DEPTHM1-LABEL: @unique_herd_name +// DEPTHM1: scf.parallel {{.*}} { +// DEPTHM1: air.herd @herd_0 +// DEPTHM1: } +// DEPTHM1: air.herd @herd_1 +// DEPTHM1: } +// DEPTHM1: scf.reduce +// DEPTHM1: } +// DEPTHM1: return +// DEPTHM1: } +// DEPTHM1: } +// 
DEPTH0-LABEL: @unique_herd_name +// DEPTH0: air.herd @herd_0 +// DEPTH0: scf.parallel {{.*}} { +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: scf.parallel {{.*}} { +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: } +// DEPTH0: return +// DEPTH0: } +// DEPTH0: } module { func.func @unique_herd_name(%arg0: memref<512x1024xbf16>, %arg1: memref<1024x512xbf16>, %arg2: memref<512x512xbf16>) { %c32 = arith.constant 32 : index @@ -303,6 +356,29 @@ module { // CHECK: return // CHECK: } // CHECK: } +// DEPTHM1-LABEL: @l2_to_l1_dma_infer_herd +// DEPTHM1: scf.parallel {{.*}} { +// DEPTHM1: air.herd @herd_0 +// DEPTHM1: } +// DEPTHM1: air.herd @herd_0 +// DEPTHM1: } +// DEPTHM1: scf.reduce +// DEPTHM1: } +// DEPTHM1: return +// DEPTHM1: } +// DEPTHM1: } +// DEPTH0-LABEL: @l2_to_l1_dma_infer_herd +// DEPTH0: air.herd @herd_0 +// DEPTH0: scf.parallel {{.*}} { +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: scf.parallel {{.*}} { +// DEPTH0: scf.reduce +// DEPTH0: } +// DEPTH0: } +// DEPTH0: return +// DEPTH0: } +// DEPTH0: } module { func.func @l2_to_l1_dma_infer_herd() { %c32 = arith.constant 32 : index