Skip to content

Commit

Permalink
#0: API Unification for Device and MeshDevice
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jan 10, 2025
1 parent d46776a commit b2c2351
Show file tree
Hide file tree
Showing 17 changed files with 1,066 additions and 672 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "tests/tt_metal/distributed/distributed_fixture.hpp"
#include "tt_metal/distributed/system_mesh.hpp"

namespace tt::tt_metal::distributed::test {

Expand Down
54 changes: 27 additions & 27 deletions tests/tt_metal/distributed/test_mesh_workload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,18 +518,18 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadOnActiveEth) {
AddProgramToMeshWorkload(*workload, *programs[0], devices);
}
}
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
workloads.push_back(workload);
}
for (int i = 0; i < num_iters; i++) {
if (i % 100 == 0) {
log_info(tt::LogTest, "Run MeshWorkloads for iteration {}", i);
}
for (auto& workload : workloads) {
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
}
}
Finish(mesh_device_->command_queue());
Finish(mesh_device_->mesh_command_queue());
}

TEST_F(MeshDevice_T3000, TestMeshWorkloadMixedTensixEth) {
Expand Down Expand Up @@ -567,7 +567,7 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadMixedTensixEth) {
run_on_eth = !run_on_eth;
}
}
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
workloads.push_back(workload);
}

Expand All @@ -576,10 +576,10 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadMixedTensixEth) {
log_info(tt::LogTest, "Run MeshWorkloads for iteration {}", i);
}
for (auto& workload : workloads) {
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
}
}
Finish(mesh_device_->command_queue());
Finish(mesh_device_->mesh_command_queue());
}

TEST_F(MeshDevice_T3000, TestMeshWorkloadOnActiveEthRandomGridSize) {
Expand All @@ -605,18 +605,18 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadOnActiveEthRandomGridSize) {
AddProgramToMeshWorkload(*workload, *programs[0], devices);
}
}
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
workloads.push_back(workload);
}
for (int i = 0; i < num_iters; i++) {
if (i % 100 == 0) {
log_info(tt::LogTest, "Run MeshWorkloads for iteration {}", i);
}
for (auto& workload : workloads) {
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
}
}
Finish(mesh_device_->command_queue());
Finish(mesh_device_->mesh_command_queue());
}

TEST_F(MeshDevice_T3000, TestSimultaneousMeshWorkloads) {
Expand Down Expand Up @@ -647,7 +647,7 @@ TEST_F(MeshDevice_T3000, TestSimultaneousMeshWorkloads) {
AddProgramToMeshWorkload(*random_workload, *programs[i], devices_0);
AddProgramToMeshWorkload(*random_workload, *programs[i + 1], devices_1);
}
EnqueueMeshWorkload(mesh_device_->command_queue(), *random_workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *random_workload, false);
mesh_workloads.push_back(random_workload);
}
programs = create_random_programs(num_programs, mesh_device_->compute_with_storage_grid_size(), seed);
Expand All @@ -661,7 +661,7 @@ TEST_F(MeshDevice_T3000, TestSimultaneousMeshWorkloads) {
AddProgramToMeshWorkload(*random_workload, *programs[i + 1], devices_1);
AddProgramToMeshWorkload(*random_workload, *programs[i + 2], devices_2);
AddProgramToMeshWorkload(*random_workload, *programs[i + 3], devices_3);
EnqueueMeshWorkload(mesh_device_->command_queue(), *random_workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *random_workload, false);
mesh_workloads.push_back(random_workload);
}
programs = create_random_programs(num_heterogeneous_programs, mesh_device_->compute_with_storage_grid_size(), seed);
Expand All @@ -684,7 +684,7 @@ TEST_F(MeshDevice_T3000, TestSimultaneousMeshWorkloads) {
AddProgramToMeshWorkload(*random_workload, *programs[i + 5], devices_5);
AddProgramToMeshWorkload(*random_workload, *programs[i + 6], devices_6);
AddProgramToMeshWorkload(*random_workload, *programs[i + 7], devices_7);
EnqueueMeshWorkload(mesh_device_->command_queue(), *random_workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *random_workload, false);
mesh_workloads.push_back(random_workload);
}

