diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 560a8f1ef81b..d3cbd33a14ca 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -499,7 +499,7 @@ def generate_fixed_no_dim0_dim1_transpose_permutations(N, dim0, dim1): @pytest.mark.parametrize("shape", [[7, 7, 7, 17, 17]]) @pytest.mark.parametrize("perm", [[0, 1, 4, 3, 2]]) -@pytest.mark.parametrize("dtype", [ttnn.float32]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16]) def test_permute_5d_yw(shape, perm, dtype, device): torch.set_printoptions(threshold=300000000) if is_grayskull() and dtype == ttnn.float32: @@ -507,7 +507,7 @@ def test_permute_5d_yw(shape, perm, dtype, device): torch.manual_seed(2005) torch_tensor = torch.rand(shape, dtype=torch.bfloat16) input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device) - output_tensor = ttnn.permute(input_tensor, perm, pad_value=None) + output_tensor = ttnn.permute(input_tensor, perm, pad_value=0.0) print(ttnn.from_device(output_tensor).to_torch()) output_tensor = ttnn.to_torch(output_tensor) torch_output = torch.permute(torch_tensor, perm) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp index 9b00fbaf575d..fdbaac4109ab 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp @@ -9,14 +9,21 @@ #include "compute_kernel_api/tilize.h" #include "compute_kernel_api/untilize.h" #include "compute_kernel_api/pack_untilize.h" +#include "tt_metal/hw/inc/circular_buffer.h" namespace NAMESPACE { + void MAIN { constexpr uint32_t N = get_compile_time_arg_val(0); constexpr uint32_t x_blocks = get_compile_time_arg_val(1); constexpr uint32_t w_blocks = get_compile_time_arg_val(2); constexpr uint32_t H = get_compile_time_arg_val(3); + constexpr uint32_t read_alignment = get_compile_time_arg_val(4); + constexpr uint32_t SUBTILE_LINE_BYTES = get_compile_time_arg_val(5); + constexpr uint32_t misalignment = read_alignment - SUBTILE_LINE_BYTES; + constexpr uint32_t misalignment_div_16 = misalignment >> cb_addr_shift; + uint32_t offset_div_16 = 0; uint32_t start_block = get_arg_val(0); uint32_t end_block = get_arg_val(1); @@ -28,6 +35,7 @@ void MAIN { for (uint32_t block = start_block; block < end_block; block++) { // Decompose block into w_block, x_block, and xw_block indices +#ifdef TRISC_UNPACK uint32_t rem = block; uint32_t w_block = rem % w_blocks; // Which W block are we in? rem /= w_blocks; @@ -36,6 +44,7 @@ void MAIN { rem /= x_blocks; uint32_t h = rem % H; +#endif // tilize input via unpack and then pack tilize_init_short(cb_in, 1, cb_tilize); @@ -43,7 +52,31 @@ void MAIN { cb_wait_front(cb_in, 1); cb_reserve_back(cb_tilize, 1); - tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize + // For BH, DRAM read alignment is 64B, but each subtile/face line is 32B, so every odd numbered row in BFLOAT16 + // is misaligned +#ifdef TRISC_UNPACK + if constexpr (misalignment > 0) { + // if h is an odd number, offset_div_16 by misalignment + if ((h & 1) == 1) { + // offset_div_16 = misalignment_div_16; + std::uint32_t operand_id = get_operand_id(cb_in); + get_local_cb_interface(operand_id).fifo_rd_ptr += misalignment_div_16; + } + } +#endif + // custom_tilize_block(cb_in, offset_div_16, 1, cb_tilize); // tilize and pack into cb_tilize + tilize_block(cb_in, 1, cb_tilize); + +#ifdef TRISC_UNPACK + if constexpr (misalignment > 0) { + // if h is an odd number, offset_div_16 by misalignment + if ((h & 1) == 1) { + // offset_div_16 = misalignment_div_16; + std::uint32_t operand_id = get_operand_id(cb_in); + get_local_cb_interface(operand_id).fifo_rd_ptr -= misalignment_div_16; + } + } +#endif cb_push_back(cb_tilize, 1); cb_pop_front(cb_in, 1); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_tiled_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_tiled_generic.cpp index bcded12d9971..014e7482b13d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_tiled_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_tiled_generic.cpp @@ -37,6 +37,7 @@ void kernel_main() { constexpr bool needs_x_padding = static_cast(get_compile_time_arg_val(24)); constexpr bool needs_y_padding = static_cast(get_compile_time_arg_val(25)); constexpr uint32_t non_x_rows = get_compile_time_arg_val(26); + constexpr uint32_t read_alignment = get_compile_time_arg_val(27); // ------------------------------------------------------------------------ // 2) Derived Constants (kept as constexpr) @@ -53,6 +54,8 @@ void kernel_main() { constexpr uint32_t FACE_H_STRIDE_BYTES = NUM_FACES_W * FACE_HW_BYTES; constexpr uint32_t tile_bytes = TILE_HW * element_size; + constexpr uint32_t misalignment = read_alignment - SUBTILE_LINE_BYTES; + // For x-padding logic: constexpr uint32_t final_face_real_w = (W % FACE_WIDTH); constexpr uint32_t ratio = sizeof(uint32_t) / element_size; @@ -165,6 +168,11 @@ void kernel_main() { // Reserve a slot in the circular buffer, get L1 pointer cb_reserve_back(tt::CBIndex::c_0, 1); uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + if constexpr (misalignment > 0) { + if ((h & 1) == 1) { + src_buffer_l1_addr += misalignment; + } + } // -------------------------------------------------------------------- // 5.1) Async read for [x_start..x_end) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp index bc6fe9e83368..f04450b08579 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp @@ -5,6 +5,7 @@ #include "cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" #include #include +#include namespace ttnn::operations::data_movement { @@ -72,6 +73,13 @@ ttnn::SmallVector get_inverse_permutation(const ttnn::SmallVectorbuffer_type() == tt::tt_metal::BufferType::DRAM + ? tt::tt_metal::experimental::hal::get_dram_alignment() + : tt::tt_metal::experimental::hal::get_l1_alignment()); +} + } // namespace detail PermuteDeviceOperation::MultiCoreTileInvariant::cached_program_t PermuteDeviceOperation::MultiCoreTileInvariant::create( @@ -497,6 +505,10 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe uint32_t H_t = H_p / tile_shape[0]; uint32_t W_t = W_p / tile_shape[1]; + uint32_t subtile_line_bytes = face_shape[1] * element_size; + uint32_t read_alignment = detail::get_buffer_alignment(input_tensor); + uint32_t misalignment = read_alignment - subtile_line_bytes; + uint32_t permuted_w_dim = 0; // Will hold the position of w_dim in the permuted array for (uint32_t i = 0; i < N; ++i) { if (dims[i] == N - 1) { @@ -591,10 +603,7 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe all_cores = num_cores > padded_num_cores ? all_cores : padded_all_cores; tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); - uint32_t input_page_size = detail::tile_size(tensor_return_value); - - tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); - uint32_t output_page_size = detail::tile_size(tensor_return_value); + uint32_t input_page_size = detail::tile_size(tensor_return_value) + misalignment; tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_input_pages_to_read * input_page_size, {{src0_cb_index, cb_data_format}}) @@ -650,6 +659,7 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe (uint32_t)needs_x_padding, (uint32_t)needs_y_padding, non_x_rows, + read_alignment, }; tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( @@ -664,9 +674,11 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe x_blocks, w_blocks, input_shape[N - 2], + read_alignment, + subtile_line_bytes, }; - bool fp32_dest_acc_en = cb_data_format_output == tt::DataFormat::Float32; + bool fp32_dest_acc_en = cb_data_format == tt::DataFormat::Float32; auto compute_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp",