From e6d86ea39f655dd8854e60b96a45aec55fb1f18a Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 10 Dec 2024 05:36:27 +0000 Subject: [PATCH] #0: make row-invariant permute kernel multicore --- ..._permute_interleaved_rm_row_invariant.cpp} | 6 +- ..._permute_interleaved_rm_row_invariant.cpp} | 12 ++-- .../device/permute_device_operation.cpp | 6 +- .../device/permute_device_operation.hpp | 5 +- .../device/permute_program_factory.cpp | 65 +++++++++++++------ 5 files changed, 62 insertions(+), 32 deletions(-) rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/{reader_permute_interleaved_rm.cpp => reader_permute_interleaved_rm_row_invariant.cpp} (78%) rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/{writer_permute_interleaved_rm.cpp => writer_permute_interleaved_rm_row_invariant.cpp} (83%) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp similarity index 78% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp index b5ffc12cf7d8..93a42f813251 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp @@ -12,14 +12,16 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); const uint32_t src_addr = get_arg_val(0); + const uint32_t start_row = get_arg_val(1); + const uint32_t end_row = get_arg_val(2); const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = page_size}; uint32_t curr_addr = src_addr; - for (uint32_t i = 0; i < num_rows; ++i) { + for (uint32_t row = start_row; row < end_row; ++row) { cb_reserve_back(tt::CBIndex::c_0, 1); uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); - noc_async_read_page(i, s0, src_buffer_l1_addr); + noc_async_read_page(row, s0, src_buffer_l1_addr); noc_async_read_barrier(); cb_push_back(tt::CBIndex::c_0, 1); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp similarity index 83% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp index 34be75dfdf45..46903375ff67 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp @@ -12,19 +12,21 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); const uint32_t dst_addr = get_arg_val(0); + const uint32_t start_row = get_arg_val(1); + const uint32_t end_row = get_arg_val(2); const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = page_size}; uint32_t input_shape[N], perm[N], dest_strides[N]; - for (uint32_t i = 1; i <= N; i++) { - input_shape[i - 1] = get_arg_val(i); - perm[i - 1] = get_arg_val(i + N); - dest_strides[i - 1] = get_arg_val(i + 2 * N); + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + perm[i - 3] = get_arg_val(i + N); + dest_strides[i - 3] = get_arg_val(i + 2 * N); } uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); uint32_t curr_addr = dst_addr; - for (uint32_t row = 0; row < num_rows; ++row) { + for (uint32_t row = start_row; row < end_row; ++row) { // Compute multi-dimensional index for the source row uint32_t src_multi_idx[N]; size_t remaining = row; diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp index af6ee177fca7..8bc4bece3b01 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp @@ -12,9 +12,11 @@ namespace ttnn::operations::data_movement { PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + // If the last dimension is not permuted, we can use the row-invariant kernel if (operation_attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1) { - return SingleCore{}; + return MultiCoreRowInvariant{}; } + // Otherwise, we need to use the blocked generic, row moving kernel return MultiCoreBlockedGeneric{}; } @@ -33,7 +35,7 @@ void PermuteDeviceOperation::validate_on_program_cache_hit( PermuteDeviceOperation::shape_return_value_t PermuteDeviceOperation::compute_output_shapes( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { - SmallVector shape, padded_shape; + SmallVector shape; auto input_shape = tensor_args.input_tensor.get_logical_shape(); shape.reserve(input_shape.rank()); for (auto dim : attributes.dims) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index 36f9328688dd..e27b8251bdc1 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -30,11 +30,12 @@ struct PermuteDeviceOperation { using tensor_return_value_t = Tensor; - struct SingleCore { + struct MultiCoreRowInvariant { // Shared variables are the variables that are shared between the create and override_runtime_arguments methods struct shared_variables_t { KernelHandle unary_reader_kernel_id; KernelHandle unary_writer_kernel_id; + CoreRangeSet all_cores; }; using cached_program_t = ttnn::device_operation::CachedProgram; @@ -72,7 +73,7 @@ struct PermuteDeviceOperation { tensor_return_value_t& tensor_return_value); }; - using program_factory_t = std::variant; + using program_factory_t = std::variant; // Mandatory methods diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 0d726d45be03..5ead48cb7cb8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -32,7 +32,7 @@ std::vector get_row_strides(const ttnn::SimpleShape& shape) { } // namespace detail -PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::SingleCore::create( +PermuteDeviceOperation::MultiCoreRowInvariant::cached_program_t PermuteDeviceOperation::MultiCoreRowInvariant::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -60,54 +60,78 @@ PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::Sin uint32_t src0_cb_index = tt::CBIndex::c_0; uint32_t num_input_pages_to_read = 2; - CoreRange core({0, 0}, {0, 0}); + uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; + + auto compute_with_storage_grid_size = input_tensor.device()->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_rows); + tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig( num_input_pages_to_read * input_rm_page_size, {{src0_cb_index, cb_data_format}}) .set_page_size(src0_cb_index, input_rm_page_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); uint32_t N = operation_attributes.dims.size(); - uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = {(uint32_t)src_is_dram, N, input_rm_page_size, num_rows}; tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp", - core, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "reader_permute_interleaved_rm_row_invariant.cpp", + all_cores, tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector writer_compile_time_args = {(std::uint32_t)dst_is_dram, N, output_rm_page_size, num_rows}; tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp", - core, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "writer_permute_interleaved_rm_row_invariant.cpp", + all_cores, tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - std::vector reader_runtime_args = {src_buffer->address()}; - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + std::vector reader_runtime_args = {src_buffer->address(), 0, 0}; auto input_shape_view = input_tensor.get_logical_shape().view(); auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); // in anticipation of RM padding - std::vector writer_runtime_args = {dst_buffer->address()}; + std::vector writer_runtime_args = {dst_buffer->address(), 0, 0}; writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); writer_runtime_args.insert( writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + auto cores = corerange_to_cores(all_cores, std::nullopt); + uint32_t start_row = 0; + uint32_t num_rows_per_core = 0; + for (const auto& core : cores) { + if (core_group_1.contains(core)) { + num_rows_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.contains(core)) { + num_rows_per_core = num_tiles_per_core_group_2; + } else { + // no-op + num_rows_per_core = 0; + } + uint32_t end_row = start_row + num_rows_per_core; + reader_runtime_args[1] = start_row; + reader_runtime_args[2] = end_row; + writer_runtime_args[1] = start_row; + writer_runtime_args[2] = end_row; + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + start_row = end_row; + } return { std::move(program), {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}}; } -void PermuteDeviceOperation::SingleCore::override_runtime_arguments( +void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -121,15 +145,14 @@ void PermuteDeviceOperation::SingleCore::override_runtime_arguments( auto src_buffer = input_tensor.buffer(); auto dst_buffer = output_tensor.buffer(); + auto& all_cores = cached_program.shared_variables.all_cores; - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0}); + auto cores = corerange_to_cores(all_cores, std::nullopt); + for (const auto& core : cores) { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); runtime_args[0] = src_buffer->address(); - } - - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0}); - runtime_args[0] = dst_buffer->address(); + auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args_writer[0] = dst_buffer->address(); } }