Skip to content

Commit

Permalink
#0: address more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jan 14, 2025
1 parent f1c9b4a commit b246da3
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 196 deletions.
2 changes: 1 addition & 1 deletion tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/coordinate_translation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_handle.cpp
${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device_view.cpp
Expand Down
92 changes: 92 additions & 0 deletions tt_metal/distributed/coordinate_translation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

#include "tt_metal/distributed/coordinate_translation.hpp"

#include <nlohmann/json.hpp>

namespace tt::tt_metal::distributed {

namespace {

// Builds the path of a mesh-configuration file located under
// <root>/tt_metal/distributed/mesh_configurations/.
//
// <root> is taken from the TT_METAL_HOME environment variable. When the variable is
// unset or set to an empty string, "./" is used instead; an empty value must not be
// used as-is because it would resolve to a path rooted at "/tt_metal/...".
//
// @param filename  Name of the file inside the mesh_configurations directory.
// @return          <root> + "/tt_metal/distributed/mesh_configurations/" + filename.
std::string get_config_path(const std::string& filename) {
    // Query the environment once so the null check and the value used cannot disagree.
    const char* home = getenv("TT_METAL_HOME");
    const bool has_home = home != nullptr && home[0] != '\0';
    const std::string root_path = has_home ? home : "./";
    return root_path + "/tt_metal/distributed/mesh_configurations/" + filename;
}

// Parses the JSON file `filename` and converts the array stored under `key` into a
// CoordinateTranslationMap.
//
// Each entry of that array must have two elements: a 2-element logical coordinate
// followed by a 5-element physical coordinate. Fails (TT_FATAL / TT_THROW) when the
// file cannot be opened, cannot be parsed, lacks `key`, or has an entry of the wrong size.
CoordinateTranslationMap load_translation_map(const std::string& filename, const std::string& key) {
    std::ifstream input(filename);
    TT_FATAL(input.is_open(), "Unable to open file: {}", filename);

    nlohmann::json document;
    try {
        input >> document;
    } catch (const nlohmann::json::parse_error& e) {
        TT_THROW("JSON parsing error in file {}: {}", filename, e.what());
    }

    TT_FATAL(document.contains(key), "Key '{}' not found in JSON file: {}", key, filename);

    CoordinateTranslationMap translation;
    for (const auto& entry : document[key]) {
        const bool well_formed = entry.size() == 2 && entry[0].size() == 2 && entry[1].size() == 5;
        if (!well_formed) {
            TT_THROW("Invalid coordinate format in JSON file: {}", filename);
        }

        const auto& logical = entry[0];
        const auto& physical = entry[1];
        // The physical tuple is read with elements 1 and 2 swapped: index 2 is passed
        // as x and index 1 as y.
        translation.emplace(
            Coordinate{logical[0], logical[1]},
            PhysicalCoordinate{
                physical[0],  // cluster_id
                physical[2],  // x
                physical[1],  // y
                physical[3],  // rack
                physical[4]   // shelf
            });
    }

    return translation;
}

} // namespace

// Maps the number of devices in the system to the logical shape of the system mesh.
// Fails with TT_FATAL when the device count is not one of the supported values.
MeshShape get_system_mesh_shape(size_t system_num_devices) {
    static const std::unordered_map<size_t, MeshShape> shape_by_device_count = {
        {1, MeshShape{1, 1}},   // single-device
        {2, MeshShape{1, 2}},   // N300
        {8, MeshShape{2, 4}},   // T3000; as ring to match existing tests
        {32, MeshShape{8, 4}},  // TG, QG
        {64, MeshShape{8, 8}},  // TGG
    };

    const auto entry = shape_by_device_count.find(system_num_devices);
    TT_FATAL(entry != shape_by_device_count.end(), "Unsupported number of devices: {}", system_num_devices);

    auto shape = entry->second;
    log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols);
    return shape;
}

std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map() {
auto system_num_devices = tt::Cluster::instance().number_of_devices();

// TG has 32 non-mmio user devices and 4 mmio devices not exposed to the user
// QG has 32 mmio user devices
// Once TG is fully deprecated, can remove TG code path
std::string galaxy_mesh_descriptor = "TG.json";
if (tt::Cluster::instance().number_of_pci_devices() == system_num_devices) {
galaxy_mesh_descriptor = "QG.json";
}
const std::unordered_map<size_t, std::string> system_mesh_translation_map = {
{1, "device.json"},
{2, "N300.json"},
{8, "T3000.json"},
{32, galaxy_mesh_descriptor},
{64, "TGG.json"},
};
TT_FATAL(
system_mesh_translation_map.contains(system_num_devices),
"Unsupported number of devices: {}",
system_num_devices);
auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices));
return {
load_translation_map(translation_config_file, "logical_to_physical_coordinates"),
get_system_mesh_shape(system_num_devices)};
}

} // namespace tt::tt_metal::distributed
23 changes: 23 additions & 0 deletions tt_metal/distributed/coordinate_translation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <unordered_map>

#include "umd/device/types/cluster_descriptor_types.h"
#include "tt_metal/distributed/mesh_device_view.hpp"

namespace tt::tt_metal::distributed {

// TODO: Consider conversion to StrongType instead of alias
// Coordinate of a device in the logical 2D mesh (plain alias of Coordinate).
using LogicalCoordinate = Coordinate;
// Physical coordinate of a device (plain alias of eth_coord_t).
// NOTE(review): eth_coord_t presumably comes from the UMD cluster descriptor header
// included above - confirm.
using PhysicalCoordinate = eth_coord_t;
// Lookup table from a logical coordinate to the matching physical coordinate.
using CoordinateTranslationMap = std::unordered_map<LogicalCoordinate, PhysicalCoordinate>;

// Returns a translation map from logical coordinates in logical 2D space
// to the physical coordinates as defined by the UMD layer.
// The second element of the returned pair is the MeshShape that accompanies the map.
std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map();

} // namespace tt::tt_metal::distributed
13 changes: 13 additions & 0 deletions tt_metal/distributed/mesh_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ struct MeshShape {
size_t num_cols = 0;
};

/**
 * @brief Defines the organization of physical devices in a user-defined MeshDevice.
 *
 * The mesh type imposes properties on the physical connectivity of devices;
 * see the individual enumerators for details.
 */
enum class MeshType {
    // Devices are arranged in a 2D grid and accessed in row-major order.
    // This is the default configuration for most multi-device setups.
    RowMajor,

    // Devices are arranged in a circular topology where each device is connected
    // to its neighbors, forming a ring structure.
    Ring,

    // Devices are arranged linearly in a single dimension.
    Line,
};

struct MeshDeviceConfig {
Expand Down
45 changes: 36 additions & 9 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#include "tt_metal/common/logger.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/distributed/mesh_handle.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/distributed/system_mesh.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_command_queue.hpp"
Expand All @@ -27,6 +27,33 @@ MeshDeviceID generate_unique_mesh_id() {
}
}

// Requests the devices for `config` from the SystemMesh, opens them through
// detail::CreateDevices, and records them in devices_ in the order the ids were returned.
MeshDevice::ScopedDevices::ScopedDevices(
    size_t l1_small_size,
    size_t trace_region_size,
    size_t num_command_queues,
    const DispatchCoreConfig& dispatch_core_config,
    const MeshDeviceConfig& config) {
    auto requested_ids = SystemMesh::instance().request_available_devices(config);

    opened_devices_ = tt::tt_metal::detail::CreateDevices(
        requested_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_config);

    // devices_ follows the order of the requested ids; .at() fails loudly if an id is
    // missing from the opened set.
    for (auto device_id : requested_ids) {
        devices_.push_back(opened_devices_.at(device_id));
    }
}

// Closes every device opened by the constructor and drops the cached pointers.
// Does nothing when no devices are recorded as open.
MeshDevice::ScopedDevices::~ScopedDevices() {
    if (opened_devices_.empty()) {
        return;
    }

    tt::tt_metal::detail::CloseDevices(opened_devices_);
    opened_devices_.clear();
    devices_.clear();
}

// Read-only access to the device pointers collected by the constructor.
const std::vector<IDevice*>& MeshDevice::ScopedDevices::get_devices() const {
    return devices_;
}

