Fix worker <-> teardown by adding separate worker connection teardown semaphore #17033

Merged · 6 commits · Jan 25, 2025
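The change threads a dedicated worker-connection teardown semaphore from the host-side test setup through the worker kernel's runtime arguments; per the title, the intent is that the teardown handshake no longer shares a semaphore with other worker signalling such as data-path flow control. Below is a minimal host-side sketch of that pattern, assuming the TT-Metal host API calls already used in this diff (tt::tt_metal::CreateSemaphore, SetRuntimeArgs); `program`, `worker_core`, and `worker_kernel_id` are placeholder names for illustration, not identifiers from the PR.

// Hedged host-side sketch of the pattern (not the PR's exact code).
// Assumes the TT-Metal host API headers (e.g. tt_metal/host_api.hpp) are on
// the include path; `program`, `worker_core`, and `worker_kernel_id` are
// hypothetical placeholders.
auto flow_control_sem_id = tt::tt_metal::CreateSemaphore(program, worker_core, 0);
auto teardown_sem_id     = tt::tt_metal::CreateSemaphore(program, worker_core, 0);  // new, teardown-only
auto buffer_index_sem_id = tt::tt_metal::CreateSemaphore(program, worker_core, 0);

// The ids go into the worker's runtime args so the kernel can resolve a
// distinct L1 address for each purpose via get_semaphore(...).
std::vector<uint32_t> writer_rt_args = {
    flow_control_sem_id,
    teardown_sem_id,   // read by the kernel immediately after the flow-control id
    buffer_index_sem_id,
};
tt::tt_metal::SetRuntimeArgs(program, worker_kernel_id, worker_core, writer_rt_args);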
@@ -51,6 +51,8 @@ void kernel_main() {
const uint32_t eth_sender_l1_sem_id = get_arg_val<uint32_t>(arg_idx++);
volatile uint32_t* const writer_send_sem_addr =
reinterpret_cast<volatile uint32_t* const>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
volatile uint32_t* const worker_teardown_sem_addr =
reinterpret_cast<volatile uint32_t* const>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
const uint32_t eth_sender_noc_x = get_arg_val<uint32_t>(arg_idx++);
const uint32_t eth_sender_noc_y = get_arg_val<uint32_t>(arg_idx++);
const uint32_t num_buffers_per_edm_channel = get_arg_val<uint32_t>(arg_idx++);
@@ -69,6 +71,7 @@ void kernel_main() {
ASSERT(edm_buffer_index_sem_id < 8);
auto edm_buffer_index_id = edm_buffer_index_sem_id;
ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(writer_send_sem_addr));
ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(worker_teardown_sem_addr));
ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(last_message_semaphore_address));

transmit_config config;
@@ -96,6 +99,7 @@ void kernel_main() {
edm_buffer_size_bytes,
edm_buffer_index_id,
writer_send_sem_addr,
worker_teardown_sem_addr,
worker_buffer_index_semaphore_addr);

sender.open();
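The hunks above add one runtime argument to the worker kernel and resolve it to its own L1 semaphore address, which is then handed to the fabric sender adapter alongside the existing flow-control address. A kernel-side sketch of why the separation matters follows; the spin-wait at the end is purely illustrative, since in the PR the address is passed into the adapter, which presumably performs the equivalent wait when the connection is torn down.

// Hedged kernel-side sketch, not the PR's kernel: with separate addresses the
// data-path flow-control semaphore and the teardown semaphore cannot alias,
// so a late flow-control increment cannot be misread as a teardown ack.
#include "dataflow_api.h"  // standard TT-Metal dataflow kernel API (assumed include)

void kernel_main() {
    size_t arg_idx = 0;
    // Pre-existing flow-control semaphore used on the data path.
    volatile uint32_t* const writer_send_sem_addr =
        reinterpret_cast<volatile uint32_t* const>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
    // New teardown-only semaphore, read right after it (as in the diff above).
    volatile uint32_t* const worker_teardown_sem_addr =
        reinterpret_cast<volatile uint32_t* const>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));

    // ... data movement that signals/waits on writer_send_sem_addr ...

    // Teardown waits only on the dedicated semaphore (illustrative spin-wait).
    while (*worker_teardown_sem_addr == 0) {
    }
}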
@@ -246,22 +246,23 @@ static constexpr size_t PACKET_HEADER_SIZE_BYTES = sizeof(tt::fabric::PacketHead
void generate_sender_worker_kernels(
Program& program,
IDevice* device,
- CoreCoord const& worker_core,
- ttnn::ccl::SenderWorkerAdapterSpec const& worker_fabric_connection,
- mode_variant_t const& mode,
+ const CoreCoord& worker_core,
+ const ttnn::ccl::SenderWorkerAdapterSpec& worker_fabric_connection,
+ const mode_variant_t& mode,
std::size_t edm_buffer_size,
uint32_t page_plus_header_size,
uint32_t num_pages_total,
uint32_t num_pages_per_edm_buffer,
uint32_t local_worker_fabric_semaphore_id,
uint32_t local_worker_teardown_semaphore_id,
uint32_t local_worker_last_message_semaphore_id,
uint32_t dram_input_buffer_base_addr,
bool src_is_dram,
uint32_t dram_output_buffer_base_addr,
bool dest_is_dram,
uint32_t worker_buffer_index_semaphore_id,
// farthest to closest
- std::vector<ttnn::ccl::edm_termination_info_t> const& edm_termination_infos) {
+ const std::vector<ttnn::ccl::edm_termination_info_t>& edm_termination_infos) {
auto const& edm_noc_core = CoreCoord(worker_fabric_connection.edm_noc_x, worker_fabric_connection.edm_noc_y);
std::vector<uint32_t> sender_worker_reader_compile_args{
src_is_dram, //
@@ -295,6 +296,7 @@ void generate_sender_worker_kernels(
worker_fabric_connection.edm_buffer_base_addr,
worker_fabric_connection.edm_l1_sem_addr,
local_worker_fabric_semaphore_id,
local_worker_teardown_semaphore_id,
(uint32_t)edm_noc_core.x,
(uint32_t)edm_noc_core.y,
worker_fabric_connection.num_buffers_per_channel,
@@ -377,6 +379,7 @@ bool RunLoopbackTest(
std::vector<CoreCoord> worker_cores = {CoreCoord(0, 0)};

auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);
auto local_worker_teardown_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);
auto local_worker_last_message_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);
auto worker_buffer_index_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);

@@ -447,6 +450,7 @@ bool RunLoopbackTest(
num_pages_total,
pages_per_send,
local_worker_fabric_semaphore_id,
local_worker_teardown_semaphore_id,
local_worker_last_message_semaphore_id,
local_input_buffer_address,
src_is_dram,
@@ -947,6 +951,7 @@ bool RunLineFabricTest(
////////////////////////////////////////////////////////////////////////////

auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
auto local_worker_teardown_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
auto local_worker_last_message_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
auto worker_buffer_index_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
////////////////////////////////////////////////////////////////////////////
@@ -976,6 +981,7 @@ bool RunLineFabricTest(
num_pages_total,
pages_per_send,
local_worker_fabric_semaphore_id,
local_worker_teardown_semaphore_id,
local_worker_last_message_semaphore_id,
local_input_buffer_address,
src_is_dram,
@@ -1939,78 +1945,6 @@ TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage1) {
// ////////////////////////////////////////////////////////////////////
// ////////////////////////////////////////////////////////////////////

TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SinglePageTile_OneHop) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 32});
constexpr size_t distance_dest_device = 1;
constexpr size_t num_devices = 4;
Layout const layout = Layout::TILE;
MemoryConfig const in0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);
MemoryConfig const in1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);
MemoryConfig const out0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);
MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);

auto num_elems = std::reduce(tensor_shape.cbegin(), tensor_shape.cend(), 1, std::multiplies<uint32_t>());
Tensor input_tensor0 =
ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
Tensor input_tensor1 =
ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
Tensor output_tensor0 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape);
Tensor output_tensor1 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape);

input_tensor0.set_tensor_spec(TensorSpec(
tensor_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config)));
input_tensor1.set_tensor_spec(TensorSpec(
tensor_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in1_memory_config)));
output_tensor0.set_tensor_spec(TensorSpec(
tensor_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out0_memory_config)));
output_tensor1.set_tensor_spec(TensorSpec(
tensor_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out1_memory_config)));

size_t page_size = tile_size(DataFormat::RawUInt32);

ttnn::ccl::Shape4D<uint32_t> tensor_shape_in_pages = shape_to_shape_in_tiles(tensor_shape);
ttnn::ccl::Shape4D<uint32_t> tensor_slice_shape_in_pages = tensor_shape_in_pages;
ttnn::ccl::Shape4D<uint32_t> tensor_slice_offset = {0, 0, 0, 0};
ttnn::ccl::Shape4D<uint32_t> worker_slice_shape = tensor_shape_in_pages;
ttnn::ccl::Shape4D<uint32_t> worker_slice_offset = {0, 0, 0, 0};

ttnn::ccl::v2::TensorSlice tensor_slice{
tensor_shape_in_pages,
tensor_slice_shape_in_pages,
tensor_slice_offset,
worker_slice_shape,
worker_slice_offset};

auto const in0_tensor_slice = tensor_slice;
auto const in1_tensor_slice = tensor_slice;
auto const out0_tensor_slice = tensor_slice;
auto const out1_tensor_slice = tensor_slice;

ttnn::ccl::cmd::CclCommandDestArgs dest_args = ttnn::ccl::cmd::UnicastCommandDestArgs{distance_dest_device, true};
auto pass = TestMultiInputReaderKernel(
num_devices,
input_tensor0,
in0_memory_config,
input_tensor1,
in1_memory_config,
output_tensor0,
out0_memory_config,
output_tensor1,
out1_memory_config,

in0_tensor_slice,
in1_tensor_slice,
out0_tensor_slice,
out1_tensor_slice,

page_size,
TwoInputReaderKernelWriteMode::FABRIC_UNICAST,
dest_args,
false);

ASSERT_TRUE(pass);
}

TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SinglePageTile_OneHop_PersistentFabric) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 32});
constexpr size_t distance_dest_device = 1;
@@ -2165,50 +2099,6 @@ void RunFabricMcastFullTensorPropagateTest(
ASSERT_TRUE(pass);
}

TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_SingleHop) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 32});
constexpr size_t distance_dest_device = 1;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}
TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_TwoHop) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 32});
constexpr size_t distance_dest_device = 2;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}
TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_ThreeHop) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 32});
constexpr size_t distance_dest_device = 3;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}

TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_4PageTile_SingleHop) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 128});
constexpr size_t distance_dest_device = 1;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}
TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, DMultiInputReader_4PageTile_TwoHop) {
ttnn::SimpleShape tensor_shape({1, 1, 128, 32});
constexpr size_t distance_dest_device = 2;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}
TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_4PageTile_ThreeHop) {
ttnn::SimpleShape tensor_shape({1, 1, 64, 64});
constexpr size_t distance_dest_device = 3;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}
TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_lotsPageTile_ThreeHop) {
ttnn::SimpleShape tensor_shape({1, 1, 64, 16384});
constexpr size_t distance_dest_device = 3;
constexpr size_t num_devices = 4;
RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false);
}

TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_SingleHop_PersistentFabric) {
ttnn::SimpleShape tensor_shape({1, 1, 32, 32});
constexpr size_t distance_dest_device = 1;
tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py (133 changes: 0 additions & 133 deletions)
@@ -466,71 +466,6 @@ def test_all_gather_sharded(
)


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
"num_devices, num_links, per_chip_output_shape, dim, layout",
[
(2, 1, [1, 2, 32, 1280], 1, ttnn.TILE_LAYOUT),
(2, 1, [2, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
(2, 1, [1, 2, 32, 2048], 1, ttnn.TILE_LAYOUT),
(2, 1, [1, 2, 32, 2304], 1, ttnn.TILE_LAYOUT),
(2, 1, [1, 2, 32, 4096], 1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
],
)
@pytest.mark.parametrize(
"buffer_type",
[
ttnn.BufferType.DRAM,
],
)
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [4])
def test_line_all_gather_async_on_T3K_cols_transient_fabric_post_commit(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
dim,
num_links,
input_dtype,
layout,
buffer_type,
use_program_cache,
function_level_defaults,
enable_async,
replication_factor,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")
run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
dim,
num_links,
input_dtype,
layout,
buffer_type,
use_program_cache,
function_level_defaults,
enable_async=enable_async,
num_iters=num_iters,
num_all_gather_instances=replication_factor,
cluster_axis=0,
use_all_gather_async=True,
enable_persistent_fabric=False,
create_persistent_fabric=False,
teardown_persistent_fabric=False,
)


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
"num_devices, num_links, per_chip_output_shape, dim, layout",
@@ -596,74 +531,6 @@ def test_line_all_gather_async_on_T3K_cols_persistent_fabric_post_commit(
)


# Enumerate the post-commit cases explicitly
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
"num_devices, num_links, per_chip_output_shape, dim, layout",
[
(4, 1, [4, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
(4, 1, [1, 1, 32, 16384 * 4], 3, ttnn.TILE_LAYOUT),
(4, 1, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
(4, 1, [1, 4, 32, 4096], 1, ttnn.TILE_LAYOUT),
(4, 1, [1, 4, 32, 6656], 1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
@pytest.mark.parametrize(
"buffer_type",
[
ttnn.BufferType.DRAM,
ttnn.BufferType.L1,
],
)
@pytest.mark.parametrize("replication_factor", [2])
@pytest.mark.parametrize("enable_async", [True])
def test_line_all_gather_async_on_T3K_rows_transient_fabric_post_commit(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
dim,
num_links,
input_dtype,
layout,
buffer_type,
use_program_cache,
function_level_defaults,
enable_async,
replication_factor,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")
run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
dim,
num_links,
input_dtype,
layout,
buffer_type,
use_program_cache,
function_level_defaults,
enable_async=enable_async,
num_iters=num_iters,
num_all_gather_instances=replication_factor,
cluster_axis=1,
use_all_gather_async=True,
enable_persistent_fabric=False,
create_persistent_fabric=False,
teardown_persistent_fabric=False,
)


# Enumerate the post-commit cases explicitly
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(