Skip to content

Commit

Permalink
#0: address more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jan 15, 2025
1 parent e066fff commit e77d03d
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_run_sfpu_eps(device):
Arch.BLACKHOLE: 1.1920899822825959e-07,
}
value = eps_mapping[device.arch()]
assert np.isclose(value, device.sfpu_eps())
assert value == device.sfpu_eps()


def test_run_sfpu_tensor(device):
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/coordinate_translation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_handle.cpp
${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device_view.cpp
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
98 changes: 98 additions & 0 deletions tt_metal/distributed/coordinate_translation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/distributed/coordinate_translation.hpp"

#include <nlohmann/json.hpp>

namespace tt::tt_metal::distributed {

namespace {

std::string get_config_path(const std::string& filename) {
std::string root_path = getenv("TT_METAL_HOME") ? getenv("TT_METAL_HOME") : "./";
return root_path + "/tt_metal/distributed/config/" + filename;
}

CoordinateTranslationMap load_translation_map(const std::string& filename, const std::string& key) {
std::ifstream file(filename);
TT_FATAL(file.is_open(), "Unable to open file: {}", filename);

nlohmann::json j;
try {
file >> j;
} catch (const nlohmann::json::parse_error& e) {
TT_THROW("JSON parsing error in file {}: {}", filename, e.what());
}

TT_FATAL(j.contains(key), "Key '{}' not found in JSON file: {}", key, filename);

CoordinateTranslationMap result;
for (const auto& mapping : j[key]) {
if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 5) {
TT_THROW("Invalid coordinate format in JSON file: {}", filename);
}
result.emplace(
Coordinate{mapping[0][0], mapping[0][1]},
PhysicalCoordinate{
mapping[1][0], // cluster_id
mapping[1][2], // x
mapping[1][1], // y
mapping[1][3], // rack
mapping[1][4] // shelf
});
}

return result;
}

MeshShape get_system_mesh_shape(size_t system_num_devices) {
static const std::unordered_map<size_t, MeshShape> system_mesh_to_shape = {
{1, MeshShape{1, 1}}, // single-device
{2, MeshShape{1, 2}}, // N300
{8, MeshShape{2, 4}}, // T3000; as ring to match existing tests
{32, MeshShape{8, 4}}, // TG, QG
{64, MeshShape{8, 8}}, // TGG
};
TT_FATAL(
system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices);
auto shape = system_mesh_to_shape.at(system_num_devices);
log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols);
return shape;
}

} // namespace

std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map() {
static const auto* cached_translation_map = new std::pair<CoordinateTranslationMap, MeshShape>([] {
auto system_num_devices = tt::Cluster::instance().number_of_devices();

std::string galaxy_mesh_descriptor = "TG.json";
if (tt::Cluster::instance().number_of_pci_devices() == system_num_devices) {
galaxy_mesh_descriptor = "QG.json";
}

const std::unordered_map<size_t, std::string> system_mesh_translation_map = {
{1, "device.json"},
{2, "N300.json"},
{8, "T3000.json"},
{32, galaxy_mesh_descriptor},
{64, "TGG.json"},
};

TT_FATAL(
system_mesh_translation_map.contains(system_num_devices),
"Unsupported number of devices: {}",
system_num_devices);

auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices));
return std::pair<CoordinateTranslationMap, MeshShape>{
load_translation_map(translation_config_file, "logical_to_physical_coordinates"),
get_system_mesh_shape(system_num_devices)};
}());

return *cached_translation_map;
}

} // namespace tt::tt_metal::distributed
23 changes: 23 additions & 0 deletions tt_metal/distributed/coordinate_translation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <unordered_map>

#include "umd/device/types/cluster_descriptor_types.h"
#include "tt_metal/distributed/mesh_device_view.hpp"

