Add a kernel that performs permute on tiled inputs where the tile height and width can both be swapped around (#17009)

### Ticket

- #16464 (overarching issue tracking permute generality for tiled)
- #16467 (tiled permute across all dimensions)
- #16988 (program cache bug fix for transpose)

### Problem description

- Currently permute is implemented with several recursive transposes.
- Permute does not reach 100% on sweeps as some of the transposes have limited support.
- Permute can be non-performant in many cases due to the recursive calls, which add dispatch overhead and have an unnecessary amount of reads/writes.

### What's changed

- Create permute kernels that work generically for all interleaved inputs, both RM and tiled.
- This PR in particular adds the final missing case: permute on tiled inputs where both tiled dimensions can be swapped around.
- Also add support for swapping around the W dimension in the row invariant kernel, as it's more performant to do there if H is not broken.
- Remove recursive transpose calls in permute. Only keep calls to transpose when there's a dedicated kernel for those swaps.
- Achieve 100% on PyTorch2 sweeps for permute.
- Fix a program cache bug where the runtime args are not overwritten.

I didn't make much of an attempt to optimize, but even the initial attempt has a lot of improvements for some tiled inputs.

| Permutation | Shape | Kernel Duration Old [ns] | Kernel Duration New [ns] | Difference [ns] | % Difference (Improvement) |
|--------------|------------------|----------|----------|----------|-----|
| [0, 2, 3, 1] | [1, 1, 32, 32] | 10484 | 6031 | 4453 | 42 |
| [0, 2, 3, 1] | [1, 1, 128, 128] | 49752 | 42526 | 7226 | 15 |
| [0, 2, 3, 1] | [32, 32, 32, 32] | 103483 | 118654 | -15171 | -15 |
| [0, 2, 3, 1] | [96, 96, 96, 96] | 8624100 | 8995629 | -371529 | -4 |
| [1, 2, 3, 0] | [1, 1, 32, 32] | 11351 | 6136 | 5215 | 46 |
| [1, 2, 3, 0] | [1, 1, 128, 128] | 49438 | 41728 | 7710 | 16 |
| [1, 2, 3, 0] | [32, 32, 32, 32] | 127924 | 112612 | 15312 | 12 |
| [1, 2, 3, 0] | [96, 96, 96, 96] | 10482327 | 10596224 | -113897 | -1 |
| [2, 1, 3, 0] | [1, 1, 32, 32] | 13236 | 6122 | 7114 | 54 |
| [2, 1, 3, 0] | [1, 1, 128, 128] | 61827 | 44486 | 17341 | 28 |
| [2, 1, 3, 0] | [32, 32, 32, 32] | 154753 | 141337 | 13416 | 9 |
| [2, 1, 3, 0] | [96, 96, 96, 96] | 12078552 | 10501675 | 1576877 | 13 |
| [2, 0, 3, 1] | [1, 1, 32, 32] | 13122 | 6018 | 7104 | 54 |
| [2, 0, 3, 1] | [1, 1, 128, 128] | 61895 | 39760 | 22135 | 36 |
| [2, 0, 3, 1] | [32, 32, 32, 32] | 130613 | 107358 | 23255 | 18 |
| [2, 0, 3, 1] | [96, 96, 96, 96] | 10475954 | 9088701 | 1387253 | 13 |
| [0, 3, 2, 1] | [1, 1, 32, 32] | 13531 | 6398 | 7133 | 53 |
| [0, 3, 2, 1] | [1, 1, 128, 128] | 54249 | 32988 | 21261 | 39 |
| [0, 3, 2, 1] | [32, 32, 32, 32] | 127644 | 103850 | 23794 | 19 |
| [0, 3, 2, 1] | [96, 96, 96, 96] | 10363108 | 9040940 | 1322168 | 13 |
| [3, 1, 2, 0] | [1, 1, 32, 32] | 15433 | 6318 | 9115 | 59 |
| [3, 1, 2, 0] | [1, 1, 128, 128] | 66011 | 34051 | 31960 | 48 |
| [3, 1, 2, 0] | [32, 32, 32, 32] | 182886 | 140970 | 41916 | 23 |
| [3, 1, 2, 0] | [96, 96, 96, 96] | 13773854 | 14002262 | -228408 | -2 |
| [1, 3, 2, 0] | [1, 1, 32, 32] | 13506 | 6532 | 6974 | 52 |
| [1, 3, 2, 0] | [1, 1, 128, 128] | 54333 | 33442 | 20891 | 38 |
| [1, 3, 2, 0] | [32, 32, 32, 32] | 159025 | 121565 | 37460 | 24 |
| [1, 3, 2, 0] | [96, 96, 96, 96] | 12104105 | 13157161 | -1053056 | -9 |
| [3, 0, 2, 1] | [1, 1, 32, 32] | 15407 | 6314 | 9093 | 59 |
| [3, 0, 2, 1] | [1, 1, 128, 128] | 66209 | 32760 | 33449 | 51 |
| [3, 0, 2, 1] | [32, 32, 32, 32] | 157755 | 104751 | 53004 | 34 |
| [3, 0, 2, 1] | [96, 96, 96, 96] | 12029488 | 8846762 | 3182726 | 26 |
| [2, 3, 0, 1] | [1, 1, 32, 32] | 86795 | 77277 | 9518 | 11 |
| [2, 3, 0, 1] | [1, 1, 128, 128] | 1007534 | 981804 | 25730 | 3 |
| [2, 3, 0, 1] | [32, 32, 32, 32] | 209988 | 102957 | 107031 | 51 |
| [2, 3, 0, 1] | [96, 96, 96, 96] | 17523170 | 8847372 | 8675798 | 50 |
| [3, 2, 1, 0] | [1, 1, 32, 32] | 114489 | 80261 | 34228 | 30 |
| [3, 2, 1, 0] | [1, 1, 128, 128] | 1290711 | 970495 | 320216 | 25 |
| [3, 2, 1, 0] | [32, 32, 32, 32] | 264289 | 113664 | 150625 | 57 |
| [3, 2, 1, 0] | [96, 96, 96, 96] | 20708815 | 12742424 | 7966391 | 38 |
| [2, 3, 1, 0] | [1, 1, 32, 32] | 117225 | 79516 | 37709 | 32 |
| [2, 3, 1, 0] | [1, 1, 128, 128] | 1301694 | 980410 | 321284 | 25 |
| [2, 3, 1, 0] | [32, 32, 32, 32] | 234711 | 116980 | 117731 | 50 |
| [2, 3, 1, 0] | [96, 96, 96, 96] | 19136366 | 12757227 | 6379139 | 33 |
| [3, 2, 0, 1] | [1, 1, 32, 32] | 86944 | 77929 | 9015 | 10 |
| [3, 2, 0, 1] | [1, 1, 128, 128] | 998782 | 975924 | 22858 | 2 |
| [3, 2, 0, 1] | [32, 32, 32, 32] | 233709 | 101485 | 132224 | 57 |
| [3, 2, 0, 1] | [96, 96, 96, 96] | 19106782 | 9550606 | 9556176 | 50 |

### Checklist

- [ ] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12921962567
- [ ] Blackhole Post commit (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12921962567
- [ ] Model regression CI testing passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12921219125
- [ ] Device performance regression CI testing passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12918280077
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes
- [ ] New/Existing tests provide coverage for changes
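
### Illustration (host-side sketch, not the device kernels)

To make the problem description concrete, the following is a minimal NumPy sketch of why chaining pairwise transposes costs more passes over the data than a single generic permute. It is only an illustration under stated assumptions: the helper names `permutation_to_swaps`, `permute_via_swaps`, and `permute_single_pass` are hypothetical, NumPy stands in for the device, each `np.ascontiguousarray` copy stands in for one full read/write of the tensor, and tiling is not modeled at all. It does not reproduce the old recursive implementation or the kernels added here.

```python
import itertools

import numpy as np


def permutation_to_swaps(perm):
    """Decompose `perm` into pairwise axis swaps (selection-sort style).

    Hypothetical helper: `current[k]` tracks which original axis sits at
    position k after the swaps applied so far.
    """
    current = list(range(len(perm)))
    swaps = []
    for i, target in enumerate(perm):
        j = current.index(target)
        if j != i:
            current[i], current[j] = current[j], current[i]
            swaps.append((i, j))
    return swaps


def permute_via_swaps(x, perm):
    """Apply `perm` as a chain of two-axis transposes.

    Each step materializes a full intermediate copy, standing in for one
    extra read/write pass (and one extra dispatch) over the tensor.
    """
    passes = 0
    for i, j in permutation_to_swaps(perm):
        x = np.ascontiguousarray(np.swapaxes(x, i, j))
        passes += 1
    return x, passes


def permute_single_pass(x, perm):
    """Apply `perm` with one copy, analogous to a single generic permute."""
    return np.ascontiguousarray(np.transpose(x, perm))


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Both routes agree for every 4D permutation; only the number of passes differs.
for perm in itertools.permutations(range(4)):
    chained, passes = permute_via_swaps(x, perm)
    assert np.array_equal(chained, permute_single_pass(x, perm))
    assert passes <= 3  # a 4-axis permutation needs at most 3 pairwise swaps
```

In this model a 4D permutation can need up to three full passes when chained, compared with one pass for the direct route, which is the kind of overhead the problem description attributes to the recursive transpose calls.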