uint32_t MeshDevice::build_key() const {
TT_FATAL(tt::tt_metal::hal.is_coordinate_virtualization_enabled(), "MeshDevice::build_key() expects coordinate virtualization to be enabled");
return reference_device()->build_key();
Expand All @@ -44,8 +71,8 @@ uint32_t MeshDevice::dram_size_per_channel() const { return reference_device()->

// Returns the first device reported by get_devices(); the lookup goes through .at(0).
IDevice* MeshDevice::reference_device() const {
    const auto& mesh_devices = this->get_devices();
    return mesh_devices.at(0);
}

MeshDevice::MeshDevice(std::shared_ptr<IMeshHandle> mesh_handle, const MeshShape& mesh_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh) :
mesh_handle_(std::move(mesh_handle)),
MeshDevice::MeshDevice(std::shared_ptr<ScopedDevices> mesh_handle, const MeshShape& mesh_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh) :
scoped_devices_(std::move(mesh_handle)),
mesh_shape_(mesh_shape),
type_(type),
mesh_id_(generate_unique_mesh_id()),
Expand All @@ -58,7 +85,7 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config) {
auto mesh_device = std::make_shared<MeshDevice>(
std::make_shared<MeshHandle>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
std::make_shared<ScopedDevices>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
config.mesh_shape,
config.mesh_type);
mesh_device->initialize();
Expand Down Expand Up @@ -90,7 +117,7 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
mesh_shape_.num_cols);
}

auto submesh = std::make_shared<MeshDevice>(mesh_handle_, submesh_shape, type, shared_from_this());
auto submesh = std::make_shared<MeshDevice>(scoped_devices_, submesh_shape, type, shared_from_this());
auto start_coordinate = Coordinate{offset.row, offset.col};
auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1};

Expand Down Expand Up @@ -123,7 +150,7 @@ std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const Mesh
}

void MeshDevice::initialize() {
view_ = std::make_unique<MeshDeviceView>(mesh_handle_->get_devices(), mesh_shape_);
view_ = std::make_unique<MeshDeviceView>(scoped_devices_->get_devices(), mesh_shape_);
SystemMesh::instance().register_mesh_device(shared_from_this(), this->get_devices());
if (this->using_fast_dispatch()) {
mesh_command_queue_ = std::make_unique<MeshCommandQueue>(this, 0);
Expand Down Expand Up @@ -219,16 +246,16 @@ void MeshDevice::reshape(const MeshShape& new_shape) {
}

mesh_shape_ = new_shape;
view_ = std::make_unique<MeshDeviceView>(mesh_handle_->get_devices(), mesh_shape_);
view_ = std::make_unique<MeshDeviceView>(scoped_devices_->get_devices(), mesh_shape_);
}

bool MeshDevice::close() {
for (const auto& submesh : submeshes_) {
submesh->close();
}
submeshes_.clear();
if (mesh_handle_) {
mesh_handle_.reset();
if (scoped_devices_) {
scoped_devices_.reset();
}
parent_mesh_.reset();
view_.reset();
Expand Down
29 changes: 25 additions & 4 deletions tt_metal/distributed/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <vector>

#include "tt_metal/include/tt_metal/device.hpp"
#include "tt_metal/distributed/mesh_handle.hpp"

#include "tt_metal/distributed/mesh_config.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
Expand All @@ -25,8 +24,30 @@ class MeshSubDeviceManagerId;

class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
private:
std::shared_ptr<IMeshHandle> mesh_handle_;

// Resource management class / RAII wrapper for the *physical devices* of the mesh.
// Non-copyable: the copy constructor and copy assignment are deleted.
class ScopedDevices {
public:
    // Constructor acquires physical resources
    ScopedDevices(
        size_t l1_small_size,
        size_t trace_region_size,
        size_t num_command_queues,
        const DispatchCoreConfig& dispatch_core_config,
        const MeshDeviceConfig& config);

    // Destructor releases physical resources
    ~ScopedDevices();

    ScopedDevices(const ScopedDevices&) = delete;
    ScopedDevices& operator=(const ScopedDevices&) = delete;

    // Devices held by this wrapper, as a flat list.
    const std::vector<IDevice*>& get_devices() const;

private:
    // Opened devices keyed by chip id.
    std::map<chip_id_t, IDevice*> opened_devices_;
    // Flat list of the same device pointers, exposed through get_devices().
    std::vector<IDevice*> devices_;
};

std::shared_ptr<ScopedDevices> scoped_devices_;
MeshDeviceID mesh_id_;
MeshShape mesh_shape_;
MeshType type_;
Expand All @@ -43,7 +64,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic

public:
MeshDevice(
std::shared_ptr<IMeshHandle> mesh_handle,
std::shared_ptr<ScopedDevices> mesh_handle,
const MeshShape& mesh_shape,
MeshType type,
std::weak_ptr<MeshDevice> parent_mesh = {});
Expand Down
37 changes: 0 additions & 37 deletions tt_metal/distributed/mesh_handle.cpp

This file was deleted.

54 changes: 0 additions & 54 deletions tt_metal/distributed/mesh_handle.hpp

This file was deleted.

Loading

0 comments on commit b246da3

Please sign in to comment.