Skip to content

Commit

Permalink
Disable DMA transpose on attention forward kernel
Browse files Browse the repository at this point in the history
This is a temporary measure while we stabilize the DMA transpose.
  • Loading branch information
aws-zhehongb authored Jan 9, 2025
1 parent c72a66b commit 21ed227
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions src/nki_samples/reference/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,13 @@ class FlashConfig:
@nki.jit(mode='trace')
def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ):
for i in nl.affine_range(LARGE_TILE_SZ // 512):
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.sbuf, dtype=p_local.dtype)
else:
p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)
p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)

for j in nl.affine_range(512 // 128):
j_128_slice = nl.ds(j * 128, 128)
i_j_128_slice = nl.ds(i * 512 + j * 128, 128)

if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice])

p_local_transposed[:, nl.ds(i * 512, 512)] = nl.copy(
Expand Down

0 comments on commit 21ed227

Please sign in to comment.