diff --git a/tests/tt_metal/distributed/CMakeLists.txt b/tests/tt_metal/distributed/CMakeLists.txt
index 97aa4feff0b..27bb9ee7b53 100644
--- a/tests/tt_metal/distributed/CMakeLists.txt
+++ b/tests/tt_metal/distributed/CMakeLists.txt
@@ -4,6 +4,8 @@ set(UNIT_TESTS_DISTRIBUTED_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_sub_device.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_events.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
 )
 
 # Define the function to create test executables for each architecture
diff --git a/tests/tt_metal/distributed/test_mesh_events.cpp b/tests/tt_metal/distributed/test_mesh_events.cpp
new file mode 100644
index 00000000000..c19d3632800
--- /dev/null
+++ b/tests/tt_metal/distributed/test_mesh_events.cpp
@@ -0,0 +1,253 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include
+#include
+#include
+#include
+
+#include "tests/tt_metal/tt_metal/dispatch/dispatch_test_utils.hpp"
+#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"
+#include "tests/tt_metal/distributed/utils.hpp"
+
+namespace tt::tt_metal::distributed::test {
+namespace {
+
+using MeshEventsTest = T3000MultiCQMultiDeviceFixture;
+
+TEST_F(MeshEventsTest, ReplicatedAsyncIO) {
+    uint32_t NUM_TILES = 1000;
+    uint32_t num_iterations = 20;
+    int32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);
+
+    DeviceLocalBufferConfig per_device_buffer_config{
+        .page_size = single_tile_size,
+        .buffer_type = BufferType::L1,
+        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
+        .bottom_up = false};
+    ReplicatedBufferConfig global_buffer_config = {
+        .size = NUM_TILES * single_tile_size,
+    };
+
+    std::shared_ptr<MeshBuffer> buf =
+        MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());
+
+    for (std::size_t i = 0; i < num_iterations; i++) {
+        std::vector<uint32_t> src_vec(NUM_TILES * single_tile_size / sizeof(uint32_t), 0);
+        std::iota(src_vec.begin(), src_vec.end(), i);
+
+        std::vector<std::vector<uint32_t>> readback_vecs = {};
+        std::shared_ptr<MeshEvent> event = std::make_shared<MeshEvent>();
+        // Writes on CQ 0
+        EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(0), buf, src_vec);
+        // Device to Device Synchronization
+        EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), event);
+        EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), event);
+
+        // Reads on CQ 1
+        for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
+            for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
+                readback_vecs.push_back({});
+                auto shard = buf->get_device_buffer(Coordinate(logical_y, logical_x));
+                ReadShard(
+                    mesh_device_->mesh_command_queue(1), readback_vecs.back(), buf, Coordinate(logical_y, logical_x));
+            }
+        }
+
+        for (auto& vec : readback_vecs) {
+            EXPECT_EQ(vec, src_vec);
+        }
+    }
+}
+
+TEST_F(MeshEventsTest, ShardedAsyncIO) {
+    uint32_t num_iterations = 20;
+    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);
+
+    DeviceLocalBufferConfig per_device_buffer_config{
+        .page_size = single_tile_size,
+        .buffer_type = BufferType::DRAM,
+        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
+        .bottom_up = true};
+
+    Shape2D global_buffer_shape = {2048, 2048};
+    Shape2D shard_shape = {512, 1024};
+
+    uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t);
+
+    ShardedBufferConfig sharded_config{
+        .global_size = global_buffer_size,
+        .global_buffer_shape = global_buffer_shape,
+        .shard_shape = shard_shape,
+        .shard_orientation = ShardOrientation::ROW_MAJOR,
+    };
+
+    auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get());
+    for (std::size_t i = 0; i < num_iterations; i++) {
+        std::vector<uint32_t> src_vec =
+            std::vector<uint32_t>(global_buffer_shape.height() * global_buffer_shape.width(), 0);
+        std::iota(src_vec.begin(), src_vec.end(), i);
+        std::shared_ptr<MeshEvent> event = std::make_shared<MeshEvent>();
+        // Writes on CQ 0
+        EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(0), mesh_buffer, src_vec);
+        if (i % 2) {
+            // Test Host <-> Device synchronization
+            EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0), event);
+            EventSynchronize(event);
+        } else {
+            // Test Device <-> Device synchronization
+            EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), event);
+            EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), event);
+        }
+        // Reads on CQ 1
+        std::vector<uint32_t> dst_vec = {};
+        EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(1), dst_vec, mesh_buffer);
+
+        EXPECT_EQ(dst_vec, src_vec);
+    }
+}
+
+TEST_F(MeshEventsTest, AsyncWorkloadAndIO) {
+    uint32_t num_iters = 5;
+    std::vector<std::shared_ptr<MeshBuffer>> src0_bufs = {};
+    std::vector<std::shared_ptr<MeshBuffer>> src1_bufs = {};
+    std::vector<std::shared_ptr<MeshBuffer>> output_bufs = {};
+
+    CoreCoord worker_grid_size = mesh_device_->compute_with_storage_grid_size();
+
+    auto programs = tt::tt_metal::distributed::test::utils::create_eltwise_bin_programs(
+        mesh_device_, src0_bufs, src1_bufs, output_bufs);
+    auto mesh_workload = CreateMeshWorkload();
+    LogicalDeviceRange devices_0 = LogicalDeviceRange({0, 0}, {3, 0});
+    LogicalDeviceRange devices_1 = LogicalDeviceRange({0, 1}, {3, 1});
+
+    AddProgramToMeshWorkload(mesh_workload, *programs[0], devices_0);
+    AddProgramToMeshWorkload(mesh_workload, *programs[1], devices_1);
+
+    for (int iter = 0; iter < num_iters; iter++) {
+        std::vector<bfloat16> src0_vec = create_constant_vector_of_bfloat16(src0_bufs[0]->size(), iter + 2);
+        std::vector<bfloat16> src1_vec = create_constant_vector_of_bfloat16(src1_bufs[0]->size(), iter + 3);
+
+        std::shared_ptr<MeshEvent> write_event = std::make_shared<MeshEvent>();
+        std::shared_ptr<MeshEvent> op_event = std::make_shared<MeshEvent>();
+
+        // Issue writes on MeshCQ 1
+        for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) {
+            for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) {
+                EnqueueWriteMeshBuffer(
+                    mesh_device_->mesh_command_queue(1), src0_bufs[col_idx * worker_grid_size.y + row_idx], src0_vec);
+                EnqueueWriteMeshBuffer(
+                    mesh_device_->mesh_command_queue(1), src1_bufs[col_idx * worker_grid_size.y + row_idx], src1_vec);
+            }
+        }
+        if (iter % 2) {
+            // Test Host <-> Device Synchronization
+            EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), write_event);
+            EventSynchronize(write_event);
+        } else {
+            // Test Device <-> Device Synchronization
+            EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), write_event);
+            EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), write_event);
+        }
+        // Issue workloads on MeshCQ 0
+        EnqueueMeshWorkload(mesh_device_->mesh_command_queue(0), mesh_workload, false);
+        if (iter % 2) {
+            // Test Device <-> Device Synchronization
+            EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), op_event);
+            EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), op_event);
+        } else {
+            // Test Host <-> Device Synchronization
+            EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0), op_event);
+            EventSynchronize(op_event);
+        }
+
+        // Issue reads on MeshCQ 1
+        for (std::size_t logical_y = 0; logical_y < mesh_device_->num_rows(); logical_y++) {
+            for (std::size_t logical_x = 0; logical_x < mesh_device_->num_cols(); logical_x++) {
+                for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) {
+                    for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) {
+                        std::vector<bfloat16> dst_vec = {};
+                        ReadShard(
+                            mesh_device_->mesh_command_queue(1),
+                            dst_vec,
+                            output_bufs[col_idx * worker_grid_size.y + row_idx],
+                            Coordinate(logical_y, logical_x));
+                        if (logical_y == 0) {
+                            for (int i = 0; i < dst_vec.size(); i++) {
+                                EXPECT_EQ(dst_vec[i].to_float(), (2 * iter + 5));
+                            }
+                        } else {
+                            for (int i = 0; i < dst_vec.size(); i++) {
+                                EXPECT_EQ(dst_vec[i].to_float(), (iter + 2) * (iter + 3));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+TEST_F(MeshEventsTest, CustomDeviceRanges) {
+    uint32_t NUM_TILES = 1000;
+    uint32_t num_iterations = 20;
+    int32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);
+
+    DeviceLocalBufferConfig per_device_buffer_config{
+        .page_size = single_tile_size,
+        .buffer_type = BufferType::L1,
+        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
+        .bottom_up = false};
+    ReplicatedBufferConfig global_buffer_config = {
+        .size = NUM_TILES * single_tile_size,
+    };
+
+    std::shared_ptr<MeshBuffer> buf =
+        MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());
+
+    for (std::size_t i = 0; i < num_iterations; i++) {
+        std::vector<uint32_t> src_vec(NUM_TILES * single_tile_size / sizeof(uint32_t), i);
+        std::iota(src_vec.begin(), src_vec.end(), i);
+        LogicalDeviceRange devices_0 = LogicalDeviceRange({0, 0}, {3, 0});
+        LogicalDeviceRange devices_1 = LogicalDeviceRange({0, 1}, {3, 1});
+
+        std::vector<std::vector<uint32_t>> readback_vecs = {};
+        std::shared_ptr<MeshEvent> event_0 = std::make_shared<MeshEvent>();
+        std::shared_ptr<MeshEvent> event_1 = std::make_shared<MeshEvent>();
+
+        mesh_device_->mesh_command_queue(1).enqueue_write_shard_to_sub_grid(*buf, src_vec.data(), devices_0, false);
+        EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), event_0, {}, devices_0);
+        EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), event_0);
+
+        for (std::size_t logical_x = devices_0.start_coord.x; logical_x < devices_0.end_coord.x; logical_x++) {
+            for (std::size_t logical_y = devices_0.start_coord.y; logical_y < devices_0.end_coord.y; logical_y++) {
+                readback_vecs.push_back({});
+                auto shard = buf->get_device_buffer(Coordinate(logical_y, logical_x));
+                ReadShard(
+                    mesh_device_->mesh_command_queue(0), readback_vecs.back(), buf, Coordinate(logical_y, logical_x));
+            }
+        }
+
+        mesh_device_->mesh_command_queue(1).enqueue_write_shard_to_sub_grid(*buf, src_vec.data(), devices_1, false);
+        EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), event_1, {}, devices_1);
+        EventSynchronize(event_1);
+
+        for (std::size_t logical_x = devices_1.start_coord.x; logical_x < devices_1.end_coord.x; logical_x++) {
+            for (std::size_t logical_y = devices_1.start_coord.y; logical_y < devices_1.end_coord.y; logical_y++) {
+                readback_vecs.push_back({});
+                auto shard = buf->get_device_buffer(Coordinate(logical_y, logical_x));
+                ReadShard(
+                    mesh_device_->mesh_command_queue(0), readback_vecs.back(), buf, Coordinate(logical_y, logical_x));
+            }
+        }
+        for (auto& vec : readback_vecs) {
+            EXPECT_EQ(vec, src_vec);
+        }
+    }
+    Finish(mesh_device_->mesh_command_queue(0));
+    Finish(mesh_device_->mesh_command_queue(1));
+}
+
+} // namespace
+} // namespace tt::tt_metal::distributed::test
diff --git a/tests/tt_metal/distributed/test_mesh_sub_device.cpp b/tests/tt_metal/distributed/test_mesh_sub_device.cpp
index 90c0983f4c1..7a21597dd59 100644
---
a/tests/tt_metal/distributed/test_mesh_sub_device.cpp +++ b/tests/tt_metal/distributed/test_mesh_sub_device.cpp @@ -116,34 +116,10 @@ TEST_F(MeshSubDeviceTest, DataCopyOnSubDevices) { std::vector src_vec(input_buf->size() / sizeof(uint32_t)); std::iota(src_vec.begin(), src_vec.end(), i); - EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), input_buf, src_vec, false); - // Read Back global semaphore value across all cores to verify that it has been reset to 0 - // before updating it through host - auto shard_parameters = - ShardSpecBuffer(all_cores, {1, 1}, ShardOrientation::ROW_MAJOR, {1, 1}, {all_cores.size(), 1}); - DeviceLocalBufferConfig global_sem_buf_local_config{ - .page_size = sizeof(uint32_t), - .buffer_type = BufferType::L1, - .buffer_layout = TensorMemoryLayout::HEIGHT_SHARDED, - .shard_parameters = shard_parameters, - .bottom_up = false}; - ReplicatedBufferConfig global_sem_buf_global_config{ - .size = all_cores.size() * sizeof(uint32_t), - }; - - auto global_sem_buf = MeshBuffer::create( - global_sem_buf_global_config, global_sem_buf_local_config, mesh_device_.get(), global_sem.address()); - - for (std::size_t logical_x = 0; logical_x < input_buf->device()->num_cols(); logical_x++) { - for (std::size_t logical_y = 0; logical_y < input_buf->device()->num_rows(); logical_y++) { - std::vector dst_vec; - ReadShard( - mesh_device_->mesh_command_queue(), dst_vec, global_sem_buf, Coordinate(logical_y, logical_x)); - for (const auto& val : dst_vec) { - EXPECT_EQ(val, 0); - } - } - } + // Block after this write on host, since the global semaphore update starting the + // program goes through an independent path (UMD) and can go out of order wrt the + // buffer data + EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), input_buf, src_vec, true); for (auto device : mesh_device_->get_devices()) { tt::llrt::write_hex_vec_to_core( diff --git a/tests/tt_metal/distributed/test_mesh_workload.cpp b/tests/tt_metal/distributed/test_mesh_workload.cpp index ec25670047e..dcf3f9a4158 100644 --- a/tests/tt_metal/distributed/test_mesh_workload.cpp +++ b/tests/tt_metal/distributed/test_mesh_workload.cpp @@ -11,6 +11,7 @@ #include "tests/tt_metal/tt_metal/dispatch/dispatch_test_utils.hpp" #include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp" +#include "tests/tt_metal/distributed/utils.hpp" namespace tt::tt_metal::distributed::test { namespace { @@ -323,123 +324,6 @@ std::shared_ptr initialize_dummy_program(CoreCoord worker_grid_size) { return program; } -std::vector> create_eltwise_bin_programs( - std::shared_ptr& mesh_device, - std::vector>& src0_bufs, - std::vector>& src1_bufs, - std::vector>& output_bufs) { - const std::vector op_id_to_op_define = {"add_tiles", "mul_tiles"}; - const std::vector op_id_to_op_type_define = {"EltwiseBinaryType::ELWADD", "EltwiseBinaryType::ELWMUL"}; - - CoreCoord worker_grid_size = mesh_device->compute_with_storage_grid_size(); - - std::vector> programs = {std::make_shared(), std::make_shared()}; - auto full_grid = CoreRange({0, 0}, {worker_grid_size.x - 1, worker_grid_size.y - 1}); - - for (std::size_t eltwise_op = 0; eltwise_op < op_id_to_op_define.size(); eltwise_op++) { - auto& program = *programs[eltwise_op]; - uint32_t single_tile_size = 2 * 1024; - uint32_t num_tiles = 2048; - uint32_t dram_buffer_size = - single_tile_size * num_tiles; // num_tiles of FP16_B, hard-coded in the reader/writer kernels - uint32_t page_size = single_tile_size; - - ReplicatedBufferConfig global_buffer_config{.size = dram_buffer_size}; - DeviceLocalBufferConfig 
per_device_buffer_config{ - .page_size = page_size, - .buffer_type = tt_metal::BufferType::DRAM, - .buffer_layout = TensorMemoryLayout::INTERLEAVED, - .bottom_up = true}; - - for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) { - for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) { - auto src0_dram_buffer = - MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device.get()); - src0_bufs.push_back(src0_dram_buffer); - - auto src1_dram_buffer = - MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device.get()); - src1_bufs.push_back(src1_dram_buffer); - auto dst_dram_buffer = - MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device.get()); - output_bufs.push_back(dst_dram_buffer); - } - } - - uint32_t src0_cb_index = tt::CBIndex::c_0; - uint32_t num_input_tiles = 2; - tt_metal::CircularBufferConfig cb_src0_config = - tt_metal::CircularBufferConfig( - num_input_tiles * single_tile_size, {{src0_cb_index, tt::DataFormat::Float16_b}}) - .set_page_size(src0_cb_index, single_tile_size); - auto cb_src0 = tt_metal::CreateCircularBuffer(program, full_grid, cb_src0_config); - - uint32_t src1_cb_index = tt::CBIndex::c_1; - tt_metal::CircularBufferConfig cb_src1_config = - tt_metal::CircularBufferConfig( - num_input_tiles * single_tile_size, {{src1_cb_index, tt::DataFormat::Float16_b}}) - .set_page_size(src1_cb_index, single_tile_size); - auto cb_src1 = tt_metal::CreateCircularBuffer(program, full_grid, cb_src1_config); - - uint32_t ouput_cb_index = tt::CBIndex::c_16; - uint32_t num_output_tiles = 2; - tt_metal::CircularBufferConfig cb_output_config = - tt_metal::CircularBufferConfig( - num_output_tiles * single_tile_size, {{ouput_cb_index, tt::DataFormat::Float16_b}}) - .set_page_size(ouput_cb_index, single_tile_size); - auto cb_output = tt_metal::CreateCircularBuffer(program, full_grid, cb_output_config); - - auto binary_reader_kernel = tt_metal::CreateKernel( - program, - "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_dual_8bank.cpp", - full_grid, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_1, .noc = tt_metal::NOC::RISCV_1_default}); - - auto unary_writer_kernel = tt_metal::CreateKernel( - program, - "tests/tt_metal/tt_metal/test_kernels/dataflow/writer_unary_8bank.cpp", - full_grid, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}); - - std::vector compute_kernel_args = {}; - - bool fp32_dest_acc_en = false; - bool math_approx_mode = false; - std::map binary_defines = { - {"ELTWISE_OP", op_id_to_op_define[eltwise_op]}, {"ELTWISE_OP_TYPE", op_id_to_op_type_define[eltwise_op]}}; - auto eltwise_binary_kernel = tt_metal::CreateKernel( - program, - "tt_metal/kernels/compute/eltwise_binary.cpp", - full_grid, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args, .defines = binary_defines}); - - SetRuntimeArgs(program, eltwise_binary_kernel, full_grid, {2048, 1}); - - for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) { - for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) { - CoreCoord curr_core = {col_idx, row_idx}; - const std::array reader_args = { - src0_bufs.at(col_idx * worker_grid_size.y + row_idx)->address(), - 0, - num_tiles, - src1_bufs.at(col_idx * worker_grid_size.y + row_idx)->address(), - 0, - num_tiles, - 0}; - - const std::array writer_args = { - output_bufs.at(col_idx * worker_grid_size.y + row_idx)->address(), 0, 
num_tiles}; - - SetRuntimeArgs(program, unary_writer_kernel, curr_core, writer_args); - SetRuntimeArgs(program, binary_reader_kernel, curr_core, reader_args); - } - } - } - return programs; -} - void verify_cb_config( std::shared_ptr& mesh_device, MeshWorkload& workload, @@ -650,7 +534,8 @@ TEST_F(MeshWorkloadTest, EltwiseBinaryMeshWorkload) { CoreCoord worker_grid_size = mesh_device_->compute_with_storage_grid_size(); - auto programs = create_eltwise_bin_programs(mesh_device_, src0_bufs, src1_bufs, output_bufs); + auto programs = tt::tt_metal::distributed::test::utils::create_eltwise_bin_programs( + mesh_device_, src0_bufs, src1_bufs, output_bufs); auto mesh_workload = CreateMeshWorkload(); LogicalDeviceRange devices_0 = LogicalDeviceRange({0, 0}, {3, 0}); LogicalDeviceRange devices_1 = LogicalDeviceRange({0, 1}, {3, 1}); diff --git a/tests/tt_metal/distributed/utils.cpp b/tests/tt_metal/distributed/utils.cpp new file mode 100644 index 00000000000..c53f1c9d96a --- /dev/null +++ b/tests/tt_metal/distributed/utils.cpp @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tests/tt_metal/distributed/utils.hpp" + +namespace tt::tt_metal::distributed::test::utils { + +std::vector> create_eltwise_bin_programs( + std::shared_ptr& mesh_device, + std::vector>& src0_bufs, + std::vector>& src1_bufs, + std::vector>& output_bufs) { + const std::vector op_id_to_op_define = {"add_tiles", "mul_tiles"}; + const std::vector op_id_to_op_type_define = {"EltwiseBinaryType::ELWADD", "EltwiseBinaryType::ELWMUL"}; + + CoreCoord worker_grid_size = mesh_device->compute_with_storage_grid_size(); + + std::vector> programs = {std::make_shared(), std::make_shared()}; + auto full_grid = CoreRange({0, 0}, {worker_grid_size.x - 1, worker_grid_size.y - 1}); + + for (std::size_t eltwise_op = 0; eltwise_op < op_id_to_op_define.size(); eltwise_op++) { + auto& program = *programs[eltwise_op]; + uint32_t single_tile_size = 2 * 1024; + uint32_t num_tiles = 2048; + uint32_t dram_buffer_size = + single_tile_size * num_tiles; // num_tiles of FP16_B, hard-coded in the reader/writer kernels + uint32_t page_size = single_tile_size; + + ReplicatedBufferConfig global_buffer_config{.size = dram_buffer_size}; + DeviceLocalBufferConfig per_device_buffer_config{ + .page_size = page_size, + .buffer_type = tt_metal::BufferType::DRAM, + .buffer_layout = TensorMemoryLayout::INTERLEAVED, + .bottom_up = true}; + + for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) { + for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) { + auto src0_dram_buffer = + MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device.get()); + src0_bufs.push_back(src0_dram_buffer); + + auto src1_dram_buffer = + MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device.get()); + src1_bufs.push_back(src1_dram_buffer); + auto dst_dram_buffer = + MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device.get()); + output_bufs.push_back(dst_dram_buffer); + } + } + + uint32_t src0_cb_index = tt::CBIndex::c_0; + uint32_t num_input_tiles = 2; + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig( + num_input_tiles * single_tile_size, {{src0_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(src0_cb_index, single_tile_size); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, full_grid, cb_src0_config); + + uint32_t src1_cb_index = tt::CBIndex::c_1; + 
tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig( + num_input_tiles * single_tile_size, {{src1_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(src1_cb_index, single_tile_size); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, full_grid, cb_src1_config); + + uint32_t ouput_cb_index = tt::CBIndex::c_16; + uint32_t num_output_tiles = 2; + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig( + num_output_tiles * single_tile_size, {{ouput_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(ouput_cb_index, single_tile_size); + auto cb_output = tt_metal::CreateCircularBuffer(program, full_grid, cb_output_config); + + auto binary_reader_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_dual_8bank.cpp", + full_grid, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, .noc = tt_metal::NOC::RISCV_1_default}); + + auto unary_writer_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/writer_unary_8bank.cpp", + full_grid, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}); + + std::vector compute_kernel_args = {}; + + bool fp32_dest_acc_en = false; + bool math_approx_mode = false; + std::map binary_defines = { + {"ELTWISE_OP", op_id_to_op_define[eltwise_op]}, {"ELTWISE_OP_TYPE", op_id_to_op_type_define[eltwise_op]}}; + auto eltwise_binary_kernel = tt_metal::CreateKernel( + program, + "tt_metal/kernels/compute/eltwise_binary.cpp", + full_grid, + tt_metal::ComputeConfig{.compile_args = compute_kernel_args, .defines = binary_defines}); + + SetRuntimeArgs(program, eltwise_binary_kernel, full_grid, {2048, 1}); + + for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) { + for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) { + CoreCoord curr_core = {col_idx, row_idx}; + const std::array reader_args = { + src0_bufs.at(col_idx * worker_grid_size.y + row_idx)->address(), + 0, + num_tiles, + src1_bufs.at(col_idx * worker_grid_size.y + row_idx)->address(), + 0, + num_tiles, + 0}; + + const std::array writer_args = { + output_bufs.at(col_idx * worker_grid_size.y + row_idx)->address(), 0, num_tiles}; + + SetRuntimeArgs(program, unary_writer_kernel, curr_core, writer_args); + SetRuntimeArgs(program, binary_reader_kernel, curr_core, reader_args); + } + } + } + return programs; +} + +} // namespace tt::tt_metal::distributed::test::utils diff --git a/tests/tt_metal/distributed/utils.hpp b/tests/tt_metal/distributed/utils.hpp new file mode 100644 index 00000000000..36b1bbb2fdd --- /dev/null +++ b/tests/tt_metal/distributed/utils.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace tt::tt_metal::distributed::test::utils { + +std::vector> create_eltwise_bin_programs( + std::shared_ptr& mesh_device, + std::vector>& src0_bufs, + std::vector>& src1_bufs, + std::vector>& output_bufs); + +} // namespace tt::tt_metal::distributed::test::utils diff --git a/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp b/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp index 04a8ce84a78..1fa6f2443c9 100644 --- a/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp +++ b/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp @@ -52,7 +52,7 @@ class N300DeviceFixture : public MultiDeviceFixture { class T3000MultiDeviceFixture : public ::testing::Test { protected: - void SetUp() override { + virtual void SetUp() override { using tt::tt_metal::distributed::MeshDevice; using tt::tt_metal::distributed::MeshDeviceConfig; using tt::tt_metal::distributed::MeshShape; @@ -66,7 +66,7 @@ class T3000MultiDeviceFixture : public ::testing::Test { if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) { GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine."; } - mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}}); + create_mesh_device(); } void TearDown() override { @@ -77,5 +77,28 @@ class T3000MultiDeviceFixture : public ::testing::Test { mesh_device_->close(); mesh_device_.reset(); } + +protected: + virtual void create_mesh_device() { + using tt::tt_metal::distributed::MeshDevice; + using tt::tt_metal::distributed::MeshDeviceConfig; + using tt::tt_metal::distributed::MeshShape; + + mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}}); + } + std::shared_ptr mesh_device_; }; + +class T3000MultiCQMultiDeviceFixture : public T3000MultiDeviceFixture { +protected: + // Override only the mesh device creation logic + void create_mesh_device() override { + using tt::tt_metal::distributed::MeshDevice; + using tt::tt_metal::distributed::MeshDeviceConfig; + using tt::tt_metal::distributed::MeshShape; + + mesh_device_ = + MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}}, 0, 0, 2, DispatchCoreType::ETH); + } +}; diff --git a/tt_metal/api/tt-metalium/command_queue.hpp b/tt_metal/api/tt-metalium/command_queue.hpp index 9c9bb3b29de..3c1a57fe7e7 100644 --- a/tt_metal/api/tt-metalium/command_queue.hpp +++ b/tt_metal/api/tt-metalium/command_queue.hpp @@ -75,10 +75,9 @@ class CommandQueue { tt::stl::Span sub_device_ids = {}) = 0; virtual void enqueue_record_event( - const std::shared_ptr& event, - bool clear_count = false, - tt::stl::Span sub_device_ids = {}) = 0; - virtual void enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count = false) = 0; + const std::shared_ptr& event, tt::stl::Span sub_device_ids = {}) = 0; + + virtual void enqueue_wait_for_event(const std::shared_ptr& sync_event) = 0; virtual void enqueue_write_buffer( const std::variant, std::shared_ptr>& buffer, diff --git a/tt_metal/api/tt-metalium/dispatch_core_manager.hpp b/tt_metal/api/tt-metalium/dispatch_core_manager.hpp index 62433e832b5..2edda1f01ae 100644 --- a/tt_metal/api/tt-metalium/dispatch_core_manager.hpp +++ b/tt_metal/api/tt-metalium/dispatch_core_manager.hpp @@ -143,6 +143,8 @@ class dispatch_core_manager { bool is_dispatcher_s_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + bool is_dispatcher_d_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + /// @brief Gets the 
location of the kernel designated to relay fast dispatch commands to worker cores from a particular command queue /// @param device_id ID of the device that should be running the command /// @param channel assigned to the command queue where commands are enqueued diff --git a/tt_metal/api/tt-metalium/distributed.hpp b/tt_metal/api/tt-metalium/distributed.hpp index 96b3a23ed10..017214b437a 100644 --- a/tt_metal/api/tt-metalium/distributed.hpp +++ b/tt_metal/api/tt-metalium/distributed.hpp @@ -6,7 +6,7 @@ #include "mesh_buffer.hpp" #include "mesh_command_queue.hpp" -#include "mesh_workload.hpp" +#include "mesh_event.hpp" namespace tt::tt_metal { @@ -78,7 +78,23 @@ void EnqueueReadMeshBuffer( mesh_cq.enqueue_read_mesh_buffer(dst.data(), mesh_buffer, blocking); } -void Finish(MeshCommandQueue& mesh_cq); +void EnqueueRecordEvent( + MeshCommandQueue& mesh_cq, + const std::shared_ptr& event, + tt::stl::Span sub_device_ids = {}, + const std::optional& device_range = std::nullopt); + +void EnqueueRecordEventToHost( + MeshCommandQueue& mesh_cq, + const std::shared_ptr& event, + tt::stl::Span sub_device_ids = {}, + const std::optional& device_range = std::nullopt); + +void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const std::shared_ptr& event); + +void EventSynchronize(const std::shared_ptr& event); + +void Finish(MeshCommandQueue& mesh_cq, tt::stl::Span sub_device_ids = {}); } // namespace distributed } // namespace tt::tt_metal diff --git a/tt_metal/api/tt-metalium/mesh_command_queue.hpp b/tt_metal/api/tt-metalium/mesh_command_queue.hpp index 61263207b9c..11ca2ab65e8 100644 --- a/tt_metal/api/tt-metalium/mesh_command_queue.hpp +++ b/tt_metal/api/tt-metalium/mesh_command_queue.hpp @@ -5,6 +5,8 @@ #pragma once #include +#include + #include "buffer.hpp" #include "command_queue_interface.hpp" #include "mesh_buffer.hpp" @@ -13,6 +15,9 @@ namespace tt::tt_metal::distributed { +class MeshEvent; +struct MeshReadEventDescriptor; + class MeshCommandQueue { // Main interface to dispatch data and workloads to a MeshDevice // Currently only supports dispatching workloads and relies on the @@ -39,12 +44,18 @@ class MeshCommandQueue { // Helper functions for read and write entire Sharded-MeshBuffers void write_sharded_buffer(const MeshBuffer& buffer, const void* src); void read_sharded_buffer(MeshBuffer& buffer, void* dst); + void enqueue_record_event_helper( + const std::shared_ptr& event, + tt::stl::Span sub_device_ids, + bool notify_host, + const std::optional& device_range = std::nullopt); std::array config_buffer_mgr_; std::array expected_num_workers_completed_; MeshDevice* mesh_device_ = nullptr; uint32_t id_ = 0; CoreCoord dispatch_core_; CoreType dispatch_core_type_ = CoreType::WORKER; + std::queue> event_descriptors_; public: MeshCommandQueue(MeshDevice* mesh_device, uint32_t id); @@ -76,7 +87,18 @@ class MeshCommandQueue { const std::shared_ptr& mesh_buffer, bool blocking); - void finish(); + void enqueue_record_event( + const std::shared_ptr& event, + tt::stl::Span sub_device_ids = {}, + const std::optional& device_range = std::nullopt); + void enqueue_record_event_to_host( + const std::shared_ptr& event, + tt::stl::Span sub_device_ids = {}, + const std::optional& device_range = std::nullopt); + void enqueue_wait_for_event(const std::shared_ptr& sync_event); + void drain_events_from_completion_queue(); + void verify_reported_events_after_draining(const std::shared_ptr& event); + void finish(tt::stl::Span sub_device_ids = {}); void reset_worker_state( bool reset_launch_msg_state, uint32_t 
num_sub_devices, diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp index ec04ada058f..c4f1469ee46 100644 --- a/tt_metal/api/tt-metalium/mesh_device.hpp +++ b/tt_metal/api/tt-metalium/mesh_device.hpp @@ -58,7 +58,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this> submeshes_; // Parent owns submeshes and is responsible for their destruction std::weak_ptr parent_mesh_; // Submesh created with reference to parent mesh - std::unique_ptr mesh_command_queue_; + std::vector> mesh_command_queues_; std::unique_ptr sub_device_manager_tracker_; // This is a reference device used to query properties that are the same for all devices in the mesh. @@ -238,7 +238,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this sub_devices, DeviceAddr local_l1_size); // TODO #16526: Temporary api until migration to actual fabric is complete diff --git a/tt_metal/api/tt-metalium/mesh_device_view.hpp b/tt_metal/api/tt-metalium/mesh_device_view.hpp index 98a7cad5740..fbadc8f32c2 100644 --- a/tt_metal/api/tt-metalium/mesh_device_view.hpp +++ b/tt_metal/api/tt-metalium/mesh_device_view.hpp @@ -39,6 +39,15 @@ struct Coordinate { return os << "Coord(" << coord.row << ", " << coord.col << ")"; } }; +// TODO (Issue #17477): MeshWorkload and MeshEvent currently rely on the coordinate systems +// exposed below. These must be uplifted to an ND coordinate system (DeviceCoord and DeviceRange), +// keeping things more consistent across the stack. +// For now, since the LogicalDeviceRange concept is fundamentally identical to the CoreRange concept +// on a 2D Mesh use this definition. CoreRange contains several utility functions required +// in the MeshWorkload context. + +using DeviceCoord = CoreCoord; +using LogicalDeviceRange = CoreRange; /** * @brief The MeshDeviceView class provides a view of a specific sub-region within the MeshDevice. diff --git a/tt_metal/api/tt-metalium/mesh_event.hpp b/tt_metal/api/tt-metalium/mesh_event.hpp new file mode 100644 index 00000000000..f115a118d15 --- /dev/null +++ b/tt_metal/api/tt-metalium/mesh_event.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "mesh_device.hpp" + +namespace tt::tt_metal::distributed { + +class MeshEvent { +public: + MeshDevice* device = nullptr; + LogicalDeviceRange device_range = LogicalDeviceRange({0, 0}); + uint32_t cq_id = 0; + uint32_t event_id = 0; +}; + +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/api/tt-metalium/mesh_workload.hpp b/tt_metal/api/tt-metalium/mesh_workload.hpp index 577c1f0e7d6..f57bccb3edf 100644 --- a/tt_metal/api/tt-metalium/mesh_workload.hpp +++ b/tt_metal/api/tt-metalium/mesh_workload.hpp @@ -9,11 +9,6 @@ #include "mesh_buffer.hpp" namespace tt::tt_metal::distributed { -// The LogicalDeviceRange concept is fundamentally identical to the CoreRange concept -// Use this definition for now, since CoreRange contains several utility functions required -// in the MeshWorkload context. CoreRange can eventually be renamed to Range2D. 
-using LogicalDeviceRange = CoreRange; -using DeviceCoord = CoreCoord; using RuntimeArgsPerCore = std::vector>; class MeshCommandQueue; diff --git a/tt_metal/distributed/distributed.cpp b/tt_metal/distributed/distributed.cpp index d7410816baa..b92546832a1 100644 --- a/tt_metal/distributed/distributed.cpp +++ b/tt_metal/distributed/distributed.cpp @@ -20,6 +20,34 @@ void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, mesh_cq.enqueue_mesh_workload(mesh_workload, blocking); } -void Finish(MeshCommandQueue& mesh_cq) { mesh_cq.finish(); } +void EnqueueRecordEvent( + MeshCommandQueue& mesh_cq, + const std::shared_ptr& event, + tt::stl::Span sub_device_ids, + const std::optional& device_range) { + mesh_cq.enqueue_record_event(event, sub_device_ids, device_range); +} + +void EnqueueRecordEventToHost( + MeshCommandQueue& mesh_cq, + const std::shared_ptr& event, + tt::stl::Span sub_device_ids, + const std::optional& device_range) { + mesh_cq.enqueue_record_event_to_host(event, sub_device_ids, device_range); +} + +void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const std::shared_ptr& event) { + mesh_cq.enqueue_wait_for_event(event); +} + +void EventSynchronize(const std::shared_ptr& event) { + auto& mesh_cq = event->device->mesh_command_queue(event->cq_id); + mesh_cq.drain_events_from_completion_queue(); + mesh_cq.verify_reported_events_after_draining(event); +} + +void Finish(MeshCommandQueue& mesh_cq, tt::stl::Span sub_device_ids) { + mesh_cq.finish(sub_device_ids); +} } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp index 89eaaff1b03..d19911a3112 100644 --- a/tt_metal/distributed/mesh_command_queue.cpp +++ b/tt_metal/distributed/mesh_command_queue.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -11,9 +12,15 @@ #include "tt_metal/distributed/mesh_workload_utils.hpp" #include "tt_metal/impl/buffers/dispatch.hpp" #include "tt_metal/impl/program/dispatch.hpp" +#include "tt_metal/impl/dispatch/dispatch_query_manager.hpp" namespace tt::tt_metal::distributed { +struct MeshReadEventDescriptor { + ReadEventDescriptor single_device_descriptor; + LogicalDeviceRange device_range; +}; + MeshCommandQueue::MeshCommandQueue(MeshDevice* mesh_device, uint32_t id) { this->mesh_device_ = mesh_device; this->id_ = id; @@ -62,6 +69,8 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b auto sub_device_index = sub_device_id.to_index(); auto mesh_device_id = this->mesh_device_->id(); auto& sysmem_manager = mesh_device_->get_device(0, 0)->sysmem_manager(); + auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config(); + CoreType dispatch_core_type = dispatch_core_config.get_core_type(); TT_FATAL( mesh_workload.get_program_binary_status(mesh_device_id) != ProgramBinaryStatus::NotSent, @@ -105,7 +114,7 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b sysmem_manager.get_worker_launch_message_buffer_state()[sub_device_index].get_unicast_wptr(), expected_num_workers_completed_[sub_device_index], this->virtual_program_dispatch_core(), - this->dispatch_core_type(), + dispatch_core_type, sub_device_id, dispatch_metadata, mesh_workload.get_program_binary_status(mesh_device_id), @@ -117,14 +126,13 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b logical_x++) { for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y + 1; 
logical_y++) { - experimental::write_program_commands( - this->mesh_device_->get_device(logical_y, logical_x)->command_queue(this->id_), + program_dispatch::write_program_command_sequence( program_cmd_seq, - num_workers, - sub_device_id, + this->mesh_device_->get_device(logical_y, logical_x)->sysmem_manager(), + id_, + dispatch_core_type, dispatch_metadata.stall_first, - dispatch_metadata.stall_before_program, - false); + dispatch_metadata.stall_before_program); chip_ids_in_workload.insert(this->mesh_device_->get_device(logical_y, logical_x)->id()); } } @@ -132,8 +140,11 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b // Send go signals to devices not running a program to ensure consistent global state for (auto& device : this->mesh_device_->get_devices()) { if (chip_ids_in_workload.find(device->id()) == chip_ids_in_workload.end()) { - experimental::write_go_signal( - device->command_queue(this->id_), + write_go_signal( + id_, + device, + sub_device_id, + device->sysmem_manager(), expected_num_workers_completed_[sub_device_index], this->virtual_program_dispatch_core(), mcast_go_signals, @@ -159,10 +170,11 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b } } -void MeshCommandQueue::finish() { - for (auto device : this->mesh_device_->get_devices()) { - Finish(device->command_queue(this->id_)); - } +void MeshCommandQueue::finish(tt::stl::Span sub_device_ids) { + std::shared_ptr event = std::make_shared(); + this->enqueue_record_event_to_host(event, sub_device_ids); + this->drain_events_from_completion_queue(); + this->verify_reported_events_after_draining(event); } void MeshCommandQueue::write_shard_to_device( @@ -181,6 +193,7 @@ void MeshCommandQueue::read_shard_from_device( void* dst, const BufferRegion& region, tt::stl::Span sub_device_ids) { + this->drain_events_from_completion_queue(); auto device = shard_view->device(); chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); @@ -417,6 +430,110 @@ void MeshCommandQueue::enqueue_read_shards( } } +void MeshCommandQueue::enqueue_record_event_helper( + const std::shared_ptr& event, + tt::stl::Span sub_device_ids, + bool notify_host, + const std::optional& device_range) { + auto& sysmem_manager = mesh_device_->get_device(0, 0)->sysmem_manager(); + event->cq_id = id_; + event->event_id = sysmem_manager.get_next_event(id_); + event->device = mesh_device_; + event->device_range = + device_range.value_or(LogicalDeviceRange({0, 0}, {mesh_device_->num_cols() - 1, mesh_device_->num_rows() - 1})); + + sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids); + for (std::size_t logical_x = event->device_range.start_coord.x; logical_x < event->device_range.end_coord.x + 1; + logical_x++) { + for (std::size_t logical_y = event->device_range.start_coord.y; logical_y < event->device_range.end_coord.y + 1; + logical_y++) { + event_dispatch::issue_record_event_commands( + mesh_device_, + event->event_id, + id_, + mesh_device_->num_hw_cqs(), + mesh_device_->get_device(logical_y, logical_x)->sysmem_manager(), + sub_device_ids, + expected_num_workers_completed_, + notify_host); + } + } +} + +void MeshCommandQueue::enqueue_record_event( + const std::shared_ptr& event, + tt::stl::Span sub_device_ids, + const std::optional& device_range) { + this->enqueue_record_event_helper(event, sub_device_ids, false, device_range); +} + +void 
MeshCommandQueue::enqueue_record_event_to_host( + const std::shared_ptr& event, + tt::stl::Span sub_device_ids, + const std::optional& device_range) { + this->enqueue_record_event_helper(event, sub_device_ids, true, device_range); + event_descriptors_.push(std::make_shared(MeshReadEventDescriptor{ + .single_device_descriptor = ReadEventDescriptor(event->event_id), .device_range = event->device_range})); +} + +void MeshCommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_event) { + for (std::size_t logical_x = sync_event->device_range.start_coord.x; + logical_x < sync_event->device_range.end_coord.x + 1; + logical_x++) { + for (std::size_t logical_y = sync_event->device_range.start_coord.y; + logical_y < sync_event->device_range.end_coord.y + 1; + logical_y++) { + event_dispatch::issue_wait_for_event_commands( + id_, + sync_event->cq_id, + mesh_device_->get_device(logical_y, logical_x)->sysmem_manager(), + sync_event->event_id); + } + } +} + +void MeshCommandQueue::drain_events_from_completion_queue() { + constexpr bool exit_condition = false; + auto num_events = event_descriptors_.size(); + for (std::size_t event_idx = 0; event_idx < num_events; event_idx++) { + auto& mesh_read_descriptor = event_descriptors_.front(); + auto& device_range = mesh_read_descriptor->device_range; + for (std::size_t logical_x = device_range.start_coord.x; logical_x < device_range.end_coord.x + 1; + logical_x++) { + for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y + 1; + logical_y++) { + auto device = mesh_device_->get_device(logical_y, logical_x); + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id()); + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); + bool exit_condition = false; + device->sysmem_manager().completion_queue_wait_front(id_, exit_condition); + event_dispatch::read_events_from_completion_queue( + mesh_read_descriptor->single_device_descriptor, + mmio_device_id, + channel, + id_, + device->sysmem_manager()); + } + } + event_descriptors_.pop(); + } +} + +void MeshCommandQueue::verify_reported_events_after_draining(const std::shared_ptr& event) { + auto& device_range = event->device_range; + for (std::size_t logical_x = device_range.start_coord.x; logical_x < device_range.end_coord.x + 1; logical_x++) { + for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y + 1; + logical_y++) { + TT_FATAL( + mesh_device_->get_device(logical_y, logical_x) + ->sysmem_manager() + .get_last_completed_event(event->cq_id) >= event->event_id, + "Expected to see event id {} in completion queue", + event->event_id); + } + } +} + void MeshCommandQueue::reset_worker_state( bool reset_launch_msg_state, uint32_t num_sub_devices, const vector_memcpy_aligned& go_signal_noc_data) { for (auto device : mesh_device_->get_devices()) { @@ -433,11 +550,6 @@ void MeshCommandQueue::reset_worker_state( } program_dispatch::reset_config_buf_mgrs_and_expected_workers( config_buffer_mgr_, expected_num_workers_completed_, mesh_device_->num_sub_devices()); - for (auto device : mesh_device_->get_devices()) { - for (int i = 0; i < mesh_device_->num_sub_devices(); i++) { - device->command_queue(id_).set_expected_num_workers_completed_for_sub_device(i, 0); - } - } if (reset_launch_msg_state) { auto& sysmem_manager = mesh_device_->get_device(0, 0)->sysmem_manager(); sysmem_manager.reset_worker_launch_message_buffer_state(num_sub_devices); diff --git 
a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index e02498c3c28..312d164934b 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -221,9 +221,10 @@ IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); } -MeshCommandQueue& MeshDevice::mesh_command_queue() { - TT_FATAL(this->using_fast_dispatch(), "Can only acess the MeshCommandQueue when using Fast Dispatch."); - return *(mesh_command_queue_); +MeshCommandQueue& MeshDevice::mesh_command_queue(std::size_t cq_id) const { + TT_FATAL(this->using_fast_dispatch(), "Can only access the MeshCommandQueue when using Fast Dispatch."); + TT_FATAL(cq_id < mesh_command_queues_.size(), "cq_id {} is out of range", cq_id); + return *(mesh_command_queues_[cq_id]); } const DeviceIds MeshDevice::get_device_ids() const { @@ -626,9 +627,11 @@ bool MeshDevice::initialize( const auto& allocator = reference_device()->allocator(); sub_device_manager_tracker_ = std::make_unique( this, std::make_unique(allocator->get_config()), sub_devices); - + mesh_command_queues_.reserve(this->num_hw_cqs()); if (this->using_fast_dispatch()) { - mesh_command_queue_ = std::make_unique(this, 0); + for (std::size_t cq_id = 0; cq_id < this->num_hw_cqs(); cq_id++) { + mesh_command_queues_.push_back(std::make_unique(this, cq_id)); + } } return true; } diff --git a/tt_metal/distributed/mesh_workload_utils.cpp b/tt_metal/distributed/mesh_workload_utils.cpp index 634249da09c..c51a99c957a 100644 --- a/tt_metal/distributed/mesh_workload_utils.cpp +++ b/tt_metal/distributed/mesh_workload_utils.cpp @@ -6,54 +6,28 @@ #include #include "tt_metal/impl/program/dispatch.hpp" +#include "tt_metal/impl/dispatch/dispatch_query_manager.hpp" namespace tt::tt_metal::distributed { -namespace experimental { - -void write_program_commands( - CommandQueue& cq, - ProgramCommandSequence& program_cmd_seq, - uint32_t num_active_cores_in_program, - SubDeviceId sub_device_id, - bool stall_first, - bool stall_before_program, - bool blocking) { - auto sub_device_index = sub_device_id.to_index(); - // Increment expected num workers inside single device CQs to ensure other paths dont break. - // This is temporary, since data movement and events rely on single device CQs. Once MeshCommandQueue - // supports all runtime features, this will be removed, and program dispatch commands will be written - // directly through dedicated interfaces. - - uint32_t num_workers_in_cq = cq.get_expected_num_workers_completed_for_sub_device(sub_device_index); - cq.set_expected_num_workers_completed_for_sub_device( - sub_device_index, num_workers_in_cq + num_active_cores_in_program); - // Write program command stream to device - program_dispatch::write_program_command_sequence( - program_cmd_seq, - cq.device()->sysmem_manager(), - cq.id(), - dispatch_core_manager::instance().get_dispatch_core_type(cq.device()->id()), - stall_first, - stall_before_program); -} - // Use this function to send go signals to a device not running a program. // In the MeshWorkload context, a go signal must be sent to each device when // a workload is dispatched, in order to maintain consistent global state. 
void write_go_signal( - CommandQueue& cq, + uint8_t cq_id, + IDevice* device, + SubDeviceId sub_device_id, + SystemMemoryManager& sysmem_manager, uint32_t expected_num_workers_completed, CoreCoord dispatch_core, bool send_mcast, bool send_unicasts, - int num_unicast_txns = -1) { + int num_unicast_txns) { uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); uint32_t cmd_sequence_sizeB = align(sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd), pcie_alignment) + hal.get_alignment(HalMemType::HOST); - auto& manager = cq.device()->sysmem_manager(); - void* cmd_region = manager.issue_queue_reserve(cmd_sequence_sizeB, cq.id()); + void* cmd_region = sysmem_manager.issue_queue_reserve(cmd_sequence_sizeB, cq_id); HugepageDeviceCommand go_signal_cmd_sequence(cmd_region, cmd_sequence_sizeB); go_msg_t run_program_go_signal; @@ -63,30 +37,37 @@ void write_go_signal( run_program_go_signal.master_y = dispatch_core.y; run_program_go_signal.dispatch_message_offset = 0; - CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(cq.device()->id()); + CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); uint32_t dispatch_message_addr = DispatchMemMap::get(dispatch_core_type) .get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - go_signal_cmd_sequence.add_notify_dispatch_s_go_signal_cmd( - 0, /* wait */ - 1 /* index_bitmask */); - + auto sub_device_index = sub_device_id.to_index(); + // When running with dispatch_s enabled: + // - dispatch_d must notify dispatch_s that a go signal can be sent + // - dispatch_s then mcasts the go signal to all workers. + // When running without dispatch_s: + // - dispatch_d handles sending the go signal to all workers + // There is no need for dispatch_d to barrier before sending the dispatch_s notification or go signal, + // since this go signal is not preceeded by NOC txns for program config data + if (DispatchQueryManager::instance().dispatch_s_enabled()) { + uint16_t index_bitmask = 1 << sub_device_index; + go_signal_cmd_sequence.add_notify_dispatch_s_go_signal_cmd( + 0, /* wait */ + index_bitmask /* index_bitmask */); // When running on sub devices, we must account for this + } go_signal_cmd_sequence.add_dispatch_go_signal_mcast( expected_num_workers_completed, *reinterpret_cast(&run_program_go_signal), dispatch_message_addr, - send_mcast ? cq.device()->num_noc_mcast_txns(SubDeviceId{0}) : 0, - send_unicasts ? ((num_unicast_txns > 0) ? num_unicast_txns : cq.device()->num_noc_unicast_txns(SubDeviceId{0})) - : 0, - 0, /* noc_data_start_idx */ + send_mcast ? device->num_noc_mcast_txns(sub_device_id) : 0, + send_unicasts ? ((num_unicast_txns > 0) ? 
num_unicast_txns : device->num_noc_unicast_txns(sub_device_id)) : 0, + device->noc_data_start_index(sub_device_id, send_mcast, send_unicasts), /* noc_data_start_idx */ DispatcherSelect::DISPATCH_SLAVE); - manager.issue_queue_push_back(cmd_sequence_sizeB, cq.id()); + sysmem_manager.issue_queue_push_back(cmd_sequence_sizeB, cq_id); - manager.fetch_queue_reserve_back(cq.id()); - manager.fetch_queue_write(cmd_sequence_sizeB, cq.id()); + sysmem_manager.fetch_queue_reserve_back(cq_id); + sysmem_manager.fetch_queue_write(cmd_sequence_sizeB, cq_id); } -} // namespace experimental - } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_workload_utils.hpp b/tt_metal/distributed/mesh_workload_utils.hpp index e6b0429dd54..1461aad13f8 100644 --- a/tt_metal/distributed/mesh_workload_utils.hpp +++ b/tt_metal/distributed/mesh_workload_utils.hpp @@ -4,30 +4,19 @@ #include +// Utility functions for dispatch MeshWorkloads +// Used by MeshCommandQueue namespace tt::tt_metal::distributed { -namespace experimental { -// Utility functions for writing program dispatch commands -// and go signals through the per device CQ. -// Usage of these functions is temporary, until the MeshCQ -// can function independently and support MeshBuffer reads and -// writes. -void write_program_commands( - CommandQueue& cq, - ProgramCommandSequence& program_cmd_seq, - uint32_t num_active_cores_in_program, - SubDeviceId sub_device_id, - bool stall_first, - bool stall_before_program, - bool blocking); - void write_go_signal( - CommandQueue& cq, + uint8_t cq_id, + IDevice* device, + SubDeviceId sub_device_id, + SystemMemoryManager& sysmem_manager, uint32_t expected_num_workers_completed, CoreCoord dispatch_core, bool send_mcast, bool send_unicasts, int num_unicast_txns = -1); -} // namespace experimental } // namespace tt::tt_metal::distributed diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index c72409857bf..46a2578a2af 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -47,6 +47,7 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/trace/trace.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trace/trace_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event/dispatch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types_from_flatbuffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types_to_flatbuffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_from_flatbuffer.cpp diff --git a/tt_metal/impl/buffers/dispatch.hpp b/tt_metal/impl/buffers/dispatch.hpp index 15c3fa6e440..c2064fce6a4 100644 --- a/tt_metal/impl/buffers/dispatch.hpp +++ b/tt_metal/impl/buffers/dispatch.hpp @@ -8,6 +8,7 @@ #include #include #include "buffer.hpp" +#include "tt_metal/impl/event/dispatch.hpp" namespace tt::tt_metal { @@ -44,17 +45,6 @@ struct ReadBufferDescriptor { starting_host_page_id(starting_host_page_id) {} }; -// Used so host knows data in completion queue is just an event ID -struct ReadEventDescriptor { - uint32_t event_id; - uint32_t global_offset; - - explicit ReadEventDescriptor(uint32_t event) : event_id(event), global_offset(0) {} - - void set_global_offset(uint32_t offset) { global_offset = offset; } - uint32_t get_global_event_id() { return global_offset + event_id; } -}; - using CompletionReaderVariant = std::variant; // Contains helper functions to interface with buffers on device diff --git a/tt_metal/impl/dispatch/dispatch_core_manager.cpp b/tt_metal/impl/dispatch/dispatch_core_manager.cpp index 09b8f7e4b4a..401172737e9 100644 --- 
a/tt_metal/impl/dispatch/dispatch_core_manager.cpp +++ b/tt_metal/impl/dispatch/dispatch_core_manager.cpp @@ -225,6 +225,11 @@ bool dispatch_core_manager::is_dispatcher_s_core_allocated(chip_id_t device_id, return assignment.dispatcher_s.has_value(); } +bool dispatch_core_manager::is_dispatcher_d_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + return assignment.dispatcher_d.has_value(); +} + const tt_cxy_pair& dispatch_core_manager::dispatcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; if (assignment.dispatcher_d.has_value()) { diff --git a/tt_metal/impl/dispatch/dispatch_query_manager.cpp b/tt_metal/impl/dispatch/dispatch_query_manager.cpp index e49af46ef7e..9eef6cbc72a 100644 --- a/tt_metal/impl/dispatch/dispatch_query_manager.cpp +++ b/tt_metal/impl/dispatch/dispatch_query_manager.cpp @@ -6,6 +6,8 @@ #include "tt_cluster.hpp" +using dispatch_core_mgr = tt::tt_metal::dispatch_core_manager; + namespace { tt::tt_metal::DispatchCoreConfig dispatch_core_config() { @@ -13,7 +15,7 @@ tt::tt_metal::DispatchCoreConfig dispatch_core_config() { tt::tt_metal::DispatchCoreConfig first_dispatch_core_config; for (chip_id_t device_id = 0; device_id < tt::Cluster::instance().number_of_devices(); device_id++) { - dispatch_core_config = tt::tt_metal::dispatch_core_manager::instance().get_dispatch_core_config(device_id); + dispatch_core_config = dispatch_core_mgr::instance().get_dispatch_core_config(device_id); if (device_id == 0) { first_dispatch_core_config = dispatch_core_config; } else { @@ -26,6 +28,36 @@ tt::tt_metal::DispatchCoreConfig dispatch_core_config() { return dispatch_core_config; } +tt_cxy_pair dispatch_core(uint8_t cq_id) { + tt_cxy_pair dispatch_core = tt_cxy_pair(0, 0, 0); + std::optional first_dispatch_core = std::nullopt; + for (chip_id_t device_id = 0; device_id < tt::Cluster::instance().number_of_devices(); device_id++) { + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id); + if (tt::Cluster::instance().get_associated_mmio_device(device_id) == device_id) { + // Dispatch core is not allocated on this MMIO device, skip it + if (not dispatch_core_mgr::instance().is_dispatcher_core_allocated(device_id, channel, cq_id)) { + continue; + } + dispatch_core = dispatch_core_mgr::instance().dispatcher_core(device_id, channel, cq_id); + } else { + // Dispatch core is not allocated on this Non-MMIO device, skip it + if (not dispatch_core_mgr::instance().is_dispatcher_d_core_allocated(device_id, channel, cq_id)) { + continue; + } + dispatch_core = dispatch_core_mgr::instance().dispatcher_d_core(device_id, channel, cq_id); + } + if (not first_dispatch_core.has_value()) { + first_dispatch_core = dispatch_core; + } else { + TT_FATAL( + dispatch_core.x == first_dispatch_core.value().x and dispatch_core.y == first_dispatch_core.value().y, + "Expected the Dispatch Cores to be consistent across physical devices"); + } + } + TT_FATAL(first_dispatch_core.has_value(), "Could not find the dispatch core for {}", cq_id); + return dispatch_core; +} + tt::tt_metal::DispatchQueryManager* inst = nullptr; } // namespace @@ -60,6 +92,8 @@ void DispatchQueryManager::reset(uint8_t num_hw_cqs) { distributed_dispatcher_ = (num_hw_cqs == 1 and dispatch_core_config_.get_dispatch_core_type() == DispatchCoreType::ETH); go_signal_noc_ = 
dispatch_s_enabled_ ? NOC::NOC_1 : NOC::NOC_0; + // Reset the dispatch cores reported by the manager. Will be re-populated when the associated query is made + dispatch_cores_ = {}; } const DispatchCoreConfig& DispatchQueryManager::get_dispatch_core_config() const { return dispatch_core_config_; } @@ -72,6 +106,19 @@ const std::vector& DispatchQueryManager::get_logical_dispatch_cores(u return tt::get_logical_dispatch_cores(device_id, num_hw_cqs_, dispatch_core_config_); } +tt_cxy_pair DispatchQueryManager::get_dispatch_core(uint8_t cq_id) const { + if (dispatch_cores_.empty()) { + for (auto cq = 0; cq < num_hw_cqs_; cq++) { + // Populate when queried. Statically allocating at + // the start of the process causes the dispatch core + // order to change, which leads to lower performance + // with ethernet dispatch. + dispatch_cores_.push_back(dispatch_core(cq)); + } + } + return dispatch_cores_[cq_id]; +} + DispatchQueryManager::DispatchQueryManager(uint8_t num_hw_cqs) { this->reset(num_hw_cqs); } } // namespace tt::tt_metal diff --git a/tt_metal/impl/dispatch/dispatch_query_manager.hpp b/tt_metal/impl/dispatch/dispatch_query_manager.hpp index e01cae1d068..9435871461f 100644 --- a/tt_metal/impl/dispatch/dispatch_query_manager.hpp +++ b/tt_metal/impl/dispatch/dispatch_query_manager.hpp @@ -31,6 +31,7 @@ class DispatchQueryManager { const DispatchCoreConfig& get_dispatch_core_config() const; const std::vector& get_logical_storage_cores(uint32_t device_id) const; const std::vector& get_logical_dispatch_cores(uint32_t device_id) const; + tt_cxy_pair get_dispatch_core(uint8_t cq_id) const; private: void reset(uint8_t num_hw_cqs); @@ -41,6 +42,9 @@ class DispatchQueryManager { NOC go_signal_noc_ = NOC::NOC_0; uint8_t num_hw_cqs_ = 0; DispatchCoreConfig dispatch_core_config_; + // Make this mutable, since this is JIT populated + // through a const instance when queried + mutable std::vector dispatch_cores_; }; } // namespace tt::tt_metal diff --git a/tt_metal/impl/dispatch/hardware_command_queue.cpp b/tt_metal/impl/dispatch/hardware_command_queue.cpp index ed24132819c..8a72db6e742 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.cpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.cpp @@ -399,7 +399,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) { } void HWCommandQueue::enqueue_record_event( - const std::shared_ptr& event, bool clear_count, tt::stl::Span sub_device_ids) { + const std::shared_ptr& event, tt::stl::Span sub_device_ids) { ZoneScopedN("HWCommandQueue_enqueue_record_event"); TT_FATAL(!this->manager.get_bypass_mode(), "Enqueue Record Event cannot be used with tracing"); @@ -413,38 +413,22 @@ void HWCommandQueue::enqueue_record_event( event->ready = true; // what does this mean??? 
sub_device_ids = buffer_dispatch::select_sub_device_ids(this->device_, sub_device_ids); - - auto command = EnqueueRecordEventCommand( - this->id_, - this->device_, - this->noc_index_, - this->manager, + event_dispatch::issue_record_event_commands( + device_, event->event_id, - this->expected_num_workers_completed, + id_, + device_->num_hw_cqs(), + this->manager, sub_device_ids, - clear_count, - true); - this->enqueue_command(command, false, sub_device_ids); - - if (clear_count) { - for (const auto& id : sub_device_ids) { - this->expected_num_workers_completed[id.to_index()] = 0; - } - } + this->expected_num_workers_completed); this->issued_completion_q_reads.push( std::make_shared(std::in_place_type, event->event_id)); this->increment_num_entries_in_completion_q(); } -void HWCommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count) { +void HWCommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_event) { ZoneScopedN("HWCommandQueue_enqueue_wait_for_event"); - - auto command = EnqueueWaitForEventCommand(this->id_, this->device_, this->manager, *sync_event, clear_count); - this->enqueue_command(command, false, {}); - - if (clear_count) { - this->manager.reset_event_id(this->id_); - } + event_dispatch::issue_wait_for_event_commands(id_, sync_event->cq_id, this->manager, sync_event->event_id); } void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { @@ -528,29 +512,8 @@ void HWCommandQueue::read_completion_queue() { this->exit_condition); } else if constexpr (std::is_same_v) { ZoneScopedN("CompletionQueueReadEvent"); - uint32_t read_ptr = this->manager.get_completion_queue_read_ptr(this->id_); - thread_local static std::vector dispatch_cmd_and_event( - (sizeof(CQDispatchCmd) + DispatchSettings::EVENT_PADDED_SIZE) / sizeof(uint32_t)); - tt::Cluster::instance().read_sysmem( - dispatch_cmd_and_event.data(), - sizeof(CQDispatchCmd) + DispatchSettings::EVENT_PADDED_SIZE, - read_ptr, - mmio_device_id, - channel); - uint32_t event_completed = dispatch_cmd_and_event[sizeof(CQDispatchCmd) / sizeof(uint32_t)]; - - TT_ASSERT( - event_completed == read_descriptor.event_id, - "Event Order Issue: expected to read back completion signal for event {} but got {}!", - read_descriptor.event_id, - event_completed); - this->manager.completion_queue_pop_front(1, this->id_); - this->manager.set_last_completed_event(this->id_, read_descriptor.get_global_event_id()); - log_trace( - LogAlways, - "Completion queue popped event {} (global: {})", - event_completed, - read_descriptor.get_global_event_id()); + event_dispatch::read_events_from_completion_queue( + read_descriptor, mmio_device_id, channel, this->id_, this->manager); } }, read_descriptor); @@ -570,7 +533,7 @@ void HWCommandQueue::finish(tt::stl::Span sub_device_ids) { ZoneScopedN("HWCommandQueue_finish"); tt::log_debug(tt::LogDispatch, "Finish for command queue {}", this->id_); std::shared_ptr event = std::make_shared(); - this->enqueue_record_event(event, false, sub_device_ids); + this->enqueue_record_event(event, sub_device_ids); if (tt::llrt::RunTimeOptions::get_instance().get_test_mode_enabled()) { while (this->num_entries_in_completion_q > this->num_completed_completion_q_reads) { if (DPrintServerHangDetected()) { diff --git a/tt_metal/impl/dispatch/hardware_command_queue.hpp b/tt_metal/impl/dispatch/hardware_command_queue.hpp index b281934db54..eeb8c1b9fe8 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.hpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.hpp @@ -72,10 +72,8 @@ 
class HWCommandQueue : public CommandQueue { tt::stl::Span sub_device_ids = {}) override; void enqueue_record_event( - const std::shared_ptr& event, - bool clear_count = false, - tt::stl::Span sub_device_ids = {}) override; - void enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count = false) override; + const std::shared_ptr& event, tt::stl::Span sub_device_ids = {}) override; + void enqueue_wait_for_event(const std::shared_ptr& sync_event) override; void enqueue_write_buffer( const std::variant, std::shared_ptr>& buffer, diff --git a/tt_metal/impl/dispatch/host_runtime_commands.cpp b/tt_metal/impl/dispatch/host_runtime_commands.cpp index e1e0dfa8b5b..368bc663199 100644 --- a/tt_metal/impl/dispatch/host_runtime_commands.cpp +++ b/tt_metal/impl/dispatch/host_runtime_commands.cpp @@ -173,166 +173,6 @@ void EnqueueProgramCommand::process() { program.set_program_binary_status(device->id(), ProgramBinaryStatus::Committed); } -EnqueueRecordEventCommand::EnqueueRecordEventCommand( - uint32_t command_queue_id, - IDevice* device, - NOC noc_index, - SystemMemoryManager& manager, - uint32_t event_id, - tt::stl::Span expected_num_workers_completed, - tt::stl::Span sub_device_ids, - bool clear_count, - bool write_barrier) : - command_queue_id(command_queue_id), - device(device), - noc_index(noc_index), - manager(manager), - event_id(event_id), - expected_num_workers_completed(expected_num_workers_completed), - sub_device_ids(sub_device_ids), - clear_count(clear_count), - write_barrier(write_barrier) {} - -void EnqueueRecordEventCommand::process() { - std::vector event_payload(DispatchSettings::EVENT_PADDED_SIZE / sizeof(uint32_t), 0); - event_payload[0] = this->event_id; - - uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); - uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - uint8_t num_hw_cqs = - this->device->num_hw_cqs(); // Device initialize asserts that there can only be a maximum of 2 HW CQs - uint32_t packed_event_payload_sizeB = - align(sizeof(CQDispatchCmd) + num_hw_cqs * sizeof(CQDispatchWritePackedUnicastSubCmd), l1_alignment) + - (align(DispatchSettings::EVENT_PADDED_SIZE, l1_alignment) * num_hw_cqs); - uint32_t packed_write_sizeB = align(sizeof(CQPrefetchCmd) + packed_event_payload_sizeB, pcie_alignment); - uint32_t num_worker_counters = this->sub_device_ids.size(); - - uint32_t cmd_sequence_sizeB = - hal.get_alignment(HalMemType::HOST) * - num_worker_counters + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT - packed_write_sizeB + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_PACKED + unicast subcmds + event - // payload - align( - sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd) + DispatchSettings::EVENT_PADDED_SIZE, - pcie_alignment); // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_LINEAR_HOST + event ID - - void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->command_queue_id); - - HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); - - CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); - uint32_t dispatch_message_base_addr = - DispatchMemMap::get(dispatch_core_type) - .get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - - uint32_t last_index = num_worker_counters - 1; - // We only need the write barrier for the last wait cmd - for (uint32_t i = 0; i < last_index; ++i) { - auto offset_index = this->sub_device_ids[i].to_index(); - uint32_t dispatch_message_addr = - dispatch_message_base_addr 
+ - DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index); - command_sequence.add_dispatch_wait( - false, dispatch_message_addr, this->expected_num_workers_completed[offset_index], this->clear_count); - } - auto offset_index = this->sub_device_ids[last_index].to_index(); - uint32_t dispatch_message_addr = - dispatch_message_base_addr + - DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index); - command_sequence.add_dispatch_wait( - this->write_barrier, - dispatch_message_addr, - this->expected_num_workers_completed[offset_index], - this->clear_count); - - CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); - std::vector unicast_sub_cmds(num_hw_cqs); - std::vector> event_payloads(num_hw_cqs); - - for (uint8_t cq_id = 0; cq_id < num_hw_cqs; cq_id++) { - tt_cxy_pair dispatch_location; - if (device->is_mmio_capable()) { - dispatch_location = dispatch_core_manager::instance().dispatcher_core(this->device->id(), channel, cq_id); - } else { - dispatch_location = dispatch_core_manager::instance().dispatcher_d_core(this->device->id(), channel, cq_id); - } - - CoreCoord dispatch_virtual_core = this->device->virtual_core_from_logical_core(dispatch_location, core_type); - unicast_sub_cmds[cq_id] = CQDispatchWritePackedUnicastSubCmd{ - .noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, dispatch_virtual_core)}; - event_payloads[cq_id] = {event_payload.data(), event_payload.size() * sizeof(uint32_t)}; - } - - uint32_t completion_q0_last_event_addr = DispatchMemMap::get(core_type).get_device_command_queue_addr( - CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); - uint32_t completion_q1_last_event_addr = DispatchMemMap::get(core_type).get_device_command_queue_addr( - CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); - uint32_t address = this->command_queue_id == 0 ? completion_q0_last_event_addr : completion_q1_last_event_addr; - const uint32_t packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(this->device); - command_sequence.add_dispatch_write_packed( - num_hw_cqs, - address, - DispatchSettings::EVENT_PADDED_SIZE, - packed_event_payload_sizeB, - unicast_sub_cmds, - event_payloads, - packed_write_max_unicast_sub_cmds); - - bool flush_prefetch = true; - command_sequence.add_dispatch_write_host( - flush_prefetch, DispatchSettings::EVENT_PADDED_SIZE, true, event_payload.data()); - - this->manager.issue_queue_push_back(cmd_sequence_sizeB, this->command_queue_id); - - this->manager.fetch_queue_reserve_back(this->command_queue_id); - this->manager.fetch_queue_write(cmd_sequence_sizeB, this->command_queue_id); -} - -EnqueueWaitForEventCommand::EnqueueWaitForEventCommand( - uint32_t command_queue_id, - IDevice* device, - SystemMemoryManager& manager, - const Event& sync_event, - bool clear_count) : - command_queue_id(command_queue_id), - device(device), - manager(manager), - sync_event(sync_event), - clear_count(clear_count) { - this->dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); - // Should not be encountered under normal circumstances (record, wait) unless user is modifying sync event ID. - // TT_ASSERT(command_queue_id != sync_event.cq_id || event != sync_event.event_id, - // "EnqueueWaitForEventCommand cannot wait on it's own event id on the same CQ. 
Event ID: {} CQ ID: {}", - // event, command_queue_id); -} - -void EnqueueWaitForEventCommand::process() { - uint32_t cmd_sequence_sizeB = - hal.get_alignment(HalMemType::HOST); // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT - - void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->command_queue_id); - - HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); - uint32_t completion_q0_last_event_addr = - DispatchMemMap::get(this->dispatch_core_type) - .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); - uint32_t completion_q1_last_event_addr = - DispatchMemMap::get(this->dispatch_core_type) - .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); - - uint32_t last_completed_event_address = - sync_event.cq_id == 0 ? completion_q0_last_event_addr : completion_q1_last_event_addr; - - command_sequence.add_dispatch_wait(false, last_completed_event_address, sync_event.event_id, this->clear_count); - - this->manager.issue_queue_push_back(cmd_sequence_sizeB, this->command_queue_id); - - this->manager.fetch_queue_reserve_back(this->command_queue_id); - - this->manager.fetch_queue_write(cmd_sequence_sizeB, this->command_queue_id); -} - EnqueueTraceCommand::EnqueueTraceCommand( uint32_t command_queue_id, IDevice* device, @@ -584,7 +424,7 @@ void EnqueueProgram(CommandQueue& cq, Program& program, bool blocking) { void EnqueueRecordEvent( CommandQueue& cq, const std::shared_ptr& event, tt::stl::Span sub_device_ids) { detail::DispatchStateCheck(true); - cq.enqueue_record_event(event, false, sub_device_ids); + cq.enqueue_record_event(event, sub_device_ids); } void EnqueueWaitForEvent(CommandQueue& cq, const std::shared_ptr& event) { diff --git a/tt_metal/impl/dispatch/host_runtime_commands.hpp b/tt_metal/impl/dispatch/host_runtime_commands.hpp index 655a379deb1..6a62c3a2053 100644 --- a/tt_metal/impl/dispatch/host_runtime_commands.hpp +++ b/tt_metal/impl/dispatch/host_runtime_commands.hpp @@ -96,61 +96,6 @@ class EnqueueProgramCommand : public Command { constexpr bool has_side_effects() { return true; } }; -class EnqueueRecordEventCommand : public Command { -private: - uint32_t command_queue_id; - IDevice* device; - NOC noc_index; - SystemMemoryManager& manager; - uint32_t event_id; - tt::stl::Span expected_num_workers_completed; - tt::stl::Span sub_device_ids; - bool clear_count; - bool write_barrier; - -public: - EnqueueRecordEventCommand( - uint32_t command_queue_id, - IDevice* device, - NOC noc_index, - SystemMemoryManager& manager, - uint32_t event_id, - tt::stl::Span expected_num_workers_completed, - tt::stl::Span sub_device_ids, - bool clear_count = false, - bool write_barrier = true); - - void process(); - - EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_RECORD_EVENT; } - - constexpr bool has_side_effects() { return false; } -}; - -class EnqueueWaitForEventCommand : public Command { -private: - uint32_t command_queue_id; - IDevice* device; - SystemMemoryManager& manager; - const Event& sync_event; - CoreType dispatch_core_type; - bool clear_count; - -public: - EnqueueWaitForEventCommand( - uint32_t command_queue_id, - IDevice* device, - SystemMemoryManager& manager, - const Event& sync_event, - bool clear_count = false); - - void process(); - - EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_WAIT_FOR_EVENT; } - - constexpr bool has_side_effects() { return false; } -}; - class EnqueueTraceCommand : public Command { private: uint32_t command_queue_id; diff 
--git a/tt_metal/impl/event/dispatch.cpp b/tt_metal/impl/event/dispatch.cpp new file mode 100644 index 00000000000..36a62181c60 --- /dev/null +++ b/tt_metal/impl/event/dispatch.cpp @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_metal/impl/event/dispatch.hpp" +#include +#include "tt_metal/impl/dispatch/dispatch_query_manager.hpp" +#include + +namespace tt::tt_metal { + +namespace event_dispatch { + +namespace { +uint32_t get_packed_write_max_unicast_sub_cmds(IDevice* device) { + return device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y; +} +} // namespace + +void issue_record_event_commands( + IDevice* device, + uint32_t event_id, + uint8_t cq_id, + uint32_t num_command_queues, + SystemMemoryManager& manager, + tt::stl::Span sub_device_ids, + tt::stl::Span expected_num_workers_completed, + bool notify_host) { + std::vector event_payload(DispatchSettings::EVENT_PADDED_SIZE / sizeof(uint32_t), 0); + event_payload[0] = event_id; + + uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); + uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); + uint32_t packed_event_payload_sizeB = + align(sizeof(CQDispatchCmd) + num_command_queues * sizeof(CQDispatchWritePackedUnicastSubCmd), l1_alignment) + + (align(DispatchSettings::EVENT_PADDED_SIZE, l1_alignment) * num_command_queues); + uint32_t packed_write_sizeB = align(sizeof(CQPrefetchCmd) + packed_event_payload_sizeB, pcie_alignment); + uint32_t num_worker_counters = sub_device_ids.size(); + + uint32_t cmd_sequence_sizeB = + hal.get_alignment(HalMemType::HOST) * + num_worker_counters + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + packed_write_sizeB + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_PACKED + + // unicast subcmds + event payload + align( + sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd) + DispatchSettings::EVENT_PADDED_SIZE, + pcie_alignment) * + notify_host; // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_LINEAR_HOST + event ID ===> Write + // event notification back to host, if requested by user + + void* cmd_region = manager.issue_queue_reserve(cmd_sequence_sizeB, cq_id); + + HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); + + auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config(); + CoreType dispatch_core_type = dispatch_core_config.get_core_type(); + + uint32_t dispatch_message_base_addr = + DispatchMemMap::get(dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + + uint32_t last_index = num_worker_counters - 1; + for (uint32_t i = 0; i < num_worker_counters; ++i) { + auto offset_index = sub_device_ids[i].to_index(); + uint32_t dispatch_message_addr = + dispatch_message_base_addr + + DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index); + // recording an event does not have any side-effects on the dispatch completion count + // hence clear_count is set to false, i.e. the number of workers on the dispatcher is + // not reset + // We only need the write barrier for the last wait cmd. 
+ command_sequence.add_dispatch_wait( + (i == num_worker_counters - 1), /* write_barrier ensures that all writes initiated by the dispatcher are + flushed before the event is recorded */ + dispatch_message_addr, + expected_num_workers_completed[offset_index], + false /* recording an event does not have any side-effects on the dispatch completion count */); + } + + std::vector unicast_sub_cmds(num_command_queues); + std::vector> event_payloads(num_command_queues); + + for (auto cq_id = 0; cq_id < num_command_queues; cq_id++) { + tt_cxy_pair dispatch_location = DispatchQueryManager::instance().get_dispatch_core(cq_id); + CoreCoord dispatch_virtual_core = device->virtual_core_from_logical_core(dispatch_location, dispatch_core_type); + unicast_sub_cmds[cq_id] = CQDispatchWritePackedUnicastSubCmd{ + .noc_xy_addr = device->get_noc_unicast_encoding(dispatch_downstream_noc, dispatch_virtual_core)}; + event_payloads[cq_id] = {event_payload.data(), event_payload.size() * sizeof(uint32_t)}; + } + + uint32_t completion_q0_last_event_addr = + DispatchMemMap::get(dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); + uint32_t completion_q1_last_event_addr = + DispatchMemMap::get(dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); + uint32_t address = cq_id == 0 ? completion_q0_last_event_addr : completion_q1_last_event_addr; + const uint32_t packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(device); + command_sequence.add_dispatch_write_packed( + num_command_queues, + address, + DispatchSettings::EVENT_PADDED_SIZE, + packed_event_payload_sizeB, + unicast_sub_cmds, + event_payloads, + packed_write_max_unicast_sub_cmds); + + if (notify_host) { + bool flush_prefetch = true; + command_sequence.add_dispatch_write_host( + flush_prefetch, DispatchSettings::EVENT_PADDED_SIZE, true, event_payload.data()); + } + + manager.issue_queue_push_back(cmd_sequence_sizeB, cq_id); + + manager.fetch_queue_reserve_back(cq_id); + manager.fetch_queue_write(cmd_sequence_sizeB, cq_id); +} + +void issue_wait_for_event_commands( + uint8_t cq_id, uint8_t event_cq_id, SystemMemoryManager& sysmem_manager, uint32_t event_id) { + uint32_t cmd_sequence_sizeB = + hal.get_alignment(HalMemType::HOST); // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + + auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config(); + CoreType dispatch_core_type = dispatch_core_config.get_core_type(); + + void* cmd_region = sysmem_manager.issue_queue_reserve(cmd_sequence_sizeB, cq_id); + + HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); + uint32_t completion_q0_last_event_addr = + DispatchMemMap::get(dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); + uint32_t completion_q1_last_event_addr = + DispatchMemMap::get(dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); + + uint32_t last_completed_event_address = + event_cq_id == 0 ? 
completion_q0_last_event_addr : completion_q1_last_event_addr; + + command_sequence.add_dispatch_wait(false, last_completed_event_address, event_id, false); + + sysmem_manager.issue_queue_push_back(cmd_sequence_sizeB, cq_id); + + sysmem_manager.fetch_queue_reserve_back(cq_id); + + sysmem_manager.fetch_queue_write(cmd_sequence_sizeB, cq_id); +} + +void read_events_from_completion_queue( + ReadEventDescriptor& event_descriptor, + chip_id_t mmio_device_id, + uint16_t channel, + uint8_t cq_id, + SystemMemoryManager& sysmem_manager) { + uint32_t read_ptr = sysmem_manager.get_completion_queue_read_ptr(cq_id); + thread_local static std::vector dispatch_cmd_and_event( + (sizeof(CQDispatchCmd) + DispatchSettings::EVENT_PADDED_SIZE) / sizeof(uint32_t)); + tt::Cluster::instance().read_sysmem( + dispatch_cmd_and_event.data(), + sizeof(CQDispatchCmd) + DispatchSettings::EVENT_PADDED_SIZE, + read_ptr, + mmio_device_id, + channel); + uint32_t event_completed = dispatch_cmd_and_event[sizeof(CQDispatchCmd) / sizeof(uint32_t)]; + + TT_ASSERT( + event_completed == event_descriptor.event_id, + "Event Order Issue: expected to read back completion signal for event {} but got {}!", + event_descriptor.event_id, + event_completed); + sysmem_manager.completion_queue_pop_front(1, cq_id); + sysmem_manager.set_last_completed_event(cq_id, event_descriptor.get_global_event_id()); + log_trace( + LogAlways, + "Completion queue popped event {} (global: {})", + event_completed, + event_descriptor.get_global_event_id()); +} + +} // namespace event_dispatch + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/event/dispatch.hpp b/tt_metal/impl/event/dispatch.hpp new file mode 100644 index 00000000000..461fd47018f --- /dev/null +++ b/tt_metal/impl/event/dispatch.hpp @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +namespace tt::tt_metal { + +// Used so host knows data in completion queue is just an event ID +struct ReadEventDescriptor { + uint32_t event_id; + uint32_t global_offset; + + explicit ReadEventDescriptor(uint32_t event) : event_id(event), global_offset(0) {} + + void set_global_offset(uint32_t offset) { global_offset = offset; } + uint32_t get_global_event_id() { return global_offset + event_id; } +}; + +namespace event_dispatch { + +void issue_record_event_commands( + IDevice* device, + uint32_t event_id, + uint8_t cq_id, + uint32_t num_command_queues, + SystemMemoryManager& manager, + tt::stl::Span sub_device_ids, + tt::stl::Span expected_num_workers_completed, + bool notify_host = true); + +void issue_wait_for_event_commands( + uint8_t cq_id, uint8_t event_cq_id, SystemMemoryManager& sysmem_manager, uint32_t event_id); + +void read_events_from_completion_queue( + ReadEventDescriptor& event_descriptor, + chip_id_t mmio_device_id, + uint16_t channel, + uint8_t cq_id, + SystemMemoryManager& sysmem_manager); + +} // namespace event_dispatch + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/program/dispatch.cpp b/tt_metal/impl/program/dispatch.cpp index fcd9b76494d..67e9a1a2740 100644 --- a/tt_metal/impl/program/dispatch.cpp +++ b/tt_metal/impl/program/dispatch.cpp @@ -406,7 +406,8 @@ void insert_empty_program_dispatch_preamble_cmd(ProgramCommandSequence& program_ void insert_stall_cmds(ProgramCommandSequence& program_command_sequence, SubDeviceId sub_device_id, IDevice* device) { // Initialize stall command sequences for this program. 
- auto dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); + auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config(); + auto dispatch_core_type = dispatch_core_config.get_core_type(); uint32_t dispatch_message_addr = DispatchMemMap::get(dispatch_core_type) .get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE) + @@ -549,7 +550,8 @@ void assemble_runtime_args_commands( ProgramCommandSequence& program_command_sequence, Program& program, IDevice* device) { static const uint32_t packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(device); NOC noc_index = dispatch_downstream_noc; - CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); + auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config(); + auto dispatch_core_type = dispatch_core_config.get_core_type(); const uint32_t max_prefetch_command_size = DispatchMemMap::get(dispatch_core_type).max_prefetch_command_size(); // Dispatch Commands to Unicast Unique Runtime Args to Workers @@ -812,7 +814,8 @@ void insert_write_packed_payloads( void assemble_device_commands( ProgramCommandSequence& program_command_sequence, Program& program, IDevice* device, SubDeviceId sub_device_id) { DeviceCommandCalculator calculator; - CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); + auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config(); + auto dispatch_core_type = dispatch_core_config.get_core_type(); NOC noc_index = dispatch_downstream_noc; const uint32_t max_prefetch_command_size = DispatchMemMap::get(dispatch_core_type).max_prefetch_command_size(); static const uint32_t packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(device);
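Editor's note: the `DispatchQueryManager::get_dispatch_core` change above caches dispatch cores lazily, in a `mutable` member, rather than at construction (the in-diff comment explains this avoids perturbing dispatch core order for ethernet dispatch). The sketch below models only that caching pattern with stand-in types; `Core`, `lookup_dispatch_core`, and `QueryManager` are illustrative names, not the real API.

```cpp
// Minimal sketch of a JIT-populated, mutable dispatch-core cache behind a
// const accessor. Stand-in types; only the pattern mirrors the diff.
#include <cstdint>
#include <vector>

struct Core { uint32_t chip, x, y; };

// Stand-in for the per-CQ lookup done against dispatch_core_manager.
static Core lookup_dispatch_core(uint8_t cq_id) { return Core{0, 0, cq_id}; }

class QueryManager {
public:
    explicit QueryManager(uint8_t num_hw_cqs) : num_hw_cqs_(num_hw_cqs) {}

    // const accessor; the cache is mutable so the first query can fill it in.
    Core get_dispatch_core(uint8_t cq_id) const {
        if (dispatch_cores_.empty()) {
            for (uint8_t cq = 0; cq < num_hw_cqs_; cq++) {
                // Populate on first query rather than at construction time.
                dispatch_cores_.push_back(lookup_dispatch_core(cq));
            }
        }
        return dispatch_cores_[cq_id];
    }

private:
    uint8_t num_hw_cqs_;
    mutable std::vector<Core> dispatch_cores_;
};

int main() {
    QueryManager mgr(/*num_hw_cqs=*/2);
    Core c = mgr.get_dispatch_core(1);  // first query fills the cache for all CQs
    return c.chip == 0 ? 0 : 1;
}
```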
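Editor's note: the issue-queue space reserved by `event_dispatch::issue_record_event_commands` follows directly from the arithmetic in the new `dispatch.cpp` above. The sketch below reproduces that arithmetic with stand-in constants; the real values come from `hal.get_alignment`, `sizeof(CQPrefetchCmd)`, `sizeof(CQDispatchCmd)`, `sizeof(CQDispatchWritePackedUnicastSubCmd)`, and `DispatchSettings::EVENT_PADDED_SIZE`, so the numbers here are illustrative assumptions only.

```cpp
// Hedged model of the cmd_sequence_sizeB computation for a record-event.
// All numeric constants are assumptions, not the production values.
#include <cstdint>
#include <cstdio>

static uint32_t align_up(uint32_t v, uint32_t a) { return ((v + a - 1) / a) * a; }

int main() {
    const uint32_t pcie_alignment = 32;       // assumed hal.get_alignment(HalMemType::HOST)
    const uint32_t l1_alignment = 16;         // assumed hal.get_alignment(HalMemType::L1)
    const uint32_t event_padded_size = 32;    // assumed DispatchSettings::EVENT_PADDED_SIZE
    const uint32_t dispatch_cmd_size = 16;    // assumed sizeof(CQDispatchCmd)
    const uint32_t prefetch_cmd_size = 16;    // assumed sizeof(CQPrefetchCmd)
    const uint32_t unicast_sub_cmd_size = 4;  // assumed sizeof(CQDispatchWritePackedUnicastSubCmd)

    const uint32_t num_command_queues = 2;    // device supports at most 2 HW CQs
    const uint32_t num_worker_counters = 1;   // one wait per sub-device id
    const bool notify_host = true;

    // One packed write that updates COMPLETION_Qx_LAST_EVENT on every CQ's dispatcher.
    uint32_t packed_event_payload_sizeB =
        align_up(dispatch_cmd_size + num_command_queues * unicast_sub_cmd_size, l1_alignment) +
        align_up(event_padded_size, l1_alignment) * num_command_queues;
    uint32_t packed_write_sizeB = align_up(prefetch_cmd_size + packed_event_payload_sizeB, pcie_alignment);

    // One wait per worker counter, plus the packed write, plus an optional
    // linear write back to host carrying the event id (the `* notify_host`
    // term mirrors the bool multiply used in the diff).
    uint32_t cmd_sequence_sizeB =
        pcie_alignment * num_worker_counters +
        packed_write_sizeB +
        align_up(prefetch_cmd_size + dispatch_cmd_size + event_padded_size, pcie_alignment) * notify_host;

    std::printf("record-event command sequence: %u bytes\n", static_cast<unsigned>(cmd_sequence_sizeB));
    return 0;
}
```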
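Editor's note: taken together, `issue_record_event_commands` and `issue_wait_for_event_commands` reduce to a small handshake on the per-CQ `COMPLETION_Qx_LAST_EVENT` slots: recording on CQ k publishes the event id into the CQ-k slot of every command queue's dispatcher (and optionally notifies the host), while a wait on CQ j blocks until its own CQ-k slot has reached that id. The sketch below models that handshake in plain C++ with in-memory stand-ins; the real implementation uses `add_dispatch_write_packed` and `add_dispatch_wait` as shown in the diff.

```cpp
// Hedged model of the cross-CQ record/wait handshake. The two-slot-per-
// dispatcher layout mirrors COMPLETION_Q0/Q1_LAST_EVENT; names and the
// assertion standing in for the dispatch wait are illustrative only.
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kNumCQs = 2;

// last_event[d][q] models the COMPLETION_Qq_LAST_EVENT slot on CQ d's dispatcher.
std::array<std::array<uint32_t, kNumCQs>, kNumCQs> last_event{};

// Record on CQ `cq_id`: publish event_id to the cq_id slot of every dispatcher.
void record_event(uint8_t cq_id, uint32_t event_id) {
    for (int d = 0; d < kNumCQs; d++) {
        last_event[d][cq_id] = event_id;
    }
}

// Wait on CQ `cq_id` for an event recorded on CQ `event_cq_id`: proceed only
// once the local slot for event_cq_id has reached event_id.
void wait_for_event(uint8_t cq_id, uint8_t event_cq_id, uint32_t event_id) {
    assert(last_event[cq_id][event_cq_id] >= event_id && "waiting CQ has not seen the event yet");
}

int main() {
    record_event(/*cq_id=*/0, /*event_id=*/42);           // CQ 0 records event 42
    wait_for_event(/*cq_id=*/1, /*event_cq_id=*/0, 42);   // CQ 1 waits for it
    return 0;
}
```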