Expand All @@ -693,10 +693,10 @@ TEST_F(MeshDevice_T3000, TestSimultaneousMeshWorkloads) {
log_info(tt::LogTest, "Run MeshWorkloads for iteration {}", i);
}
for (auto& workload : mesh_workloads) {
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
}
}
Finish(mesh_device_->command_queue());
Finish(mesh_device_->mesh_command_queue());
}

TEST_F(MeshDevice_T3000, TestRandomizedMeshWorkload) {
Expand All @@ -721,19 +721,19 @@ TEST_F(MeshDevice_T3000, TestRandomizedMeshWorkload) {
LogicalDeviceRange device_range = LogicalDeviceRange({0, 0}, {gen_x(rng), gen_y(rng)});
auto random_workload = std::make_shared<MeshWorkload>();
AddProgramToMeshWorkload(*random_workload, *programs[i], device_range);
EnqueueMeshWorkload(mesh_device_->command_queue(), *random_workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *random_workload, false);
mesh_workloads.push_back(random_workload);
}
for (int i = 0; i < num_iterations; i++) {
if (i % 100 == 0) {
log_info(tt::LogTest, "Run MeshWorkloads for iteration {}", i);
}
for (auto& workload : mesh_workloads) {
EnqueueMeshWorkload(mesh_device_->command_queue(), *workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, false);
}
}
log_info(tt::LogTest, "Calling Finish");
Finish(mesh_device_->command_queue());
Finish(mesh_device_->mesh_command_queue());
}

TEST_F(MeshDevice_T3000, TestEltwiseBinaryMeshWorkload) {
Expand Down Expand Up @@ -764,7 +764,7 @@ TEST_F(MeshDevice_T3000, TestEltwiseBinaryMeshWorkload) {
}
// Run workload multiple times
for (int i = 0; i < 1000; i++) {
EnqueueMeshWorkload(mesh_device_->command_queue(), mesh_workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), mesh_workload, false);
}

buffer_idx = 0;
Expand Down Expand Up @@ -880,7 +880,7 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadSanity) {
rtas[core.x][core.y].at(4) = ((iter % 2) + 1) * add_factor;
}
}
EnqueueMeshWorkload(mesh_device_->command_queue(), mesh_workload, false);
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), mesh_workload, false);
buffer_idx = 0;
for (auto device : mesh_device_->get_devices()) {
for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) {
Expand Down Expand Up @@ -922,8 +922,8 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadCBUpdate) {
LogicalDeviceRange devices = LogicalDeviceRange({0, 0}, {4, 2});

AddProgramToMeshWorkload(mesh_workload, *program, devices);
EnqueueMeshWorkload(mesh_device_->command_queue(), mesh_workload, false);
Finish(mesh_device_->command_queue());
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), mesh_workload, false);
Finish(mesh_device_->mesh_command_queue());
verify_cb_config(mesh_device_, mesh_workload, cb_config_vector, cr_set);

std::vector<CBConfig> updated_cb_config_vector = cb_config_vector;
Expand All @@ -933,8 +933,8 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadCBUpdate) {
const uint32_t cb_size = cb_config.num_pages * cb_config.page_size;
UpdateCircularBufferTotalSize(mesh_workload.get_program_on_device_range(devices), cb_handles[cb_id], cb_size);
}
EnqueueMeshWorkload(mesh_device_->command_queue(), mesh_workload, false);
Finish(mesh_device_->command_queue());
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), mesh_workload, false);
Finish(mesh_device_->mesh_command_queue());
verify_cb_config(mesh_device_, mesh_workload, updated_cb_config_vector, cr_set);
}