namespace tt::tt_metal::distributed {

// TODO: Consider conversion to StrongType instead of alias
using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;
using CoordinateTranslationMap = std::unordered_map<LogicalCoordinate, PhysicalCoordinate>;

// Returns a translation map between logical coordinates in logical 2D space
// to the physical coordinates as defined by the UMD layer.
std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map();

} // namespace tt::tt_metal::distributed
13 changes: 13 additions & 0 deletions tt_metal/distributed/mesh_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ struct MeshShape {
size_t num_cols = 0;
};

/**
* @brief Defines the organization of physical devices in a user-defined MeshDevice.
*
* The mesh type imposes properties on the physical connectivity of devices:
*
* - RowMajor: Devices are arranged in a 2D grid and accessed in row-major order.
* This is the default configuration for most multi-device setups.
*
* - Ring: Devices are arranged in a circular topology where each device is connected
* to its neighbors, forming a ring structure.
*
* - Line: Devices are arranged linearly in a single dimension.
*/
enum class MeshType { RowMajor, Ring, Line };

struct MeshDeviceConfig {
Expand Down
50 changes: 41 additions & 9 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#include "tt_metal/common/logger.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/distributed/mesh_handle.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/distributed/system_mesh.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_command_queue.hpp"
Expand All @@ -27,6 +27,33 @@ MeshDeviceID generate_unique_mesh_id() {
}
}

MeshDevice::ScopedDevices::ScopedDevices(
size_t l1_small_size,
size_t trace_region_size,
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config,
const MeshDeviceConfig& config) {
auto& system_mesh = SystemMesh::instance();
auto physical_device_ids = system_mesh.request_available_devices(config);

opened_devices_ = tt::tt_metal::detail::CreateDevices(
physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_config);

for (auto physical_device_id : physical_device_ids) {
devices_.push_back(opened_devices_.at(physical_device_id));
}
}

MeshDevice::ScopedDevices::~ScopedDevices() {
if (not opened_devices_.empty()) {
tt::tt_metal::detail::CloseDevices(opened_devices_);
opened_devices_.clear();
devices_.clear();
}
}

const std::vector<IDevice*>& MeshDevice::ScopedDevices::get_devices() const { return devices_; }

uint32_t MeshDevice::build_key() const {
TT_FATAL(tt::tt_metal::hal.is_coordinate_virtualization_enabled(), "MeshDevice::build_key() expects coordinate virtualization to be enabled");
return reference_device()->build_key();
Expand All @@ -44,8 +71,8 @@ uint32_t MeshDevice::dram_size_per_channel() const { return reference_device()->

IDevice* MeshDevice::reference_device() const { return this->get_devices().at(0); }

MeshDevice::MeshDevice(std::shared_ptr<IMeshHandle> mesh_handle, const MeshShape& mesh_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh) :
mesh_handle_(std::move(mesh_handle)),
MeshDevice::MeshDevice(std::shared_ptr<ScopedDevices> mesh_handle, const MeshShape& mesh_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh) :
scoped_devices_(std::move(mesh_handle)),
mesh_shape_(mesh_shape),
type_(type),
mesh_id_(generate_unique_mesh_id()),
Expand All @@ -58,7 +85,7 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config) {
auto mesh_device = std::make_shared<MeshDevice>(
std::make_shared<MeshHandle>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
std::make_shared<ScopedDevices>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
config.mesh_shape,
config.mesh_type);
mesh_device->initialize();
Expand Down Expand Up @@ -90,7 +117,7 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
mesh_shape_.num_cols);
}

auto submesh = std::make_shared<MeshDevice>(mesh_handle_, submesh_shape, type, shared_from_this());
auto submesh = std::make_shared<MeshDevice>(scoped_devices_, submesh_shape, type, shared_from_this());
auto start_coordinate = Coordinate{offset.row, offset.col};
auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1};

Expand Down Expand Up @@ -123,7 +150,7 @@ std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const Mesh
}

