From 21ed2274faf9e4e8f98cbfe7516fc1ea04bb94d5 Mon Sep 17 00:00:00 2001
From: aws-zhehongb <114249045+aws-zhehongb@users.noreply.github.com>
Date: Thu, 9 Jan 2025 10:37:51 -0800
Subject: [PATCH] Disable DMA transpose on attention forward kernel

While we are stabilizing the DMA transpose
---
 src/nki_samples/reference/attention.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/src/nki_samples/reference/attention.py b/src/nki_samples/reference/attention.py
index 3c456a6..27ac89d 100644
--- a/src/nki_samples/reference/attention.py
+++ b/src/nki_samples/reference/attention.py
@@ -44,20 +44,13 @@ class FlashConfig:
 @nki.jit(mode='trace')
 def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ):
   for i in nl.affine_range(LARGE_TILE_SZ // 512):
-    if nisa.get_nc_version() == nisa.nc_version.gen3:
-      p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.sbuf, dtype=p_local.dtype)
-    else:
-      p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)
+    p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)
 
     for j in nl.affine_range(512 // 128):
       j_128_slice = nl.ds(j * 128, 128)
       i_j_128_slice = nl.ds(i * 512 + j * 128, 128)
 
-      if nisa.get_nc_version() == nisa.nc_version.gen3:
-        p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
-          p_local[:, i_j_128_slice])
-      else:
-        p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
+      p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
         p_local[:, i_j_128_slice])
 
     p_local_transposed[:, nl.ds(i * 512, 512)] = nl.copy(
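
Note for reviewers: below is a sketch of how transpose_p_local reads with this patch applied, assembled from the hunk above. The import lines and the arguments of the trailing nl.copy(...) call are assumptions (the hunk ends before that statement's continuation line); everything else is taken directly from the diff.

# NOTE: imports are assumed to match the rest of attention.py.
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.language import par_dim

@nki.jit(mode='trace')
def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ):
  for i in nl.affine_range(LARGE_TILE_SZ // 512):
    # The gen3-only SBUF + nisa.dma_transpose staging path is removed; the
    # transpose result is always staged in PSUM as float32.
    p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)

    for j in nl.affine_range(512 // 128):
      j_128_slice = nl.ds(j * 128, 128)
      i_j_128_slice = nl.ds(i * 512 + j * 128, 128)

      # Transpose each 128x128 block with nisa.nc_transpose on every
      # NeuronCore generation instead of branching to nisa.dma_transpose.
      p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
        p_local[:, i_j_128_slice])

    # Assumed completion of the truncated context line in the hunk: copy the
    # transposed 128x512 tile into the output buffer in its dtype.
    p_local_transposed[:, nl.ds(i * 512, 512)] = nl.copy(
      p_local_t_tmp, dtype=p_local_transposed.dtype)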