Expand All @@ -951,8 +951,8 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadSemaphoreSanity) {
auto mesh_workload = CreateMeshWorkload();
LogicalDeviceRange devices = LogicalDeviceRange({0, 0}, {4, 2});
AddProgramToMeshWorkload(mesh_workload, program, devices);
EnqueueMeshWorkload(mesh_device_->command_queue(), mesh_workload, false);
Finish(mesh_device_->command_queue());
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), mesh_workload, false);
Finish(mesh_device_->mesh_command_queue());

for (const auto device : mesh_device_->get_devices()) {
validate_sems(mesh_device_, device, full_grid, mesh_workload, expected_semaphore_values);
Expand Down Expand Up @@ -980,8 +980,8 @@ TEST_F(MeshDevice_T3000, TestMeshWorkloadSemaphoreDifferentPrograms) {

AddProgramToMeshWorkload(mesh_workload, program0, devices_0);
AddProgramToMeshWorkload(mesh_workload, program1, devices_1);
EnqueueMeshWorkload(mesh_device_->command_queue(), mesh_workload, false);
Finish(mesh_device_->command_queue());
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), mesh_workload, false);
Finish(mesh_device_->mesh_command_queue());

for (std::size_t logical_x = devices_0.start_coord.x; logical_x < devices_0.end_coord.x; logical_x++) {
for (std::size_t logical_y = devices_0.start_coord.y; logical_y < devices_0.end_coord.y; logical_y++) {
Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DistributedTest : public ::testing::Test {
};

TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
auto& sys = tt::tt_metal::distributed::SystemMesh::instance();
auto& sys = SystemMesh::instance();
auto mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

Expand Down
2 changes: 2 additions & 0 deletions tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
set(DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_handle.cpp
${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device_view.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_workload.cpp
Expand Down
47 changes: 47 additions & 0 deletions tt_metal/distributed/mesh_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <vector>

namespace tt::tt_metal::distributed {

// Plain-integer identifier aliases used by the mesh configuration types below.
using chip_id_t = int;
using MeshDeviceID = int;
using DeviceIds = std::vector<int>;

/// Row/column offset pair. Both components default to zero.
struct MeshOffset {
    size_t row{0};
    size_t col{0};
};

/// Row/column extent pair. Both components default to zero.
struct MeshShape {
    size_t num_rows{0};
    size_t num_cols{0};
};

/// Mesh type selector carried by MeshDeviceConfig.
/// NOTE(review): the meaning of each enumerator is defined by its consumers,
/// which are not visible in this header.
enum class MeshType {
    RowMajor,
    Ring,
    Line,
};

/// Value bundle describing a mesh: its shape, an offset, an optional list of
/// physical device ids, and a mesh type. The constructors only store their
/// arguments; no validation is performed here.
struct MeshDeviceConfig {
    MeshShape mesh_shape;
    MeshOffset offset;
    std::vector<chip_id_t> physical_device_ids;
    MeshType mesh_type;

    /// General constructor. Every argument after `mesh_shape` is optional and
    /// defaults to a zero offset, an empty id list and MeshType::RowMajor.
    MeshDeviceConfig(
        const MeshShape& mesh_shape,
        const MeshOffset& offset = MeshOffset{0, 0},
        const std::vector<chip_id_t>& physical_device_ids = {},
        MeshType mesh_type = MeshType::RowMajor) :
        mesh_shape(mesh_shape),
        offset(offset),
        physical_device_ids(physical_device_ids),
        mesh_type(mesh_type) {}

    /// Shape + type shorthand: forwards to the general constructor with a
    /// zero offset and an empty physical device id list.
    MeshDeviceConfig(const MeshShape& mesh_shape, MeshType mesh_type) :
        MeshDeviceConfig(mesh_shape, MeshOffset{0, 0}, std::vector<chip_id_t>(), mesh_type) {}
};

}  // namespace tt::tt_metal::distributed
Loading

0 comments on commit b2c2351

Please sign in to comment.