Optimize EDM-fabric flow-control protocols (#17495)
This is essentially a rewrite of the majority of the fabric EDM.

Worker -> EDM connection teardown bug fix

Updates flow control for
* Worker -> EDM
* EDM Sender Channel
* EDM Receiver Channel

The fourth piece of the data path that is not updated here is the sender ->
receiver flow control over Ethernet. That work is planned as a future change and is
tracked in issue:
#17430

# Worker -> EDM Flow Connection Teardown Fix
Fixes worker <-> EDM fabric connection state transitions to prevent a race.
Connection state transitions previously followed an invalid design:
worker: 0 (closed) -> 1 (open)
worker: 1 (open) -> 0 (closed)

This design was inadequate because the worker was able to open and close a connection without the EDM fabric being in the loop. This could lead to the following race condition in workloads with few packets per connection:

* edm checks sender channel
* worker opens conn
* worker sends payload
* worker tears down connection
* edm checks channel  (misses teardown request)

Additionally, this change fixes a bug in worker <-> EDM connection teardown by adding a new discrete state: `close_connection_request`.

The new connection management is as follows (see the sketch after this list):
* worker (open connection): update the connection semaphore from 0 (unused connection) -> 1 (open connection)
* worker: send traffic
* worker (close connection): update the connection semaphore from 1 (open) -> 2 (teardown request)
* worker: wait for an ack on the local teardown address (`worker_teardown_addr` in `WorkerToFabricEdmSender::close()`)
* EDM: acknowledge the connection close by updating the `worker_teardown` semaphore in worker L1
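
A minimal sketch of the handshake, using hypothetical constant and function names (the real values and addresses live in the EDM fabric headers, e.g. `WorkerToFabricEdmSender`, and the real stores cross the NoC/Ethernet rather than being plain pointer writes):

```cpp
#include <cstdint>

// Connection-semaphore states as described above; names are illustrative.
constexpr uint32_t UNUSED_CONNECTION = 0;         // no worker attached
constexpr uint32_t OPEN_CONNECTION = 1;           // worker attached, sending traffic
constexpr uint32_t CLOSE_CONNECTION_REQUEST = 2;  // worker asked for teardown

// Worker-side close handshake (sketch): advertise the teardown request on the
// EDM's connection semaphore, then spin on the local worker_teardown address
// until the EDM writes its ack there.
inline void worker_close_connection(
    volatile uint32_t* edm_connection_semaphore,    // lives in EDM L1
    volatile uint32_t* local_worker_teardown_addr)  // lives in worker L1
{
    *local_worker_teardown_addr = 0;
    *edm_connection_semaphore = CLOSE_CONNECTION_REQUEST;  // 1 -> 2
    while (*local_worker_teardown_addr == 0) {
        // wait for the EDM to ack the close by bumping worker_teardown
    }
}
```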


# Worker -> EDM Flow Control Protocol:
The flow control protocol is rd/wr ptr based and is implemented as follows (from the worker's perspective):
The adapter has a local write pointer (wrptr) which is used to track the next buffer slot to write to. The adapter
also has a local memory slot that holds the remote read pointer (rdptr) of the EDM. The adapter uses the difference
between these two pointers (where rdptr trails wrptr) to determine if the EDM has space to accept a new packet.

As the adapter writes into the EDM, it updates its local wrptr. As the EDM reads from its local L1 channel buffer,
it notifies the worker/adapter by updating the worker's remote_rdptr copy to carry the value of the EDM rdptr.
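
A minimal sketch of that space check, assuming illustrative names and a made-up slot count (the real adapter is `WorkerToFabricEdmSender`):

```cpp
#include <cstdint>

// Free-running slot pointers: the EDM's rdptr trails the worker's wrptr, and
// their wraparound difference is the number of occupied EDM buffer slots.
constexpr uint32_t NUM_EDM_BUFFER_SLOTS = 8;  // assumption for illustration

struct WorkerAdapterView {
    uint32_t local_wrptr;                     // next EDM slot this worker writes
    volatile uint32_t* edm_rdptr_copy_in_l1;  // EDM-updated copy of its rdptr
};

inline bool edm_has_space_for_packet(const WorkerAdapterView& a) {
    uint32_t occupied_slots = a.local_wrptr - *a.edm_rdptr_copy_in_l1;
    return occupied_slots < NUM_EDM_BUFFER_SLOTS;
}
```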

## EDM <-> EDM Channel Flow Control
The flow control protocol between EDM channels is built on a rd/wr ptr based protocol where the pointers refer
to buffer slots within the channel (as opposed to something else, like a byte or word offset). Pointers are
free to advance independently from each other as long as there is no overflow or underflow.

### Sender Channel Flow Control
Both sender channels share the same flow control view into the receiver channel, because both channels
write to the same receiver channel (a pointer sketch follows this list).
* `wrptr`:
  * points to the next buffer slot to write into in the remote (over-Ethernet) receiver channel
  * leads the other pointers
  * the writer updates it for every new packet
  * `has_data_to_send(): local_wrptr != remote_sender_wrptr`
* `ackptr`:
  * trails `wrptr`
  * advances as the channel receives acknowledgements from the receiver
  * as this advances, the sender channel can notify the upstream worker of additional space in the sender channel buffer
* `completion_ptr`:
  * trails `local_wrptr`
  * the "rdptr" from the remote sender's perspective
  * advances as packets are completed by the receiver
  * as this advances, the sender channel can write additional packets to the receiver at this slot
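
The pointer relationships above can be summarized with a small sketch (hypothetical names, wrap handling elided):

```cpp
#include <cstdint>

struct SenderChannelPointers {
    uint32_t wrptr;           // leads: next remote receiver slot to write to
    uint32_t ackptr;          // trails wrptr: advanced on receiver acks
    uint32_t completion_ptr;  // trails wrptr: advanced on receiver completions
};

// A slot in the remote receiver channel may only be (re)written once the
// receiver has completed the packet that previously occupied it.
inline bool can_write_to_receiver(const SenderChannelPointers& p, uint32_t receiver_num_slots) {
    return (p.wrptr - p.completion_ptr) < receiver_num_slots;
}

// Space reported back to the upstream worker tracks acks, not completions.
inline uint32_t newly_freed_worker_slots(uint32_t prev_ackptr, uint32_t ackptr) {
    return ackptr - prev_ackptr;
}
```
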

### Receiver Channel Flow Control
The receiver tracks four pointers, ordered as in the sketch after this list.
* `ackptr`/`rdptr`:
  * leads all pointers
  * indicates the next buffer slot at which we expect data to arrive (from the remote sender)
    * advances as packets are received (and acked)
  * must not overlap the completion pointer
* `wr_sent_ptr`:
  * trails `ackptr`
  * indicates the buffer slot currently being processed (written out)
  * advances after all forwarding writes (to the NoC or a downstream EDM) are initiated
* `wr_flush_ptr`:
  * trails `wr_sent_ptr`
  * advances as writes are flushed
* `completion_ptr`:
  * trails `wr_flush_ptr`
  * indicates the next buffer slot in the receiver channel to send completion acks for
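
A matching sketch of the receiver-side ordering (hypothetical names, wrap handling elided):

```cpp
#include <cstdint>

struct ReceiverChannelPointers {
    uint32_t ackptr;          // leads: next slot expected from the remote sender
    uint32_t wr_sent_ptr;     // trails ackptr: forwarding writes issued up to here
    uint32_t wr_flush_ptr;    // trails wr_sent_ptr: forwarding writes flushed
    uint32_t completion_ptr;  // trails wr_flush_ptr: completion acks sent up to here
};

// A newly arriving packet can be accepted only while ackptr has not wrapped
// around onto the completion pointer (i.e. the oldest slot is fully retired).
inline bool can_accept_packet(const ReceiverChannelPointers& p, uint32_t num_slots) {
    return (p.ackptr - p.completion_ptr) < num_slots;
}
```
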
SeanNijjar authored Feb 4, 2025
1 parent b220459 commit 63bdb89
Showing 18 changed files with 2,089 additions and 473 deletions.
tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp (new file, 213 additions, 0 deletions)
@@ -0,0 +1,213 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp"
#include "dataflow_api.h"

#include <cstdint>
#include <cstddef>

static constexpr bool enable_start_synchronization = get_compile_time_arg_val(0) != 0;
static constexpr bool enable_finish_synchronization = get_compile_time_arg_val(1) != 0;
static constexpr bool enable_any_synchronization = enable_start_synchronization || enable_finish_synchronization;

// Issue a multicast atomic-increment on the line-sync semaphore in the forward and
// backward directions, then join the sync locally.
FORCE_INLINE void line_sync(
    FabricConnectionManager& fabric_connection,
    volatile tt::fabric::PacketHeader* mcast_fwd_packet_header,
    volatile tt::fabric::PacketHeader* mcast_bwd_packet_header,
    size_t sync_bank_addr,
    size_t sync_noc_x,
    size_t sync_noc_y,
    size_t sync_val) {
    using namespace tt::fabric;
    mcast_fwd_packet_header->to_atomic_inc();
    mcast_bwd_packet_header->to_atomic_inc();

    if (fabric_connection.has_forward_connection()) {
        mcast_fwd_packet_header->to_noc_unicast_atomic_inc(NocUnicastAtomicIncCommandHeader{
            sync_bank_addr, 1, 128, static_cast<uint8_t>(sync_noc_x), static_cast<uint8_t>(sync_noc_y)});
        fabric_connection.get_forward_connection().wait_for_empty_write_slot();
        fabric_connection.get_forward_connection().send_payload_flush_non_blocking_from_address(
            (uint32_t)mcast_fwd_packet_header, sizeof(tt::fabric::PacketHeader));
    }

    if (fabric_connection.has_backward_connection()) {
        mcast_bwd_packet_header->to_noc_unicast_atomic_inc(NocUnicastAtomicIncCommandHeader{
            sync_bank_addr, 1, 128, static_cast<uint8_t>(sync_noc_x), static_cast<uint8_t>(sync_noc_y)});
        fabric_connection.get_backward_connection().wait_for_empty_write_slot();
        fabric_connection.get_backward_connection().send_payload_flush_non_blocking_from_address(
            (uint32_t)mcast_bwd_packet_header, sizeof(tt::fabric::PacketHeader));
    }
    // Increment the sync semaphore on the sync core (possibly this core); the sync core
    // waits until all participants have arrived.
    noc_semaphore_inc(get_noc_addr(sync_noc_x, sync_noc_y, sync_bank_addr), 1);
    if (sync_noc_x == my_x[0] && sync_noc_y == my_y[0]) {
        noc_semaphore_wait_min(reinterpret_cast<volatile uint32_t*>(sync_bank_addr), sync_val);
    }
}

void kernel_main() {
    using namespace tt::fabric;
    size_t arg_idx = 0;

    const size_t dest_bank_addr = get_arg_val<uint32_t>(arg_idx++);
    const size_t packet_payload_size_bytes = get_arg_val<uint32_t>(arg_idx++);
    const size_t dest_noc_x = get_arg_val<uint32_t>(arg_idx++);
    const size_t dest_noc_y = get_arg_val<uint32_t>(arg_idx++);

    const size_t num_mcasts = get_arg_val<uint32_t>(arg_idx++);
    const size_t mcast_fwd_hops = get_arg_val<uint32_t>(arg_idx++);
    const size_t mcast_bwd_hops = get_arg_val<uint32_t>(arg_idx++);

    const size_t num_unicasts = get_arg_val<uint32_t>(arg_idx++);
    const size_t unicast_hops = get_arg_val<uint32_t>(arg_idx++);
    const bool unicast_is_fwd = get_arg_val<uint32_t>(arg_idx++) != 0;

    const size_t source_l1_cb_index = get_arg_val<uint32_t>(arg_idx++);
    const size_t packet_header_cb = get_arg_val<uint32_t>(arg_idx++);
    const size_t packet_header_size_in_headers = get_arg_val<uint32_t>(arg_idx++);

    auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx);

    ASSERT(fabric_connection.is_logically_connected());

    if (!fabric_connection.is_logically_connected()) {
        return;
    }
    size_t sync_noc_x = 0;
    size_t sync_noc_y = 0;
    size_t sync_bank_addr = 0;
    size_t total_workers_per_sync = 0;
    if (enable_any_synchronization) {
        sync_noc_x = get_arg_val<uint32_t>(arg_idx++);
        sync_noc_y = get_arg_val<uint32_t>(arg_idx++);
        sync_bank_addr = get_arg_val<uint32_t>(arg_idx++);
        total_workers_per_sync = get_arg_val<uint32_t>(arg_idx++);
    }

    const size_t start_sync_val = total_workers_per_sync;
    const size_t finish_sync_val = 3 * total_workers_per_sync;

    fabric_connection.open();

    cb_reserve_back(source_l1_cb_index, 1);
    cb_reserve_back(packet_header_cb, packet_header_size_in_headers);
    const auto source_l1_buffer_address = get_write_ptr(source_l1_cb_index);
    const auto packet_header_buffer_address = get_write_ptr(packet_header_cb);

    auto* mcast_fwd_packet_header = reinterpret_cast<PacketHeader*>(packet_header_buffer_address);
    auto* mcast_bwd_packet_header =
        reinterpret_cast<PacketHeader*>(packet_header_buffer_address + sizeof(tt::fabric::PacketHeader));
    auto* unicast_packet_header =
        reinterpret_cast<PacketHeader*>(packet_header_buffer_address + sizeof(tt::fabric::PacketHeader) * 2);
    mcast_fwd_packet_header->to_write().to_chip_multicast(
        MulticastRoutingCommandHeader{1, static_cast<uint8_t>(mcast_fwd_hops)});
    mcast_bwd_packet_header->to_write().to_chip_multicast(
        MulticastRoutingCommandHeader{1, static_cast<uint8_t>(mcast_bwd_hops)});

    if (enable_start_synchronization) {
        line_sync(
            fabric_connection,
            mcast_fwd_packet_header,
            mcast_bwd_packet_header,
            sync_bank_addr,
            sync_noc_x,
            sync_noc_y,
            start_sync_val);
        line_sync(
            fabric_connection,
            mcast_fwd_packet_header,
            mcast_bwd_packet_header,
            sync_bank_addr,
            sync_noc_x,
            sync_noc_y,
            2 * start_sync_val);
    }

    // Re-arm the multicast write headers (line_sync converted them to atomic-inc
    // headers) and set up the unicast atomic-inc header.
    mcast_fwd_packet_header->to_write().to_chip_multicast(
        MulticastRoutingCommandHeader{1, static_cast<uint8_t>(mcast_fwd_hops)});
    mcast_bwd_packet_header->to_write().to_chip_multicast(
        MulticastRoutingCommandHeader{1, static_cast<uint8_t>(mcast_bwd_hops)});
    unicast_packet_header->to_atomic_inc().to_chip_unicast(
        UnicastRoutingCommandHeader{static_cast<uint8_t>(unicast_hops)});

    {
        DeviceZoneScopedN("MAIN-WRITE-ZONE");
        for (size_t i = 0; i < num_mcasts; i++) {
            // Write the payload to the local destination core, then forward the same
            // payload over the fabric in each connected direction.
            noc_async_write(
                source_l1_buffer_address,
                safe_get_noc_addr(static_cast<uint8_t>(dest_noc_x), static_cast<uint8_t>(dest_noc_y), dest_bank_addr),
                packet_payload_size_bytes);
            if (fabric_connection.has_forward_connection()) {
                DeviceZoneScopedN("WR-FWD");
                mcast_fwd_packet_header->to_noc_unicast(NocUnicastCommandHeader{
                    dest_bank_addr,
                    packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader),
                    static_cast<uint8_t>(dest_noc_x),
                    static_cast<uint8_t>(dest_noc_y)});
                {
                    DeviceZoneScopedN("WR-FWD-WAIT");
                    fabric_connection.get_forward_connection().wait_for_empty_write_slot();
                }
                fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address(
                    source_l1_buffer_address, packet_payload_size_bytes);
                fabric_connection.get_forward_connection().send_payload_flush_non_blocking_from_address(
                    (uint32_t)mcast_fwd_packet_header, sizeof(tt::fabric::PacketHeader));
            }

            if (fabric_connection.has_backward_connection()) {
                DeviceZoneScopedN("WR-BWD");
                mcast_bwd_packet_header->to_noc_unicast(NocUnicastCommandHeader{
                    dest_bank_addr,
                    packet_payload_size_bytes + sizeof(tt::fabric::PacketHeader),
                    static_cast<uint8_t>(dest_noc_x),
                    static_cast<uint8_t>(dest_noc_y)});
                {
                    DeviceZoneScopedN("WR-BWD-WAIT");
                    fabric_connection.get_backward_connection().wait_for_empty_write_slot();
                }
                fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address(
                    source_l1_buffer_address, packet_payload_size_bytes);
                fabric_connection.get_backward_connection().send_payload_flush_non_blocking_from_address(
                    (uint32_t)mcast_bwd_packet_header, sizeof(tt::fabric::PacketHeader));
            }
            {
                noc_async_writes_flushed();
            }
        }
    }

    for (size_t i = 0; i < num_unicasts; i++) {
        DeviceZoneScopedN("UNICAST-WRITE");
        auto& fabric_conn =
            unicast_is_fwd ? fabric_connection.get_forward_connection() : fabric_connection.get_backward_connection();
        unicast_packet_header->to_noc_unicast(NocUnicastCommandHeader{
            dest_bank_addr,
            packet_payload_size_bytes,
            static_cast<uint8_t>(dest_noc_x),
            static_cast<uint8_t>(dest_noc_y)});
        fabric_conn.wait_for_empty_write_slot();
        fabric_conn.send_payload_without_header_non_blocking_from_address(
            source_l1_buffer_address, packet_payload_size_bytes);
        fabric_conn.send_payload_flush_blocking_from_address(
            (uint32_t)unicast_packet_header, sizeof(tt::fabric::PacketHeader));
    }

    if (enable_finish_synchronization) {
        line_sync(
            fabric_connection,
            mcast_fwd_packet_header,
            mcast_bwd_packet_header,
            sync_bank_addr,
            sync_noc_x,
            sync_noc_y,
            finish_sync_val);
    }

    {
        DeviceZoneScopedN("WR-CLOSE");
        fabric_connection.close();
    }
    noc_async_write_barrier();
}
@@ -150,8 +150,6 @@ void kernel_main() {
         packet_header.reserved2 = 0x1111;  // debug only
     }

-    uint64_t buffer_address = sender.edm_buffer_addr +
-        (*sender.buffer_index_ptr * (sender.buffer_size_bytes + sizeof(eth_channel_sync_t)));
     sender.send_payload_blocking_from_address(packet_addr, packet_size);
     noc_async_writes_flushed();
     cb_pop_front(cb_id_in0, pages_to_send);