void MeshDevice::initialize() {
view_ = std::make_unique<MeshDeviceView>(mesh_handle_->get_devices(), mesh_shape_);
view_ = std::make_unique<MeshDeviceView>(scoped_devices_->get_devices(), mesh_shape_);
SystemMesh::instance().register_mesh_device(shared_from_this(), this->get_devices());
if (this->using_fast_dispatch()) {
mesh_command_queue_ = std::make_unique<MeshCommandQueue>(this, 0);
Expand Down Expand Up @@ -219,16 +246,16 @@ void MeshDevice::reshape(const MeshShape& new_shape) {
}

mesh_shape_ = new_shape;
view_ = std::make_unique<MeshDeviceView>(mesh_handle_->get_devices(), mesh_shape_);
view_ = std::make_unique<MeshDeviceView>(scoped_devices_->get_devices(), mesh_shape_);
}

bool MeshDevice::close() {
for (const auto& submesh : submeshes_) {
submesh->close();
}
submeshes_.clear();
if (mesh_handle_) {
mesh_handle_.reset();
if (scoped_devices_) {
scoped_devices_.reset();
}
parent_mesh_.reset();
view_.reset();
Expand Down Expand Up @@ -636,6 +663,11 @@ void MeshDevice::deallocate_buffers(SubDeviceId sub_device_id) { reference_devic
void MeshDevice::dump_memory_blocks(const BufferType& buffer_type, std::ofstream& out) const { reference_device()->dump_memory_blocks(buffer_type, out); }
void MeshDevice::dump_memory_blocks(const BufferType& buffer_type, std::ofstream& out, SubDeviceId sub_device_id) const { reference_device()->dump_memory_blocks(buffer_type, out, sub_device_id); }

MemoryBlockTable MeshDevice::get_memory_block_table(const BufferType& buffer_type) const {
TT_THROW("get_memory_block_table() is not supported on MeshDevice - use individual devices instead");
return reference_device()->get_memory_block_table(buffer_type);
}

MeshSubDeviceManagerId MeshDevice::mesh_create_sub_device_manager(
tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size) {
MeshSubDeviceManagerId mesh_sub_device_manager_id(*this);
Expand Down
30 changes: 26 additions & 4 deletions tt_metal/distributed/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <vector>

#include "tt_metal/include/tt_metal/device.hpp"
#include "tt_metal/distributed/mesh_handle.hpp"

#include "tt_metal/distributed/mesh_config.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
Expand All @@ -25,8 +24,30 @@ class MeshSubDeviceManagerId;

class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevice> {
private:
std::shared_ptr<IMeshHandle> mesh_handle_;

// Resource management class / RAII wrapper for *physical devices* of the mesh
class ScopedDevices {
private:
std::map<chip_id_t, IDevice*> opened_devices_;
std::vector<IDevice*> devices_;

public:
// Constructor acquires physical resources
ScopedDevices(
size_t l1_small_size,
size_t trace_region_size,
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config,
const MeshDeviceConfig& config);

// Destructor releases physical resources
~ScopedDevices();
ScopedDevices(const ScopedDevices&) = delete;
ScopedDevices& operator=(const ScopedDevices&) = delete;

const std::vector<IDevice*>& get_devices() const;
};

std::shared_ptr<ScopedDevices> scoped_devices_;
MeshDeviceID mesh_id_;
MeshShape mesh_shape_;
MeshType type_;
Expand All @@ -43,7 +64,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic

public:
MeshDevice(
std::shared_ptr<IMeshHandle> mesh_handle,
std::shared_ptr<ScopedDevices> mesh_handle,
const MeshShape& mesh_shape,
MeshType type,
std::weak_ptr<MeshDevice> parent_mesh = {});
Expand Down Expand Up @@ -201,6 +222,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
uint32_t get_completion_queue_reader_core() const override;
bool is_mmio_capable() const override;
std::vector<std::vector<chip_id_t>> get_tunnels_from_mmio() const override;
MemoryBlockTable get_memory_block_table(const BufferType& buffer_type) const override;

// A MeshDevice is a collection of devices arranged in a 2D grid.
// The type parameter allows the caller to specify how to linearize the devices in the mesh.
Expand Down
37 changes: 0 additions & 37 deletions tt_metal/distributed/mesh_handle.cpp

This file was deleted.

Loading

0 comments on commit e77d03d

Please sign in to comment.