Implemented tiling and fusion path for GPU #383
base: main
Conversation
There are currently two tests failing:
The first one can be fixed by capping the max work group size at 16 (a hacky workaround, but it will do for now until we figure out a proper fix):

diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
index cb3f5972..4e0f1265 100644
--- a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
+++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
@@ -1,6 +1,6 @@
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s
-module {
+module attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"GPU" : #dlti.target_device_spec<#dlti.dl_entry<"max_work_group_size", 16 : i64>>>} {
func.func @linalg_mlp(%arg0: tensor<32x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2 : tensor<32x4096xf16>,
%arg3: tensor<4096x4096xf16>, %arg4 : tensor<32x4096xf16>) {
%cst = arith.constant 0.000000e+00 : f16

The second one fails because one of the inputs of the second

f16_mlp_32x4096x4096x4096_transpose.mlir with NEW tiling

Both matmuls are placed into a single gpu kernel:

func.func @linalg_mlp(%arg0: memref<32x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<32x4096xf16>, %arg3: memref<4096x4096xf16>, %arg4: memref<32x4096xf16>) {
%c128 = arith.constant 128 : index
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = memref.get_global @__constant_8x4096xf16 : memref<8x4096xf16>
%1 = memref.get_global @__constant_8x128xf16 : memref<8x128xf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c1, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c2, %arg15 = %c32, %arg16 = %c1) {
%2 = affine.apply #map(%arg5)
%subview_0 = memref.subview %alloc[%2, 0] [16, 4096] [1, 1] : memref<32x4096xf16> to memref<16x4096xf16, strided<[4096, 1], offset: ?>>
%3 = affine.apply #map1(%arg8)
%4 = affine.apply #map2(%arg9)
%5 = arith.addi %3, %2 : index
%subview_1 = memref.subview %arg4[%5, %4] [8, 128] [1, 1] : memref<32x4096xf16> to memref<8x128xf16, strided<[4096, 1], offset: ?>>
%subview_2 = memref.subview %arg2[%5, 0] [8, 4096] [1, 1] : memref<32x4096xf16> to memref<8x4096xf16, strided<[4096, 1], offset: ?>>
%subview_3 = memref.subview %arg0[%5, 0] [8, 4096] [1, 1] : memref<32x4096xf16> to memref<8x4096xf16, strided<[4096, 1], offset: ?>>
%alloc_4 = memref.alloc() : memref<16x131072xf16, 3>
%6 = arith.muli %arg8, %c8 : index
%7 = arith.muli %arg9, %c4096 : index
%subview_5 = memref.subview %alloc_4[%6, %7] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
linalg.fill ins(%cst : f16) outs(%subview_5 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
linalg.matmul_transpose_b ins(%subview_3, %arg1 : memref<8x4096xf16, strided<[4096, 1], offset: ?>>, memref<4096x4096xf16>) outs(%subview_5 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
%alloc_6 = memref.alloc() : memref<16x131072xf16, 3>
%8 = arith.muli %arg8, %c8 : index
%9 = arith.muli %arg9, %c4096 : index
%subview_7 = memref.subview %alloc_6[%8, %9] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
linalg.add ins(%subview_2, %subview_5 : memref<8x4096xf16, strided<[4096, 1], offset: ?>>, memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>) outs(%subview_7 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
%alloc_8 = memref.alloc() : memref<16x131072xf16, 3>
%10 = arith.muli %arg8, %c8 : index
%11 = arith.muli %arg9, %c4096 : index
%subview_9 = memref.subview %alloc_8[%10, %11] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
linalg.max ins(%0, %subview_7 : memref<8x4096xf16>, memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>) outs(%subview_9 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
%subview_10 = memref.subview %arg3[%4, 0] [128, 4096] [1, 1] : memref<4096x4096xf16> to memref<128x4096xf16, strided<[4096, 1], offset: ?>>
%alloc_11 = memref.alloc() : memref<16x4096xf16, 3>
%12 = arith.muli %arg8, %c8 : index
%13 = arith.muli %arg9, %c128 : index
%subview_12 = memref.subview %alloc_11[%12, %13] [8, 128] [1, 1] : memref<16x4096xf16, 3> to memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>
linalg.fill ins(%cst : f16) outs(%subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
linalg.matmul_transpose_b ins(%subview_9, %subview_10 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>, memref<128x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
%alloc_13 = memref.alloc() : memref<16x4096xf16, 3>
%14 = arith.muli %arg8, %c8 : index
%15 = arith.muli %arg9, %c128 : index
%subview_14 = memref.subview %alloc_13[%14, %15] [8, 128] [1, 1] : memref<16x4096xf16, 3> to memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>
linalg.add ins(%subview_1, %subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>>, memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>) outs(%subview_14 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
%subview_15 = memref.subview %subview_0[%3, %4] [8, 128] [1, 1] : memref<16x4096xf16, strided<[4096, 1], offset: ?>> to memref<8x128xf16, strided<[4096, 1], offset: ?>>
linalg.max ins(%1, %subview_14 : memref<8x128xf16>, memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>) outs(%subview_15 : memref<8x128xf16, strided<[4096, 1], offset: ?>>)
gpu.terminator
} {SCFToGPU_visited}
%subview = memref.subview %alloc[0, 0] [32, 2] [1, 1] : memref<32x4096xf16> to memref<32x2xf16, strided<[4096, 1]>>
%cast = memref.cast %subview : memref<32x2xf16, strided<[4096, 1]>> to memref<*xf16>
call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
memref.dealloc %alloc : memref<32x4096xf16>
return
}

f16_mlp_32x4096x4096x4096_transpose.mlir with OLD tiling

Two dependent matmuls are placed into separate kernels:

func.func @linalg_mlp(%arg0: memref<32x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<32x4096xf16>, %arg3: memref<4096x4096xf16>, %arg4: memref<32x4096xf16>, %arg5: memref<i8>) {
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = memref.get_global @__constant_32x32xf16 : memref<32x32xf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c128, %arg13 = %c1, %arg14 = %c1) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1, %arg16 = %c1, %arg17 = %c1) {
%1 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6)
%subview_5 = memref.subview %arg1[%1, 0] [32, 4096] [1, 1] : memref<4096x4096xf16> to memref<32x4096xf16, strided<[4096, 1], offset: ?>>
%subview_6 = memref.subview %alloc[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.fill ins(%cst : f16) outs(%subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
linalg.matmul_transpose_b ins(%arg0, %subview_5 : memref<32x4096xf16>, memref<32x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_7 = memref.subview %arg2[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
%subview_8 = memref.subview %alloc_0[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.add ins(%subview_7, %subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_9 = memref.subview %alloc_1[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.max ins(%0, %subview_8 : memref<32x32xf16>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_9 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
gpu.terminator
} {SCFToGPU_visited}
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c128, %arg13 = %c1, %arg14 = %c1) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1, %arg16 = %c1, %arg17 = %c1) {
%1 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6)
%subview_5 = memref.subview %arg3[%1, 0] [32, 4096] [1, 1] : memref<4096x4096xf16> to memref<32x4096xf16, strided<[4096, 1], offset: ?>>
%alloc_6 = memref.alloc() : memref<4096x32xf16, 3>
%2 = arith.muli %arg9, %c4096 : index
%3 = arith.muli %arg10, %c32 : index
%subview_7 = memref.subview %alloc_6[%2, %3] [4096, 32] [1, 1] : memref<4096x32xf16, 3> to memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>
linalg.transpose ins(%subview_5 : memref<32x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_7 : memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>) permutation = [1, 0]
%subview_8 = memref.subview %alloc_2[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.fill ins(%cst : f16) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
linalg.matmul ins(%alloc_1, %subview_7 : memref<32x4096xf16>, memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_9 = memref.subview %arg4[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
%subview_10 = memref.subview %alloc_3[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.add ins(%subview_9, %subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_10 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_11 = memref.subview %alloc_4[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.max ins(%0, %subview_10 : memref<32x32xf16>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_11 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
gpu.terminator
} {SCFToGPU_visited}
%subview = memref.subview %alloc_4[0, 0] [32, 2] [1, 1] : memref<32x4096xf16> to memref<32x2xf16, strided<[4096, 1]>>
%cast = memref.cast %subview : memref<32x2xf16, strided<[4096, 1]>> to memref<*xf16>
call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
memref.dealloc %alloc : memref<32x4096xf16>
memref.dealloc %alloc_0 : memref<32x4096xf16>
memref.dealloc %alloc_1 : memref<32x4096xf16>
memref.dealloc %alloc_2 : memref<32x4096xf16>
memref.dealloc %alloc_3 : memref<32x4096xf16>
memref.dealloc %alloc_4 : memref<32x4096xf16>
return
}

The non-transposed case (f16_mlp_32x4096x4096x4096.mlir) works fine, since its two dependent matmuls are split into two separate kernels as expected; only the transposed one causes the problem.
Now the matmuls are split into two kernels, and both tests fail due to
The error |
This path creates two nested loops for linalg operations, which are later converted to gpu.launch.
The outer loop is mapped to the grid sizes and the inner loop is mapped to the block sizes.
The tile size calculation is based on the device information retrieved either from the module's DLTI attributes or from the path options.
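For illustration, here is a minimal sketch (not taken from this PR) of the kind of intermediate IR that description implies, assuming the tiling materializes as nested scf.forall loops over memrefs and that the device information comes from the same DLTI attribute used in the diff above; the actual tile sizes and loop structure produced by the path may differ:

module attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"GPU" : #dlti.target_device_spec<#dlti.dl_entry<"max_work_group_size", 16 : i64>>>} {
  func.func @tiled_fill(%out: memref<32x4096xf16>) {
    %cst = arith.constant 0.000000e+00 : f16
    // Outer loop: one iteration per work group, later mapped to the gpu.launch grid sizes.
    scf.forall (%bi) in (2) {
      %boff = affine.apply affine_map<(d0) -> (d0 * 16)>(%bi)
      %btile = memref.subview %out[%boff, 0] [16, 4096] [1, 1] : memref<32x4096xf16> to memref<16x4096xf16, strided<[4096, 1], offset: ?>>
      // Inner loop: one iteration per thread, later mapped to the gpu.launch block sizes.
      scf.forall (%ti) in (16) {
        %ttile = memref.subview %btile[%ti, 0] [1, 4096] [1, 1] : memref<16x4096xf16, strided<[4096, 1], offset: ?>> to memref<1x4096xf16, strided<[4096, 1], offset: ?>>
        linalg.fill ins(%cst : f16) outs(%ttile : memref<1x4096xf16, strided<[4096, 1], offset: ?>>)
      }
    }
    return
  }
}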
Depends on #406