Skip to content

Commit

Permalink
#0: Hoist SubDeviceManager/Lock-Step Allocator to MeshDevice
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jan 23, 2025
1 parent 33bb3c0 commit a04ce94
Show file tree
Hide file tree
Showing 12 changed files with 518 additions and 225 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ set(UNIT_TESTS_DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp
)

add_executable(distributed_unit_tests ${UNIT_TESTS_DISTRIBUTED_SRC})
Expand Down
33 changes: 33 additions & 0 deletions tests/tt_metal/distributed/test_mesh_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <memory>

#include <mesh_device.hpp>
#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"

namespace tt::tt_metal::distributed::test {

using MeshAllocatorTest = T3000MultiDeviceFixture;

TEST_F(MeshAllocatorTest, BasicAllocationSanityCheck) {
const size_t allocation_size = 1024 * 8; // 1KB
const tt::tt_metal::BufferType buffer_type = tt::tt_metal::BufferType::L1;

auto config = InterleavedBufferConfig{
.device = mesh_device_.get(),
.size = allocation_size,
.page_size = 1024,
.buffer_type = buffer_type,
.buffer_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED};

auto buffer = CreateBuffer(config);

EXPECT_TRUE(buffer->is_allocated());
EXPECT_EQ(buffer->size(), allocation_size);
EXPECT_EQ(buffer->buffer_type(), buffer_type);
}

} // namespace tt::tt_metal::distributed::test
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include <memory>
#include <mutex>
#include <utility>

#include "device.hpp"
Expand Down Expand Up @@ -254,7 +253,8 @@ class Device : public IDevice {
static constexpr uint32_t DEFAULT_NUM_SUB_DEVICES = 1;

void initialize_cluster();
std::unique_ptr<Allocator> initialize_allocator(size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
std::unique_ptr<Allocator> initialize_allocator(
size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
void initialize_build();
void initialize_device_kernel_defines();
void initialize_device_bank_to_noc_tables(const HalProgrammableCoreType &core_type, CoreCoord virtual_core);
Expand Down
18 changes: 13 additions & 5 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
#include "mesh_device_view.hpp"
#include "sub_device_types.hpp"
#include "span.hpp"
#include "work_executor.hpp"

namespace tt::tt_metal::distributed {
namespace tt::tt_metal {

class SubDeviceManagerTracker;

namespace distributed {

class MeshCommandQueue;
class MeshDeviceView;
Expand Down Expand Up @@ -56,8 +61,8 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
submeshes_; // Parent owns submeshes and is responsible for their destruction
std::weak_ptr<MeshDevice> parent_mesh_; // Submesh created with reference to parent mesh
std::unique_ptr<MeshCommandQueue> mesh_command_queue_;

void initialize();
std::unique_ptr<SubDeviceManagerTracker> sub_device_manager_tracker_;
std::unique_ptr<WorkExecutor> work_executor_;

// This is a reference device used to query properties that are the same for all devices in the mesh.
IDevice* reference_device() const;
Expand Down Expand Up @@ -292,7 +297,8 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
size_t l1_small_size = DEFAULT_L1_SMALL_SIZE,
size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE,
size_t num_command_queues = 1,
const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{});
const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{},
tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
};

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device);
Expand All @@ -305,4 +311,6 @@ struct MeshSubDeviceManagerId {
std::vector<SubDeviceManagerId> sub_device_manager_ids;
};

} // namespace tt::tt_metal::distributed
} // namespace distributed

} // namespace tt::tt_metal
3 changes: 2 additions & 1 deletion tt_metal/api/tt-metalium/sub_device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class SubDeviceManager {
MAX_NUM_SUB_DEVICES <= std::numeric_limits<SubDeviceId::Id>::max(),
"MAX_NUM_SUB_DEVICES must be less than or equal to the max value of SubDeviceId::Id");
// Constructor used for the default/global device
SubDeviceManager(IDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManager(
IDevice* device, std::unique_ptr<Allocator>&& global_allocator, tt::stl::Span<const SubDevice> sub_devices);
// Constructor used for regular sub-devices
SubDeviceManager(tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size, IDevice* device);

Expand Down
6 changes: 5 additions & 1 deletion tt_metal/api/tt-metalium/sub_device_manager_tracker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class SubDeviceManagerTracker {
public:
// TODO: Potentially move the global allocator creation into here instead of from the device
// This creates the SubDeviceManagerTracker with a default SubDeviceManager that has the entire grid as a sub-device
SubDeviceManagerTracker(IDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManagerTracker(
IDevice* device, std::unique_ptr<Allocator>&& global_allocator, tt::stl::Span<const SubDevice> sub_devices);

SubDeviceManagerTracker(const SubDeviceManagerTracker& other) = delete;
SubDeviceManagerTracker& operator=(const SubDeviceManagerTracker& other) = delete;
Expand Down Expand Up @@ -58,6 +59,9 @@ class SubDeviceManagerTracker {
// default case to not affect performance
SubDeviceManagerId get_default_sub_device_manager_id() const;

std::optional<DeviceAddr> lowest_occupied_compute_l1_address(
tt::stl::Span<const SubDeviceId> sub_device_ids = {}) const;

private:
void reset_sub_device_state(const std::unique_ptr<SubDeviceManager>& sub_device_manager);

Expand Down
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/sub_device_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct SubDeviceId {
return *this;
}

bool operator<(size_t other) const { return id < other; }
bool operator==(const SubDeviceId& other) const { return id == other.id; }

bool operator!=(const SubDeviceId& other) const { return id != other.id; }
Expand Down
7 changes: 3 additions & 4 deletions tt_metal/distributed/mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
}},
mesh_buffer_config);

// Rely on the single device allocator to provide the address for the entire mesh buffer.
// TODO: use mesh allocator, when available.
// Rely on the MeshDevice allocator to provide the address for the entire mesh buffer.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device->get_device(0, 0),
mesh_device,
/*address=*/address.value_or(0),
device_local_size,
device_local_config.page_size,
Expand Down Expand Up @@ -104,7 +103,7 @@ void MeshBuffer::allocate() {

auto allocate_device_buffer_at_address = [this](const Coordinate& coord) {
std::shared_ptr<Buffer> buffer = Buffer::create(
mesh_device_->get_device(coord.row, coord.col),
mesh_device_,
address_,
device_local_size_,
device_local_config_.page_size,
Expand Down
Loading

0 comments on commit a04ce94

Please sign in to comment.