diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
index b198fc020604..4370b449fd6c 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
@@ -1120,3 +1120,32 @@ def test_transpose_16411(device):
     assert_with_pcc(p_c2, ttnn.to_torch(c2), 0.9999)
     assert_with_pcc(p_c3, ttnn.to_torch(c3), 0.9999)
     assert_with_pcc(p_c4, ttnn.to_torch(c4), 0.9999)
+
+
+@pytest.mark.parametrize("rank", [5])
+@pytest.mark.parametrize("indices", [[0, 1], [0, 2], [0, 3], [0, 4], [1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]])
+@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+def test_transpose_high_rank(*, device: ttnn.Device, rank: int, indices, layout):
+    torch.manual_seed(2005)
+    ttnn.disable_and_clear_program_cache(device)
+    ttnn.enable_program_cache(device)
+
+    shape = [2] * rank
+
+    a = torch.randn(shape, dtype=torch.bfloat16)
+    b = torch.randn(shape, dtype=torch.bfloat16)
+
+    tt_a = ttnn.from_torch(a, device=device, layout=layout)
+    tt_b = ttnn.from_torch(b, device=device, layout=layout)
+
+    a = a.transpose(*indices)
+    b = b.transpose(*indices)
+
+    tt_a = ttnn.transpose(tt_a, *indices)
+    tt_b = ttnn.transpose(tt_b, *indices)
+
+    output_a = ttnn.to_torch(tt_a)
+    output_b = ttnn.to_torch(tt_b)
+
+    assert torch.allclose(a, output_a)
+    assert torch.allclose(b, output_b)
diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp
index db9dec262677..ae3a4553527e 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp
@@ -108,6 +108,7 @@ struct PermuteDeviceOperation {
     struct shared_variables_t {
         tt::tt_metal::KernelHandle unary_reader_kernel_id;
         tt::tt_metal::KernelHandle unary_writer_kernel_id;
+        tt::tt_metal::KernelHandle compute_kernel_id;
         CoreRangeSet core_range;
     };
     using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;
diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_rm_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_rm_program_factory.cpp
index 1ec9befaf941..9362c1d75eba 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_rm_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_rm_program_factory.cpp
@@ -132,7 +132,10 @@ PermuteDeviceOperation::MultiCoreRowInvariant::cached_program_t PermuteDeviceOpe

     return {
         std::move(program),
-        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
+        {.unary_reader_kernel_id = unary_reader_kernel_id,
+         .unary_writer_kernel_id = unary_writer_kernel_id,
+         .core_range = all_cores},
+    };
 }

 void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments(
diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp
index 7f2bcadd83be..c0920fe9a1e4 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp
@@ -179,11 +179,13 @@ PermuteDeviceOperation::MultiCoreTileInvariant::cached_program_t PermuteDeviceOp
     auto inv_perm = detail::get_inverse_permutation(operation_attributes.dims);

     std::vector<uint32_t> reader_runtime_args = {src_buffer->address(), 0, 0};
+    reader_runtime_args.insert(reader_runtime_args.end(), output_shape_view.begin(), output_shape_view.end());
     reader_runtime_args.insert(reader_runtime_args.end(), inv_perm.begin(), inv_perm.end());
     reader_runtime_args.insert(reader_runtime_args.end(), input_tile_strides.begin(), input_tile_strides.end());

     std::vector<uint32_t> writer_runtime_args = {dst_buffer->address(), 0, 0};

+    std::vector<uint32_t> compute_runtime_args = {0};

     auto cores = corerange_to_cores(all_cores, std::nullopt);
@@ -216,7 +218,10 @@ PermuteDeviceOperation::MultiCoreTileInvariant::cached_program_t PermuteDeviceOp

     return {
         std::move(program),
-        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
+        {.unary_reader_kernel_id = unary_reader_kernel_id,
+         .unary_writer_kernel_id = unary_writer_kernel_id,
+         .compute_kernel_id = compute_kernel_id,
+         .core_range = all_cores}};
 }

 void PermuteDeviceOperation::MultiCoreTileInvariant::override_runtime_arguments(
@@ -239,6 +244,7 @@ void PermuteDeviceOperation::MultiCoreTileInvariant::override_runtime_arguments(
     for (const auto& core : cores) {
         auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core);
         runtime_args[0] = src_buffer->address();
+
         auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core);
         runtime_args_writer[0] = dst_buffer->address();
     }
@@ -492,7 +498,10 @@ PermuteDeviceOperation::MultiCoreTileRowInvariant::create(

     return {
         std::move(program),
-        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
+        {.unary_reader_kernel_id = unary_reader_kernel_id,
+         .unary_writer_kernel_id = unary_writer_kernel_id,
+         .compute_kernel_id = compute_kernel_id,
+         .core_range = all_cores}};
 }

 void PermuteDeviceOperation::MultiCoreTileRowInvariant::override_runtime_arguments(
@@ -855,7 +864,10 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe

     return {
         std::move(program),
-        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
+        {.unary_reader_kernel_id = unary_reader_kernel_id,
+         .unary_writer_kernel_id = unary_writer_kernel_id,
+         .compute_kernel_id = compute_kernel_id,
+         .core_range = all_cores}};
 }

 void PermuteDeviceOperation::MultiCoreTiledGeneric::override_runtime_arguments(
diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp
index c11535f19814..1b8bb37ea154 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp
@@ -26,24 +26,17 @@ inline Tensor transpose_(
     TransposeOpDim transpose_dim,
     const MemoryConfig& output_mem_config,
     const std::optional<float>& pad_value) {
-    bool tiled_only = false;
-    constexpr uint32_t FACE_WIDTH =
-        tt::constants::FACE_WIDTH;  // this is a highly restrictive constraint on the RM transpose_wh kernel, and with
-                                    // all the other bugs/limitations we should rewrite it
-    // use device->get_allocator_alignment when the it reflects the alignment of the buffer and doesn't just default to
-    // DRAM
-    auto BUFFER_ALIGNMENT = a.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? hal::get_dram_alignment()
-                                                                                        : hal::get_l1_alignment();
-    uint32_t W = a.get_padded_shape()[-1];
-    uint32_t H = a.get_padded_shape()[-2];
+    auto prim_permute = [&](const ttnn::Tensor& input, ttnn::SmallVector<int64_t> dims) -> ttnn::Tensor {
+        return ttnn::prim::permute(input, dims, output_mem_config, std::nullopt, pad_value);
+    };
+
+    bool interleaved_rm = !a.is_sharded() && a.layout() == Layout::ROW_MAJOR;
     switch (transpose_dim) {
         case TransposeOpDim::HC:
-            tiled_only = a.get_layout() == Layout::TILE;
-            if ((!tiled_only) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) {  //
-                tiled_only = true;
+            if (interleaved_rm) {
+                return prim_permute(a, ttnn::SmallVector<int64_t>{0, 2, 1, 3});
             }
             break;
-        // bubble dim around to make it possible as these implementations don't have a kernel
         case TransposeOpDim::NH:
             return ttnn::permute(
                 (const ttnn::Tensor)a, ttnn::SmallVector<int64_t>({2, 1, 0, 3}), output_mem_config, pad_value);
@@ -54,32 +47,18 @@ inline Tensor transpose_(
             return ttnn::permute(
                 (const ttnn::Tensor)a, ttnn::SmallVector<int64_t>({0, 3, 2, 1}), output_mem_config, pad_value);
         case TransposeOpDim::CN:
-            tiled_only = true;  // CN only has a tiled implementation at the moment
+            if (interleaved_rm) {
+                return prim_permute(a, ttnn::SmallVector<int64_t>({1, 0, 2, 3}));
+            }
             break;
         case TransposeOpDim::WH:
-            if (!a.is_sharded() && a.layout() == Layout::ROW_MAJOR) {
-                return ttnn::prim::permute(
-                    a, ttnn::SmallVector<int64_t>({0, 1, 3, 2}), output_mem_config, std::nullopt);
+            if (interleaved_rm) {
+                return prim_permute(a, ttnn::SmallVector<int64_t>({0, 1, 3, 2}));
             }
             break;
         default: break;
     }
-    if (a.get_layout() == Layout::ROW_MAJOR) {
-        // the assorted cases where only tiled works right now (HC with stick width constraint, WH with stick width
-        // constraint, CN).
-        if (tiled_only) {
-            // convert to tiled
-            Tensor b = ttnn::to_layout(a, Layout::TILE, std::nullopt, std::nullopt, (IDevice*)nullptr);
-            // run the transpose.
-            b = operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {b}).at(0);
-            // back to original layout
-            b = ttnn::to_layout(b, a.get_layout(), std::nullopt, std::nullopt, (IDevice*)nullptr);
-            return b;
-        }
-        return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
-    } else {
-        return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
-    }
+    return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
 }

 ttnn::Tensor transpose_nd(
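
Usage note (a minimal sketch, not part of the patch): the new test exercises rank-5 ttnn.transpose against torch.Tensor.transpose on both layouts. The standalone snippet below shows the same call pattern outside the pytest harness; the device_id of 0 and the open_device/close_device setup are assumptions about the local environment, and running it requires a Tenstorrent device with the ttnn and torch packages installed.

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)  # assumption: one local TT device at id 0

    x = torch.randn([2, 2, 2, 2, 2], dtype=torch.bfloat16)
    tt_x = ttnn.from_torch(x, device=device, layout=ttnn.ROW_MAJOR_LAYOUT)

    # Swap dims 1 and 3 of a rank-5 row-major tensor, the kind of high-rank
    # transpose the new test above locks down, and compare against torch.
    tt_y = ttnn.transpose(tt_x, 1, 3)
    assert torch.allclose(x.transpose(1, 3), ttnn.to_torch(tt_y))

    ttnn.close_device(device)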