
Add a kernel that performs permute on tiled inputs where the tile height and width can both be swapped around (#17009)

### Ticket
#16464 (overarching issue tracking permute generality for tiled)
#16467 (tiled permute across all dimensions)
#16988 (program cache bug fix for transpose)

### Problem description

- Currently, permute is implemented as several recursive transpose calls.
- Permute does not reach 100% on PyTorch2 sweeps because some of those transposes have limited support.
- Permute can be non-performant in many cases: the recursive calls add dispatch overhead and perform an unnecessary number of reads/writes (see the sketch after this list).
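
As a rough illustration of the recursive-transpose approach, here is a plain-PyTorch sketch (not the actual tt-metal implementation; `permute_via_transposes` is a hypothetical helper for illustration). Each swap below corresponds to one extra transpose dispatch and a full read/write of the tensor:

```python
import torch

def permute_via_transposes(x: torch.Tensor, perm) -> torch.Tensor:
    # Selection-sort style decomposition: each iteration fixes one output
    # dimension with a single pairwise transpose (one dispatch each).
    cur = list(range(x.dim()))  # cur[k] = original dim currently sitting at axis k
    for i, target in enumerate(perm):
        j = cur.index(target)   # where the wanted original dim currently lives
        if j != i:
            x = x.transpose(i, j)
            cur[i], cur[j] = cur[j], cur[i]
    return x

x = torch.randn(2, 3, 4, 5)
assert torch.equal(permute_via_transposes(x, (3, 2, 1, 0)), x.permute(3, 2, 1, 0))
```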

### What's changed
- Create permute kernels that work generically for all interleaved inputs, both row-major (RM) and tiled
- This PR in particular adds the final missing case: permute on tiled inputs where both tiled dimensions can be swapped around
- Also add support for swapping around the W dimension in the row-invariant kernel, as it's more performant to do it there when H is not broken
- Remove recursive transpose calls in permute; only keep calls to transpose when there's a dedicated kernel for those swaps
- Achieve 100% on PyTorch2 sweeps for permute
- Fix a program cache bug where the runtime args were not overwritten (a minimal repro sketch follows this list)
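
A minimal sketch of the program cache issue, modeled on the `test_transpose_high_rank` test added in this PR (`check_program_cache_reuse` is a hypothetical helper; `ttnn.transpose` on dims 3 and 4 is just one way to hit the cached path). The second call reuses the compiled program, so its runtime args (e.g. buffer addresses) must be rewritten for the new tensors:

```python
import torch
import ttnn

def check_program_cache_reuse(device):
    # Run the same op twice with different inputs while the program cache is
    # enabled. The second call hits the cache, so its runtime args must be
    # overwritten to point at the new input/output buffers.
    ttnn.enable_program_cache(device)

    a = torch.randn([2] * 5, dtype=torch.bfloat16)
    b = torch.randn([2] * 5, dtype=torch.bfloat16)
    tt_a = ttnn.from_torch(a, device=device, layout=ttnn.TILE_LAYOUT)
    tt_b = ttnn.from_torch(b, device=device, layout=ttnn.TILE_LAYOUT)

    out_a = ttnn.to_torch(ttnn.transpose(tt_a, 3, 4))
    out_b = ttnn.to_torch(ttnn.transpose(tt_b, 3, 4))  # cached-program path

    assert torch.allclose(a.transpose(3, 4), out_a)
    # Before the fix this could fail, because the cached program kept the
    # runtime args from the first invocation.
    assert torch.allclose(b.transpose(3, 4), out_b)
```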


I didn't make much of an attempt to optimize, but even this initial attempt shows sizable improvements for some tiled inputs.

| Permutation | Shape | Kernel Duration Old [ns] | Kernel Duration New [ns] | Difference [ns] | % Difference (Improvement) |
|-------------|-------|--------------------------|--------------------------|-----------------|----------------------------|
| [0,2,3,1] | [1, 1, 32, 32] | 10484 | 6031 | 4453 | 42 |
| [0,2,3,1] | [1, 1, 128, 128]| 49752 | 42526 | 7226 | 15 |
| [0,2,3,1] | [32, 32, 32, 32]| 103483 | 118654 | -15171 | -15 |
| [0,2,3,1] | [96, 96, 96, 96]| 8624100 | 8995629 | -371529 | -4 |
| [1,2,3,0] | [1, 1, 32, 32] | 11351 | 6136 | 5215 | 46 |
| [1,2,3,0] | [1, 1, 128, 128]| 49438 | 41728 | 7710 | 16 |
| [1,2,3,0] | [32, 32, 32, 32]| 127924 | 112612 | 15312 | 12 |
| [1,2,3,0] | [96, 96, 96, 96]| 10482327 | 10596224 | -113897 | -1 |
| [2,1,3,0] | [1, 1, 32, 32] | 13236 | 6122 | 7114 | 54 |
| [2,1,3,0] | [1, 1, 128, 128]| 61827 | 44486 | 17341 | 28 |
| [2,1,3,0] | [32, 32, 32, 32]| 154753 | 141337 | 13416 | 9 |
| [2,1,3,0] | [96, 96, 96, 96]| 12078552 | 10501675 | 1576877 | 13 |
| [2,0,3,1] | [1, 1, 32, 32] | 13122 | 6018 | 7104 | 54 |
| [2,0,3,1] | [1, 1, 128, 128]| 61895 | 39760 | 22135 | 36 |
| [2,0,3,1] | [32, 32, 32, 32]| 130613 | 107358 | 23255 | 18 |
| [2,0,3,1] | [96, 96, 96, 96]| 10475954 | 9088701 | 1387253 | 13 |
| [0, 3, 2, 1]| [1, 1, 32, 32] | 13531 | 6398 | 7133 | 53 |
| [0, 3, 2, 1]| [1, 1, 128, 128]| 54249 | 32988 | 21261 | 39 |
| [0, 3, 2, 1]| [32, 32, 32, 32]| 127644 | 103850 | 23794 | 19 |
| [0, 3, 2, 1]| [96, 96, 96, 96]| 10363108 | 9040940 | 1322168 | 13 |
| [3, 1, 2, 0]| [1, 1, 32, 32] | 15433 | 6318 | 9115 | 59 |
| [3, 1, 2, 0]| [1, 1, 128, 128]| 66011 | 34051 | 31960 | 48 |
| [3, 1, 2, 0]| [32, 32, 32, 32]| 182886 | 140970 | 41916 | 23 |
| [3, 1, 2, 0]| [96, 96, 96, 96]| 13773854 | 14002262 | -228408 | -2 |
| [1, 3, 2, 0]| [1, 1, 32, 32] | 13506 | 6532 | 6974 | 52 |
| [1, 3, 2, 0]| [1, 1, 128, 128]| 54333 | 33442 | 20891 | 38 |
| [1, 3, 2, 0]| [32, 32, 32, 32]| 159025 | 121565 | 37460 | 24 |
| [1, 3, 2, 0]| [96, 96, 96, 96]| 12104105 | 13157161 | -1053056 | -9 |
| [3, 0, 2, 1]| [1, 1, 32, 32] | 15407 | 6314 | 9093 | 59 |
| [3, 0, 2, 1]| [1, 1, 128, 128]| 66209 | 32760 | 33449 | 51 |
| [3, 0, 2, 1]| [32, 32, 32, 32]| 157755 | 104751 | 53004 | 34 |
| [3, 0, 2, 1]| [96, 96, 96, 96]| 12029488 | 8846762 | 3182726 | 26 |
| [2, 3, 0, 1]| [1, 1, 32, 32] | 86795 | 77277 | 9518 | 11 |
| [2, 3, 0, 1]| [1, 1, 128, 128]| 1007534 | 981804 | 25730 | 3 |
| [2, 3, 0, 1]| [32, 32, 32, 32]| 209988 | 102957 | 107031 | 51 |
| [2, 3, 0, 1]| [96, 96, 96, 96]| 17523170 | 8847372 | 8675798 | 50 |
| [3, 2, 1, 0]| [1, 1, 32, 32] | 114489 | 80261 | 34228 | 30 |
| [3, 2, 1, 0]| [1, 1, 128, 128]| 1290711 | 970495 | 320216 | 25 |
| [3, 2, 1, 0]| [32, 32, 32, 32]| 264289 | 113664 | 150625 | 57 |
| [3, 2, 1, 0]| [96, 96, 96, 96]| 20708815 | 12742424 | 7966391 | 38 |
| [2, 3, 1, 0]| [1, 1, 32, 32] | 117225 | 79516 | 37709 | 32 |
| [2, 3, 1, 0]| [1, 1, 128, 128]| 1301694 | 980410 | 321284 | 25 |
| [2, 3, 1, 0]| [32, 32, 32, 32]| 234711 | 116980 | 117731 | 50 |
| [2, 3, 1, 0]| [96, 96, 96, 96]| 19136366 | 12757227 | 6379139 | 33 |
| [3, 2, 0, 1]| [1, 1, 32, 32] | 86944 | 77929 | 9015 | 10 |
| [3, 2, 0, 1]| [1, 1, 128, 128]| 998782 | 975924 | 22858 | 2 |
| [3, 2, 0, 1]| [32, 32, 32, 32]| 233709 | 101485 | 132224 | 57 |
| [3, 2, 0, 1]| [96, 96, 96, 96]| 19106782 | 9550606 | 9556176 | 50 |
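
Each row above corresponds to a single `ttnn.permute` call on a tile-layout tensor, along the lines of the sketch below (`run_tiled_permute` is a hypothetical helper; the actual benchmark harness and kernel-duration measurement are not part of this diff):

```python
import torch
import ttnn

def run_tiled_permute(device, shape=(32, 32, 32, 32), perm=(3, 2, 1, 0)):
    torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(
        torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device
    )
    output_tensor = ttnn.permute(input_tensor, perm)  # exercises the new tiled permute kernels
    return ttnn.to_torch(output_tensor), torch.permute(torch_tensor, perm)
```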


### Checklist
- [ ] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12921962567
- [ ] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12921962567
- [ ] Model regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12921219125
- [ ] Device performance regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12918280077
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
sjameelTT authored Jan 23, 2025
1 parent 72d5f79 commit 14f2057
Showing 15 changed files with 1,354 additions and 126 deletions.
models/demos/bert_tiny/tests/test_performance.py (2 changes: 1 addition & 1 deletion)
@@ -119,7 +119,7 @@ def test_perf_device_bare_metal(batch_size, expected_perf):
    margin = 0.03

    if is_wormhole_b0():
        expected_perf = 3990.0
        expected_perf = 4114.8
    else:
        expected_perf = 3460.0

@@ -1116,3 +1116,32 @@ def test_transpose_16411(device):
    assert_with_pcc(p_c2, ttnn.to_torch(c2), 0.9999)
    assert_with_pcc(p_c3, ttnn.to_torch(c3), 0.9999)
    assert_with_pcc(p_c4, ttnn.to_torch(c4), 0.9999)


@pytest.mark.parametrize("rank", [5])
@pytest.mark.parametrize("indices", [[0, 1], [0, 2], [0, 3], [0, 4], [1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]])
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_transpose_high_rank(*, device: ttnn.Device, rank: int, indices, layout):
torch.manual_seed(2005)
ttnn.disable_and_clear_program_cache(device)
ttnn.enable_program_cache(device)

shape = [2] * rank

a = torch.randn(shape, dtype=torch.bfloat16)
b = torch.randn(shape, dtype=torch.bfloat16)

tt_a = ttnn.from_torch(a, device=device, layout=layout)
tt_b = ttnn.from_torch(b, device=device, layout=layout)

a = a.transpose(*indices)
b = b.transpose(*indices)

tt_a = ttnn.transpose(tt_a, *indices)
tt_b = ttnn.transpose(tt_b, *indices)

output_a = ttnn.to_torch(tt_a)
output_b = ttnn.to_torch(tt_b)

assert torch.allclose(a, output_a)
assert torch.allclose(b, output_b)
tests/ttnn/unit_tests/operations/test_permute.py (124 changes: 117 additions & 7 deletions)
@@ -5,6 +5,7 @@
import pytest

import torch
import math

import ttnn
import itertools
@@ -182,13 +183,13 @@ def generate_permutations(N):
        yield perm


@skip_for_blackhole("tilize_block gives bad pcc after second iteration")
@skip_for_grayskull("tilize_block gives bad pcc after second iteration")
@pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)])
@pytest.mark.parametrize("perm", generate_permutations(5))
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_5d_width(shape, perm, memory_config, dtype, device):
    if is_grayskull() and dtype == ttnn.float32:
        pytest.skip("Grayskull doesn't support float32")
    torch.manual_seed(2005)
    input_a = torch.randn(shape)
    torch_output = torch.permute(input_a, perm)
@@ -202,13 +203,13 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device):
    assert_with_pcc(torch_output, tt_output, 0.9999)


@skip_for_blackhole("tilize_block gives bad pcc after second iteration")
@skip_for_grayskull("tilize_block gives bad pcc after second iteration")
@pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)])
@pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)])
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_5d_blocked(shape, perm, memory_config, dtype, device):
    if is_grayskull() and dtype == ttnn.float32:
        pytest.skip("Grayskull doesn't support float32")
    torch.manual_seed(520)
    input_a = torch.randn(shape)

@@ -224,8 +224,6 @@ def test_permute_5d_blocked(shape, perm, memory_config, dtype, device):
    assert_with_pcc(torch_output, tt_output, 0.9999)


@skip_for_blackhole("tilize_block gives bad pcc after second iteration")
@skip_for_grayskull("tilize_block gives bad pcc after second iteration")
def test_permute_nd(device):
    torch.manual_seed(2005)
    torch_tensor = torch.rand((1, 3, 16, 16, 16, 16), dtype=torch.bfloat16)
@@ -400,7 +399,6 @@ def test_permutations_5d_fixed_w(shape, perm, dtype, device):
    assert_with_pcc(torch_output, output_tensor, 0.9999)


@pytest.mark.skip("#16575 to_layout from tiled to RM fails on reshape")
@pytest.mark.parametrize("shape", [[1, 9, 91, 7, 9]])
@pytest.mark.parametrize("perm", [[0, 3, 4, 1, 2]])
def test_permute_adversarial(shape, perm, device):
@@ -427,3 +425,115 @@ def test_permute_4d_fixed_w(shape, perm, device):
    torch_output = torch.permute(torch_tensor, perm)
    assert torch_output.shape == output_tensor.shape
    assert_with_pcc(torch_output, output_tensor, 0.9999)


def generate_fixed_no_dim0_dim1_transpose_permutations(N, dim0, dim1):
    perms_Nd = generate_permutations(N)
    for perm in perms_Nd:
        if perm[dim0] != dim1:
            yield perm


@pytest.mark.parametrize("shape", [[7, 7, 7, 17, 17]])
@pytest.mark.parametrize("perm", [[0, 1, 4, 3, 2]])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("pad_value", [35.0, float("-inf"), None])
def test_permute_5d_yw_padded(shape, perm, dtype, pad_value, device):
    if is_grayskull() and dtype == ttnn.float32:
        pytest.skip("Grayskull doesn't support float32")
    torch.manual_seed(2005)
    torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
    ttnn_output = ttnn.permute(input_tensor, perm, pad_value=pad_value)
    output_tensor = ttnn.to_torch(ttnn_output)
    torch_output = torch.permute(torch_tensor, perm)

    assert torch_output.shape == output_tensor.shape
    assert_with_pcc(torch_output, output_tensor, 0.9999)

    if pad_value != None:
        logical_shape = torch_output.shape
        output_padded = ttnn.from_device(ttnn_output).to_torch()
        padded_shape = output_padded.shape
        num_padded_values = torch.prod(torch.tensor(padded_shape)) - torch.prod(torch.tensor(logical_shape))
        assert torch.sum(output_padded == pad_value) == num_padded_values


@pytest.mark.parametrize("shape", [[33, 1, 17, 33, 33]])
@pytest.mark.parametrize("perm", generate_fixed_no_dim0_dim1_transpose_permutations(5, 4, 3))
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_5d_yw_permutations(shape, perm, dtype, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Grayskull doesn't support float32")
torch.manual_seed(2005)
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
output_tensor = ttnn.permute(input_tensor, perm)
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, perm)
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)


@pytest.mark.parametrize("shape", [[1, 1, 32, 32], [1, 1, 128, 128], [32, 32, 32, 32], [96, 96, 96, 96]])
@pytest.mark.parametrize("perm", [[0, 3, 2, 1], [3, 1, 2, 0], [1, 3, 2, 0], [3, 0, 2, 1]])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16])
def test_permute_4d_yw_permutations(shape, perm, dtype, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Grayskull doesn't support float32")
torch.manual_seed(2005)
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
output_tensor = ttnn.permute(input_tensor, perm)
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, perm)
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)


@pytest.mark.parametrize("shape", [[1, 1, 32, 32], [1, 1, 128, 128], [32, 32, 32, 32], [96, 96, 96, 96]])
@pytest.mark.parametrize("perm", [[2, 3, 0, 1], [3, 2, 1, 0], [2, 3, 1, 0], [3, 2, 0, 1]])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16])
def test_permute_4d_whyx_permutations(shape, perm, dtype, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Grayskull doesn't support float32")
torch.manual_seed(2005)
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
output_tensor = ttnn.permute(input_tensor, perm)
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, perm)
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)


@pytest.mark.parametrize("shape", [[1, 1, 32, 32], [1, 1, 128, 128], [32, 32, 32, 32], [96, 96, 96, 96]])
@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 3, 1, 2], [1, 2, 3, 0], [2, 1, 3, 0], [2, 0, 3, 1]])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16])
def test_permute_4d_other_permutations(shape, perm, dtype, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Grayskull doesn't support float32")
torch.manual_seed(2005)
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
output_tensor = ttnn.permute(input_tensor, perm)
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, perm)
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)


@pytest.mark.parametrize("shape", [[33, 1, 17, 33, 33]])
@pytest.mark.parametrize("perm", [[0, 1, 4, 2, 3], [0, 4, 1, 2, 3], [2, 4, 1, 0, 3], [4, 2, 1, 0, 3]])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_5d_wyh(shape, perm, dtype, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Grayskull doesn't support float32")
torch.manual_seed(2005)
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
output_tensor = ttnn.permute(input_tensor, perm, pad_value=0.0)
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, perm)
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)
@@ -55,7 +55,6 @@ void MAIN {

        cb_push_back(cb_out, w_block_size);

        cb_wait_front(cb_out, w_block_size);
        pack_untilize_uninit(cb_out);

        cb_pop_front(cb_tilize, 1);
@@ -0,0 +1,75 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
#include "compute_kernel_api/transpose_wh.h"
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/untilize.h"
#include "compute_kernel_api/pack_untilize.h"
#include "tt_metal/hw/inc/circular_buffer.h"

namespace NAMESPACE {

void MAIN {
    // X = output width
    // Y = output height
    // input shape = (..., H, W)
    // output shape = (..., Y, X)

    /**
     * This kernel takes the contiguous XW block read in by the reader kernel and transposes it into a WX block,
     * ready to be written out. The transpose LLK does not support transposing a tile without faces/subtiles, so we
     * need to rearrange the block into its faces, transpose it, and then pack it back such that it's de-faced (WX,
     * where X is contiguous and isn't divided into subtiles).
     */
    uint32_t start_block = get_arg_val<uint32_t>(0);
    uint32_t end_block = get_arg_val<uint32_t>(1);

    constexpr auto cb_in = tt::CBIndex::c_0;
    constexpr auto cb_tilize = tt::CBIndex::c_1;
    constexpr auto cb_out = tt::CBIndex::c_2;

    unary_op_init_common(cb_in, cb_out);

    for (uint32_t block = start_block; block < end_block; block++) {
        // tilize input via unpack and then pack
        tilize_init_short(cb_in, 1, cb_tilize);

        cb_wait_front(cb_in, 1);
        cb_reserve_back(cb_tilize, 1);

        tilize_block(cb_in, 1, cb_tilize);

        cb_push_back(cb_tilize, 1);
        cb_pop_front(cb_in, 1);

        tilize_uninit(cb_in, cb_tilize);

        // transpose input
        cb_wait_front(cb_tilize, 1);

        transpose_wh_init_short(cb_tilize);
        pack_untilize_dst_init_short<1>(cb_out);

        tile_regs_acquire();
        transpose_wh_tile(cb_tilize, 0, 0);  // transpose call
        tile_regs_commit();

        // pack and untilize
        cb_reserve_back(cb_out, 1);

        tile_regs_wait();
        pack_untilize_dst<1>(cb_out);  // pack call
        tile_regs_release();

        cb_push_back(cb_out, 1);

        pack_untilize_uninit(cb_out);

        cb_pop_front(cb_tilize, 1);
    }
}
}  // namespace NAMESPACE