diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp new file mode 100644 index 00000000000..8fdcfc5e302 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp @@ -0,0 +1,213 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "dataflow_api.h" + +#include +#include + +static constexpr bool enable_start_synchronization = get_compile_time_arg_val(0) != 0; +static constexpr bool enable_finish_synchronization = get_compile_time_arg_val(1) != 0; +static constexpr bool enable_any_synchronization = enable_start_synchronization || enable_finish_synchronization; + +FORCE_INLINE void line_sync( + FabricConnectionManager& fabric_connection, + volatile tt::fabric::PacketHeader* mcast_fwd_packet_header, + volatile tt::fabric::PacketHeader* mcast_bwd_packet_header, + size_t sync_bank_addr, + size_t sync_noc_x, + size_t sync_noc_y, + size_t sync_val) { + using namespace tt::fabric; + mcast_fwd_packet_header->to_atomic_inc(); + mcast_bwd_packet_header->to_atomic_inc(); + + if (fabric_connection.has_forward_connection()) { + mcast_fwd_packet_header->to_noc_unicast_atomic_inc(NocUnicastAtomicIncCommandHeader{ + sync_bank_addr, 1, 128, static_cast(sync_noc_x), static_cast(sync_noc_y)}); + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + fabric_connection.get_forward_connection().send_payload_flush_non_blocking_from_address( + (uint32_t)mcast_fwd_packet_header, sizeof(tt::fabric::PacketHeader)); + } + + if (fabric_connection.has_backward_connection()) { + mcast_bwd_packet_header->to_noc_unicast_atomic_inc(NocUnicastAtomicIncCommandHeader{ + sync_bank_addr, 1, 128, static_cast(sync_noc_x), static_cast(sync_noc_y)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_flush_non_blocking_from_address( + (uint32_t)mcast_bwd_packet_header, sizeof(tt::fabric::PacketHeader)); + } + noc_semaphore_inc(get_noc_addr(sync_noc_x, sync_noc_y, sync_bank_addr), 1); + if (sync_noc_x == my_x[0] && sync_noc_y == my_y[0]) { + noc_semaphore_wait_min(reinterpret_cast(sync_bank_addr), sync_val); + } +} + +void kernel_main() { + using namespace tt::fabric; + size_t arg_idx = 0; + + const size_t dest_bank_addr = get_arg_val(arg_idx++); + const size_t packet_payload_size_bytes = get_arg_val(arg_idx++); + const size_t dest_noc_x = get_arg_val(arg_idx++); + const size_t dest_noc_y = get_arg_val(arg_idx++); + + const size_t num_mcasts = get_arg_val(arg_idx++); + const size_t mcast_fwd_hops = get_arg_val(arg_idx++); + const size_t mcast_bwd_hops = get_arg_val(arg_idx++); + + const size_t num_unicasts = get_arg_val(arg_idx++); + const size_t unicast_hops = get_arg_val(arg_idx++); + const bool unicast_is_fwd = get_arg_val(arg_idx++) != 0; + + const size_t source_l1_cb_index = get_arg_val(arg_idx++); + const size_t packet_header_cb = get_arg_val(arg_idx++); + const size_t packet_header_size_in_headers = get_arg_val(arg_idx++); + + auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); + + ASSERT(fabric_connection.is_logically_connected()); + + if 
(!fabric_connection.is_logically_connected()) { + return; + } + size_t sync_noc_x = 0; + size_t sync_noc_y = 0; + size_t sync_bank_addr = 0; + size_t total_workers_per_sync = 0; + if (enable_any_synchronization) { + sync_noc_x = get_arg_val(arg_idx++); + sync_noc_y = get_arg_val(arg_idx++); + sync_bank_addr = get_arg_val(arg_idx++); + total_workers_per_sync = get_arg_val(arg_idx++); + } + + const size_t start_sync_val = total_workers_per_sync; + const size_t finish_sync_val = 3 * total_workers_per_sync; + + fabric_connection.open(); + + cb_reserve_back(source_l1_cb_index, 1); + cb_reserve_back(packet_header_cb, packet_header_size_in_headers); + const auto source_l1_buffer_address = get_write_ptr(source_l1_cb_index); + const auto packet_header_buffer_address = get_write_ptr(packet_header_cb); + + auto* mcast_fwd_packet_header = reinterpret_cast(packet_header_buffer_address); + auto* mcast_bwd_packet_header = + reinterpret_cast(packet_header_buffer_address + sizeof(tt::fabric::PacketHeader)); + auto* unicast_packet_header = + reinterpret_cast(packet_header_buffer_address + sizeof(tt::fabric::PacketHeader) * 2); + mcast_fwd_packet_header->to_write().to_chip_multicast( + MulticastRoutingCommandHeader{1, static_cast(mcast_fwd_hops)}); + mcast_bwd_packet_header->to_write().to_chip_multicast( + MulticastRoutingCommandHeader{1, static_cast(mcast_bwd_hops)}); + + if (enable_start_synchronization) { + line_sync( + fabric_connection, + mcast_fwd_packet_header, + mcast_bwd_packet_header, + sync_bank_addr, + sync_noc_x, + sync_noc_y, + start_sync_val); + line_sync( + fabric_connection, + mcast_fwd_packet_header, + mcast_bwd_packet_header, + sync_bank_addr, + sync_noc_x, + sync_noc_y, + 2 * start_sync_val); + } + + mcast_fwd_packet_header->to_write().to_chip_multicast( + MulticastRoutingCommandHeader{1, static_cast(mcast_fwd_hops)}); + mcast_bwd_packet_header->to_write().to_chip_multicast( + MulticastRoutingCommandHeader{1, static_cast(mcast_bwd_hops)}); + unicast_packet_header->to_atomic_inc().to_chip_unicast( + UnicastRoutingCommandHeader{static_cast(unicast_hops)}); + + { + DeviceZoneScopedN("MAIN-WRITE-ZONE"); + for (size_t i = 0; i < num_mcasts; i++) { + noc_async_write( + source_l1_buffer_address, + safe_get_noc_addr(static_cast(dest_noc_x), static_cast(dest_noc_y), dest_bank_addr), + packet_payload_size_bytes); + if (fabric_connection.has_forward_connection()) { + DeviceZoneScopedN("WR-FWD"); + mcast_fwd_packet_header->to_noc_unicast(NocUnicastCommandHeader{ + dest_bank_addr, + packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader), + static_cast(dest_noc_x), + static_cast(dest_noc_y)}); + { + DeviceZoneScopedN("WR-FWD-WAIT"); + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + } + fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address( + source_l1_buffer_address, packet_payload_size_bytes); + fabric_connection.get_forward_connection().send_payload_flush_non_blocking_from_address( + (uint32_t)mcast_fwd_packet_header, sizeof(tt::fabric::PacketHeader)); + } + + if (fabric_connection.has_backward_connection()) { + DeviceZoneScopedN("WR-BWD"); + mcast_bwd_packet_header->to_noc_unicast(NocUnicastCommandHeader{ + dest_bank_addr, + packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader), + static_cast(dest_noc_x), + static_cast(dest_noc_y)}); + { + DeviceZoneScopedN("WR-BWD-WAIT"); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + } + 
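+// The fabric send on each direction happens in two phases: the payload is staged to the
+// EDM sender channel without a header first, then the separately prepared packet header is
+// flushed behind it. This is also why the header's noc-unicast size field above is set to
+// packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader) rather than the payload size alone.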
fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address( + source_l1_buffer_address, packet_payload_size_bytes); + fabric_connection.get_backward_connection().send_payload_flush_non_blocking_from_address( + (uint32_t)mcast_bwd_packet_header, sizeof(tt::fabric::PacketHeader)); + } + { + noc_async_writes_flushed(); + } + } + } + + for (size_t i = 0; i < num_unicasts; i++) { + DeviceZoneScopedN("UNICAST-WRITE"); + auto& fabric_conn = + unicast_is_fwd ? fabric_connection.get_forward_connection() : fabric_connection.get_backward_connection(); + unicast_packet_header->to_noc_unicast(NocUnicastCommandHeader{ + dest_bank_addr, + packet_payload_size_bytes, + static_cast(dest_noc_x), + static_cast(dest_noc_y)}); + fabric_conn.wait_for_empty_write_slot(); + fabric_conn.send_payload_without_header_non_blocking_from_address( + source_l1_buffer_address, packet_payload_size_bytes); + fabric_conn.send_payload_flush_blocking_from_address( + (uint32_t)unicast_packet_header, sizeof(tt::fabric::PacketHeader)); + } + + if (enable_finish_synchronization) { + line_sync( + fabric_connection, + mcast_fwd_packet_header, + mcast_bwd_packet_header, + sync_bank_addr, + sync_noc_x, + sync_noc_y, + finish_sync_val); + } + + { + DeviceZoneScopedN("WR-CLOSE"); + fabric_connection.close(); + } + noc_async_write_barrier(); +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp index 86f71c8b31d..bd9b986c2f3 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp @@ -150,8 +150,6 @@ void kernel_main() { packet_header.reserved2 = 0x1111; // debug only } - uint64_t buffer_address = sender.edm_buffer_addr + - (*sender.buffer_index_ptr * (sender.buffer_size_bytes + sizeof(eth_channel_sync_t))); sender.send_payload_blocking_from_address(packet_addr, packet_size); noc_async_writes_flushed(); cb_pop_front(cb_id_in0, pages_to_send); diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index aa628a931c4..410c3206ee5 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -9,6 +9,7 @@ #include #include #include +#include "tt-metalium/kernel_types.hpp" #include "tt_metal/test_utils/df/df.hpp" #include "tt_metal/test_utils/env_vars.hpp" #include "ttnn/common/constants.hpp" @@ -223,6 +224,18 @@ std::tuple, std::vector> build_input_buffer( return {local_input_buffer, inputs}; } +static void build_and_enqueue(const std::vector& devices, std::vector& programs) { + TT_FATAL( + devices.size() == programs.size(), + "Number of devices must match number of programs when calling build_and_enqueue in test"); + for (size_t i = 0; i < devices.size(); i++) { + tt::tt_metal::detail::CompileProgram(devices[i], programs[i]); + } + for (size_t i = 0; i < devices.size(); i++) { + tt_metal::EnqueueProgram(devices[i]->command_queue(), programs[i], false); + } +} + struct EthLinkHop { CoreCoord hop_src; CoreCoord hop_dest; @@ -1074,12 +1087,7 @@ void setup_test_with_persistent_fabric( log_info(tt::LogTest, "Building EDM kernels"); line_fabric->build_kernels(); - for 
(size_t i = 0; i < devices.size(); i++) { - tt::tt_metal::detail::CompileProgram(devices[i], fabric_programs->at(i)); - } - for (size_t i = 0; i < devices.size(); i++) { - tt_metal::EnqueueProgram(devices[i]->command_queue(), fabric_programs->at(i), false); - } + build_and_enqueue(devices, *fabric_programs); } } @@ -1470,7 +1478,8 @@ bool TestMultiInputReaderKernel( /// MESSAGE COUNT TERMINATION MODE //////////////////////////////////////////////////////////////////// -TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_SingleMessage) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_FabricEDMLoopback_With_Workers_SingleMessage) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 1; const bool src_is_dram = true; @@ -1481,7 +1490,8 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_SingleMessage) { } // Will wrap sender but not receiver buffers -TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_2_messages) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_FabricEDMLoopback_With_Workers_2_messages) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 2; const bool src_is_dram = true; @@ -1491,7 +1501,8 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_2_messages) { ASSERT_EQ(result, 0); } // Will wrap sender but not receiver buffers -TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_10_messages) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_FabricEDMLoopback_With_Workers_10_messages) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 10; const bool src_is_dram = true; @@ -1502,7 +1513,8 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_10_messages) { } // Will wrap sender and receiver buffers -TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_20_messages) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_FabricEDMLoopback_With_Workers_20_messages) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 20; const bool src_is_dram = true; @@ -1512,7 +1524,8 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_20_messages) { ASSERT_EQ(result, 0); } -TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_FabricEDMLoopback_With_Workers) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 10000; const bool src_is_dram = true; @@ -1580,7 +1593,8 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_PersistentFabric) { //////////////////////////////// -TEST(WorkerFabricEdmDatapath, LineFabricMcast_SingleMessage_SingleSource) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast_SingleMessage_SingleSource) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 1; const bool src_is_dram = true; @@ -1595,7 +1609,8 @@ TEST(WorkerFabricEdmDatapath, LineFabricMcast_SingleMessage_SingleSource) { } // Non-functional on harvested parts. Needs testing on unharvested parts.
-TEST(WorkerFabricEdmDatapath, LineFabricMcast_ManyMessages_SingleSource) { +// Disabled non-persistent fabric tests - non-persistent fabric mode is not supported +TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast_ManyMessages_SingleSource) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 10000; const bool src_is_dram = true; @@ -2809,6 +2824,13 @@ TEST(CclAsyncOp, ReduceScatterSmall_PersistentFabric) { log_info(tt::LogTest, "Finished"); } +static void wait_for_worker_subdevice_program_completion( + const std::vector& devices, const std::optional& subdevice_managers) { + std::ranges::for_each(devices, [&](IDevice* d) { + tt_metal::Finish(d->command_queue(), {subdevice_managers->worker_subdevice_id.at(d->id())}); + }); +} + #include "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_links, ttnn::Shape const& input_shape) { log_info(tt::LogTest, "entering test"); @@ -2892,10 +2914,7 @@ void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_li true); // wait for op completion - log_info(tt::LogTest, "Waiting for Op finish"); - std::ranges::for_each(devices, [&](IDevice* d) { - tt_metal::Finish(d->command_queue(), {subdevice_managers->worker_subdevice_id.at(d->id())}); - }); + wait_for_worker_subdevice_program_completion(devices, subdevice_managers); log_info(tt::LogTest, "Main op done"); log_info(tt::LogTest, "Fabric teardown"); @@ -2923,3 +2942,707 @@ TEST(CclAsyncOp, DISABLED_AllGather_PersistentFabric_Dim3_Links2_Shape1_1_32_128 TEST(CclAsyncOp, DISABLED_AllGather_PersistentFabric_Dim3_Links2_Shape1_1_32_8192) { run_all_gather_with_persistent_fabric(3, 2, ttnn::Shape({1, 1, 32, 8192})); } + +struct WriteThroughputStabilityTestWithPersistentFabricParams { + size_t line_size = 4; + size_t num_devices_with_workers = 0; + bool line_sync = false; +}; + +void RunWriteThroughputStabilityTestWithPersistentFabric( + size_t num_mcasts, + size_t num_unicasts, + size_t num_links, + size_t num_op_invocations, + const WriteThroughputStabilityTestWithPersistentFabricParams& params = {}) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 4) { + log_info(tt::LogTest, "This test can only be run on T3000 devices"); + return; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + + size_t line_size = params.line_size; + size_t num_devices_with_workers = params.num_devices_with_workers; + if (num_devices_with_workers == 0) { + num_devices_with_workers = line_size; + } + using namespace ttnn::ccl; + TT_FATAL(num_devices_with_workers <= line_size, "num_devices_with_workers must be less than or equal to line_size"); + + if (params.line_sync) { + TT_FATAL(num_op_invocations == 1, "Performance reporting only supported for 1 invocation per test"); + } + + auto worker_core_logical = [](size_t link) { return CoreCoord(link, 0); }; + + // static constexpr size_t source_l1_buffer_address = 1000000; + static constexpr uint32_t packet_header_cb_index = tt::CB::c_in0; + static constexpr uint32_t source_payload_cb_index = tt::CB::c_in1; + static constexpr size_t packet_header_cb_size_in_headers = 4; + static constexpr bool enable_persistent_fabric_mode = true; + static constexpr size_t packet_payload_size_bytes = 4096; + static constexpr size_t dest_buffer_size = packet_payload_size_bytes * 4; + static constexpr tt::DataFormat
cb_df = tt::DataFormat::Bfp8; + + T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); + + // Get the inner 4 device ring on a WH T3K device so that we can use both links for all devices + std::vector devices_ = { + view.get_device(0, 1), view.get_device(0, 2), view.get_device(1, 2), view.get_device(1, 1)}; + std::vector devices; + devices.reserve(line_size); + for (size_t i = 0; i < line_size; i++) { + devices.push_back(devices_[i]); + } + // build the mesh device + + // Persistent Fabric Setup + std::vector dummy_worker_programs; + std::optional subdevice_managers = std::nullopt; + std::optional> fabric_programs; + std::vector fabric_program_ptrs; + std::optional fabric_handle; + setup_test_with_persistent_fabric( + devices, + dummy_worker_programs, + subdevice_managers, + fabric_programs, + fabric_program_ptrs, + fabric_handle, + enable_persistent_fabric_mode, + num_links); + + // Other boilerplate setup + CoreRangeSet worker_cores = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(num_links - 1, 0))); + auto worker_cores_vec = corerange_to_cores(worker_cores, std::nullopt, false); + auto dest_core_coord = CoreCoord(2, 2); + auto sync_core_coord = CoreCoord(0, 0); + + ttnn::SmallVector> device_dest_buffers; + device_dest_buffers.reserve(line_size); + for (auto* d : devices) { + auto local_input_buffer = + CreateBuffer(InterleavedBufferConfig{d, dest_buffer_size, dest_buffer_size, BufferType::L1}); + device_dest_buffers.push_back(local_input_buffer); + } + + size_t dest_bank_addr = device_dest_buffers[0]->address(); + TT_FATAL( + std::all_of( + device_dest_buffers.begin(), + device_dest_buffers.end(), + [dest_bank_addr](const auto& buffer) { return buffer->address() == dest_bank_addr; }), + "Test setup error: all destination buffers must have the same bank address across devices"); + + auto global_semaphores = ttnn::global_semaphore::create_global_semaphore_with_same_address( + test_fixture.mesh_device_.get(), + devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), + 0, // initial value + tt::tt_metal::BufferType::L1, // buffer type + 1000 // attempts + ); + auto global_semaphore_addr = + ttnn::global_semaphore::get_global_semaphore_address(global_semaphores.global_semaphores.at(0)); + + std::vector worker_devices; + for (size_t i = 0; i < num_devices_with_workers; i++) { + worker_devices.push_back(devices[i]); + } + // Worker program setup + std::vector programs(num_devices_with_workers); + TT_FATAL( + programs.size() == worker_devices.size(), + "Test misconfiguration. Mismatch in worker program and device counts. Expected {} programs but got {} devices " + "instead.", + num_devices_with_workers, + worker_devices.size()); + for (size_t i = 0; i < num_devices_with_workers; i++) { + const size_t line_index = i; + auto& program = programs[i]; + auto* device = devices[i]; + const size_t dest_noc_x = device->worker_core_from_logical_core(dest_core_coord).x; + const size_t dest_noc_y = device->worker_core_from_logical_core(dest_core_coord).y; + const size_t sync_core_noc_x = device->worker_core_from_logical_core(sync_core_coord).x; + const size_t sync_core_noc_y = device->worker_core_from_logical_core(sync_core_coord).y; + + IDevice* backward_device = i == 0 ? nullptr : devices[i - 1]; + IDevice* forward_device = i == line_size - 1 ? nullptr : devices[i + 1]; + + // Initialize the fabric handle for worker connection + bool start_of_line = line_index == 0; + bool end_of_line = line_index == line_size - 1; + bool has_forward_connection = !end_of_line; + bool has_backward_connection = !start_of_line; + bool unicast_forward = !end_of_line; + size_t mcast_fwd_hops = line_size - line_index - 1; + size_t mcast_bwd_hops = line_index; + size_t unicast_hops = unicast_forward ? mcast_fwd_hops : mcast_bwd_hops; + + auto local_device_fabric_handle = + ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links); + + // Reserve CBs for packet headers and payload staging + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig( + packet_header_cb_size_in_headers * sizeof(tt::fabric::PacketHeader), {{packet_header_cb_index, cb_df}}) + .set_page_size(packet_header_cb_index, sizeof(tt::fabric::PacketHeader)); + CBHandle sender_workers_cb = CreateCircularBuffer(program, worker_cores, cb_src0_config); + + tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig(packet_payload_size_bytes, {{source_payload_cb_index, cb_df}}) + .set_page_size(source_payload_cb_index, packet_payload_size_bytes); + CBHandle sender_workers_payload_cb = CreateCircularBuffer(program, worker_cores, cb_src1_config); + + TT_FATAL( + local_device_fabric_handle.get_num_links() == num_links, + "Error in test setup. Expected {} links between devices but got {} links for device {}", + num_links, + local_device_fabric_handle.get_num_links(), + device->id()); + + std::vector worker_ct_args = {params.line_sync, params.line_sync}; + + auto worker_kernel_id = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp", + worker_cores, + tt_metal::WriterDataMovementConfig(worker_ct_args)); + for (size_t l = 0; l < num_links; l++) { + auto worker_core = worker_cores_vec[l]; + auto build_connection_args = [&local_device_fabric_handle, device, &program, &worker_core]( + bool is_connected_in_direction, + ttnn::ccl::EdmLineFabricOpInterface::Direction direction, + std::vector& rt_args_out) { + rt_args_out.push_back(is_connected_in_direction); + if (is_connected_in_direction) { + const auto connection = local_device_fabric_handle.uniquely_connect_worker(device, direction); + const auto new_rt_args = + ttnn::ccl::worker_detail::generate_edm_connection_rt_args(connection, program, {worker_core}); + log_info( + tt::LogTest, + "On device: {}, connecting to EDM fabric in {} direction.
EDM noc_x: {}, noc_y: {}", + device->id(), + direction, + connection.edm_noc_x, + connection.edm_noc_y); + std::copy(new_rt_args.begin(), new_rt_args.end(), std::back_inserter(rt_args_out)); + } + }; + // RT ARGS + std::vector rt_args = { + dest_bank_addr, + packet_payload_size_bytes, + dest_noc_x, + dest_noc_y, + + num_mcasts, + mcast_fwd_hops, + mcast_bwd_hops, + + num_unicasts, + unicast_hops, + unicast_forward, + + source_payload_cb_index, // source_l1_buffer_address, + packet_header_cb_index, + packet_header_cb_size_in_headers, + }; + + build_connection_args(has_forward_connection, ttnn::ccl::EdmLineFabricOpInterface::FORWARD, rt_args); + build_connection_args(has_backward_connection, ttnn::ccl::EdmLineFabricOpInterface::BACKWARD, rt_args); + + if (params.line_sync) { + rt_args.push_back(sync_core_noc_x); + rt_args.push_back(sync_core_noc_y); + rt_args.push_back(global_semaphore_addr); + rt_args.push_back(num_links * num_devices_with_workers /*line_size*/); + } + + tt_metal::SetRuntimeArgs(program, worker_kernel_id, worker_core, rt_args); + } + } + + for (size_t i = 0; i < num_op_invocations; i++) { + log_info(tt::LogTest, "Iteration: {}", i); + build_and_enqueue(worker_devices, programs); + + log_info(tt::LogTest, "Waiting for Op finish on all devices"); + wait_for_worker_subdevice_program_completion(worker_devices, subdevice_managers); + log_info(tt::LogTest, "Main op done"); + } + + TT_FATAL(fabric_programs->size() == devices.size(), "Expected fabric programs size to be same as devices size"); + log_info(tt::LogTest, "Fabric teardown"); + persistent_fabric_teardown_sequence( + devices, subdevice_managers, fabric_handle.value(), tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); + + log_info(tt::LogTest, "Waiting for teardown completion"); + for (IDevice* d : devices) { + tt_metal::Synchronize(d, ttnn::DefaultQueueId); + } + for (size_t i = 0; i < programs.size(); i++) { + auto d = worker_devices[i]; + auto& program = programs[i]; + tt_metal::DumpDeviceProfileResults(d, program); + } + for (size_t i = 0; i < fabric_programs->size(); i++) { + auto d = devices[i]; + auto& program = fabric_programs.value()[i]; + tt_metal::DumpDeviceProfileResults(d, program); + } + log_info(tt::LogTest, "Finished"); +} + +TEST(EdmFabric, BasicMcastThroughputTest_SingleLink_LineSize2_SingleMcast) { + const size_t num_mcasts = 1; + const size_t num_unicasts = 2; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + params.line_size = 2; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} + +TEST(EdmFabric, BasicMcastThroughputTest_SingleMcast) { + const size_t num_mcasts = 1; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap) { + const size_t num_mcasts = 9; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + 
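+    // The mcast counts used across these tests (9, 10, 18, 19, 36, ...) appear to be chosen
+    // around the EDM sender/receiver channel buffer depths so that the wrap conditions named
+    // in each test (no wrap, one-element wrap, once/twice filled) are hit at the channel
+    // boundaries. This is a reading of the test names; the actual buffer depths come from the
+    // FabricEriscDatamoverConfig and are not shown in this diff.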
RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap_2Device) { + const size_t num_mcasts = 10; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap) { + const size_t num_mcasts = 10; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled_2Device) { + const size_t num_mcasts = 18; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled) { + const size_t num_mcasts = 18; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderTwoWrap_ReceiverOneWrap) { + const size_t num_mcasts = 19; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} + +TEST(EdmFabric, BasicMcastThroughputTest_SingleLink_LineSize2_SingleMcast_LineSync) { + const size_t num_mcasts = 1; + const size_t num_unicasts = 2; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} + +TEST(EdmFabric, BasicMcastThroughputTest_SingleMcast_LineSync) { + const size_t num_mcasts = 1; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, 
BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_LineSync) { + const size_t num_mcasts = 9; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap_2Device_LineSync) { + const size_t num_mcasts = 10; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap_LineSync) { + const size_t num_mcasts = 10; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled_2Device_LineSync) { + const size_t num_mcasts = 18; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled_LineSync) { + const size_t num_mcasts = 18; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderFourTImesFilled_ReceiverTwiceFilled_2Device_1Worker) { + const size_t num_mcasts = 36; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + params.num_devices_with_workers = 1; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderFourTImesFilled_ReceiverTwiceFilled_2Device_LineSync) { + const size_t num_mcasts = 36; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + 
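+    // line_sync feeds the two compile-time flags of edm_fabric_writer.cpp, enabling the
+    // start/finish barriers implemented by line_sync() there (a multicast atomic-inc in both
+    // directions plus a local semaphore wait), so all workers enter and exit the measured
+    // region together.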
RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderFourTImesFilled_ReceiverTwiceFilled_LineSync) { + const size_t num_mcasts = 36; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SenderTwoWrap_ReceiverOneWrap_LineSync) { + const size_t num_mcasts = 19; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} + +TEST(EdmFabric, BasicMcastThroughputTest_SmallPerf_2Device) { + const size_t num_mcasts = 70; + const size_t num_unicasts = 0; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool report_performance = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = report_performance; + params.line_size = line_size; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} + +TEST(EdmFabric, BasicMcastThroughputTest_SmallPerf0) { + const size_t num_mcasts = 70; + const size_t num_unicasts = 0; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = true; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_SmallPerf1) { + const size_t num_mcasts = 70; + const size_t num_unicasts = 0; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = true; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} + +TEST(EdmFabric, BasicMcastThroughputTest_0) { + const size_t num_mcasts = 100; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = 2; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_1) { + const size_t num_mcasts = 1000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = false; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +TEST(EdmFabric, BasicMcastThroughputTest_2) { + const size_t num_mcasts = 50000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +TEST(EdmFabric, BasicMcastThroughputTest_3) { + const size_t 
num_mcasts = 200000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +TEST(EdmFabric, BasicMcastThroughputTest_4) { + const size_t num_mcasts = 800000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} + +TEST(EdmFabric, BasicMcastThroughputTest_5) { + const size_t num_mcasts = 1; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 20000; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// DISABLED due to long runtime +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_6) { + const size_t num_mcasts = 100; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 8000; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// DISABLED due to long runtime +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_7) { + const size_t num_mcasts = 1000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1000; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// DISABLED due to long runtime +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_8) { + const size_t num_mcasts = 50000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 200; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// DISABLED due to long runtime +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_9) { + const size_t num_mcasts = 200000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 150; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// DISABLED due to long runtime +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_10) { + const size_t num_mcasts = 800000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 50; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// Shortened variant of the DISABLED long-runtime test above +TEST(EdmFabric, BasicMcastThroughputTest_6_Short) { + const size_t num_mcasts = 100; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 100; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// Shortened variant of the DISABLED long-runtime test above +TEST(EdmFabric, BasicMcastThroughputTest_7_Short) { + const size_t num_mcasts = 1000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 50; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// Shortened variant of the DISABLED long-runtime test above +TEST(EdmFabric, BasicMcastThroughputTest_8_Short) { + const size_t num_mcasts = 50000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 20; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// Shortened variant of the DISABLED long-runtime test above +TEST(EdmFabric, BasicMcastThroughputTest_9_Short) { + const size_t num_mcasts = 200000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 10; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} +// Shortened variant of the DISABLED long-runtime test above +TEST(EdmFabric, BasicMcastThroughputTest_10_Short) { + const size_t num_mcasts = 800000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 5; + RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); +} + +TEST(EdmFabric, BasicMcastThroughputTest_0_WithLineSync) { + const size_t num_mcasts = 100; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_1_WithLineSync) { + const size_t num_mcasts = 1000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_2_WithLineSync) { + const size_t num_mcasts = 50000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_3_WithLineSync) { + const size_t num_mcasts = 200000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, BasicMcastThroughputTest_4_WithLineSync) { + const size_t num_mcasts = 800000; + const size_t num_unicasts = 2; + const size_t num_links = 2; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py index b3e79cb95be..65fa2a49b73 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py @@ -7,7 +7,12 @@ def create_and_load_sub_device_manager_with_fabric_interface( - mesh_device, worker_sub_devices, ccl_worker_sub_device_id, local_allocator_size, enable_persistent_fabric=True + mesh_device, + worker_sub_devices, + ccl_worker_sub_device_id, + local_allocator_size, +
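+    # wrap_fabric_around_mesh (added below) routes a single fabric line around the ring of
+    # mesh devices instead of building per-row/per-column fabric lines; it is forwarded to
+    # ttnn.initialize_edm_fabric (see initialize_edm_fabric in erisc_datamover_builder.cpp).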
enable_persistent_fabric=True, + wrap_fabric_around_mesh=False, ): assert ccl_worker_sub_device_id < len(worker_sub_devices) mesh_sub_device_manager_id, fabric_subdevice_id = mesh_device.create_sub_device_manager_with_fabric( @@ -16,7 +21,7 @@ def create_and_load_sub_device_manager_with_fabric_interface( # fabric sub-device id can also be queried from device, no need to explicitly pass it in mesh_device.load_sub_device_manager(mesh_sub_device_manager_id) if enable_persistent_fabric: - ttnn.initialize_edm_fabric(mesh_device) + ttnn.initialize_edm_fabric(mesh_device, wrap_fabric_around_mesh=wrap_fabric_around_mesh) return mesh_sub_device_manager_id diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py index 7b0f1d04629..3b11f56b80a 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -139,6 +139,7 @@ def run_all_gather_impl( cluster_axis=None, create_persistent_fabric=True, teardown_persistent_fabric=True, + wrap_fabric_around_mesh=False, ): enable_persistent_fabric = True if num_iters < 1: @@ -162,7 +163,12 @@ def run_all_gather_impl( sub_device_stall_group = [worker_sub_device_id] if create_persistent_fabric: mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( - mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric + mesh_device, + [worker_sub_device], + 0, + 0, + enable_persistent_fabric, + wrap_fabric_around_mesh=wrap_fabric_around_mesh, ) mesh_device.set_sub_device_stall_group(sub_device_stall_group) @@ -284,6 +290,7 @@ def run_all_gather_impl( ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group) logger.info(f"Done iteration {i}") + passed = True for tensor_index in range(len(tt_out_tensor_list)): tt_out_tensor = tt_out_tensor_list[tensor_index] output_tensor = output_tensor_goldens_list[tensor_index] @@ -297,12 +304,15 @@ def run_all_gather_impl( eq, output = comp_pcc(tt_output_tensor, output_tensor) if not eq: logger.error(f"output mismatch for tensor {i}") - assert eq, f"{i} FAILED: {output}" + passed = False if enable_persistent_fabric and teardown_persistent_fabric: mesh_device.reset_sub_device_stall_group() teardown_fabric_interface(mesh_device) + if not passed: + assert False, "output mismatch for one or more tensors; see error log above" + # Enumerate the post-commit cases explicitly @skip_for_grayskull("Requires eth connected devices to run") @@ -312,12 +322,14 @@ def run_all_gather_impl( (4, 1, [1, 1, 64, 512], 3, ttnn.TILE_LAYOUT), # (4, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # (4, 1, [1, 1, 2048, 16384], 3, ttnn.TILE_LAYOUT), + (4, 1, [1, 1, 32, 1280], 3, ttnn.TILE_LAYOUT), ], ) @pytest.mark.parametrize( "input_dtype", [ ttnn.bfloat16, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize( @@ -353,10 +365,12 @@ def test_all_gather( layout, use_program_cache, function_level_defaults, - all_gather_topology=ttnn.Topology.Ring, + all_gather_topology=ttnn.Topology.Linear, num_iters=num_iters, enable_async=enable_async, rand_tensor=True, + create_persistent_fabric=True, + teardown_persistent_fabric=True, mem_config=mem_config, ) @@ -432,8 +446,7 @@ def test_all_gather( @pytest.mark.parametrize("num_iters", [8]) @pytest.mark.parametrize("enable_async", [True]) def test_all_gather_sharded( - t3k_mesh_device, - # pcie_mesh_device, + pcie_mesh_device, num_devices, output_shape, dim, @@ -449,7 +462,7 @@ def test_all_gather_sharded( tensor_mem_layout, ): run_all_gather_impl( - t3k_mesh_device,
+ pcie_mesh_device, num_devices, output_shape, dim, @@ -458,7 +471,7 @@ def test_all_gather_sharded( layout, use_program_cache, function_level_defaults, - all_gather_topology=ttnn.Topology.Ring, + all_gather_topology=ttnn.Topology.Linear, num_iters=num_iters, enable_async=enable_async, rand_tensor=True, @@ -467,6 +480,7 @@ def test_all_gather_sharded( tensor_mem_layout=tensor_mem_layout, create_persistent_fabric=True, teardown_persistent_fabric=True, + wrap_fabric_around_mesh=True, ) diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp index a0470fd0185..adbd4c341ad 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp @@ -18,7 +18,12 @@ void py_bind_common(pybind11::module& module) { .value("Ring", ttnn::ccl::Topology::Ring) .value("Linear", ttnn::ccl::Topology::Linear); - module.def("initialize_edm_fabric", &ttnn::ccl::initialize_edm_fabric, py::arg("mesh_device"), py::kw_only()); + module.def( + "initialize_edm_fabric", + &ttnn::ccl::initialize_edm_fabric, + py::arg("mesh_device"), + py::kw_only(), + py::arg("wrap_fabric_around_mesh") = false); module.def("teardown_edm_fabric", &ttnn::ccl::teardown_edm_fabric, py::arg("mesh_device"), py::kw_only()); } diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp index a7270f3ee67..72a48f32827 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp @@ -1038,6 +1038,24 @@ static void log_command_stream(ttnn::ccl::cmd::CclHostLowLevelCommandSequence co } } +std::vector generate_edm_connection_rt_args( + ttnn::ccl::SenderWorkerAdapterSpec const& connection_info, + Program &program, + CoreRangeSet worker_cores) { + std::vector new_rt_args; + auto worker_flow_control_semaphore_id = CreateSemaphore(program, worker_cores, 0); + auto worker_teardown_semaphore_id = CreateSemaphore(program, worker_cores, 0); + auto worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_cores, 0); + append_worker_to_fabric_edm_sender_rt_args( + connection_info, + worker_flow_control_semaphore_id, + worker_teardown_semaphore_id, + worker_buffer_index_semaphore_id, + new_rt_args); + + return new_rt_args; +} + void generate_multi_input_command_stream_kernel_rt_args( Program& program, KernelHandle kernel_id, @@ -1147,27 +1165,15 @@ void generate_multi_input_command_stream_kernel_rt_args( rt_args.push_back(forward_fabric_connections.has_value()); if (forward_fabric_connections.has_value()) { - auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, worker_core_range, 0); - auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, worker_core_range, 0); - auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_core_range, 0); - append_worker_to_fabric_edm_sender_rt_args( - forward_fabric_connections.value(), - sender_worker_flow_control_semaphore_id, - sender_worker_teardown_semaphore_id, - sender_worker_buffer_index_semaphore_id, - rt_args); + const auto new_rt_args = + generate_edm_connection_rt_args(*forward_fabric_connections, program, worker_core_range); + std::copy(new_rt_args.begin(), new_rt_args.end(), std::back_inserter(rt_args)); } rt_args.push_back(backward_fabric_connections.has_value()); if (backward_fabric_connections.has_value()) { - auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, worker_core_range, 0); 
- auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, worker_core_range, 0); - auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_core_range, 0); - append_worker_to_fabric_edm_sender_rt_args( - backward_fabric_connections.value(), - sender_worker_flow_control_semaphore_id, - sender_worker_teardown_semaphore_id, - sender_worker_buffer_index_semaphore_id, - rt_args); + const auto new_rt_args = + generate_edm_connection_rt_args(*backward_fabric_connections, program, worker_core_range); + std::copy(new_rt_args.begin(), new_rt_args.end(), std::back_inserter(rt_args)); } for (size_t i = 0; i < num_command_streams; i++) { diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp index b03d8d398c4..23271b809b8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp @@ -62,6 +62,12 @@ void generate_ccl_command_stream_to_kernel_args( ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider_out, std::vector& rt_args_out); +/* + * Creates the flow control, teardown, and buffer index semaphores a worker needs to attach + * to an EDM fabric connection and packs the connection info into kernel runtime args. + * @return the runtime args + */ +std::vector generate_edm_connection_rt_args( + const ttnn::ccl::SenderWorkerAdapterSpec& connection_info, Program& program, CoreRangeSet worker_cores); + // TODO: eventually take a fabric handle void generate_multi_input_command_stream_kernel_rt_args( Program& program, diff --git a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp index 3a4c40961e2..61d790efcae 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp @@ -4,6 +4,8 @@ #pragma once +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" + class FabricConnectionManager final { public: // return if there is/should be a connection - doesn't return whether or not the connection diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp index 27cf8adfd33..8be28978f47 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp @@ -44,6 +44,34 @@ namespace ttnn::ccl { FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size) { + TT_FATAL( + (receiver_completed_packet_header_cb_address % eth_word_l1_alignment == 0), + "receiver_completed_packet_header_cb_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_0_completed_packet_header_cb_address % eth_word_l1_alignment == 0), + "sender_0_completed_packet_header_cb_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_1_completed_packet_header_cb_address % eth_word_l1_alignment == 0), + "sender_1_completed_packet_header_cb_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_channel_0_buffer_index_address % eth_word_l1_alignment == 0), + "sender_channel_0_buffer_index_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_channel_0_worker_conn_info_base_address % eth_word_l1_alignment == 0), + "sender_channel_0_worker_conn_info_base_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_channel_0_local_flow_control_semaphore_address % eth_word_l1_alignment == 0), + "sender_channel_0_local_flow_control_semaphore_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_channel_0_producer_terminate_connection_address % eth_word_l1_alignment == 0), + "sender_channel_0_producer_terminate_connection_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_channel_1_local_flow_control_semaphore_address % eth_word_l1_alignment == 0), + "sender_channel_1_local_flow_control_semaphore_address must be aligned to 16 bytes"); + TT_FATAL( + (sender_channel_1_producer_terminate_connection_address % eth_word_l1_alignment == 0), + "sender_channel_1_producer_terminate_connection_address must be aligned to 16 bytes"); + TT_FATAL(sender_channel_1_buffer_index_address != sender_channel_0_buffer_index_address, "FabricEriscDatamoverConfig was constructed with illegal buffer index address"); const size_t min_buffer_size = sizeof(tt::fabric::PacketHeader) + 2 * FabricEriscDatamoverConfig::eth_channel_sync_size; TT_FATAL(channel_buffer_size_bytes >= min_buffer_size, "FabricEriscDatamoverConfig was constructed with `channel_buffer_size_bytes` argument set smaller than minimum size of {}", min_buffer_size); @@ -190,9 +218,9 @@ FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( receiver_channel_local_buffer_index_address(config.receiver_channel_local_buffer_index_address), local_sender_channel_0_buffer_address(config.sender_0_channel_base_address), - local_sender_channel_0_connection_info_addr(config.sender_channel_0_worker_connection_info_address), + local_sender_channel_0_connection_info_addr(config.sender_channel_0_worker_conn_info_base_address), local_sender_channel_1_buffer_address(config.sender_1_channel_base_address), - local_sender_channel_1_connection_info_addr(config.sender_channel_1_worker_connection_info_address), + local_sender_channel_1_connection_info_addr(config.sender_channel_1_worker_conn_info_base_address), local_receiver_channel_buffer_address(config.receiver_channel_base_address), termination_signal_ptr(config.termination_signal_address), @@ -211,6 +239,7 @@ std::vector FabricEriscDatamoverBuilder::get_compile_time_args() const log_trace(tt::LogTest, "Sender 1 channel address: {}", this->local_sender_channel_1_buffer_address); log_trace(tt::LogTest, "Receiver num buffers: {}", this->receiver_num_buffers); log_trace(tt::LogTest, "Receiver channel address: {}", this->local_receiver_channel_buffer_address); + return std::vector{ this->firmware_context_switch_interval, is_handshake_master, @@ -221,9 +250,9 @@ std::vector FabricEriscDatamoverBuilder::get_compile_time_args() const this->receiver_num_buffers, config.sender_0_channel_base_address, - config.sender_channel_0_worker_connection_info_address, + config.sender_channel_0_worker_conn_info_base_address, config.sender_1_channel_base_address, - config.sender_channel_1_worker_connection_info_address, + config.sender_channel_1_worker_conn_info_base_address, config.receiver_channel_base_address, config.receiver_channel_base_address, @@ -231,7 +260,23 @@ std::vector FabricEriscDatamoverBuilder::get_compile_time_args() const config.sender_1_channel_base_address, this->termination_signal_ptr, - this->enable_persistent_mode}; + this->enable_persistent_mode, + + // fabric counters + FabricEriscDatamoverConfig::enable_fabric_counters, + config.receiver_channel_counters_address, + config.sender_channel_0_counters_address, + config.sender_channel_1_counters_address, + + // fabric pkt header recording +
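+        // The order of these compile-time args forms the contract with the EDM kernel, which
+        // presumably consumes them positionally via get_compile_time_arg_val (the kernel side
+        // is not shown in this diff); new entries are appended after enable_persistent_mode so
+        // the earlier indices stay stable.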
+        FabricEriscDatamoverConfig::enable_fabric_pkt_header_recording,
+
+        config.receiver_completed_packet_header_cb_address,
+        FabricEriscDatamoverConfig::receiver_completed_packet_header_cb_size_headers,
+        config.sender_0_completed_packet_header_cb_address,
+        FabricEriscDatamoverConfig::sender_completed_packet_header_cb_size_headers,
+        config.sender_1_completed_packet_header_cb_address,
+        FabricEriscDatamoverConfig::sender_completed_packet_header_cb_size_headers};
 }

 std::vector<uint32_t> FabricEriscDatamoverBuilder::get_runtime_args() const {
@@ -349,34 +394,32 @@ SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_worker_
         log_trace(tt::LogOp, "Building connection to non-persistent fabric");
     }
     TT_FATAL(sender_channel_0_buffer_index_semaphore_id != sender_channel_0_flow_control_semaphore_id, "Internal error - sender_channel_0_buffer_index_semaphore_id and sender_channel_0_flow_control_semaphore_id aliased each other");
-    return SenderWorkerAdapterSpec {
+    return SenderWorkerAdapterSpec{
         this->my_noc_x,
         this->my_noc_y,
         this->local_sender_channel_0_buffer_address,
         this->sender_0_num_buffers,
         this->sender_channel_0_flow_control_semaphore_id,
         this->sender_channel_0_connection_semaphore_id,
-        this->config.sender_channel_0_worker_connection_info_address,
+        this->config.sender_channel_0_worker_conn_info_base_address,
         this->config.channel_buffer_size_bytes,
         this->sender_channel_0_buffer_index_semaphore_id,
-        this->enable_persistent_mode
-    };
+        this->enable_persistent_mode};
 }

 SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_fabric_channel() const {
-    return SenderWorkerAdapterSpec {
+    return SenderWorkerAdapterSpec{
         this->my_noc_x,
         this->my_noc_y,
         this->local_sender_channel_1_buffer_address,
         this->sender_1_num_buffers,
         this->sender_channel_1_flow_control_semaphore_id,
         this->sender_channel_1_connection_semaphore_id,
-        this->config.sender_channel_1_worker_connection_info_address,
+        this->config.sender_channel_1_worker_conn_info_base_address,
         this->config.channel_buffer_size_bytes,
         this->sender_channel_1_buffer_index_semaphore_id,
-        false
-    };
+        false};
 }

 void FabricEriscDatamoverBuilder::connect_to_downstream_edm(FabricEriscDatamoverBuilder const& downstream_edm) {
@@ -401,7 +444,8 @@ EdmLineFabricOpInterface::EdmLineFabricOpInterface(
     std::optional<size_t> desired_num_links,
     bool build_in_worker_connection_mode) :
     device_sequence(device_sequence), programs(program_sequence) {
-    static constexpr std::size_t edm_buffer_size = 4096 + sizeof(tt::fabric::PacketHeader);
+    static constexpr std::size_t edm_buffer_size =
+        FabricEriscDatamoverBuilder::default_packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader);
     auto const config = FabricEriscDatamoverConfig(edm_buffer_size, 1, 2);
     TT_ASSERT(device_sequence.size() == program_sequence.size());

@@ -496,7 +540,8 @@ EdmLineFabricOpInterface::EdmLineFabricOpInterface(
     std::optional<size_t> desired_num_links,
     bool build_in_worker_connection_mode) :
     device_sequence({local_device}), programs({program}) {
-    static constexpr std::size_t edm_buffer_size = 4096 + sizeof(tt::fabric::PacketHeader);
+    static constexpr std::size_t edm_buffer_size =
+        FabricEriscDatamoverBuilder::default_packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader);
     auto const config = FabricEriscDatamoverConfig(edm_buffer_size, 1, 2);

     log_trace(tt::LogOp, "device id={}", local_device->id());
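For reference, these pieces compose on the host roughly as follows (an illustrative sketch; `edm_builder`, `worker_cores`, and `worker_rt_args` are assumed names, not part of this change):

    // Build an adapter spec for a worker feeding sender channel 0, then
    // serialize it into the worker's runtime args.
    const ttnn::ccl::SenderWorkerAdapterSpec spec = edm_builder.build_connection_to_worker_channel();
    const std::vector<uint32_t> conn_args = generate_edm_connection_rt_args(spec, program, worker_cores);
    std::copy(conn_args.begin(), conn_args.end(), std::back_inserter(worker_rt_args));

@@ -603,12 +648,19 @@ EdmLineFabricOpInterface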
EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( IDevice* local_device, - std::optional forward_device, - std::optional backward_device, + IDevice* forward_device, + IDevice* backward_device, Program* program, bool enable_persistent_mode, std::optional desired_num_links) { - return EdmLineFabricOpInterface(local_device, forward_device, backward_device, program, enable_persistent_mode, desired_num_links, true); + return EdmLineFabricOpInterface( + local_device, + forward_device == nullptr ? std::nullopt : std::optional(forward_device), + backward_device == nullptr ? std::nullopt : std::optional(backward_device), + program, + enable_persistent_mode, + desired_num_links, + true); } void EdmLineFabricOpInterface::build_kernels() const { @@ -669,7 +721,8 @@ std::vector EdmLineFabricOpInterface::generate_local_chi } std::vector EdmLineFabricOpInterface::generate_ordered_termination_info_farthest_to_nearest() const { - static constexpr std::size_t edm_buffer_size = 4096 + sizeof(tt::fabric::PacketHeader); + static constexpr std::size_t edm_buffer_size = + FabricEriscDatamoverBuilder::default_packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader); static const auto config = FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); TT_ASSERT(device_sequence.size() > 0); const size_t num_hops = device_sequence.size() - 1; @@ -753,46 +806,70 @@ void EdmLineFabricOpInterface::set_firmware_context_switch_interval(size_t inter } } -void initialize_edm_fabric(distributed::MeshDevice* mesh_device) { - - std::vector row_fabric_lines; - row_fabric_lines.reserve(mesh_device->get_view().get_row_views().size()); - std::vector col_fabric_lines; - col_fabric_lines.reserve(mesh_device->get_view().get_column_views().size()); - - size_t num_rows = mesh_device->get_view().get_row_views().size(); - size_t num_cols = mesh_device->get_view().get_column_views().size(); - std::vector> programs(num_rows); - for (size_t r = 0; r < num_rows; r++) { - programs[r].resize(num_cols); - } - - for (size_t i = 0; i < num_rows; i++) { +void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabric_around_mesh) { + if (wrap_fabric_around_mesh) { + auto devices = mesh_device->get_view().get_ring_devices(); std::vector program_ptrs; - program_ptrs.reserve(num_cols); - std::transform(programs[i].begin(), programs[i].end(), std::back_inserter(program_ptrs), [](Program& p) { return &p; }); - row_fabric_lines.push_back(EdmLineFabricOpInterface(mesh_device->get_view().get_row_views()[i], program_ptrs, true)); - } - - for (size_t i = 0; i < num_cols; i++) { - std::vector program_ptrs; - program_ptrs.reserve(num_rows); + std::vector programs(devices.size()); + program_ptrs.reserve(devices.size()); + + std::transform( + programs.begin(), programs.end(), std::back_inserter(program_ptrs), [](Program& p) { return &p; }); + EdmLineFabricOpInterface fabric_device_builders = EdmLineFabricOpInterface(devices, program_ptrs, true); + fabric_device_builders.build_kernels(); + + for (size_t i = 0; i < devices.size(); i++) { + auto* device = devices[i]; + auto* program_ptr = program_ptrs[i]; + device->push_work([&]() { tt::tt_metal::detail::CompileProgram(device, *program_ptr); }, false); + device->push_work( + [&]() { tt::tt_metal::EnqueueProgram(device->command_queue(), *program_ptr, false); }, true); + } + } else { + std::vector row_fabric_lines; + row_fabric_lines.reserve(mesh_device->get_view().get_row_views().size()); + std::vector col_fabric_lines; + 
col_fabric_lines.reserve(mesh_device->get_view().get_column_views().size()); + + size_t num_rows = mesh_device->get_view().get_row_views().size(); + size_t num_cols = mesh_device->get_view().get_column_views().size(); + std::vector> programs(num_rows); for (size_t r = 0; r < num_rows; r++) { - program_ptrs.push_back(&programs[r][i]); + programs[r].resize(num_cols); + } + + for (size_t i = 0; i < num_rows; i++) { + std::vector program_ptrs; + program_ptrs.reserve(num_cols); + std::transform(programs[i].begin(), programs[i].end(), std::back_inserter(program_ptrs), [](Program& p) { + return &p; + }); + row_fabric_lines.push_back( + EdmLineFabricOpInterface(mesh_device->get_view().get_row_views()[i], program_ptrs, true)); + } + + for (size_t i = 0; i < num_cols; i++) { + std::vector program_ptrs; + program_ptrs.reserve(num_rows); + for (size_t r = 0; r < num_rows; r++) { + program_ptrs.push_back(&programs[r][i]); + } + col_fabric_lines.push_back( + EdmLineFabricOpInterface(mesh_device->get_view().get_column_views()[i], program_ptrs, true)); } - col_fabric_lines.push_back(EdmLineFabricOpInterface(mesh_device->get_view().get_column_views()[i], program_ptrs, true)); - } - std::for_each(row_fabric_lines.begin(), row_fabric_lines.end(), [](auto& line) { line.build_kernels(); }); - std::for_each(col_fabric_lines.begin(), col_fabric_lines.end(), [](auto& line) { line.build_kernels(); }); + std::for_each(row_fabric_lines.begin(), row_fabric_lines.end(), [](auto& line) { line.build_kernels(); }); + std::for_each(col_fabric_lines.begin(), col_fabric_lines.end(), [](auto& line) { line.build_kernels(); }); - for (size_t r = 0; r < num_rows; r++) { - for (size_t c = 0; c < num_cols; c++) { - log_info(tt::LogAlways, "Compile EDM program"); - IDevice*device = mesh_device->get_device(r, c); - auto& program = programs.at(r).at(c); - device->push_work([&](){tt::tt_metal::detail::CompileProgram(device, program);}, false); - device->push_work([&](){tt::tt_metal::EnqueueProgram(device->command_queue(), program, false);}, true); + for (size_t r = 0; r < num_rows; r++) { + for (size_t c = 0; c < num_cols; c++) { + log_info(tt::LogAlways, "Compile EDM program"); + IDevice* device = mesh_device->get_device(r, c); + auto& program = programs.at(r).at(c); + device->push_work([&]() { tt::tt_metal::detail::CompileProgram(device, program); }, false); + device->push_work( + [&]() { tt::tt_metal::EnqueueProgram(device->command_queue(), program, false); }, true); + } } } } diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp index eaeeea20501..1d32db7f8c3 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -14,6 +14,9 @@ #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp" + +#include #include #include #include @@ -29,21 +32,57 @@ namespace ccl { struct FabricEriscDatamoverConfig { static constexpr std::size_t field_size = 16; static constexpr std::size_t buffer_alignment = 32; + static constexpr std::size_t eth_word_l1_alignment = 16; static_assert(((buffer_alignment - 1) & buffer_alignment) == 0); + static constexpr bool enable_fabric_counters = false; + static constexpr bool enable_fabric_pkt_header_recording = false; // Global static constexpr std::size_t eth_channel_sync_size = 
16;
    std::size_t handshake_addr = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_base()/* + 1024*/;
    std::size_t edm_channel_ack_addr = handshake_addr + eth_channel_sync_size;
    std::size_t termination_signal_address =
-        edm_channel_ack_addr + (2 * eth_channel_sync_size); // pad extra bytes to match old EDM so handshake logic will still work
+        edm_channel_ack_addr +
+        (4 * eth_channel_sync_size);  // pad extra bytes to match old EDM so handshake logic will still work
+
+    // Debug and Counters
+    static constexpr std::size_t receiver_channel_counters_size_bytes =
+        (((tt::fabric::receiver_channel_counters_l1_size - 1) / field_size) + 1) * field_size;
+    static constexpr std::size_t sender_channel_counters_size_bytes =
+        (((tt::fabric::sender_channel_counters_l1_size - 1) / field_size) + 1) * field_size;
+
+    std::size_t receiver_channel_counters_address = termination_signal_address + field_size;
+    std::size_t sender_channel_0_counters_address =
+        receiver_channel_counters_address + receiver_channel_counters_size_bytes;
+    std::size_t sender_channel_1_counters_address =
+        sender_channel_0_counters_address + sender_channel_counters_size_bytes;
+
+    // Packet header history buffer(s)
+    static constexpr std::size_t receiver_completed_packet_header_cb_size_headers = 32;
+    static constexpr std::size_t receiver_completed_packet_header_cb_size_bytes =
+        sizeof(tt::fabric::PacketHeader) * receiver_completed_packet_header_cb_size_headers;
+    static constexpr std::size_t sender_completed_packet_header_cb_size_headers = 32;
+    static constexpr std::size_t sender_completed_packet_header_cb_size_bytes =
+        sizeof(tt::fabric::PacketHeader) * sender_completed_packet_header_cb_size_headers;
+    std::size_t receiver_completed_packet_header_cb_address =
+        sender_channel_1_counters_address + sender_channel_counters_size_bytes;
+    std::size_t sender_0_completed_packet_header_cb_address =
+        receiver_completed_packet_header_cb_address + receiver_completed_packet_header_cb_size_bytes;
+    std::size_t sender_1_completed_packet_header_cb_address =
+        sender_0_completed_packet_header_cb_address + sender_completed_packet_header_cb_size_bytes;

     // ----------- Sender Channel 0
-    std::size_t sender_channel_0_buffer_index_address = termination_signal_address + field_size;
-    std::size_t sender_channel_0_worker_connection_info_address =
-        sender_channel_0_buffer_index_address + field_size;
+    std::size_t sender_channel_0_buffer_index_address =
+        sender_1_completed_packet_header_cb_address + sender_completed_packet_header_cb_size_bytes;
+    // Connection info layout:
+    // 0: buffer_index_rdptr -> Tells EDM the address in worker L1 to update EDM's copy of channel rdptr
+    // 1: worker_teardown_semaphore_address -> Tells EDM where to signal connection teardown completion in worker's L1
+    // 2: WorkerXY (as uint32_t)
+    // 3: Holds the EDM's rdptr for the buffer index in the channel
+    std::size_t sender_channel_0_worker_conn_info_base_address = sender_channel_0_buffer_index_address + field_size;
     std::size_t sender_channel_0_local_flow_control_semaphore_address =
-        sender_channel_0_worker_connection_info_address + field_size;
+        sender_channel_0_worker_conn_info_base_address + sizeof(tt::fabric::EDMChannelWorkerLocationInfo);
     std::size_t sender_channel_0_producer_terminate_connection_address =
         sender_channel_0_local_flow_control_semaphore_address + field_size;
     // persistent mode field
@@ -53,17 +92,23 @@ struct FabricEriscDatamoverConfig {
     std::size_t 
sender_channel_0_buffer_index_semaphore_address =
         sender_channel_0_connection_semaphore_address + field_size;

-    static_assert(field_size >= sizeof(tt::fabric::EDMChannelWorkerLocationInfo));
+    static_assert(sizeof(tt::fabric::EDMChannelWorkerLocationInfo) % field_size == 0);

     // ----------- Sender Channel 1
     std::size_t sender_channel_1_buffer_index_address =
         sender_channel_0_buffer_index_semaphore_address + field_size;
-    std::size_t sender_channel_1_worker_connection_info_address =
-        sender_channel_1_buffer_index_address + field_size;
+    // Connection info layout:
+    // 0: buffer_index_rdptr -> Tells EDM the address in worker L1 to update EDM's copy of channel rdptr
+    // 1: worker_teardown_semaphore_address -> Tells EDM where to signal connection teardown completion in worker's L1
+    // 2: WorkerXY (as uint32_t)
+    // 3: Holds the EDM's rdptr for the buffer index in the channel
+    std::size_t sender_channel_1_worker_conn_info_base_address = sender_channel_1_buffer_index_address + field_size;
     std::size_t sender_channel_1_local_flow_control_semaphore_address =
-        sender_channel_1_worker_connection_info_address + field_size;
+        sender_channel_1_worker_conn_info_base_address + sizeof(tt::fabric::EDMChannelWorkerLocationInfo);
     std::size_t sender_channel_1_producer_terminate_connection_address =
         sender_channel_1_local_flow_control_semaphore_address + field_size;
+
     // persistent mode field
     std::size_t sender_channel_1_connection_semaphore_address =
         sender_channel_1_producer_terminate_connection_address + field_size;
@@ -135,6 +180,8 @@ size_t log_worker_to_fabric_edm_sender_rt_args(std::vector<uint32_t> const& args
 class FabricEriscDatamoverBuilder {
    public:
     static constexpr size_t default_firmware_context_switch_interval = 200000;
+    // payload only, no header
+    static constexpr size_t default_packet_payload_size_bytes = 4096;

     FabricEriscDatamoverBuilder(
         const CoreCoord& my_eth_core_logical,
@@ -253,7 +300,13 @@ class EdmLineFabricOpInterface {
     EdmLineFabricOpInterface (IDevice* local_device, std::optional<IDevice*> forward_device, std::optional<IDevice*> backward_device, Program* program, bool enable_persistent_mode, std::optional<size_t> desired_num_links, bool build_in_worker_connection_mode = false);

     static EdmLineFabricOpInterface build_program_builder_worker_connection_fabric(std::vector<IDevice*> const& device_sequence, std::vector<Program*> const& program_sequence, bool enable_persistent_mode, std::optional<size_t> desired_num_links = std::nullopt);
-    static EdmLineFabricOpInterface build_program_builder_worker_connection_fabric(IDevice* local_device, std::optional<IDevice*> forward_device, std::optional<IDevice*> backward_device, Program* program, bool enable_persistent_mode, std::optional<size_t> desired_num_links = std::nullopt);
+    static EdmLineFabricOpInterface build_program_builder_worker_connection_fabric(
+        IDevice* local_device,
+        IDevice* forward_device,
+        IDevice* backward_device,
+        Program* program,
+        bool enable_persistent_mode,
+        std::optional<size_t> desired_num_links = std::nullopt);
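+    // Example call (illustrative sketch; local_device, forward_device, and program
+    // are hypothetical locals). Neighbors are raw pointers here, with nullptr
+    // meaning "no neighbor in that direction":
+    //   auto fabric = EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric(
+    //       local_device, forward_device, /*backward_device=*/nullptr,
+    //       &program, /*enable_persistent_mode=*/true);

     // Will create a connection adapter for a worker which can be used to pass args to the worker kernel talking to the
     // corresponding fabric endpoint. 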
This interface will guarantee unique connections only so requesting more unique connections
@@ -316,7 +369,7 @@ class EdmLineFabricOpInterface {
     size_t firmware_context_switch_interval = FabricEriscDatamoverBuilder::default_firmware_context_switch_interval;
 };

-void initialize_edm_fabric(distributed::MeshDevice* mesh_device);
+void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabric_around_mesh = false);
 void teardown_edm_fabric(distributed::MeshDevice* mesh_device);

 }; // namespace ccl
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp
new file mode 100644
index 00000000000..b0db94b98df
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp
@@ -0,0 +1,79 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <cstdint>
+
+namespace tt::fabric {
+
+struct EdmFabricReceiverChannelCounters {
+    uint32_t n_pkts_processed = 0;
+    uint32_t n_pkts_fwded = 0;
+    uint32_t n_pkts_written_locally = 0;
+    uint32_t n_pkts_rx_acked = 0;
+    uint32_t n_pkts_completion_acked = 0;
+
+    uint32_t n_fabric_mcast_noc_atomic_processed = 0;
+    uint32_t n_fabric_mcast_noc_write_processed = 0;
+    uint32_t n_fabric_unicast_noc_atomic_processed = 0;
+    uint32_t n_fabric_unicast_noc_write_processed = 0;
+
+    EdmFabricReceiverChannelCounters() = default;
+};
+static constexpr uint32_t receiver_channel_counters_l1_size = sizeof(EdmFabricReceiverChannelCounters);
+
+struct EdmFabricSenderChannelCounters {
+    uint32_t n_lifetime_connections = 0;
+
+    uint32_t n_lifetime_pkts_received = 0;
+    uint32_t n_lifetime_pkts_fwded = 0;
+    uint32_t n_lifetime_pkts_acked = 0;
+    uint32_t n_lifetime_pkts_complete = 0;
+
+    uint32_t n_lifetime_fabric_mcast_noc_atomic_pkts = 0;
+    uint32_t n_lifetime_fabric_mcast_noc_write_pkts = 0;
+    uint32_t n_lifetime_fabric_unicast_noc_atomic_pkts = 0;
+    uint32_t n_lifetime_fabric_unicast_noc_write_pkts = 0;
+
+    uint32_t n_connection_pkts_received = 0;
+    uint32_t n_connection_pkts_fwded = 0;
+    uint32_t n_connection_pkts_acked = 0;
+    uint32_t n_connection_pkts_complete = 0;
+
+    uint32_t n_connection_fabric_mcast_noc_atomic_pkts = 0;
+    uint32_t n_connection_fabric_mcast_noc_write_pkts = 0;
+    uint32_t n_connection_fabric_unicast_noc_atomic_pkts = 0;
+    uint32_t n_connection_fabric_unicast_noc_write_pkts = 0;
+
+    void add_connection() volatile {
+        this->n_lifetime_connections++;
+        this->reset_connection_counters();
+    }
+
+    void add_pkt_received() volatile {
+        this->n_lifetime_pkts_received++;
+        this->n_connection_pkts_received++;
+    }
+
+    void add_pkt_sent() volatile {
+        this->n_lifetime_pkts_fwded++;
+        this->n_connection_pkts_fwded++;
+    }
+
+    void reset_connection_counters() volatile {
+        this->n_connection_pkts_received = 0;
+        this->n_connection_pkts_fwded = 0;
+        this->n_connection_pkts_acked = 0;
+        this->n_connection_pkts_complete = 0;
+
+        this->n_connection_fabric_mcast_noc_atomic_pkts = 0;
+        this->n_connection_fabric_mcast_noc_write_pkts = 0;
+        this->n_connection_fabric_unicast_noc_atomic_pkts = 0;
+        this->n_connection_fabric_unicast_noc_write_pkts = 0;
+    }
+};
+static constexpr uint32_t sender_channel_counters_l1_size = sizeof(EdmFabricSenderChannelCounters);
+
+} // namespace tt::fabric
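These counter structs are plain L1 overlays; a consumer bumps them through the volatile helpers. A minimal kernel-side sketch (illustrative; the address variable stands in for the config-provided counter base address and is an assumption):

    // Overlay the sender-channel counters on their reserved L1 region.
    auto* counters = reinterpret_cast<volatile tt::fabric::EdmFabricSenderChannelCounters*>(
        sender_channel_0_counters_address);
    counters->add_connection();    // bumps lifetime connections, resets per-connection counts
    counters->add_pkt_received();  // bumps lifetime and per-connection received counts
    counters->add_pkt_sent();      // bumps lifetime and per-connection forwarded counts

diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp
index 9f1ca2a1a78..30aba536630 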
100644
--- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp
@@ -9,18 +9,37 @@
 #include "tt_metal/hw/inc/ethernet/dataflow_api.h"
 #include "cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp"
 #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp"
 #include "debug/assert.h"
 #include "debug/dprint.h"
-
 #include

 namespace tt::fabric {

+/*
+ * The WorkerToFabricEdmSender acts as an adapter between the worker and the EDM; it hides the details
+ * of the worker-to-EDM communication so the implementation can change over time without kernel updates.
+ * Additionally, the details of adapter setup w.r.t. runtime args are also hidden.
+ * The main functionality provided is:
+ * - Opening a connection with the EDM
+ * - Closing a connection with the EDM
+ * - Flow control protocol between worker and EDM
+ *
+ * ### Flow Control Protocol:
+ * The flow control protocol is rd/wr ptr based and is implemented as follows (from the worker's perspective):
+ * The adapter has a local write pointer (wrptr) which is used to track the next buffer slot to write to. The adapter
+ * also has a local memory slot that holds the remote read pointer (rdptr) of the EDM. The adapter uses the difference
+ * between these two pointers (where rdptr trails wrptr) to determine if the EDM has space to accept a new packet.
+ *
+ * As the adapter writes into the EDM, it updates the local wrptr. As the EDM reads from its local L1 channel buffer,
+ * it will notify the worker/adapter (here) by updating the worker remote_rdptr to carry the value of the EDM rdptr.
+ */
 struct WorkerToFabricEdmSender {
+    static constexpr uint32_t unused_connection_value = 0;
     static constexpr uint32_t open_connection_value = 1;
-    static constexpr uint32_t close_connection_value = 0;
+    static constexpr uint32_t close_connection_request_value = 2;

-    WorkerToFabricEdmSender() : worker_sem_addr(nullptr) {}
+    WorkerToFabricEdmSender() : from_remote_buffer_slot_rdptr_ptr(nullptr) {}

     template
     static WorkerToFabricEdmSender build_from_args(std::size_t& arg_idx) {
@@ -72,11 +91,11 @@
         std::size_t edm_worker_location_info_addr,  // The EDM's location for `EDMChannelWorkerLocationInfo`
         uint16_t buffer_size_bytes,
         size_t edm_buffer_index_id,
-        volatile uint32_t* const worker_sem_addr,
+        volatile uint32_t* const from_remote_buffer_slot_rdptr_ptr,
         volatile uint32_t* const worker_teardown_addr,
         uint32_t local_buffer_index_addr) :
         edm_buffer_addr(edm_buffer_base_addr),
-        edm_semaphore_addr(
+        edm_buffer_slot_wrptr_addr(
             connected_to_persistent_fabric ? edm_l1_sem_id
                                            : get_semaphore(edm_l1_sem_id)),
         edm_connection_handshake_l1_addr(
@@ -87,10 +106,10 @@
         edm_buffer_index_addr(
             connected_to_persistent_fabric ? 
edm_buffer_index_id : get_semaphore(edm_buffer_index_id)), - worker_sem_addr(worker_sem_addr), + from_remote_buffer_slot_rdptr_ptr(from_remote_buffer_slot_rdptr_ptr), worker_teardown_addr(worker_teardown_addr), edm_buffer_base_addr(edm_buffer_base_addr), - buffer_index_ptr(reinterpret_cast(local_buffer_index_addr)), + buffer_slot_wrptr_ptr(reinterpret_cast(local_buffer_index_addr)), buffer_size_bytes(buffer_size_bytes), num_buffers_per_channel(num_buffers_per_channel), last_buffer_index(num_buffers_per_channel - 1), @@ -99,11 +118,16 @@ struct WorkerToFabricEdmSender { ASSERT(buffer_size_bytes > 0); } - [[nodiscard]] FORCE_INLINE bool consumer_has_space() const { return *this->worker_sem_addr == 1; } - FORCE_INLINE void clear_flow_control_semaphore() const { noc_semaphore_set(this->worker_sem_addr, 0); } + FORCE_INLINE bool edm_has_space_for_packet() const { + auto const wrptr = *this->buffer_slot_wrptr_ptr; + auto const rdptr = *this->from_remote_buffer_slot_rdptr_ptr; + bool wrptr_ge_rptr = wrptr >= rdptr; + uint8_t slots_used = wrptr_ge_rptr ? (wrptr - rdptr) : ((2 * this->num_buffers_per_channel) - rdptr) + wrptr; + return slots_used < this->num_buffers_per_channel; + } + FORCE_INLINE void wait_for_empty_write_slot() const { - DPRINT << "Wait for write slot @ " << (uint32_t)this->worker_sem_addr << "\n"; - noc_semaphore_wait(this->worker_sem_addr, 1); + while (!this->edm_has_space_for_packet()); } FORCE_INLINE void send_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { @@ -127,6 +151,9 @@ struct WorkerToFabricEdmSender { FORCE_INLINE void send_payload_flush_blocking_from_address(uint32_t source_address, size_t size_bytes) { send_payload_from_address_impl(source_address, size_bytes); } + FORCE_INLINE void send_payload_flush_non_blocking_from_address(uint32_t source_address, size_t size_bytes) { + send_payload_from_address_impl(source_address, size_bytes); + } FORCE_INLINE void send_payload_blocking_from_address(uint32_t source_address, size_t size_bytes) { send_payload_from_address_impl(source_address, size_bytes); } @@ -141,37 +168,40 @@ struct WorkerToFabricEdmSender { static constexpr size_t edm_sender_channel_field_stride_bytes = 16; - FORCE_INLINE void open() { - const auto dest_noc_addr_coord_only = - get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr) & ~(uint64_t)NOC_COORDINATE_MASK; + void open() { + const auto dest_noc_addr_coord_only = get_noc_addr(this->edm_noc_x, this->edm_noc_y, 0); const uint64_t remote_buffer_index_addr = dest_noc_addr_coord_only | edm_buffer_index_addr; ASSERT(remote_buffer_index_addr > 0); - noc_async_read(remote_buffer_index_addr, reinterpret_cast(this->buffer_index_ptr), sizeof(uint32_t)); + noc_async_read(remote_buffer_index_addr, reinterpret_cast(this->buffer_slot_wrptr_ptr), sizeof(uint32_t)); - const uint64_t dest_edm_location_info_addr = dest_noc_addr_coord_only | edm_worker_location_info_addr; + tt::fabric::EDMChannelWorkerLocationInfo* worker_location_info_ptr = reinterpret_cast(edm_worker_location_info_addr); + const uint64_t edm_rdptr_addr = dest_noc_addr_coord_only | reinterpret_cast(edm_worker_location_info_addr + offsetof(tt::fabric::EDMChannelWorkerLocationInfo, edm_rdptr)); + noc_async_read(edm_rdptr_addr, reinterpret_cast(this->from_remote_buffer_slot_rdptr_ptr), sizeof(uint32_t)); // TODO: Need to change byte enable to be word enable - noc_inline_dw_write(dest_edm_location_info_addr, reinterpret_cast(worker_sem_addr)); - noc_inline_dw_write(dest_edm_location_info_addr + sizeof(uint32_t), 
reinterpret_cast(worker_teardown_addr));
-        noc_inline_dw_write(
-            dest_edm_location_info_addr + 2 * sizeof(uint32_t), ttnn::ccl::WorkerXY(my_x[0], my_y[0]).to_uint32());
+        const uint64_t dest_edm_location_info_addr = dest_noc_addr_coord_only | edm_worker_location_info_addr;
+        const uint64_t edm_teardown_semaphore_address_address = dest_noc_addr_coord_only | reinterpret_cast(&(worker_location_info_ptr->worker_teardown_semaphore_address));
+        const uint64_t connection_worker_xy_address = dest_noc_addr_coord_only | reinterpret_cast(&(worker_location_info_ptr->worker_xy));
+        noc_inline_dw_write(dest_edm_location_info_addr, reinterpret_cast(from_remote_buffer_slot_rdptr_ptr));
+        noc_inline_dw_write(edm_teardown_semaphore_address_address, reinterpret_cast(worker_teardown_addr));
+        noc_inline_dw_write(connection_worker_xy_address, ttnn::ccl::WorkerXY(my_x[0], my_y[0]).to_uint32());

         const uint64_t edm_connection_handshake_noc_addr = dest_noc_addr_coord_only | edm_connection_handshake_l1_addr;
         noc_inline_dw_write(edm_connection_handshake_noc_addr, open_connection_value);
         noc_async_read_barrier();
-        ASSERT(*this->buffer_index_ptr < 20);
+        ASSERT(*this->buffer_slot_wrptr_ptr < 20);
     }

-    FORCE_INLINE void close() {
+    void close() {
         const auto dest_noc_addr_coord_only =
-            get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr) & ~(uint64_t)NOC_COORDINATE_MASK;
+            get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_slot_wrptr_addr) & ~(uint64_t)NOC_COORDINATE_MASK;

         const uint64_t dest_edm_connection_state_addr = dest_noc_addr_coord_only | edm_connection_handshake_l1_addr;
-        noc_inline_dw_write(dest_edm_connection_state_addr, close_connection_value);
+        noc_inline_dw_write(dest_edm_connection_state_addr, close_connection_request_value);

         // buffer index stored at location after handshake addr
         const uint64_t remote_buffer_index_addr = dest_noc_addr_coord_only | edm_buffer_index_addr;
-        noc_inline_dw_write(remote_buffer_index_addr, *this->buffer_index_ptr);
+        noc_inline_dw_write(remote_buffer_index_addr, *this->buffer_slot_wrptr_ptr);

         // Need to wait for the ack to the teardown notice, from the edm
         noc_semaphore_wait(this->worker_teardown_addr, 1);
@@ -180,38 +210,70 @@
     }

     uint32_t edm_buffer_addr;
-    uint32_t edm_semaphore_addr;
+
+    // the L1 address of the buffer slot wrptr on the EDM we are writing to.
+    // Writing to this address tells the EDM that the wrptr has changed and
+    // that new data is available
+    uint32_t edm_buffer_slot_wrptr_addr;
     size_t edm_connection_handshake_l1_addr;
     size_t edm_worker_location_info_addr;
     size_t edm_buffer_index_addr;
-    volatile uint32_t* worker_sem_addr;
+
+    // Local copy of the buffer slot rdptr on the EDM.
+    // The EDM will update this to indicate that packets have been read (and hence
+    // that space is available)
+    volatile uint32_t* from_remote_buffer_slot_rdptr_ptr;
     volatile uint32_t* worker_teardown_addr;
     size_t edm_buffer_base_addr;
-    size_t* buffer_index_ptr;
+
+    // TODO: keep a local copy that we use during the lifetime of the channel to avoid repeated L1 reads
+    size_t* buffer_slot_wrptr_ptr;
+
     uint16_t buffer_size_bytes;
     uint8_t num_buffers_per_channel;
+
+    // Index of the last buffer slot in the EDM channel
     uint8_t last_buffer_index;
+
+    // noc location of the edm we are connected to (where packets are sent to)
     uint8_t edm_noc_x;
     uint8_t edm_noc_y;
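+    // Typical call sequence from a worker kernel (illustrative sketch; src_l1_addr
+    // and size_bytes are assumptions, not part of this interface):
+    //   sender.open();
+    //   sender.wait_for_empty_write_slot();
+    //   sender.send_payload_flush_blocking_from_address(src_l1_addr, size_bytes);
+    //   sender.close();

   private:
+    FORCE_INLINE void update_edm_buffer_slot_wrptr() {
+        uint64_t const noc_sem_addr = get_noc_addr(this->edm_noc_x, this->edm_noc_y, 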
this->edm_buffer_slot_wrptr_addr); + + noc_inline_dw_write(noc_sem_addr, *this->buffer_slot_wrptr_ptr); + } + + FORCE_INLINE void advance_buffer_slot_wrptr() { + // TODO: smarter addition if we are working with pow2 + uint8_t wrptr = *this->buffer_slot_wrptr_ptr; + *this->buffer_slot_wrptr_ptr = + !(wrptr == ((this->num_buffers_per_channel * 2) - 1)) ? wrptr + 1 : 0; + } + + FORCE_INLINE uint8_t get_buffer_slot_index() const { + auto const wrptr = *this->buffer_slot_wrptr_ptr; + bool normalize = wrptr >= this->num_buffers_per_channel; + return wrptr - (normalize * this->num_buffers_per_channel); + } + template FORCE_INLINE void send_packet_header_and_notify_fabric(uint32_t source_address) { - this->clear_flow_control_semaphore(); + uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); send_chunk_from_address(source_address, 1, sizeof(tt::fabric::PacketHeader), buffer_address); - auto const noc_sem_addr = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr); - noc_semaphore_inc(noc_sem_addr, 1); - *this->buffer_index_ptr = - (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + this->advance_buffer_slot_wrptr(); + this->update_edm_buffer_slot_wrptr(); } template FORCE_INLINE void send_payload_without_header_from_address_impl(uint32_t source_address, size_t size_bytes) { uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); // skip past the first part of the buffer which will be occupied by the packet header send_chunk_from_address(source_address, 1, size_bytes, buffer_address + sizeof(tt::fabric::PacketHeader)); @@ -219,34 +281,26 @@ struct WorkerToFabricEdmSender { template FORCE_INLINE void send_payload_from_address_impl(uint32_t source_address, size_t size_bytes) { - this->clear_flow_control_semaphore(); uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); ASSERT(size_bytes <= this->buffer_size_bytes); - - DPRINT << "SND PKT TO @ " << (uint64_t)buffer_address << "\n"; ASSERT(tt::fabric::is_valid(*const_cast( reinterpret_cast(source_address)))); send_chunk_from_address(source_address, 1, size_bytes, buffer_address); - auto const noc_sem_addr = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr); - DPRINT << "\tSEMINC TO @ " << (uint64_t)noc_sem_addr << "\n"; - noc_semaphore_inc(noc_sem_addr, 1); - *this->buffer_index_ptr = - (*this->buffer_index_ptr == this->last_buffer_index) ? 
0 : *this->buffer_index_ptr + 1; + this->advance_buffer_slot_wrptr(); + this->update_edm_buffer_slot_wrptr(); } template FORCE_INLINE void send_payload_impl(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { - this->clear_flow_control_semaphore(); uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); ASSERT(num_pages * page_size <= this->buffer_size_bytes); send_chunk(cb_id, num_pages, page_size, buffer_address); - noc_semaphore_inc(get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr), 1); - *this->buffer_index_ptr = - (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + this->advance_buffer_slot_wrptr(); + this->update_edm_buffer_slot_wrptr(); } }; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp index cfc77c9a945..272a1ca4d7d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp @@ -59,6 +59,7 @@ void print_pkt_header_noc_fields(volatile tt::fabric::PacketHeader *const packet break; } case tt::fabric::NocSendType::NOC_MULTICAST: { + ASSERT(false); // unimplemented break; } } @@ -177,28 +178,23 @@ void update_packet_header_for_next_hop(volatile tt::fabric::PacketHeader * packe // Modifies the packet header (decrements hop counts) so ... // // !!!WARNING!!! -// !!!WARNING!!! do NOT call before determining if the packet should be consumed locally or forwarded +// !!!WARNING!!! * do NOT call before determining if the packet should be consumed locally or forwarded +// !!!WARNING!!! * ENSURE DOWNSTREAM EDM HAS SPACE FOR PACKET BEFORE CALLING // !!!WARNING!!! 
-tt::fabric::SendStatus forward_payload_to_downstream_edm(
+void forward_payload_to_downstream_edm(
     volatile tt::fabric::PacketHeader *packet_header,
     tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) {
     DPRINT << "Fwding pkt to downstream\n";

     // TODO: PERF - this should already be getting checked by the caller so this should be redundant; make it an ASSERT
-    bool safe_to_send = downstream_edm_interface.consumer_has_space();
-    if (!safe_to_send) {
-        return tt::fabric::SendStatus::NOT_SENT;
-    }
+    ASSERT(downstream_edm_interface.edm_has_space_for_packet()); // best effort check

     // This is a good place to print the packet header for debug if you are trying to inspect packets
     // because it is before we start manipulating the header for forwarding
     update_packet_header_for_next_hop(packet_header);
-
     downstream_edm_interface.send_payload_blocking_from_address(
         reinterpret_cast(packet_header),
         packet_header->get_payload_size_including_header());
-
-    return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
 }
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp
index 30fcebf9a60..3a7198e9b01 100644
--- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp
@@ -48,10 +48,26 @@ enum SendStatus : uint8_t {

 struct EDMChannelWorkerLocationInfo {
     uint32_t worker_semaphore_address;
+    uint32_t align_pad_0; // Padding added for safe reading over noc
+    uint32_t align_pad_1;
+    uint32_t align_pad_2;
+
     uint32_t worker_teardown_semaphore_address;
+    uint32_t align_pad_3; // Padding added for safe reading over noc
+    uint32_t align_pad_4;
+    uint32_t align_pad_5;
+
     ttnn::ccl::WorkerXY worker_xy;
+    uint32_t align_pad_6; // Padding added for safe reading over noc
+    uint32_t align_pad_7;
+    uint32_t align_pad_8;
+
+    uint32_t edm_rdptr = 0;
+    uint32_t align_pad_9; // Padding added for safe reading over noc
+    uint32_t align_pad_10;
+    uint32_t align_pad_11;
 };

-static_assert(sizeof(EDMChannelWorkerLocationInfo) <= 16);
+static_assert(sizeof(EDMChannelWorkerLocationInfo) <= 64);

 } // namespace tt::fabric
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp
index 73955c9b469..5e46f93e0e5 100644
--- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp
@@ -16,6 +16,9 @@
 #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp"
 #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"

+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp"
+
+
 using ttnn::ccl::WorkerXY;

 /*
@@ -207,12 +210,112 @@ Sending a packet is done as follows:

 *NOTE*: !!!ALL PACKETS MUST CONTAIN DESTINATION NOC X/Y AS NOC 0 COORDINATES, REGARDLESS OF THE `noc_index` OF THE SENDER!!!

+
+## EDM <-> EDM Channel Flow Control
+The flow control protocol between EDM channels is built on a rd/wr ptr based protocol where pointers are
+to buffer slots within the channel (as opposed to something else like a byte or word offset). Ptrs are
+free to advance independently from each other as long as there is no overflow or underflow.
+
+### Sender Channel Flow Control
+Both sender channels share the same flow control view into the receiver channel. 
This is because both channels
+write to the same receiver channel.
+* wrptr:
+  * points to the next buffer slot to write to in the remote (over Ethernet) receiver channel.
+  * leads the other pointers
+  * updated by the writer for every new packet
+  * `has_data_to_send(): local_wrptr != remote_sender_wrptr`
+* ackptr
+  * trails `wrptr`
+  * advances as the channel receives acknowledgements from the receiver
+  * as this advances, the sender channel can notify the upstream worker of additional space in the sender channel buffer
+* completion_ptr:
+  * trails `local_wrptr`
+  * the "rdptr" from the remote sender's perspective
+  * advances as packets are completed by the receiver
+  * as this advances, the sender channel can write additional packets to the receiver at this slot
+
+### Receiver Channel Flow Control
+* ackptr/rdptr:
+  * leads all pointers
+  * indicates the next buffer slot we expect data to arrive (from the remote sender) at
+  * advances as packets are received (and acked)
+  * must not advance past the completion pointer
+* wr_sent_ptr:
+  * trails `ackptr`
+  * indicates the buffer slot currently being processed, written out
+  * advances after all forwarding writes (to noc or downstream EDM) are initiated
+* wr_flush_ptr:
+  * trails `wr_sent_ptr`
+  * advances as writes are flushed
+* completion_ptr:
+  * trails `wr_flush_ptr`
+  * indicates the next receiver buffer slot in the receiver channel to send completion acks for
 */

+
 ////////////////////////////////////////////////
 // Data structures, types, enums, and constants
 ////////////////////////////////////////////////

+/*
+ * Tracks receiver channel pointers (from the sender side)
+ */
+template <size_t RECEIVER_NUM_BUFFERS>
+struct OutboundReceiverChannelPointers {
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> wrptr;
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> ack_ptr;
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> completion_ptr;
+
+    bool has_space_for_packet() const {
+        return completion_ptr.distance_behind(wrptr) < RECEIVER_NUM_BUFFERS;
+    }
+
+    bool has_unacknowledged_eth_packets() const {
+        return ack_ptr.get_ptr() != wrptr.get_ptr();
+    }
+
+    bool has_incomplete_eth_packets() const {
+        return completion_ptr.get_ptr() != wrptr.get_ptr();
+    }
+
+    bool has_unacknowledged_or_incomplete_eth_packets() const {
+        return has_incomplete_eth_packets() || has_unacknowledged_eth_packets();
+    }
+};
+
+/*
+ * Tracks receiver channel pointers (from the receiver side)
+ */
+template <size_t RECEIVER_NUM_BUFFERS>
+struct ReceiverChannelPointers {
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> wr_sent_ptr;
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> wr_flush_ptr;
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> ack_ptr;
+    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> completion_ptr;
+};
+
+struct PacketHeaderRecorder {
+    volatile tt::fabric::PacketHeader *buffer_ptr;
+    size_t buffer_n_headers;
+    size_t buffer_index;
+
+    PacketHeaderRecorder(volatile tt::fabric::PacketHeader *buffer_ptr, size_t buffer_n_headers) :
+        buffer_ptr(buffer_ptr), buffer_n_headers(buffer_n_headers), buffer_index(0) {}
+
+    void record_packet_header(volatile tt::fabric::PacketHeader *packet_header_ptr) {
+        uint32_t dest_l1_addr = (uint32_t)buffer_ptr + buffer_index * sizeof(tt::fabric::PacketHeader);
+        noc_async_write(
+            (uint32_t)packet_header_ptr,
+            get_noc_addr(my_x[0], my_y[0], dest_l1_addr),
+            sizeof(tt::fabric::PacketHeader),
+            1 - noc_index  // avoid contention on the main noc
+        );
+        buffer_index++;
+        if (buffer_index == buffer_n_headers) {
+            buffer_index = 0;
+        }
+    }
+};
+
 enum SenderState : uint8_t {
     SENDER_DONE = 0,
@@ -258,7 +361,13 @@ enum PacketLocalForwardType : uint8_t {
     PACKET_FORWARD_LOCAL_AND_REMOTE = 0x3
 };
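+// Worked example of the slot accounting above (illustrative, values hypothetical):
+// with RECEIVER_NUM_BUFFERS = 16, the pointers count modulo 32 (2x the slot count)
+// so that a full channel is distinguishable from an empty one. If wrptr holds a
+// raw value of 20 and completion_ptr holds 6, then
+// completion_ptr.distance_behind(wrptr) == 14 < 16, so has_space_for_packet()
+// allows one more send; once the distance reaches 16 the channel is full and the
+// sender must wait for completion acks.
+
-static constexpr uint32_t 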
SWITCH_INTERVAL = get_compile_time_arg_val(0); +static constexpr uint32_t SWITCH_INTERVAL = +#ifndef DEBUG_PRINT_ENABLED +get_compile_time_arg_val(0); +#else +0; +#endif + static constexpr size_t ETH_BYTES_TO_WORDS_SHIFT = 4; static constexpr size_t NUM_SENDER_CHANNELS = 2; static constexpr size_t num_workers_ctor = 1; @@ -271,44 +380,40 @@ static constexpr size_t worker_info_offset_past_connection_semaphore = 32; // SENDER SIDE HELPERS ///////////////////////////////////////////// -FORCE_INLINE void sender_notify_workers_if_buffer_available_sequence( - tt::fabric::EdmChannelWorkerInterface &local_sender_worker_interface) { - local_sender_worker_interface.clear_local_semaphore(); - local_sender_worker_interface.increment_worker_semaphore(); -} - template void send_channel_sync( tt::fabric::EthChannelBuffer &sender_buffer_channel, - tt::fabric::EthChannelBuffer &receiver_buffer_channel) { - + tt::fabric::ChannelBufferPointer &sender_wrptr, + tt::fabric::EthChannelBuffer &receiver_buffer_channel, + tt::fabric::ChannelBufferPointer &remote_receiver_wrptr + ) { + auto src_addr = sender_buffer_channel.get_bytes_sent_address(sender_wrptr.get_buffer_index()); + auto dest_addr = receiver_buffer_channel.get_bytes_sent_address(remote_receiver_wrptr.get_buffer_index()); eth_send_bytes_over_channel_payload_only_unsafe( - reinterpret_cast(sender_buffer_channel.get_current_bytes_sent_address()), - reinterpret_cast(receiver_buffer_channel.get_current_bytes_sent_address()), + reinterpret_cast(src_addr), + reinterpret_cast(dest_addr), sizeof(eth_channel_sync_t), sizeof(eth_channel_sync_t), sizeof(eth_channel_sync_t) >> ETH_BYTES_TO_WORDS_SHIFT); } template -tt::fabric::SendStatus send_next_data( +void send_next_data( tt::fabric::EthChannelBuffer &sender_buffer_channel, + tt::fabric::EdmChannelWorkerInterface &sender_worker_interface, + OutboundReceiverChannelPointers &outbound_to_receiver_channel_pointers, tt::fabric::EthChannelBuffer &receiver_buffer_channel) { - auto status = tt::fabric::SendStatus::NOT_SENT; + auto &remote_receiver_wrptr = outbound_to_receiver_channel_pointers.wrptr; + auto &local_sender_wrptr = sender_worker_interface.local_wrptr; + auto local_sender_wrptr_buffer_index = local_sender_wrptr.get_buffer_index(); ASSERT(!eth_txq_is_busy()); - - status = tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC; ASSERT( - reinterpret_cast(sender_buffer_channel.get_current_bytes_sent_address()) == - (reinterpret_cast(sender_buffer_channel.get_current_buffer_address()) + - reinterpret_cast(sender_buffer_channel.get_current_max_eth_payload_size()) - + reinterpret_cast(sender_buffer_channel.get_bytes_sent_address(local_sender_wrptr_buffer_index)) == + (reinterpret_cast(sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index)) + + reinterpret_cast(sender_buffer_channel.get_max_eth_payload_size()) - (uint32_t)sizeof(eth_channel_sync_t))); - *sender_buffer_channel.get_current_bytes_sent_address() = sender_buffer_channel.get_current_max_eth_payload_size(); - *sender_buffer_channel.get_current_bytes_acked_address() = 0; - *sender_buffer_channel.get_current_src_id_address() = sender_buffer_channel.get_id(); - ASSERT(*sender_buffer_channel.get_current_src_id_address() < 2); // TODO: TUNING - experiment with only conditionally breaking the transfer up into multiple packets if we are // a certain threshold less than full packet @@ -316,134 +421,124 @@ tt::fabric::SendStatus send_next_data( // compare // NOTE: if we always send full packet, then we don't need the second branch below dedicated for // 
channel sync
-    ASSERT(tt::fabric::is_valid(*const_cast(reinterpret_cast(sender_buffer_channel.get_current_buffer_address()))));
-    const size_t payload_size = sender_buffer_channel.get_current_payload_plus_channel_sync_size();
+    ASSERT(tt::fabric::is_valid(*const_cast(reinterpret_cast(sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index)))));
+    const size_t payload_size = sender_buffer_channel.get_payload_plus_channel_sync_size(local_sender_wrptr_buffer_index);
+    *sender_buffer_channel.get_bytes_sent_address(local_sender_wrptr_buffer_index) = payload_size;
+    *sender_buffer_channel.get_bytes_acked_address(local_sender_wrptr_buffer_index) = 0;
+    *sender_buffer_channel.get_src_id_address(local_sender_wrptr_buffer_index) = sender_buffer_channel.get_id();
+    ASSERT(*sender_buffer_channel.get_src_id_address(local_sender_wrptr_buffer_index) < 2);
+
+    auto src_addr = sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index);
+    auto dest_addr = receiver_buffer_channel.get_buffer_address(remote_receiver_wrptr.get_buffer_index());
     eth_send_bytes_over_channel_payload_only_unsafe(
-        sender_buffer_channel.get_current_buffer_address(),
-        receiver_buffer_channel.get_current_buffer_address(), // get_remote_eth_buffer_address(),
+        src_addr,
+        dest_addr,
         payload_size,
         payload_size,
         payload_size >> ETH_BYTES_TO_WORDS_SHIFT);

     bool sent_payload_and_channel_sync_in_one_shot =
-        payload_size == sender_buffer_channel.get_current_max_eth_payload_size();
+        payload_size == sender_buffer_channel.get_max_eth_payload_size();
     if (!sent_payload_and_channel_sync_in_one_shot) {
         // We weren't able to send the channel_sync_t in one shot with the payload so we need to send a second
         // packet
         // TODO: TUNING - consider busy waiting for a maximum amount of time
-        if (!eth_txq_is_busy()) {
-            send_channel_sync(sender_buffer_channel, receiver_buffer_channel);
-        } else {
-            status = tt::fabric::SendStatus::SENT_PAYLOAD_ONLY;
-        }
+        while (eth_txq_is_busy()) {}
+        send_channel_sync(
+            sender_buffer_channel, local_sender_wrptr, receiver_buffer_channel, remote_receiver_wrptr);
     }

     // Note: We can only advance to the next buffer index if we have fully completed the send (both the payload and sync
     // messages)
-    if (status == tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC) {
-        sender_buffer_channel.advance_buffer_index();
-        receiver_buffer_channel.advance_buffer_index();
-    }
-
-    return status;
+    local_sender_wrptr.increment();
+    remote_receiver_wrptr.increment();
 }

-template
-FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence(
-    tt::fabric::EthChannelBuffer &sender_buffer_channel,
-    tt::fabric::EthChannelBuffer &receiver_buffer_channel) {
-    return sender_buffer_channel.is_local_semaphore_full();
-}

-template
-FORCE_INLINE void sender_eth_check_receiver_ack_sequence(
-    tt::fabric::EthChannelBuffer &sender_buffer_channel,
-    tt::fabric::EdmChannelWorkerInterface &sender_worker_interface) {
-    sender_buffer_channel.eth_clear_sender_channel_ack();
-
-    sender_notify_workers_if_buffer_available_sequence(sender_worker_interface);
-}

 /////////////////////////////////////////////
 //   RECEIVER SIDE HELPERS
 /////////////////////////////////////////////

-template
-FORCE_INLINE bool new_unacknowledged_packet_avilable_on_reciever_channel(
-    tt::fabric::EthChannelBuffer &local_receiver_channel) {
-    return local_receiver_channel.eth_bytes_are_available_on_channel();
-}
-
 /*
 * Acting as the receiver, we are looking at our receiver channel and acking the sender who sent us the latest packet. 
* Doesn't check to see if indeed a new message is available. It's assumed the caller has handled that separately.
+ * MUST CHECK !is_eth_txq_busy() before calling
 */
-// MUST CHECK !is_eth_txq_busy() before calling
template
void receiver_send_received_ack(
+    std::array, NUM_SENDER_CHANNELS> &remote_eth_sender_ackptrs,
     std::array, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    // currently the pointer is working multiple jobs (ack, completion, read) because we haven't implemented the
+    // decoupling of those jobs yet to separate pointers
+    tt::fabric::ChannelBufferPointer &receiver_channel_ptr,
     tt::fabric::EthChannelBuffer &local_receiver_buffer_channel) {
     // Set the acknowledgement bits. We have a different location than the

-    const auto src_id = *local_receiver_buffer_channel.get_current_src_id_address();
-    ASSERT(src_id < NUM_SENDER_CHANNELS);
-    auto &sender_buffer_channel = remote_sender_channels[src_id];
-    ASSERT(
-        reinterpret_cast(sender_buffer_channel.get_current_bytes_sent_address()) ==
-        reinterpret_cast(sender_buffer_channel.get_current_buffer_address()) +
-            reinterpret_cast(sender_buffer_channel.get_current_max_eth_payload_size()) -
-            sizeof(eth_channel_sync_t));
+    auto receiver_buffer_index = receiver_channel_ptr.get_buffer_index();
+    const auto src_id = *local_receiver_buffer_channel.get_src_id_address(receiver_buffer_index);
+    auto &sender_ackptr = remote_eth_sender_ackptrs[src_id];
+    ASSERT(src_id < NUM_SENDER_CHANNELS);
     const size_t local_ack_channel_sync_src_addr =
         local_receiver_buffer_channel.get_eth_transaction_ack_word_addr() + (src_id * sizeof(eth_channel_sync_t));
-    reinterpret_cast(local_ack_channel_sync_src_addr)->bytes_sent =
-        *local_receiver_buffer_channel.get_current_bytes_sent_address();
+    reinterpret_cast(local_ack_channel_sync_src_addr)->bytes_sent = 1; // *local_receiver_buffer_channel.get_bytes_sent_address();
     reinterpret_cast(local_ack_channel_sync_src_addr)->receiver_ack = 1;
-    reinterpret_cast(local_ack_channel_sync_src_addr)->src_id =
-        *local_receiver_buffer_channel.get_current_src_id_address();
+    reinterpret_cast(local_ack_channel_sync_src_addr)->src_id = src_id;
+    reinterpret_cast(local_ack_channel_sync_src_addr)->reserved_2 = 0xc0ffee2;
+
+    auto &sender_buffer_channel = remote_sender_channels[src_id];
+    auto sender_ackptr_buffer_index = sender_ackptr.get_buffer_index();
+    ASSERT(src_id < NUM_SENDER_CHANNELS);
+    ASSERT(
+        reinterpret_cast(sender_buffer_channel.get_bytes_sent_address(sender_ackptr_buffer_index)) ==
+        reinterpret_cast(sender_buffer_channel.get_buffer_address(sender_ackptr_buffer_index)) +
+            reinterpret_cast(sender_buffer_channel.get_max_eth_payload_size()) -
+            sizeof(eth_channel_sync_t));

     // Make sure we don't alias the erisc_info eth_channel_sync_t
     ASSERT(
-        reinterpret_cast(local_receiver_buffer_channel.get_current_bytes_sent_address())
+        reinterpret_cast(local_receiver_buffer_channel.get_bytes_sent_address(receiver_buffer_index))
             ->bytes_sent != 0);
     ASSERT(
-        reinterpret_cast(local_receiver_buffer_channel.get_current_bytes_sent_address())
+        reinterpret_cast(local_receiver_buffer_channel.get_bytes_sent_address(receiver_buffer_index))
             ->receiver_ack == 0);
-
-    DPRINT << "EDMR rsa to " << (uint32_t)sender_buffer_channel.get_current_bytes_sent_address() << "\n";
-
     ASSERT(!eth_txq_is_busy());
     internal_::eth_send_packet_unsafe(
         0,
         local_ack_channel_sync_src_addr >> 4,
-        ((uint32_t)(sender_buffer_channel.get_current_bytes_sent_address())) >> 4,
+        ((uint32_t)(sender_buffer_channel.get_bytes_sent_address(sender_ackptr_buffer_index))) >> 4,
1); } // MUST CHECK !is_eth_txq_busy() before calling template FORCE_INLINE void receiver_send_completion_ack( + std::array, NUM_SENDER_CHANNELS> &remote_eth_sender_completion_ptrs, std::array, NUM_SENDER_CHANNELS> &remote_sender_channels, + tt::fabric::ChannelBufferPointer &receiver_channel_ptr, tt::fabric::EthChannelBuffer &local_receiver_buffer_channel) { - volatile auto local_bytes_sent_addr = local_receiver_buffer_channel.get_current_bytes_sent_address(); - volatile auto local_src_id_ptr = local_receiver_buffer_channel.get_current_src_id_address(); - auto src_sender_channel = *local_src_id_ptr; + auto receiver_buffer_index = receiver_channel_ptr.get_buffer_index(); + volatile auto local_bytes_sent_addr = local_receiver_buffer_channel.get_bytes_sent_address(receiver_buffer_index); + volatile auto local_src_id_ptr = local_receiver_buffer_channel.get_src_id_address(receiver_buffer_index); *(local_bytes_sent_addr) = 0; - *(local_receiver_buffer_channel.get_current_bytes_acked_address()) = 0; - ASSERT(src_sender_channel < NUM_SENDER_CHANNELS); + *(local_receiver_buffer_channel.get_bytes_acked_address(receiver_buffer_index)) = 0; - DPRINT << "EDMR rsc to " << (uint32_t)remote_sender_channels[src_sender_channel].get_current_bytes_sent_address() << "\n"; + auto src_sender_channel = *local_src_id_ptr; + auto &remote_sender_channel = remote_sender_channels[src_sender_channel]; + auto &remote_sender_completion_ptr = remote_eth_sender_completion_ptrs[src_sender_channel]; + ASSERT(src_sender_channel < NUM_SENDER_CHANNELS); ASSERT(!eth_txq_is_busy()); + internal_::eth_send_packet_unsafe( 0, (uint32_t)(local_bytes_sent_addr) >> 4, - (uint32_t)(remote_sender_channels[src_sender_channel].get_current_bytes_sent_address()) >> 4, + (uint32_t)(remote_sender_channel.get_bytes_sent_address(remote_sender_completion_ptr.get_buffer_index())) >> 4, 1); - local_receiver_buffer_channel.advance_buffer_index(); - remote_sender_channels[src_sender_channel].advance_buffer_index(); + receiver_channel_ptr.increment(); + remote_sender_completion_ptr.increment(); } @@ -456,46 +551,44 @@ PacketLocalForwardType get_packet_local_forward_type(const volatile tt::fabric:: } FORCE_INLINE bool can_forward_packet_completely( - const volatile tt::fabric::PacketHeader &packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) { - auto forward_status = get_packet_local_forward_type(packet_header); + const volatile tt::fabric::PacketHeader *packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) { + auto forward_status = get_packet_local_forward_type(*packet_header); switch (forward_status) { case PACKET_FORWARD_INVALID: return false; case PACKET_FORWARD_LOCAL_ONLY: return true; case PACKET_FORWARD_REMOTE_ONLY: - case PACKET_FORWARD_LOCAL_AND_REMOTE: return downstream_edm_interface.consumer_has_space(); + case PACKET_FORWARD_LOCAL_AND_REMOTE: return downstream_edm_interface.edm_has_space_for_packet(); default: ASSERT(false); return false; }; } -// template -tt::fabric::SendStatus receiver_forward_packet( +// !!!WARNING!!! - MAKE SURE CONSUMER HAS SPACE BEFORE CALLING +void receiver_forward_packet( volatile tt::fabric::PacketHeader *packet_start, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) { // Just cache the packet_header - we don't really expect (or care) if contents change during this function. 
-// template
-tt::fabric::SendStatus receiver_forward_packet(
+// !!!WARNING!!! - MAKE SURE CONSUMER HAS SPACE BEFORE CALLING
+void receiver_forward_packet(
     volatile tt::fabric::PacketHeader *packet_start, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) {
     // Just cache the packet_header - we don't really expect (or care) if contents change during this function.
-    tt::fabric::PacketHeader const &packet_header = *const_cast<tt::fabric::PacketHeader const *>(packet_start);
-    ASSERT(tt::fabric::is_valid(packet_header));
+    volatile tt::fabric::PacketHeader const &packet_header = *packet_start;
+    ASSERT(tt::fabric::is_valid(const_cast<tt::fabric::PacketHeader const &>(packet_header)));
     auto forward_status = get_packet_local_forward_type(packet_header);
-
     switch (forward_status) {
         case PACKET_FORWARD_LOCAL_ONLY: {
             execute_chip_unicast_to_local_chip(packet_start);
-            return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
         } break;
         case PACKET_FORWARD_REMOTE_ONLY: {
-            return forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
+            forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
         } break;
         case PACKET_FORWARD_LOCAL_AND_REMOTE: {
             ASSERT(packet_header.chip_send_type == tt::fabric::ChipSendType::CHIP_MULTICAST);
             // TODO: make local chip write non-blocking
             execute_chip_unicast_to_local_chip(packet_start);
-            return forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
+            forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
         } break;
         case PACKET_FORWARD_INVALID:
-        default: ASSERT(false); return tt::fabric::SendStatus::ERROR;
+        default: ASSERT(false);
     };
 }

@@ -504,144 +597,191 @@ tt::fabric::SendStatus receiver_forward_packet(
 ////////////////////////////////////
 ////////////////////////////////////
 //  Main Control Loop
 ////////////////////////////////////
 ////////////////////////////////////
-template <uint8_t SENDER_NUM_BUFFERS, uint8_t RECEIVER_NUM_BUFFERS>
-bool run_sender_channel_state_machine_step(
+template <bool enable_packet_header_recording, bool enable_fabric_counters, uint8_t SENDER_NUM_BUFFERS, uint8_t RECEIVER_NUM_BUFFERS>
+bool run_sender_channel_step(
     tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &local_sender_channel,
-    tt::fabric::EdmChannelWorkerInterface &local_sender_channel_worker_interface,
+    tt::fabric::EdmChannelWorkerInterface<SENDER_NUM_BUFFERS> &local_sender_channel_worker_interface,
+    OutboundReceiverChannelPointers &outbound_to_receiver_channel_pointers,
     tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &remote_receiver_channel,
+    volatile tt::fabric::EdmFabricSenderChannelCounters* sender_channel_counters,
+    PacketHeaderRecorder &packet_header_recorder,
     bool graceful_termination_mode,
-    SenderState *const sender_state_out,
+    bool &channel_connection_established,
     uint8_t sender_channel_index) {
-    bool incr_sender_channel_index = true;
-    switch (*sender_state_out) {
-        case SenderState::SENDER_WAITING_FOR_WORKER: {
-            bool able_to_send = local_sender_channel_worker_interface.has_payload() && !eth_txq_is_busy() &&
-                local_sender_channel.eth_is_receiver_channel_send_done();
-            if (able_to_send) {
-                DPRINT << "EDMS " << (uint32_t)sender_channel_index << "\n";
-                DPRINT << "\taddress: " << (uint32_t)local_sender_channel.get_current_buffer_address() << "\n";
-                DPRINT << "\t1st 8B: " << (uint64_t)*reinterpret_cast<uint64_t *>(local_sender_channel.get_current_buffer_address()) << "\n";
-                DPRINT << "\tsend to " << (uint32_t)remote_receiver_channel.get_current_buffer_address() << "\n";
-                auto send_status = send_next_data(local_sender_channel, remote_receiver_channel);
-                // TODO: align the enums and state values so I can just do
-                // sender_states[sender_channel_index] += send_status :)
-                ASSERT(send_status != tt::fabric::SendStatus::ERROR);
-                *sender_state_out =
-                    send_status == tt::fabric::SendStatus::NOT_SENT ? SenderState::SENDER_WAITING_FOR_WORKER
-                    : send_status == tt::fabric::SendStatus::SENT_PAYLOAD_ONLY ? SenderState::SENDER_SEND_CHANNEL_SYNC
-                    : SenderState::SENDER_WAITING_FOR_ETH;
-                // Avoid any sort of starvation/bubbles so we only advance if we've sent the packet and channel sync
-                // otherwise what can happen is we could start sending another large payload from the other channel
-                // and not be able to send the channel sync for the packet we just sent, which overall negatively
-                // impacts latency
-                incr_sender_channel_index = send_status != tt::fabric::SendStatus::SENT_PAYLOAD_ONLY;
-            } else if (!graceful_termination_mode) {
-                if (!local_sender_channel_worker_interface.has_payload() && local_sender_channel_worker_interface.has_worker_teardown_request()) {
-                    local_sender_channel_worker_interface.teardown_connection();
-                    *sender_state_out = SenderState::SENDER_WAIT_WORKER_HANDSHAKE;
-                }
-            }
-        } break;
+    bool did_something = false;
+
+    // If the receiver has space, and we have one or more packets unsent from producer, then send one
+    // TODO: convert to loop to send multiple packets back to back (or support sending multiple packets in one shot)
+    //       when moving to stream regs to manage rd/wr ptrs
+    bool receiver_has_space_for_packet = outbound_to_receiver_channel_pointers.has_space_for_packet();
+    if (receiver_has_space_for_packet && !eth_txq_is_busy()) {
+        bool has_unsent_packet = local_sender_channel_worker_interface.has_unsent_payload();
+        if (has_unsent_packet) {
+            bool sender_backpressured_from_sender_side = !(local_sender_channel_worker_interface.local_rdptr.distance_behind(local_sender_channel_worker_interface.local_wrptr) < SENDER_NUM_BUFFERS);
+            if (!sender_backpressured_from_sender_side) {
+                ASSERT(local_sender_channel.eth_is_receiver_channel_send_done(local_sender_channel_worker_interface.local_wrptr.get_buffer_index()));
+                did_something = true;
+                auto packet_header = reinterpret_cast<volatile tt::fabric::PacketHeader *>(local_sender_channel.get_buffer_address(local_sender_channel_worker_interface.local_wrptr.get_buffer_index()));
+                tt::fabric::validate(*packet_header);
+                if constexpr (enable_packet_header_recording) {
+                    packet_header_recorder.record_packet_header(packet_header);
+                }
+                send_next_data(
+                    local_sender_channel,
+                    local_sender_channel_worker_interface,
+                    outbound_to_receiver_channel_pointers,
+                    remote_receiver_channel);
+            }
+        }
+    }

-        case SenderState::SENDER_WAIT_WORKER_HANDSHAKE:
-            if (local_sender_channel_worker_interface.connection_is_live()) {
-                bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() ||
-                    local_sender_channel.eth_is_receiver_channel_send_done();
-                if (is_safe_to_receive_next_message) {
-                    DPRINT << "EDM ch " << (uint32_t)sender_channel_index << " wkr con ntfy wrkr\n";
-                    DPRINT << "\tl1 worker info ptr: " << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr << "\n";
-                    DPRINT << "\tworker.x=" << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr->worker_xy.x << ", .y=" << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr->worker_xy.y << ", sem_addr=" << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr->worker_semaphore_address << "\n";
-                    sender_notify_workers_if_buffer_available_sequence(local_sender_channel_worker_interface);
-                    *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER;
-                } else {
-                    *sender_state_out = SenderState::SENDER_WAITING_FOR_ETH;
-                }
-            }
-            break;
+    bool has_unacknowledged_eth_packets = outbound_to_receiver_channel_pointers.has_unacknowledged_or_incomplete_eth_packets();
+    if (has_unacknowledged_eth_packets) {
+        {
+            auto& sender_ackptr = local_sender_channel_worker_interface.local_ackptr;
+            auto old_ackptr = sender_ackptr;
+            // Only check for acks first
+            bool check_next = !local_sender_channel_worker_interface.all_eth_packets_acked();
+            while (check_next) {
+                // TODO: change how ack is represented so we can check both at once without
+                //       having to worry about races (i.e. right now we don't have monotonicity
+                //       but if we did we could safely (check ack || completed))
+                tt::fabric::BufferIndex rd_buffer_index = sender_ackptr.get_buffer_index();
+
+                bool acked_or_completed = local_sender_channel.eth_is_acked_or_completed(rd_buffer_index);
+                if (acked_or_completed) {
+                    local_sender_channel.eth_clear_sender_channel_ack(rd_buffer_index);
+                    sender_ackptr.increment();
+                    local_sender_channel_worker_interface.propagate_ackptr_to_connection_info();
+                    did_something = true;
+                    outbound_to_receiver_channel_pointers.ack_ptr.increment();
+                }
+                check_next = acked_or_completed && !local_sender_channel_worker_interface.all_eth_packets_acked();
+            }
+
+            bool advanced = old_ackptr.get_ptr() != sender_ackptr.get_ptr();
+            if (advanced && channel_connection_established) {
+                local_sender_channel_worker_interface.update_worker_copy_of_read_ptr();
+            }
+        }

-        case SenderState::SENDER_SEND_CHANNEL_SYNC: {
-            bool can_send_channel_sync_without_blocking = !eth_txq_is_busy();
-            if (can_send_channel_sync_without_blocking) {
-                DPRINT << "EDMS send channel sync\n";
-                send_channel_sync(local_sender_channel, remote_receiver_channel);
-                local_sender_channel.advance_buffer_index();
-                remote_receiver_channel.advance_buffer_index();
-                *sender_state_out = SenderState::SENDER_WAITING_FOR_ETH;
-            }
-        } break;
+        {
+            // stupid implementation but keeps things simple to bootstrap
+            auto& sender_rdptr = local_sender_channel_worker_interface.local_rdptr;
+            bool check_next = !local_sender_channel_worker_interface.all_eth_packets_completed();
+            while (check_next) {
+                bool completed = local_sender_channel.eth_is_receiver_channel_send_done(sender_rdptr.get_buffer_index());
+                if (completed) {
+                    did_something = true;
+                    if (local_sender_channel_worker_interface.local_ackptr.get_ptr() == sender_rdptr.get_ptr()) {
+                        // If ackptr is also here, then we need to increment it too
+                        outbound_to_receiver_channel_pointers.ack_ptr.increment();
+                        local_sender_channel_worker_interface.propagate_ackptr_to_connection_info();
+                        if (channel_connection_established) {
+                            local_sender_channel_worker_interface.update_worker_copy_of_read_ptr();
+                        }
+                    }
+                    outbound_to_receiver_channel_pointers.completion_ptr.increment();
+                    sender_rdptr.increment();
+                }
+                check_next = completed && !local_sender_channel_worker_interface.all_eth_packets_completed();
+            }
+        }
+    }

-        case SenderState::SENDER_WAITING_FOR_ETH: {
-            bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() ||
-                local_sender_channel.eth_is_receiver_channel_send_done();
-            if (is_safe_to_receive_next_message) {
-                // This also notifies workers in the same call
-                DPRINT << "EDMS:\n";
-                sender_eth_check_receiver_ack_sequence(local_sender_channel, local_sender_channel_worker_interface);
-                *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER;
-            }
-        } break;
-
-        default: break;
-    };
+    if (!channel_connection_established) {
+        // Can get rid of one of these two checks if we duplicate the logic above here in the function,
+        // depending on which of the two versions we are in (the connected version or the disconnected version).
+        // We also check if the interface has a teardown request in case the worker
+        // 1. opened the connection
+        // 2. sent off all of its packets (EDM sender channel was sufficiently empty)
+        // 3. closed the connection
+        //
+        // In a case like that, we still want to formally tear down the connection to keep things clean
+        bool connect_requested = local_sender_channel_worker_interface.connection_is_live() ||
+            local_sender_channel_worker_interface.has_worker_teardown_request();
+        if (connect_requested) {
+            if constexpr (enable_fabric_counters) {
+                sender_channel_counters->add_connection();
+            }
+            did_something = true;
+            channel_connection_established = true;
+            local_sender_channel_worker_interface.update_worker_copy_of_read_ptr();
+        }
+    } else if (local_sender_channel_worker_interface.has_worker_teardown_request()) {
+        did_something = true;
+        channel_connection_established = false;
+        local_sender_channel_worker_interface.teardown_connection(
+            local_sender_channel_worker_interface.local_rdptr.get_ptr());
+    }

-    return incr_sender_channel_index;
+    return did_something;
 };

-template <uint8_t RECEIVER_NUM_BUFFERS, uint8_t SENDER_NUM_BUFFERS, uint8_t NUM_SENDER_CHANNELS>
-void run_receiver_channel_state_machine_step(
+template <uint8_t RECEIVER_NUM_BUFFERS, uint8_t SENDER_NUM_BUFFERS, uint8_t NUM_SENDER_CHANNELS>
+void run_receiver_channel_step(
     tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
     std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channnels,
     tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface,
+    volatile tt::fabric::EdmFabricReceiverChannelCounters *receiver_channel_counters_ptr,
+    std::array<tt::fabric::ChannelBufferPointer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_eth_sender_wrptrs,
+    ReceiverChannelPointers &receiver_channel_pointers,
+    PacketHeaderRecorder &packet_header_recorder,
     ReceiverState *const receiver_state_out) {
-    switch (*receiver_state_out) {
-        case ReceiverState::RECEIVER_WAITING_FOR_ETH: {
-            bool got_payload = local_receiver_channel.eth_bytes_are_available_on_channel();
-            if (got_payload) {
-                bool can_ack = !eth_txq_is_busy();
-                if (can_ack) {
-                    DPRINT << "EDMR got pkt @: " << (uint32_t)reinterpret_cast<size_t>(local_receiver_channel.get_current_packet_header()) << "\n";
-                    DPRINT << "EDMR got pkt 0 : " << (uint64_t)reinterpret_cast<volatile uint64_t *>(local_receiver_channel.get_current_packet_header())[0] << "\n";
-                    DPRINT << "EDMR got pkt 1: " << (uint64_t)reinterpret_cast<volatile uint64_t *>(local_receiver_channel.get_current_packet_header())[1] << "\n";
-                    ASSERT(tt::fabric::is_valid(
-                        *const_cast<tt::fabric::PacketHeader *>(local_receiver_channel.get_current_packet_header())));
-                    receiver_send_received_ack(remote_sender_channnels, local_receiver_channel);
-                    // TODO: PERF Need to add feature to let use perform local noc write and defer the forward to EDM
-                    // if we are mcasting to the local chip and neighbours, but the downstream EDM isn't currently able
-                    // to accept the packet
-                    // ...
-                    // but as a starting point we can do the dumb thing and just wait for space downstream
-                    // before we do either.
-                    *receiver_state_out = ReceiverState::RECEIVER_SENDING_PAYLOAD;
-                    // TODO: PERF - SHORT CIRCUIT IF WE CAN TO NESXT STATE TO MINIMIZE LATENCY BUT CURRENTLY
-                    // A LITTLE CODE SIZE BOUND
-                }
-            }
-        } break;
-        case ReceiverState::RECEIVER_SENDING_PAYLOAD: {
-            auto& packet_header = *local_receiver_channel.get_current_packet_header();
-            bool can_send_to_all_local_chip_receivers =
-                can_forward_packet_completely(packet_header, downstream_edm_interface);
-            if (can_send_to_all_local_chip_receivers) {
-                DPRINT << "EDMR writing pkt\n";
-                receiver_forward_packet(local_receiver_channel.get_current_packet_header(), downstream_edm_interface);
-                *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH;
-            }
-        } break;
-        case ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH: {
-            bool writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index);
-            if (writes_flushed) {
-                bool can_send_ack_without_blocking = !eth_txq_is_busy();
-                if (can_send_ack_without_blocking) {
-                    receiver_send_completion_ack(remote_sender_channnels, local_receiver_channel);
-                    *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_ETH;
-                }
-            }
-        } break;
-        default: break;
-    };
+    // Optimization:
+    // 1. Let wrptr advance ahead of ackptr
+    auto &ack_ptr = receiver_channel_pointers.ack_ptr;
+    auto ack_ptr_buffer_index = ack_ptr.get_buffer_index();
+    bool packet_received = local_receiver_channel.eth_bytes_are_available_on_channel(ack_ptr_buffer_index) &&
+        receiver_channel_pointers.completion_ptr.distance_behind(ack_ptr) < RECEIVER_NUM_BUFFERS;
+    bool can_send_over_eth = !eth_txq_is_busy();
+    if (packet_received && can_send_over_eth) {
+        receiver_send_received_ack(
+            remote_eth_sender_wrptrs,
+            remote_sender_channnels,
+            ack_ptr,
+            local_receiver_channel);
+        ack_ptr.increment();
+    }

+    auto &wr_sent_ptr = receiver_channel_pointers.wr_sent_ptr;
+    bool unwritten_packets = !wr_sent_ptr.is_caught_up_to(ack_ptr);
+    if (unwritten_packets) {
+        auto receiver_buffer_index = wr_sent_ptr.get_buffer_index();
+        volatile auto packet_header = local_receiver_channel.get_packet_header(receiver_buffer_index);
+        bool can_send_to_all_local_chip_receivers =
+            can_forward_packet_completely(packet_header, downstream_edm_interface);
+        if (can_send_to_all_local_chip_receivers) {
+            receiver_forward_packet(packet_header, downstream_edm_interface);
+            wr_sent_ptr.increment();
+        }
+    }

+    auto &wr_flush_ptr = receiver_channel_pointers.wr_flush_ptr;
+    bool unflushed_writes = !wr_flush_ptr.is_caught_up_to(wr_sent_ptr);
+    if (unflushed_writes) {
+        bool writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index);
+        if (writes_flushed) {
+            auto receiver_buffer_index = wr_flush_ptr.get_buffer_index();
+            local_receiver_channel.eth_clear_sender_channel_ack(receiver_buffer_index);
+            wr_flush_ptr.increment();
+        }
+    }

+    auto &completion_ptr = receiver_channel_pointers.completion_ptr;
+    bool unsent_completions = !completion_ptr.is_caught_up_to(wr_flush_ptr);
+    if (unsent_completions) {
+        bool can_send_without_blocking = !eth_txq_is_busy();
+        if (can_send_without_blocking) {
+            // completion ptr incremented in callee
+            receiver_send_completion_ack(
+                remote_eth_sender_wrptrs,
+                remote_sender_channnels,
+                completion_ptr,
+                local_receiver_channel);
+        }
+    }
 };

@@ -660,12 +800,15 @@ FORCE_INLINE bool got_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
 template <uint8_t RECEIVER_NUM_BUFFERS, uint8_t SENDER_NUM_BUFFERS, uint8_t NUM_SENDER_CHANNELS>
 bool all_channels_drained(tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
     std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &local_sender_channels,
-    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces) {
+    std::array<tt::fabric::EdmChannelWorkerInterface<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces) {
-    bool eth_buffers_drained = local_sender_channels[0].all_buffers_drained() && local_sender_channels[1].all_buffers_drained() && local_receiver_channel.all_buffers_drained();
+    bool eth_buffers_drained =
+        !local_sender_channel_worker_interfaces[0].has_unacked_sends() &&
+        !local_sender_channel_worker_interfaces[1].has_unacked_sends() &&
+        local_receiver_channel.all_buffers_drained();

-    bool sender0_has_unsent_packets = (local_sender_channel_worker_interfaces[0].has_payload());
-    bool sender1_has_unsent_packets = (local_sender_channel_worker_interfaces[1].has_payload());
+    bool sender0_has_unsent_packets = local_sender_channel_worker_interfaces[0].has_unsent_payload();
+    bool sender1_has_unsent_packets = local_sender_channel_worker_interfaces[1].has_unsent_payload();
     return eth_buffers_drained && !sender0_has_unsent_packets && !sender1_has_unsent_packets;
 }

@@ -676,15 +819,19 @@ bool all_channels_drained(tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
 * Every loop iteration visits a sender channel and the receiver channel. Switch between sender
 * channels every iteration unless it is unsafe/undesirable to do so (e.g. for performance reasons).
 */
-template <uint8_t RECEIVER_NUM_BUFFERS, uint8_t SENDER_NUM_BUFFERS, uint8_t NUM_SENDER_CHANNELS>
+template <bool enable_packet_header_recording, bool enable_fabric_counters, uint8_t RECEIVER_NUM_BUFFERS, uint8_t SENDER_NUM_BUFFERS, uint8_t NUM_SENDER_CHANNELS>
 void run_fabric_edm_main_loop(
     tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
     std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &local_sender_channels,
-    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces,
+    std::array<tt::fabric::EdmChannelWorkerInterface<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces,
     tt::fabric::WorkerToFabricEdmSender &downstream_edm_noc_interface,
     std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
     tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &remote_receiver_channel,
-    volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    volatile tt::fabric::TerminationSignal *termination_signal_ptr,
+    volatile tt::fabric::EdmFabricReceiverChannelCounters *receiver_channel_counters_ptr,
+    std::array<volatile tt::fabric::EdmFabricSenderChannelCounters *, NUM_SENDER_CHANNELS> sender_channel_counters_ptrs,
+    PacketHeaderRecorder &receiver_channel_packet_recorder,
+    std::array<PacketHeaderRecorder, NUM_SENDER_CHANNELS> &sender_channel_packet_recorders) {
     std::array<SenderState, NUM_SENDER_CHANNELS> sender_states = {
         SenderState::SENDER_WAIT_WORKER_HANDSHAKE, SenderState::SENDER_WAIT_WORKER_HANDSHAKE};
     ReceiverState receiver_state = ReceiverState::RECEIVER_WAITING_FOR_ETH;
@@ -692,11 +839,22 @@ void run_fabric_edm_main_loop(
     size_t did_nothing_count = 0;
     *termination_signal_ptr = tt::fabric::TerminationSignal::KEEP_RUNNING;

+    // May want to promote to part of the handshake but for now we just initialize in this standalone way
+    // TODO: flatten all of these arrays into a single object (one array lookup) OR
+    //       (probably better) pack most of these into single words (e.g. we could hold a read, write, and ackptr in a single word)
+    //       this way - especially with power-of-2 wraps - we can handle both channels literally at once with math ops on
+    //       single individual words (or half words)
+    std::array<tt::fabric::ChannelBufferPointer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> remote_eth_sender_wrptrs{
+        tt::fabric::ChannelBufferPointer<SENDER_NUM_BUFFERS>(),
+        tt::fabric::ChannelBufferPointer<SENDER_NUM_BUFFERS>()};
+    OutboundReceiverChannelPointers outbound_to_receiver_channel_pointers;
+    ReceiverChannelPointers receiver_channel_pointers;
+    std::array<bool, NUM_SENDER_CHANNELS> channel_connection_established = {false, false};
+
     while (!got_immediate_termination_signal(termination_signal_ptr)) {
         bool got_graceful_termination = got_graceful_termination_signal(termination_signal_ptr);
         if (got_graceful_termination) {
             DPRINT << "EDM Graceful termination\n";
-            DPRINT << "EDMS0 ST: " << (uint32_t)sender_states[0] << "\n";
             bool all_drained = all_channels_drained(
                 local_receiver_channel, local_sender_channels, local_sender_channel_worker_interfaces);
@@ -706,28 +864,28 @@ void run_fabric_edm_main_loop(
         }

         // Capture these to see if we made progress
-        auto old_send_state = sender_states[sender_channel_index];
         auto old_recv_state = receiver_state;
-        auto &local_sender_channel = local_sender_channels[sender_channel_index];
-        auto &local_sender_channel_worker_interface = local_sender_channel_worker_interfaces[sender_channel_index];
         // There are some cases, mainly for performance, where we don't want to switch between sender channels
         // so we introduce this to provide finer-grained control over when we disable the automatic switching
-        bool incr_sender_channel_index = run_sender_channel_state_machine_step(
-            local_sender_channel,
-            local_sender_channel_worker_interface,
+        bool did_something_sender = run_sender_channel_step(
+            local_sender_channels[sender_channel_index],
+            local_sender_channel_worker_interfaces[sender_channel_index],
+            outbound_to_receiver_channel_pointers,
             remote_receiver_channel,
+            sender_channel_counters_ptrs[sender_channel_index],
+            sender_channel_packet_recorders[sender_channel_index],
             got_graceful_termination,
-            &(sender_states[sender_channel_index]),
+            channel_connection_established[sender_channel_index],
             sender_channel_index);
-        bool did_something_sender = old_send_state != sender_states[sender_channel_index];
-        if (incr_sender_channel_index) {
-            // TODO: this can probably be optimized
-            sender_channel_index = 1 - sender_channel_index;
-        }
-        run_receiver_channel_state_machine_step(
-            local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, &receiver_state);
+        sender_channel_index = 1 - sender_channel_index;
+
+        run_receiver_channel_step(
+            local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, receiver_channel_counters_ptr,
+            remote_eth_sender_wrptrs,
+            receiver_channel_pointers,
+            receiver_channel_packet_recorder, &receiver_state);

         bool did_something = did_something_sender || old_recv_state != receiver_state;

@@ -754,10 +912,8 @@ void kernel_main() {
     static constexpr size_t DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT = 0;
     if constexpr (is_handshake_sender) {
-        // DPRINT << "EDM Starting handshake as sender\n";
         erisc::datamover::handshake::sender_side_start(handshake_addr, DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT);
     } else {
-        // DPRINT << "EDM Starting handshake as receiver\n";
         erisc::datamover::handshake::receiver_side_start(handshake_addr);
     }

@@ -777,6 +933,17 @@ void kernel_main() {
     static constexpr size_t remote_sender_0_channel_address = get_compile_time_arg_val(12);
     static constexpr size_t remote_sender_1_channel_address = get_compile_time_arg_val(13);
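// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the diff: one way the TODO
// above could pack the wr/ack/rd pointers into a single word. With power-of-2
// wrap sizes each 8-bit lane wraps independently under a masked add, so the
// pointers can live in one register. The struct and lane layout are entirely
// hypothetical, not an API from this codebase.
#include <cstdint>

template <uint8_t NUM_BUFFERS>  // must be a power of 2 for the lane masks to work
struct PackedChannelPtrs {
    static_assert((NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0, "power-of-2 only");
    static constexpr uint32_t kLaneMask = 2 * NUM_BUFFERS - 1;  // free-running 2N wrap
    uint32_t word = 0;  // lanes: [wrptr | ackptr | rdptr | spare], 8 bits each

    uint8_t wrptr()  const { return (word >> 0)  & kLaneMask; }
    uint8_t ackptr() const { return (word >> 8)  & kLaneMask; }
    uint8_t rdptr()  const { return (word >> 16) & kLaneMask; }

    // Each increment masks within its own lane, leaving the other lanes intact.
    void increment_wrptr()  { word = (word & ~(0xFFu << 0))  | (uint32_t((wrptr()  + 1) & kLaneMask) << 0); }
    void increment_ackptr() { word = (word & ~(0xFFu << 8))  | (uint32_t((ackptr() + 1) & kLaneMask) << 8); }
    void increment_rdptr()  { word = (word & ~(0xFFu << 16)) | (uint32_t((rdptr()  + 1) & kLaneMask) << 16); }
};
// ---------------------------------------------------------------------------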
+    DPRINT << "SENDER_NUM_BUFFERS: " << (uint32_t)SENDER_NUM_BUFFERS << "\n";
+    DPRINT << "RECEIVER_NUM_BUFFERS: " << (uint32_t)RECEIVER_NUM_BUFFERS << "\n";
+    DPRINT << "local_sender_0_channel_address: " << (uint32_t)local_sender_0_channel_address << "\n";
+    DPRINT << "local_sender_channel_0_connection_info_addr: " << (uint32_t)local_sender_channel_0_connection_info_addr << "\n";
+    DPRINT << "local_sender_1_channel_address: " << (uint32_t)local_sender_1_channel_address << "\n";
+    DPRINT << "local_sender_channel_1_connection_info_addr: " << (uint32_t)local_sender_channel_1_connection_info_addr << "\n";
+    DPRINT << "local_receiver_channel_buffer_address: " << (uint32_t)local_receiver_channel_buffer_address << "\n";
+    DPRINT << "remote_receiver_channel_buffer_address: " << (uint32_t)remote_receiver_channel_buffer_address << "\n";
+    DPRINT << "remote_sender_0_channel_address: " << (uint32_t)remote_sender_0_channel_address << "\n";
+    DPRINT << "remote_sender_1_channel_address: " << (uint32_t)remote_sender_1_channel_address << "\n";
+
     // TODO: CONVERT TO SEMAPHORE
     volatile auto termination_signal_ptr =
         reinterpret_cast<volatile tt::fabric::TerminationSignal *>(get_compile_time_arg_val(14));

@@ -785,9 +952,45 @@ void kernel_main() {
     // resolve the semaphore addresses on the EDM core
     static constexpr bool persistent_mode = get_compile_time_arg_val(15) != 0;

+    // Per-channel counters
+    static constexpr bool enable_fabric_counters = get_compile_time_arg_val(16) != 0;
+    static constexpr size_t receiver_channel_counters_address = get_compile_time_arg_val(17);
+    static constexpr size_t sender_channel_0_counters_address = get_compile_time_arg_val(18);
+    static constexpr size_t sender_channel_1_counters_address = get_compile_time_arg_val(19);
+
+    static constexpr bool enable_packet_header_recording = get_compile_time_arg_val(20) != 0;
+    static constexpr size_t receiver_completed_packet_header_cb_address = get_compile_time_arg_val(21);
+    static constexpr size_t receiver_completed_packet_header_cb_size_headers = get_compile_time_arg_val(22);
+    static constexpr size_t sender_0_completed_packet_header_cb_address = get_compile_time_arg_val(23);
+    static constexpr size_t sender_0_completed_packet_header_cb_size_headers = get_compile_time_arg_val(24);
+    static constexpr size_t sender_1_completed_packet_header_cb_address = get_compile_time_arg_val(25);
+    static constexpr size_t sender_1_completed_packet_header_cb_size_headers = get_compile_time_arg_val(26);
+
+    std::array<PacketHeaderRecorder, NUM_SENDER_CHANNELS> sender_channel_packet_recorders{
+        PacketHeaderRecorder(
+            reinterpret_cast<volatile tt::fabric::PacketHeader *>(sender_0_completed_packet_header_cb_address),
+            sender_0_completed_packet_header_cb_size_headers),
+        PacketHeaderRecorder(
+            reinterpret_cast<volatile tt::fabric::PacketHeader *>(sender_1_completed_packet_header_cb_address),
+            sender_1_completed_packet_header_cb_size_headers)
+    };
+    PacketHeaderRecorder receiver_channel_packet_recorder(
+        reinterpret_cast<volatile tt::fabric::PacketHeader *>(receiver_completed_packet_header_cb_address),
+        receiver_completed_packet_header_cb_size_headers);
+
     static_assert(SENDER_NUM_BUFFERS > 0, "compile time argument [1]: SENDER_NUM_BUFFERS must be > 0");
     static_assert(RECEIVER_NUM_BUFFERS > 0, "compile time argument [2]: RECEIVER_NUM_BUFFERS must be > 0");

+    volatile tt::fabric::EdmFabricReceiverChannelCounters *receiver_channel_counters_ptr = nullptr;
+    volatile tt::fabric::EdmFabricSenderChannelCounters *sender_channel_0_counters_ptr = nullptr;
+    volatile tt::fabric::EdmFabricSenderChannelCounters *sender_channel_1_counters_ptr = nullptr;
+
+    if constexpr (enable_fabric_counters) {
+        // The counter objects live at reserved L1 addresses; aim the pointers at those
+        // addresses before placement-new constructs the counters in place (placement-new
+        // through a null pointer is undefined behavior).
+        receiver_channel_counters_ptr =
+            reinterpret_cast<volatile tt::fabric::EdmFabricReceiverChannelCounters *>(receiver_channel_counters_address);
+        sender_channel_0_counters_ptr =
+            reinterpret_cast<volatile tt::fabric::EdmFabricSenderChannelCounters *>(sender_channel_0_counters_address);
+        sender_channel_1_counters_ptr =
+            reinterpret_cast<volatile tt::fabric::EdmFabricSenderChannelCounters *>(sender_channel_1_counters_address);
+        new (const_cast<tt::fabric::EdmFabricReceiverChannelCounters *>(receiver_channel_counters_ptr)) tt::fabric::EdmFabricReceiverChannelCounters();
+        new (const_cast<tt::fabric::EdmFabricSenderChannelCounters *>(sender_channel_0_counters_ptr)) tt::fabric::EdmFabricSenderChannelCounters();
+        new (const_cast<tt::fabric::EdmFabricSenderChannelCounters *>(sender_channel_1_counters_ptr)) tt::fabric::EdmFabricSenderChannelCounters();
+    }
+
     size_t arg_idx = 0;
     ///////////////////////
     // Common runtime args:
@@ -854,13 +1057,18 @@ void kernel_main() {
         std::array<size_t, NUM_SENDER_CHANNELS>{remote_sender_0_channel_address, remote_sender_1_channel_address};
     std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> remote_sender_channels;
     std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> local_sender_channels;
-    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> local_sender_channel_worker_interfaces;
+    std::array<tt::fabric::EdmChannelWorkerInterface<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> local_sender_channel_worker_interfaces;
     std::array<size_t, NUM_SENDER_CHANNELS> local_sender_flow_control_semaphores = {
         reinterpret_cast<size_t>(sender0_worker_semaphore_ptr), reinterpret_cast<size_t>(sender1_worker_semaphore_ptr)};
     std::array<size_t, NUM_SENDER_CHANNELS> local_sender_connection_live_semaphore_addresses = {
         local_sender_channel_0_connection_semaphore_addr, local_sender_channel_1_connection_semaphore_addr};
     std::array<size_t, NUM_SENDER_CHANNELS> local_sender_connection_info_addresses = {
         local_sender_channel_0_connection_info_addr, local_sender_channel_1_connection_info_addr};
+    for (size_t i = 0; i < NUM_SENDER_CHANNELS; i++) {
+        auto connection_worker_info_ptr = reinterpret_cast<volatile tt::fabric::EDMChannelWorkerLocationInfo *>(
+            local_sender_connection_info_addresses[i]);
+        connection_worker_info_ptr->edm_rdptr = 0;
+    }

     auto downstream_edm_noc_interface = has_downstream_edm_buffer_connection
         ? tt::fabric::WorkerToFabricEdmSender(
@@ -916,7 +1124,8 @@ void kernel_main() {
             reinterpret_cast<volatile tt_l1_ptr uint32_t *const>(local_sender_connection_live_semaphore_addresses[i]);
         auto connection_worker_info_ptr = reinterpret_cast<volatile tt::fabric::EDMChannelWorkerLocationInfo *>(
             local_sender_connection_info_addresses[i]);
-        new (&local_sender_channel_worker_interfaces[i]) tt::fabric::EdmChannelWorkerInterface(
+        connection_worker_info_ptr->edm_rdptr = 0;
+        new (&local_sender_channel_worker_interfaces[i]) tt::fabric::EdmChannelWorkerInterface<SENDER_NUM_BUFFERS>(
             connection_worker_info_ptr,
             reinterpret_cast<volatile tt_l1_ptr uint32_t *const>(
                 local_sender_flow_control_semaphores[i]),
@@ -926,6 +1135,8 @@ void kernel_main() {

     if (has_downstream_edm_buffer_connection) {
         downstream_edm_noc_interface.open();
+        *downstream_edm_noc_interface.from_remote_buffer_slot_rdptr_ptr = 0;
+        ASSERT(*downstream_edm_noc_interface.from_remote_buffer_slot_rdptr_ptr == 0);
     }

     if constexpr (is_handshake_sender) {
@@ -933,21 +1144,24 @@ void kernel_main() {
     } else {
         erisc::datamover::handshake::receiver_side_finish(handshake_addr, DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT);
     }
-    DPRINT << "EDM Core y|x " << (uint32_t)((my_y[0] << 16) | my_x[0]) << "\n";

     //////////////////////////////
     //////////////////////////////
     //  MAIN LOOP
     //////////////////////////////
     //////////////////////////////
-    run_fabric_edm_main_loop(
+    run_fabric_edm_main_loop<enable_packet_header_recording, enable_fabric_counters>(
         local_receiver_channel,
         local_sender_channels,
         local_sender_channel_worker_interfaces,
         downstream_edm_noc_interface,
         remote_sender_channels,
         remote_receiver_channel,
-        termination_signal_ptr);
+        termination_signal_ptr,
+        receiver_channel_counters_ptr,
+        {sender_channel_0_counters_ptr, sender_channel_1_counters_ptr},
+        receiver_channel_packet_recorder,
+        sender_channel_packet_recorders);

     if constexpr (persistent_mode) {
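// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the diff: the counter init
// pattern used above, in miniature. Point the pointer at a real reserved
// address first, then placement-new the object over it. The struct and address
// here are hypothetical stand-ins for the L1 counter regions.
#include <cstdint>
#include <new>

struct FabricCounters {           // stand-in for EdmFabric*ChannelCounters
    uint32_t connections = 0;
    uint32_t packets_sent = 0;
};

inline volatile FabricCounters* init_counters_at(uintptr_t reserved_addr) {
    auto* ptr = reinterpret_cast<volatile FabricCounters*>(reserved_addr);
    // Placement-new requires a valid (non-null) address; constructing through a
    // null pointer is undefined behavior, which is why the assignment above
    // must happen before the new-expressions.
    new (const_cast<FabricCounters*>(ptr)) FabricCounters();
    return ptr;
}
// ---------------------------------------------------------------------------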
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
index 3acf9c4cd4a..d7981f23407 100644
--- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
@@ -15,23 +15,128 @@
 #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
 #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp"
 #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"

 namespace tt::fabric {

+template <typename T, typename Parameter>
+class NamedType {
+public:
+    explicit NamedType(T const& value) : value_(value) {}
+    explicit NamedType(T&& value) : value_(std::move(value)) {}
+    NamedType &operator=(NamedType const& rhs) = default;
+    T& get() { return value_; }
+    T const& get() const { return value_; }
+    operator T() const { return value_; }
+    operator T&() { return value_; }
+private:
+    T value_;
+};
+
+using BufferIndex = NamedType<uint8_t, struct BufferIndexType>;
+using BufferPtr = NamedType<uint8_t, struct BufferPtrType>;
+
 // Increments val and wraps to 0 if it reaches limit
 template <size_t LIMIT, typename T>
 auto wrap_increment(T val) -> T {
     static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0");
+    constexpr bool is_pow2 = (LIMIT & (LIMIT - 1)) == 0;
     if constexpr (LIMIT == 1) {
         return val;
     } else if constexpr (LIMIT == 2) {
         return 1 - val;
-    } else if constexpr ((LIMIT > 0) && (LIMIT & (LIMIT - 1)) == 0) {
+    } else if constexpr (is_pow2) {
         return (val + 1) & (LIMIT - 1);
     } else {
-        return (val == LIMIT - 1) ? 0 : val + 1;
+        return (val == static_cast<T>(LIMIT - 1)) ? static_cast<T>(0) : static_cast<T>(val + 1);
     }
 }

+template <uint8_t NUM_BUFFERS>
+auto normalize_ptr(BufferPtr ptr) -> BufferIndex {
+    static_assert(NUM_BUFFERS != 0, "normalize_ptr called with NUM_BUFFERS of 0; it must be greater than 0");
+    constexpr bool is_size_pow2 = (NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0;
+    constexpr bool is_size_2 = NUM_BUFFERS == 2;
+    constexpr bool is_size_1 = NUM_BUFFERS == 1;
+    constexpr uint8_t wrap_mask = NUM_BUFFERS - 1;
+    if constexpr (is_size_pow2) {
+        return BufferIndex{ptr & wrap_mask};
+    } else if constexpr (is_size_2) {
+        return BufferIndex{(uint8_t)1 - ptr};
+    } else if constexpr (is_size_1) {
+        return BufferIndex{0};
+    } else {
+        // note it may make sense to calculate this only when we increment
+        // which will save calculations overall (but may add register pressure)
+        // and introduce undesirable loads
+        bool normalize = ptr >= NUM_BUFFERS;
+        uint8_t normalized_ptr = ptr.get() - static_cast<uint8_t>(normalize * NUM_BUFFERS);
+        ASSERT(normalized_ptr < NUM_BUFFERS);
+        return BufferIndex{normalized_ptr};
+    }
+}
+
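// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the diff: why the raw
// pointer runs over 2*NUM_BUFFERS before normalization. A free-running pointer
// lets "full" (distance == NUM_BUFFERS) and "empty" (distance == 0) be told
// apart without an extra flag; normalize_ptr() then maps it into
// [0, NUM_BUFFERS) for indexing. Minimal standalone check, pow2 case only.
#include <cassert>
#include <cstdint>

int main() {
    constexpr uint8_t kNumBuffers = 4;               // pow2 case
    uint8_t wrptr = 0, rdptr = 0;                    // free-running over 2*N = 8
    auto norm = [](uint8_t p) { return static_cast<uint8_t>(p & (kNumBuffers - 1)); };
    auto dist = [](uint8_t lead, uint8_t trail) {
        return static_cast<uint8_t>((lead - trail) & (2 * kNumBuffers - 1));
    };

    for (int i = 0; i < kNumBuffers; ++i) wrptr = (wrptr + 1) & (2 * kNumBuffers - 1);
    assert(dist(wrptr, rdptr) == kNumBuffers);       // full: pointers differ by N...
    assert(norm(wrptr) == norm(rdptr));              // ...yet index the same slot
    return 0;
}
// ---------------------------------------------------------------------------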
+template <uint8_t NUM_BUFFERS>
+class ChannelBufferPointer {
+    static_assert(
+        NUM_BUFFERS <= std::numeric_limits<uint8_t>::max() / 2,
+        "NUM_BUFFERS must be no more than half of std::numeric_limits<uint8_t>::max() due to the internal implementation");
+public:
+    static constexpr bool is_size_pow2 = (NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0;
+    static constexpr bool is_size_2 = NUM_BUFFERS == 2;
+    static constexpr bool is_size_1 = NUM_BUFFERS == 1;
+    static constexpr uint8_t ptr_wrap_size = 2 * NUM_BUFFERS;
+
+    // Only to use if is_size_pow2
+    static constexpr uint8_t ptr_wrap_mask = (2 * NUM_BUFFERS) - 1;
+    static constexpr uint8_t buffer_wrap_mask = NUM_BUFFERS - 1;
+    ChannelBufferPointer() : ptr(0) {}
+    /*
+     * Returns the "raw" pointer - not usable to index the buffer channel
+     */
+    BufferPtr get_ptr() const {
+        return this->ptr;
+    }
+
+    bool is_caught_up_to(ChannelBufferPointer<NUM_BUFFERS> const& leading_ptr) const {
+        return this->is_caught_up_to(leading_ptr.get_ptr());
+    }
+    uint8_t distance_behind(ChannelBufferPointer<NUM_BUFFERS> const& leading_ptr) const {
+        return this->distance_behind(leading_ptr.get_ptr());
+    }
+
+    /*
+     * Returns the buffer index pointer which is usable to index into the buffer memory
+     */
+    BufferIndex get_buffer_index() const {
+        return BufferIndex{normalize_ptr<NUM_BUFFERS>(this->ptr)};
+    }
+
+    void increment() {
+        this->ptr = wrap_increment<2 * NUM_BUFFERS>(this->ptr);
+    }
+
+private:
+    // Make these private to make sure the caller doesn't accidentally mix two pointers pointing to
+    // different sized channels
+    bool is_caught_up_to(BufferPtr const& leading_ptr) const {
+        return this->get_ptr() == leading_ptr;
+    }
+    uint8_t distance_behind(BufferPtr const& leading_ptr) const {
+        bool leading_gte_trailing_ptr = leading_ptr >= this->ptr;
+        if constexpr (is_size_pow2) {
+            return (leading_ptr - this->ptr) & ptr_wrap_mask;
+        } else {
+            return leading_gte_trailing_ptr ?
+                leading_ptr - this->ptr :
+                ptr_wrap_size - (this->ptr - leading_ptr);
+        }
+    }
+    BufferPtr ptr = BufferPtr{0};
+};
+
 template <typename T>
 FORCE_INLINE auto wrap_increment(T val, size_t max) {
     return (val == max - 1) ? 0 : val + 1;
@@ -65,11 +170,10 @@ class EthChannelBuffer final {
         buffer_size_in_bytes(buffer_size_bytes),
         eth_transaction_ack_word_addr(eth_transaction_ack_word_addr),
         max_eth_payload_size_in_bytes(buffer_size_in_bytes + sizeof(eth_channel_sync_t)),
-        buff_idx(0),
         channel_id(channel_id) {
         for (uint8_t i = 0; i < NUM_BUFFERS; i++) {
             this->buffer_addresses[i] =
-                channel_base_address + i * this->max_eth_payload_size_in_bytes;  //(this->buffer_size_in_bytes);
+                channel_base_address + i * this->max_eth_payload_size_in_bytes;

             uint32_t channel_sync_addr = this->buffer_addresses[i] + buffer_size_in_bytes;
             auto channel_sync_ptr = reinterpret_cast<eth_channel_sync_t *>(channel_sync_addr);
@@ -83,73 +187,72 @@ class EthChannelBuffer final {
             ASSERT((uint32_t)channel_bytes_acked_addresses[i] != (uint32_t)(channel_bytes_sent_addresses[i]));
             *(channel_bytes_sent_addresses[i]) = 0;
             *(channel_bytes_acked_addresses[i]) = 0;
+            *(channel_src_id_addresses[i]) = 0x1c0ffee1;
+            (channel_src_id_addresses[i])[1] = 0x1c0ffee2;
+            // Note we don't need to overwrite the `channel_src_id_addresses` except perhaps for
+            // debug purposes where we may wish to tag this with a special value
         }
     }

-    [[nodiscard]] FORCE_INLINE size_t get_current_buffer_address() const {
-        return this->buffer_addresses[this->buffer_index()];
+    [[nodiscard]] FORCE_INLINE size_t get_buffer_address(BufferIndex const& buffer_index) const {
+        return this->buffer_addresses[buffer_index];
     }

-    [[nodiscard]] FORCE_INLINE volatile PacketHeader *get_current_packet_header() const {
-        return reinterpret_cast<volatile PacketHeader *>(this->buffer_addresses[this->buffer_index()]);
+    [[nodiscard]] FORCE_INLINE volatile PacketHeader *get_packet_header(BufferIndex const& buffer_index) const {
+        return reinterpret_cast<volatile PacketHeader *>(this->buffer_addresses[buffer_index]);
     }

-    [[nodiscard]] FORCE_INLINE size_t get_current_payload_size() const {
-        return get_current_packet_header()->get_payload_size_including_header();
+    [[nodiscard]] FORCE_INLINE size_t get_payload_size(BufferIndex const& buffer_index) const {
+        return get_packet_header(buffer_index)->get_payload_size_including_header();
     }
-    [[nodiscard]] FORCE_INLINE size_t get_current_payload_plus_channel_sync_size() const {
-        return get_current_packet_header()->get_payload_size_including_header() + sizeof(eth_channel_sync_t);
+    [[nodiscard]] FORCE_INLINE size_t get_payload_plus_channel_sync_size(BufferIndex const& buffer_index) const {
+        return get_packet_header(buffer_index)->get_payload_size_including_header() + sizeof(eth_channel_sync_t);
     }
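// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the diff: distance_behind
// for a non-power-of-2 channel, where the masked subtraction above is not
// available and the 2*NUM_BUFFERS wrap must be handled explicitly. Standalone
// check with assumed values.
#include <cassert>
#include <cstdint>

constexpr uint8_t kNumBuffers = 3;                   // non-pow2 case
constexpr uint8_t kWrapSize = 2 * kNumBuffers;       // pointers run over [0, 6)

inline uint8_t distance_behind(uint8_t trailing, uint8_t leading) {
    return leading >= trailing ? leading - trailing
                               : kWrapSize - (trailing - leading);
}

int main() {
    assert(distance_behind(5, 1) == 2);  // leading wrapped past 6 back to 1
    assert(distance_behind(2, 5) == 3);  // plain forward distance
    assert(distance_behind(4, 4) == 0);  // caught up
    return 0;
}
// ---------------------------------------------------------------------------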
-    // TODO: Split off into two separate functions:
-    //       volatile tt_l1_ptr size_t *get_current_bytes_sent_ptr() const
-    //       size_t get_current_bytes_sent_address() const
-    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_bytes_sent_address() const {
-        return this->channel_bytes_sent_addresses[this->buffer_index()];
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_bytes_sent_address(BufferIndex const& buffer_index) const {
+        return this->channel_bytes_sent_addresses[buffer_index];
     }

-    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_bytes_acked_address() const {
-        return this->channel_bytes_acked_addresses[this->buffer_index()];
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_bytes_acked_address(BufferIndex const& buffer_index) const {
+        return this->channel_bytes_acked_addresses[buffer_index];
     }

-    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_src_id_address() const {
-        return this->channel_src_id_addresses[this->buffer_index()];
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_src_id_address(BufferIndex const& buffer_index) const {
+        return this->channel_src_id_addresses[buffer_index];
     }

-    [[nodiscard]] FORCE_INLINE size_t get_channel_buffer_max_size_in_bytes() const {
+    [[nodiscard]] FORCE_INLINE size_t get_channel_buffer_max_size_in_bytes(BufferIndex const& buffer_index) const {
         return this->buffer_size_in_bytes;
     }

     // Doesn't return the message size, only the maximum eth payload size
-    [[nodiscard]] FORCE_INLINE size_t get_current_max_eth_payload_size() const {
+    [[nodiscard]] FORCE_INLINE size_t get_max_eth_payload_size() const {
         return this->max_eth_payload_size_in_bytes;
     }

     [[nodiscard]] FORCE_INLINE size_t get_id() const { return this->channel_id; }

-    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_done() const {
-        return *(this->get_current_bytes_sent_address()) == 0;
+    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_done(BufferIndex const& buffer_index) const {
+        return *(this->get_bytes_sent_address(buffer_index)) == 0;
     }

-    [[nodiscard]] FORCE_INLINE bool eth_bytes_are_available_on_channel() const {
-        return *(this->get_current_bytes_sent_address()) != 0;
+    [[nodiscard]] FORCE_INLINE bool eth_bytes_are_available_on_channel(BufferIndex const& buffer_index) const {
+        return *(this->get_bytes_sent_address(buffer_index)) != 0;
     }

-    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_acked() const {
-        return *(this->get_current_bytes_acked_address()) != 0;
+    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_acked(BufferIndex const& buffer_index) const {
+        return *(this->get_bytes_acked_address(buffer_index)) != 0;
     }

-    FORCE_INLINE void eth_clear_sender_channel_ack() const {
-        *(this->channel_bytes_acked_addresses[this->buffer_index()]) = 0;
+    FORCE_INLINE void eth_clear_sender_channel_ack(BufferIndex const& buffer_index) const {
+        *(this->channel_bytes_acked_addresses[buffer_index]) = 0;
+    }
+    [[nodiscard]] FORCE_INLINE bool eth_is_acked_or_completed(BufferIndex const& buffer_index) const {
+        return eth_is_receiver_channel_send_acked(buffer_index) || eth_is_receiver_channel_send_done(buffer_index);
     }

     [[nodiscard]] FORCE_INLINE size_t get_eth_transaction_ack_word_addr() const {
         return this->eth_transaction_ack_word_addr;
     }

-    FORCE_INLINE void advance_buffer_index() {
-        this->buff_idx = wrap_increment<decltype(this->buff_idx), NUM_BUFFERS>(this->buff_idx);
-    }
-
     [[nodiscard]] FORCE_INLINE bool all_buffers_drained() const {
         bool drained = true;
         for (size_t i = 0; i < NUM_BUFFERS && drained; i++) {
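// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the diff: the L1 slot
// layout implied by the constructor above. Each slot is buffer_size bytes of
// packet (header + payload) followed immediately by an eth_channel_sync_t, and
// slots are laid back to back at max_eth_payload_size strides. Addresses and
// sizes below are assumed example values.
#include <cassert>
#include <cstdint>

constexpr uint32_t kChannelSyncBytes = 16;           // stand-in for sizeof(eth_channel_sync_t)

struct SlotLayout {
    uint32_t base, buffer_size;
    uint32_t max_eth_payload_size() const { return buffer_size + kChannelSyncBytes; }
    uint32_t buffer_address(uint8_t i) const { return base + i * max_eth_payload_size(); }
    uint32_t sync_word_address(uint8_t i) const { return buffer_address(i) + buffer_size; }
};

int main() {
    SlotLayout ch{/*base=*/0x10000, /*buffer_size=*/4096};
    // The sync word of slot i sits flush against the start of slot i+1,
    // which is exactly what the bytes_sent-address ASSERT earlier checks.
    assert(ch.sync_word_address(0) + kChannelSyncBytes == ch.buffer_address(1));
    return 0;
}
// ---------------------------------------------------------------------------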
@@ -158,12 +261,20 @@ class EthChannelBuffer final {
         return drained;
     }

-private:
-    FORCE_INLINE auto buffer_index() const {
-        ASSERT(this->buff_idx < NUM_BUFFERS);
-        return buff_idx;
+    bool needs_to_send_channel_sync() const {
+        return this->need_to_send_channel_sync;
+    }
+
+    void set_need_to_send_channel_sync(bool need_to_send_channel_sync) {
+        this->need_to_send_channel_sync = need_to_send_channel_sync;
     }

+    void clear_need_to_send_channel_sync() {
+        this->need_to_send_channel_sync = false;
+    }
+
+private:
+
     std::array<size_t, NUM_BUFFERS> buffer_addresses;
     std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_bytes_sent_addresses;
     std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_bytes_acked_addresses;
@@ -174,13 +285,19 @@ class EthChannelBuffer final {
     // Includes header + payload + channel_sync
     const std::size_t eth_transaction_ack_word_addr;
     const std::size_t max_eth_payload_size_in_bytes;
-    uint8_t buff_idx;
     uint8_t channel_id;
 };

+template <uint8_t NUM_BUFFERS>
 struct EdmChannelWorkerInterface {
     EdmChannelWorkerInterface() :
-        worker_location_info_ptr(nullptr), local_semaphore_address(nullptr), connection_live_semaphore(nullptr) {}
+        worker_location_info_ptr(nullptr),
+        remote_producer_wrptr(nullptr),
+        connection_live_semaphore(nullptr),
+        local_wrptr(),
+        local_ackptr(),
+        local_rdptr() {}
     EdmChannelWorkerInterface(
         // TODO: PERF: See if we can make this non-volatile and then only
         // mark it volatile when we know we need to reload it (i.e. after we receive a
         // packet... Then we'll also be able to cache the uint64_t addr of the worker
         // semaphore directly (saving on regenerating it each time)
         volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr,
-        volatile tt_l1_ptr uint32_t *const local_semaphore_address,
+        volatile tt_l1_ptr uint32_t *const remote_producer_wrptr,
         volatile tt_l1_ptr uint32_t *const connection_live_semaphore) :
         worker_location_info_ptr(worker_location_info_ptr),
-        local_semaphore_address(local_semaphore_address),
-        connection_live_semaphore(connection_live_semaphore) {}
+        remote_producer_wrptr(remote_producer_wrptr),
+        connection_live_semaphore(connection_live_semaphore),
+        local_wrptr(),
+        local_ackptr(),
+        local_rdptr() {
+        DPRINT << "EDM my_x: " << (uint32_t)my_x[0] << ", my_y: " << (uint32_t)my_y[0] << " rdptr set to 0 at "
+               << (uint32_t)(void *)&(worker_location_info_ptr->edm_rdptr) << "\n";
+        *reinterpret_cast<volatile uint32_t *>(&(worker_location_info_ptr->edm_rdptr)) = 0;
+    }

     // Flow control methods
     //
-    [[nodiscard]] FORCE_INLINE auto local_semaphore_value() const { return *local_semaphore_address; }
-
-    [[nodiscard]] FORCE_INLINE bool has_payload() { return *local_semaphore_address != 0; }
-
-    FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); }
+    // local_wrptr trails from_remote_wrptr
+    // we have new data if they aren't equal
+    [[nodiscard]] FORCE_INLINE bool has_unsent_payload() {
+        return local_wrptr.get_ptr() != *remote_producer_wrptr;
+    }
+    [[nodiscard]] FORCE_INLINE bool has_unacked_sends() {
+        return local_ackptr.get_ptr() != local_wrptr.get_ptr();
+    }

     [[nodiscard]] FORCE_INLINE uint32_t get_worker_semaphore_address() const {
         return worker_location_info_ptr->worker_semaphore_address;
     }

-    void increment_worker_semaphore() const {
+    FORCE_INLINE void update_worker_copy_of_read_ptr() {
         auto const &worker_info = *worker_location_info_ptr;
         uint64_t worker_semaphore_address = get_noc_addr(
             (uint32_t)worker_info.worker_xy.x, (uint32_t)worker_info.worker_xy.y, worker_info.worker_semaphore_address);
-
-        DPRINT << "EDM ntf wrkr sem @" << (uint64_t)worker_semaphore_address << "\n";
-        noc_semaphore_inc(worker_semaphore_address, 1);
+        noc_inline_dw_write(worker_semaphore_address, local_ackptr.get_ptr());
     }

     // Connection management methods
     //
-    FORCE_INLINE void teardown_connection() const {
+    FORCE_INLINE void teardown_connection(uint32_t last_edm_rdptr_value) const {
         auto const &worker_info = *worker_location_info_ptr;
         uint64_t worker_semaphore_address = get_noc_addr(
             (uint32_t)worker_info.worker_xy.x, (uint32_t)worker_info.worker_xy.y, worker_info.worker_teardown_semaphore_address);
+        // Set connection to unused so it's available for next worker
+        *this->connection_live_semaphore = tt::fabric::WorkerToFabricEdmSender::unused_connection_value;
+
+        *reinterpret_cast<volatile uint32_t *>(&(worker_location_info_ptr->edm_rdptr)) = last_edm_rdptr_value;
+
         noc_semaphore_inc(worker_semaphore_address, 1);
     }

-    [[nodiscard]] FORCE_INLINE bool has_worker_teardown_request() const { return *connection_live_semaphore == 0; }
+    bool all_eth_packets_acked() const {
+        return this->local_ackptr.is_caught_up_to(this->local_wrptr);
+    }
+    bool all_eth_packets_completed() const {
+        return this->local_rdptr.is_caught_up_to(this->local_wrptr);
+    }

-    [[nodiscard]] FORCE_INLINE bool connection_is_live() const { return *connection_live_semaphore == 1; }
+    // Call to keep the connection flow control info fresh with worker.
+    void propagate_ackptr_to_connection_info() {
+        worker_location_info_ptr->edm_rdptr = local_ackptr.get_ptr();
+    }
+
+    [[nodiscard]] FORCE_INLINE bool has_worker_teardown_request() const { return *connection_live_semaphore == tt::fabric::WorkerToFabricEdmSender::close_connection_request_value; }
+    [[nodiscard]] FORCE_INLINE bool connection_is_live() const { return *connection_live_semaphore == tt::fabric::WorkerToFabricEdmSender::open_connection_value; }

     volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr;
-    volatile tt_l1_ptr uint32_t *const local_semaphore_address;
+    volatile tt_l1_ptr uint32_t *const remote_producer_wrptr;
     volatile tt_l1_ptr uint32_t *const connection_live_semaphore;
+
+    ChannelBufferPointer<NUM_BUFFERS> local_wrptr;
+    ChannelBufferPointer<NUM_BUFFERS> local_ackptr;
+    ChannelBufferPointer<NUM_BUFFERS> local_rdptr;  // also used as completion_ptr
 };
+
 }  // namespace tt::fabric
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
index 8fe983e403f..dbcc0d5848d 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
@@ -156,9 +156,19 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     std::optional<ttnn::ccl::EdmLineFabricOpInterface> local_fabric_handle =
         enable_persistent_fabric_mode
             ? ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric(
-                  device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links)
+                  device,
+                  forward_device.value_or(nullptr),
+                  backward_device.value_or(nullptr),
+                  &program,
+                  enable_persistent_fabric_mode,
+                  num_links)
             : ccl::EdmLineFabricOpInterface(
-                  device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links);
+                  device,
+                  forward_device.value_or(nullptr),
+                  backward_device.value_or(nullptr),
+                  &program,
+                  enable_persistent_fabric_mode,
+                  num_links);

     LineTopology line_topology(ring_size, ring_index);
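// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the diff: the value_or()
// change above in miniature. The fabric builder takes raw device pointers for
// the line neighbors, so an unset std::optional neighbor (a line endpoint)
// must be lowered to nullptr explicitly. The type name is a stand-in.
#include <optional>

struct IDevice;  // stand-in for the device type

inline IDevice* neighbor_or_null(const std::optional<IDevice*>& maybe_neighbor) {
    // Reads as: forward the neighbor if it exists, otherwise tell the fabric
    // there is no device on this side of the line.
    return maybe_neighbor.value_or(nullptr);
}
// ---------------------------------------------------------------------------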