From bad6a24b0b49dd199576fb9594ea82f5c5ffc66d Mon Sep 17 00:00:00 2001 From: Kyle Mabee Date: Fri, 13 Dec 2024 00:18:07 +0000 Subject: [PATCH] LightMetal - Add initial Trace/Replay support for many popular host APIs - Not comprehensive, just initial changes for some of these paths: CreateBuffer() EnqueueWriteBuffer() EnqueueReadBuffer() Finish() DeallocateBuffer ReleaseTrace() CreateProgram() EnqueueProgram() CreateKernel() SetRuntimeArgs(uint32) SetRuntimeArgs(Kernel,RuntimeArgs) CreateCircularBuffer() - When Metal Trace is enabled, don't capture EnqueueProgram(), instead inject ReplayTrace(), would be used alongside LoadTrace() - Serialization / Deserialization of structs/enums defined in flatbuffer schema types.fbs handled via ToFlatbuffer() and FromFlatbuffer() functions. --- tt_metal/impl/CMakeLists.txt | 1 + .../impl/buffers/circular_buffer_types.hpp | 19 + tt_metal/impl/dispatch/command_queue.cpp | 4 + tt_metal/impl/dispatch/command_queue.hpp | 1 + .../lightmetal/host_api_capture_helpers.hpp | 298 ++++++++++++++ .../impl/lightmetal/lightmetal_capture.cpp | 129 ++++++ .../impl/lightmetal/lightmetal_capture.hpp | 32 +- .../impl/lightmetal/lightmetal_replay.cpp | 381 ++++++++++++++++- .../impl/lightmetal/lightmetal_replay.hpp | 60 +++ tt_metal/impl/tracehost/command.fbs | 86 +++- tt_metal/impl/tracehost/types.fbs | 211 ++++++++++ .../impl/tracehost/types_from_flatbuffer.hpp | 310 ++++++++++++++ .../impl/tracehost/types_to_flatbuffer.hpp | 387 ++++++++++++++++++ tt_metal/tt_metal.cpp | 29 +- 14 files changed, 1939 insertions(+), 9 deletions(-) create mode 100644 tt_metal/impl/tracehost/types.fbs create mode 100644 tt_metal/impl/tracehost/types_from_flatbuffer.hpp create mode 100644 tt_metal/impl/tracehost/types_to_flatbuffer.hpp diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index 0563b052386a..4cb0cb86698f 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -47,6 +47,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/flatbuffers.cmake) set(FLATBUFFER_SCHEMAS ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/binary.fbs ${CMAKE_CURRENT_SOURCE_DIR}/tracehost/command.fbs + ${CMAKE_CURRENT_SOURCE_DIR}/tracehost/types.fbs ) foreach(FBS_FILE ${FLATBUFFER_SCHEMAS}) GENERATE_FBS_HEADER(${FBS_FILE}) diff --git a/tt_metal/impl/buffers/circular_buffer_types.hpp b/tt_metal/impl/buffers/circular_buffer_types.hpp index 2a174c92813c..1d5befc73087 100644 --- a/tt_metal/impl/buffers/circular_buffer_types.hpp +++ b/tt_metal/impl/buffers/circular_buffer_types.hpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include "flatbuffers/flatbuffer_builder.h" #include "tt_metal/common/logger.hpp" #include "tt_metal/common/tt_backend_api_types.hpp" @@ -18,12 +20,29 @@ #include "tt_metal/hw/inc/circular_buffer_constants.h" +// Forward declarations for external flatbuffer serialization function +namespace tt::target { +class CircularBufferConfig; +} +namespace tt::tt_metal { +inline namespace v0 { +class CircularBufferConfig; +} +inline flatbuffers::Offset ToFlatbuffer( + const tt::tt_metal::CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder); +inline CircularBufferConfig fromFlatBuffer(const tt::target::CircularBufferConfig* config_fb); +} // namespace tt::tt_metal + namespace tt::tt_metal { inline namespace v0 { using CBHandle = uintptr_t; class CircularBufferConfig { + friend flatbuffers::Offset tt::tt_metal::ToFlatbuffer( + const tt::tt_metal::CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder); + friend CircularBufferConfig FromFlatbuffer(const tt::target::CircularBufferConfig* config_fb); + public: // Static circular buffer spec CircularBufferConfig(uint32_t total_size, const std::map& data_format_spec); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index e8ddb2e8a86d..8b2057b98338 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -1911,6 +1911,7 @@ void EnqueueReadBuffer( void* dst, bool blocking, tt::stl::Span sub_device_ids) { + TRACE_FUNCTION_CALL(CaptureEnqueueReadBuffer, cq, buffer, dst, blocking); // FIXME (kmabee) consider sub_device_ids added recently. detail::DispatchStateCheck(true); cq.run_command(CommandInterface{ .type = EnqueueCommandType::ENQUEUE_READ_BUFFER, .blocking = blocking, .buffer = buffer, .dst = dst, .sub_device_ids = sub_device_ids}); @@ -1922,6 +1923,7 @@ void EnqueueWriteBuffer( HostDataType src, bool blocking, tt::stl::Span sub_device_ids) { + TRACE_FUNCTION_CALL(CaptureEnqueueWriteBuffer, cq, buffer, src, blocking); // FIXME (kmabee) consider sub_device_ids added recently. detail::DispatchStateCheck(true); cq.run_command(CommandInterface{ .type = EnqueueCommandType::ENQUEUE_WRITE_BUFFER, .blocking = blocking, .buffer = buffer, .src = std::move(src), .sub_device_ids = sub_device_ids}); @@ -1929,6 +1931,7 @@ void EnqueueWriteBuffer( void EnqueueProgram( CommandQueue& cq, Program& program, bool blocking) { + TRACE_FUNCTION_CALL(CaptureEnqueueProgram, cq, program, blocking); detail::DispatchStateCheck(true); cq.run_command( CommandInterface{.type = EnqueueCommandType::ENQUEUE_PROGRAM, .blocking = blocking, .program = &program}); @@ -1990,6 +1993,7 @@ bool EventQuery(const std::shared_ptr& event) { } void Finish(CommandQueue& cq, tt::stl::Span sub_device_ids) { + TRACE_FUNCTION_CALL(CaptureFinish, cq); // FIXME (kmabee) consider sub_device_ids added recently. detail::DispatchStateCheck(true); cq.run_command(CommandInterface{.type = EnqueueCommandType::FINISH, .blocking = true, .sub_device_ids = sub_device_ids}); TT_ASSERT( diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index 52956541e6a0..617aef1e366a 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -608,6 +608,7 @@ class HWCommandQueue { friend void FinishImpl(CommandQueue& cq, tt::stl::Span sub_device_ids); friend CommandQueue; friend detail::Program_; + friend void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking); }; // Common interface for all command queue types diff --git a/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp b/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp index d8d2c8145f98..cb6075bfb299 100644 --- a/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp +++ b/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp @@ -8,6 +8,8 @@ #include "lightmetal_capture.hpp" #include "command_generated.h" #include "tt_metal/common/logger.hpp" +#include "tt_metal/tt_stl/span.hpp" +#include "tracehost/types_to_flatbuffer.hpp" // FIXME (kmabee) - Temp hack, remove before merge and integrate as cmake define. #define ENABLE_TRACING 1 @@ -25,6 +27,41 @@ } while (0) #endif +namespace tt::tt_metal { + +////////////////////////////////////////////////////////////// +// Debug Code // +////////////////////////////////////////////////////////////// + +inline void PrintHostDataType(const HostDataType& data) { + std::visit( + [](const auto& value) { + using T = std::decay_t; + if constexpr (std::is_same_v>>) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + } else if constexpr (std::is_same_v>>) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + } else if constexpr (std::is_same_v>>) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + } else if constexpr (std::is_same_v>>) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + } else if constexpr (std::is_same_v>>) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + } else if constexpr (std::is_same_v>>) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + } else if constexpr (std::is_same_v) { + log_info(tt::LogMetalTrace, "HostDataType contains: const void*"); + } else { + log_info(tt::LogMetalTrace, "HostDataType contains: Unknown type"); + } + }, + data); +} + +////////////////////////////////////////////////////////////// +// Host API tracing helper functions // +////////////////////////////////////////////////////////////// + // Generic helper to build command and add to vector of cmds (CQ) inline void CaptureCommand(tt::target::CommandType cmd_type, ::flatbuffers::Offset fb_offset) { auto& ctx = LightMetalCaptureContext::Get(); @@ -51,3 +88,264 @@ inline void CaptureLoadTrace(IDevice* device, const uint8_t cq_id, const uint32_ auto cmd = tt::target::CreateLoadTraceCommand(ctx.GetBuilder(), tid, cq_id); CaptureCommand(tt::target::CommandType::LoadTraceCommand, cmd.Union()); } + +inline void CaptureReleaseTrace(IDevice* device, uint32_t tid) { + auto& ctx = LightMetalCaptureContext::Get(); + log_debug(tt::LogMetalTrace, "{}: tid: {}", __FUNCTION__, tid); + auto cmd = tt::target::CreateReleaseTraceCommand(ctx.GetBuilder(), tid); + CaptureCommand(tt::target::CommandType::ReleaseTraceCommand, cmd.Union()); +} + +// TODO (kmabee) - Consider passing Buffer* to capture funcs intead so it's clear we don't extend lifetime of buffer. +inline void CaptureCreateBuffer(std::shared_ptr buffer, const InterleavedBufferConfig& config) { + auto& ctx = LightMetalCaptureContext::Get(); + + uint32_t buffer_global_id = ctx.AddToMap(buffer.get()); + log_debug( + tt::LogMetalTrace, + "{}: size: {} page_size: {} buffer_type: {} buffer_layout: {} buffer_global_id: {}", + __FUNCTION__, + config.size, + config.page_size, + config.buffer_type, + config.buffer_layout, + buffer_global_id); + + assert(config.device->id() == 0 && "multichip not supported yet"); + auto buffer_config_offset = tt::target::CreateInterleavedBufferConfig( + ctx.GetBuilder(), + config.device->id(), + config.size, + config.page_size, + ToFlatbuffer(config.buffer_type), + ToFlatbuffer(config.buffer_layout)); + auto cmd = tt::target::CreateCreateBufferCommand(ctx.GetBuilder(), buffer_global_id, buffer_config_offset); + CaptureCommand(tt::target::CommandType::CreateBufferCommand, cmd.Union()); +} + +inline void CaptureDeallocateBuffer(Buffer* buffer) { + auto& ctx = LightMetalCaptureContext::Get(); + + auto buffer_global_id = ctx.GetGlobalId(buffer); + + log_debug( + tt::LogMetalTrace, + "{}: buffer_global_id: {} size: {} address: {}", + __FUNCTION__, + buffer_global_id, + buffer->size(), + buffer->address()); + + auto cmd = tt::target::CreateDeallocateBufferCommand(ctx.GetBuilder(), buffer_global_id); + CaptureCommand(tt::target::CommandType::DeallocateBufferCommand, cmd.Union()); +} + +inline void CaptureEnqueueWriteBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + HostDataType src, + bool blocking) { + auto& ctx = LightMetalCaptureContext::Get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t buffer_global_id = ctx.GetGlobalId(buffer_ptr); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} buffer_global_id: {}", __FUNCTION__, cq_global_id, buffer_global_id); + // PrintHostDataType(src); // Debug + + // TODO (kmabee) - Currently support limited data formats. Long term we might not store data in flatbuffer, + // but have it provided at runtime so just do what's easiest here and support few types for now. + ::flatbuffers::Offset<::flatbuffers::Vector> src_vector; + if (auto* uint32_vec = std::get_if>>(&src)) { + src_vector = ctx.GetBuilder().CreateVector(**uint32_vec); + } else if (auto* uint16_vec = std::get_if>>(&src)) { + // Convert uint16_t to uint32_t before creating the FlatBuffers vector + std::vector converted(uint16_vec->get()->begin(), uint16_vec->get()->end()); + src_vector = ctx.GetBuilder().CreateVector(converted); + } else if (auto* void_ptr = std::get_if(&src)) { + // Assuming the void* points to a buffer of uint32_t values. Infer size, cast to uint32_t. + size_t num_elements = buffer_ptr->size() / sizeof(uint32_t); + auto uint32_data = static_cast(*void_ptr); + src_vector = ctx.GetBuilder().CreateVector(uint32_data, num_elements); + } else { + throw std::runtime_error("Unsupported HostDataType for captureEnqueueWriteBuffer()"); + } + + auto cmd = tt::target::CreateEnqueueWriteBufferCommand( + ctx.GetBuilder(), cq_global_id, buffer_global_id, src_vector, blocking); + CaptureCommand(tt::target::CommandType::EnqueueWriteBufferCommand, cmd.Union()); +} + +inline void CaptureEnqueueReadBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* dst, + bool blocking) { + auto& ctx = LightMetalCaptureContext::Get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t buffer_global_id = ctx.GetGlobalId(buffer_ptr); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} buffer_global_id: {}", __FUNCTION__, cq_global_id, buffer_global_id); + + // Idea store a read_global_id to keep track of read results. + auto cmd = tt::target::CreateEnqueueReadBufferCommand(ctx.GetBuilder(), cq_global_id, buffer_global_id, blocking); + CaptureCommand(tt::target::CommandType::EnqueueReadBufferCommand, cmd.Union()); +} + +inline void CaptureFinish(CommandQueue& cq) { + auto& ctx = LightMetalCaptureContext::Get(); + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + log_debug(tt::LogMetalTrace, "{}: cq_global_id: {}", __FUNCTION__, cq_global_id); + auto cmd = tt::target::CreateFinishCommand(ctx.GetBuilder(), cq_global_id); + CaptureCommand(tt::target::CommandType::FinishCommand, cmd.Union()); +} + +inline void CaptureCreateProgram(Program& program) { + auto& ctx = LightMetalCaptureContext::Get(); + uint32_t program_global_id = ctx.AddToMap(&program); + log_debug(tt::LogMetalTrace, "{}: program_global_id: {}", __FUNCTION__, program_global_id); + + auto cmd = tt::target::CreateCreateProgramCommand(ctx.GetBuilder(), program_global_id); + CaptureCommand(tt::target::CommandType::CreateProgramCommand, cmd.Union()); +} + +inline void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking) { + auto& ctx = LightMetalCaptureContext::Get(); + + // When Metal Trace is enabled, skip EnqueueProgram capture (replaced with LoadTrace + ReplayTrace) + if (cq.hw_command_queue().manager.get_bypass_mode()) { + return; + } + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t program_global_id = ctx.GetGlobalId(&program); + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} program_global_id: {}", __FUNCTION__, cq_global_id, program_global_id); + + auto cmd = tt::target::CreateEnqueueProgramCommand(ctx.GetBuilder(), cq_global_id, program_global_id, blocking); + CaptureCommand(tt::target::CommandType::EnqueueProgramCommand, cmd.Union()); +} + +inline void CaptureCreateKernel( + KernelHandle kernel_id, + Program& program, + const std::string& file_name, + const std::variant& core_spec, + const std::variant& config) { + auto& ctx = LightMetalCaptureContext::Get(); + + std::shared_ptr kernel = program.get_kernel(kernel_id); + uint32_t kernel_global_id = ctx.AddToMap(kernel.get()); + uint32_t program_global_id = ctx.GetGlobalId(&program); + log_debug( + tt::LogMetalTrace, + "{}: file_name: {} kernel_global_id: {} (kernel_id: {}) program_global_id: {}", + __FUNCTION__, + file_name, + kernel_global_id, + kernel_id, + program_global_id); + + auto& fbb = ctx.GetBuilder(); + auto filename_offset = fbb.CreateString(file_name); + auto [core_spec_type, core_spec_offset] = ToFlatbuffer(fbb, core_spec); + auto [config_type, config_offset] = ToFlatbuffer(fbb, config); + + auto cmd = tt::target::CreateCreateKernelCommand( + fbb, + kernel_global_id, + program_global_id, + filename_offset, + core_spec_type, + core_spec_offset, + config_type, + config_offset); + CaptureCommand(tt::target::CommandType::CreateKernelCommand, cmd.Union()); +} + +inline void CaptureSetRuntimeArgsUint32( + const Program& program, + KernelHandle kernel_id, + const std::variant& core_spec, + tt::stl::Span runtime_args) { + auto& ctx = LightMetalCaptureContext::Get(); + + std::shared_ptr kernel = program.get_kernel(kernel_id); + uint32_t program_global_id = ctx.GetGlobalId(&program); + uint32_t kernel_global_id = ctx.GetGlobalId(kernel.get()); + log_debug( + tt::LogMetalTrace, + "{}(uint32): kernel_global_id: {} program_global_id: {} rt_args: {}", + __FUNCTION__, + kernel_global_id, + program_global_id, + runtime_args.size()); + + auto& fbb = ctx.GetBuilder(); + auto [core_spec_type, core_spec_offset] = ToFlatbuffer(fbb, core_spec); + auto rt_args_offset = fbb.CreateVector(runtime_args.data(), runtime_args.size()); + + auto cmd = tt::target::CreateSetRuntimeArgsUint32Command( + fbb, program_global_id, kernel_global_id, core_spec_type, core_spec_offset, rt_args_offset); + CaptureCommand(tt::target::CommandType::SetRuntimeArgsUint32Command, cmd.Union()); +} + +inline void CaptureSetRuntimeArgs( + IDevice* device, + const std::shared_ptr kernel, + const std::variant& core_spec, + std::shared_ptr runtime_args) { + auto& ctx = LightMetalCaptureContext::Get(); + auto& fbb = ctx.GetBuilder(); + uint32_t kernel_global_id = ctx.GetGlobalId(kernel.get()); + auto [core_spec_type, core_spec_offset] = ToFlatbuffer(fbb, core_spec); + auto rt_args_offset = ToFlatbuffer(fbb, runtime_args); + log_debug( + tt::LogMetalTrace, + "{}(RuntimeArgs): kernel_global_id: {} rt_args_size: {}", + __FUNCTION__, + kernel_global_id, + runtime_args->size()); + + auto cmd = tt::target::CreateSetRuntimeArgsCommand( + fbb, kernel_global_id, core_spec_type, core_spec_offset, rt_args_offset); + CaptureCommand(tt::target::CommandType::SetRuntimeArgsCommand, cmd.Union()); +} + +inline void CaptureCreateCircularBuffer( + CBHandle& cb_handle, + Program& program, + const std::variant& core_spec, + const CircularBufferConfig& config) { + auto& ctx = LightMetalCaptureContext::Get(); + auto& fbb = ctx.GetBuilder(); + uint32_t cb_global_id = ctx.AddToMap(cb_handle); + uint32_t program_global_id = ctx.GetGlobalId(&program); + auto [core_spec_type, core_spec_offset] = ToFlatbuffer(fbb, core_spec); + auto cb_config_offset = ToFlatbuffer(config, fbb); + log_debug( + tt::LogMetalTrace, + "{}: cb_global_id: {} program_global_id: {} ", + __FUNCTION__, + cb_global_id, + program_global_id); + + auto cmd = tt::target::CreateCreateCircularBufferCommand( + fbb, cb_global_id, program_global_id, core_spec_type, core_spec_offset, cb_config_offset); + CaptureCommand(tt::target::CommandType::CreateCircularBufferCommand, cmd.Union()); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture.cpp b/tt_metal/impl/lightmetal/lightmetal_capture.cpp index 23b70b75444d..11f2b2ebe6b2 100644 --- a/tt_metal/impl/lightmetal/lightmetal_capture.cpp +++ b/tt_metal/impl/lightmetal/lightmetal_capture.cpp @@ -8,6 +8,9 @@ #include "command_generated.h" #include "binary_generated.h" #include "tt_metal/impl/trace/trace_buffer.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/program/program.hpp" +#include "tt_metal/impl/kernels/kernel.hpp" #include #include @@ -50,7 +53,133 @@ std::vector LightMetalCaptureContext::CreateLightMetalBinary() { void LightMetalCaptureContext::Reset() { builder_.Clear(); + next_global_id_ = 0; cmds_vec_.clear(); + trace_descs_vec_.clear(); + buffer_to_global_id_map_.clear(); + program_to_global_id_map_.clear(); + kernel_to_global_id_map_.clear(); + cb_handle_to_global_id_map_.clear(); +} + +//////////////////////////////////////////// +// Object Map Public Accessors // +//////////////////////////////////////////// + +bool LightMetalCaptureContext::IsInMap(Buffer* obj) { + return buffer_to_global_id_map_.find(obj) != buffer_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::AddToMap(Buffer* obj) { + if (IsInMap(obj)) { + log_warning(tt::LogMetalTrace, "Buffer already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + buffer_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::RemoveFromMap(Buffer* obj) { + if (!IsInMap(obj)) { + log_warning(tt::LogMetalTrace, "Buffer not found in global_id map."); + } + buffer_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::GetGlobalId(Buffer* obj) { + auto it = buffer_to_global_id_map_.find(obj); + if (it != buffer_to_global_id_map_.end()) { + return it->second; + } else { + throw std::runtime_error("Buffer not found in global_id global_id map"); + } +} + +bool LightMetalCaptureContext::IsInMap(const Program* obj) { + return program_to_global_id_map_.find(obj) != program_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::AddToMap(const Program* obj) { + if (IsInMap(obj)) { + log_warning(tt::LogMetalTrace, "Program already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + program_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::RemoveFromMap(const Program* obj) { + if (!IsInMap(obj)) { + log_warning(tt::LogMetalTrace, "Program not found in global_id map."); + } + program_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::GetGlobalId(const Program* obj) { + auto it = program_to_global_id_map_.find(obj); + if (it != program_to_global_id_map_.end()) { + return it->second; + } else { + throw std::runtime_error("Program not found in global_id map."); + } +} + +bool LightMetalCaptureContext::IsInMap(const Kernel* obj) { + return kernel_to_global_id_map_.find(obj) != kernel_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::AddToMap(const Kernel* obj) { + if (IsInMap(obj)) { + log_warning(tt::LogMetalTrace, "Kernel already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + kernel_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::RemoveFromMap(const Kernel* obj) { + if (!IsInMap(obj)) { + log_warning(tt::LogMetalTrace, "Kernel not found in global_id map."); + } + kernel_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::GetGlobalId(const Kernel* obj) { + auto it = kernel_to_global_id_map_.find(obj); + if (it != kernel_to_global_id_map_.end()) { + return it->second; + } else { + throw std::runtime_error("Kernel not found in global_id map."); + } +} + +bool LightMetalCaptureContext::IsInMap(const CBHandle handle) { + return cb_handle_to_global_id_map_.find(handle) != cb_handle_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::AddToMap(const CBHandle handle) { + if (IsInMap(handle)) { + log_warning(tt::LogMetalTrace, "CBHandle already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + cb_handle_to_global_id_map_[handle] = global_id; + return global_id; +} + +void LightMetalCaptureContext::RemoveFromMap(const CBHandle handle) { + if (!IsInMap(handle)) { + log_warning(tt::LogMetalTrace, "CBHandle not found in global_id map."); + } + cb_handle_to_global_id_map_.erase(handle); +} + +uint32_t LightMetalCaptureContext::GetGlobalId(const CBHandle handle) { + auto it = cb_handle_to_global_id_map_.find(handle); + if (it != cb_handle_to_global_id_map_.end()) { + return it->second; + } else { + throw std::runtime_error("CBHandle not found in global_id map."); + } } //////////////////////////////////////////// diff --git a/tt_metal/impl/lightmetal/lightmetal_capture.hpp b/tt_metal/impl/lightmetal/lightmetal_capture.hpp index f6d04f07102d..988e1807f90c 100644 --- a/tt_metal/impl/lightmetal/lightmetal_capture.hpp +++ b/tt_metal/impl/lightmetal/lightmetal_capture.hpp @@ -27,7 +27,12 @@ class TraceDescriptor; namespace tt::tt_metal { inline namespace v0 { +class Buffer; +class Program; +class Kernel; +using CBHandle = uintptr_t; using TraceDescriptorByTraceIdOffset = flatbuffers::Offset; + class LightMetalCaptureContext { public: static LightMetalCaptureContext& Get(); @@ -39,9 +44,26 @@ class LightMetalCaptureContext { std::vector>& GetCmdsVector(); void CaptureTraceDescriptor(const detail::TraceDescriptor& trace_desc, const uint32_t tid); std::vector CreateLightMetalBinary(); - void Reset(); + // Object Map Public Accessors + bool IsInMap(Buffer* obj); + uint32_t AddToMap(Buffer* obj); + void RemoveFromMap(Buffer* obj); + uint32_t GetGlobalId(Buffer* obj); + bool IsInMap(const Program* obj); + uint32_t AddToMap(const Program* obj); + void RemoveFromMap(const Program* obj); + uint32_t GetGlobalId(const Program* obj); + bool IsInMap(const Kernel* obj); + uint32_t AddToMap(const Kernel* obj); + void RemoveFromMap(const Kernel* obj); + uint32_t GetGlobalId(const Kernel* obj); + bool IsInMap(const CBHandle handle); + uint32_t AddToMap(const CBHandle handle); + void RemoveFromMap(const CBHandle handle); + uint32_t GetGlobalId(const CBHandle handle); + private: LightMetalCaptureContext(); // Private constructor @@ -50,6 +72,14 @@ class LightMetalCaptureContext { std::vector> cmds_vec_; std::vector trace_descs_vec_; + // Object maps for associating each object with a global_id + uint32_t next_global_id_ = 0; // Shared across all object types. + std::unordered_map buffer_to_global_id_map_; + std::unordered_map program_to_global_id_map_; + std::unordered_map kernel_to_global_id_map_; + std::unordered_map cb_handle_to_global_id_map_; + // TODO (kmabee) - consider adding map for CommandQueue object. + // Delete copy constructor and assignment operator LightMetalCaptureContext(const LightMetalCaptureContext&) = delete; LightMetalCaptureContext& operator=(const LightMetalCaptureContext&) = delete; diff --git a/tt_metal/impl/lightmetal/lightmetal_replay.cpp b/tt_metal/impl/lightmetal/lightmetal_replay.cpp index 210dfb4c8daa..4ab488494762 100644 --- a/tt_metal/impl/lightmetal/lightmetal_replay.cpp +++ b/tt_metal/impl/lightmetal/lightmetal_replay.cpp @@ -13,6 +13,7 @@ #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/impl/dispatch/command_queue.hpp" #include "tt_metal/impl/device/device.hpp" +#include "tracehost/types_from_flatbuffer.hpp" namespace tt::tt_metal { inline namespace v0 { @@ -81,11 +82,55 @@ detail::TraceDescriptor FromFlatbuffer(const tt::target::lightmetal::TraceDescri return trace_desc; } +// Needs access to BufferMap, so part of LightMetalReplay class +std::shared_ptr LightMetalReplay::FromFlatbufferRtArgs(const FlatbufferRuntimeArgVector flatbuffer_args) { + auto runtime_args = std::make_shared(); + + for (const auto& flatbuffer_arg : *flatbuffer_args) { + const auto* runtime_arg = flatbuffer_arg; + if (!runtime_arg) { + throw std::runtime_error("Null RuntimeArg in FlatBuffer vector"); + } + + // Determine the type of the RuntimeArg + switch (runtime_arg->value_type()) { + case tt::target::RuntimeArgValue::UInt32Value: { + // Extract UInt32Value + const auto* uint32_value = runtime_arg->value_as_UInt32Value(); + if (!uint32_value) { + throw std::runtime_error("Failed to read UInt32Value"); + } + runtime_args->emplace_back(uint32_value->value()); + break; + } + case tt::target::RuntimeArgValue::BufferGlobalId: { + // Extract BufferGlobalId + const auto* buffer_global_id = runtime_arg->value_as_BufferGlobalId(); + if (!buffer_global_id) { + throw std::runtime_error("Failed to read BufferGlobalId"); + } + uint32_t global_id = buffer_global_id->id(); + auto buffer = GetBufferFromMap(global_id); + if (!buffer) { + throw std::runtime_error( + "Buffer w/ global_id: " + std::to_string(global_id) + " not previously created"); + } + runtime_args->emplace_back(buffer.get()); + break; + } + default: throw std::runtime_error("Unknown RuntimeArgValue type in FlatBuffer"); + } + } + + return runtime_args; +} + ////////////////////////////////////// // LightMetalReplay Class // ////////////////////////////////////// LightMetalReplay::LightMetalReplay(std::vector&& blob) : blob_(std::move(blob)), lm_binary_(nullptr) { + show_reads_ = parse_env("TT_LIGHT_METAL_SHOW_READS", false); lm_binary_ = ParseFlatBufferBinary(); // Parse and store the FlatBuffer binary if (!lm_binary_) { throw std::runtime_error("Failed to parse FlatBuffer binary during initialization."); @@ -126,6 +171,92 @@ std::optional LightMetalReplay::GetTraceByTraceId(uint3 return std::nullopt; } +////////////////////////////////////// +// Object Map Public Accessors // +////////////////////////////////////// + +void LightMetalReplay::AddBufferToMap(uint32_t global_id, std::shared_ptr<::tt::tt_metal::Buffer> buffer) { + if (buffer_map_.find(global_id) != buffer_map_.end()) { + log_warning(tt::LogMetalTrace, "Buffer with global_id: {} already exists in map.", global_id); + } + buffer_map_[global_id] = buffer; // Shared ownership +} + +std::shared_ptr<::tt::tt_metal::Buffer> LightMetalReplay::GetBufferFromMap(uint32_t global_id) const { + auto it = buffer_map_.find(global_id); + if (it != buffer_map_.end()) { + return it->second; // Return shared_ptr + } + return nullptr; // If not found +} + +void LightMetalReplay::RemoveBufferFromMap(uint32_t global_id) { buffer_map_.erase(global_id); } + +void LightMetalReplay::AddProgramToMap(uint32_t global_id, std::shared_ptr<::tt::tt_metal::Program> program) { + if (program_map_.find(global_id) != program_map_.end()) { + log_warning(tt::LogMetalTrace, "Program with global_id: {} already exists in map.", global_id); + } + program_map_[global_id] = program; // Shared ownership +} + +std::shared_ptr<::tt::tt_metal::Program> LightMetalReplay::GetProgramFromMap(uint32_t global_id) const { + auto it = program_map_.find(global_id); + if (it != program_map_.end()) { + return it->second; // Return shared_ptr + } + return nullptr; // If not found +} + +void LightMetalReplay::RemoveProgramFromMap(uint32_t global_id) { program_map_.erase(global_id); } + +void LightMetalReplay::AddKernelHandleToMap(uint32_t global_id, ::tt::tt_metal::KernelHandle kernel_id) { + if (kernel_handle_map_.find(global_id) != kernel_handle_map_.end()) { + log_warning(tt::LogMetalTrace, "KernelHandle with global_id: {} already exists in map.", global_id); + } + kernel_handle_map_[global_id] = kernel_id; // Shared ownership +} + +::tt::tt_metal::KernelHandle LightMetalReplay::GetKernelHandleFromMap(uint32_t global_id) const { + if (auto it = kernel_handle_map_.find(global_id); it != kernel_handle_map_.end()) { + return it->second; // Return KernelHandle. + } + throw std::runtime_error(fmt::format("KernelHandle with global_id: {} used but doesn't exist.", global_id)); +} + +void LightMetalReplay::RemoveKernelHandleFromMap(uint32_t global_id) { kernel_handle_map_.erase(global_id); } + +void LightMetalReplay::AddKernelToMap(uint32_t global_id, std::shared_ptr<::tt::tt_metal::Kernel> kernel) { + if (kernel_map_.find(global_id) != kernel_map_.end()) { + log_warning(tt::LogMetalTrace, "Kernel with global_id: {} already exists in map.", global_id); + } + kernel_map_[global_id] = kernel; // Shared ownership +} + +std::shared_ptr<::tt::tt_metal::Kernel> LightMetalReplay::GetKernelFromMap(uint32_t global_id) const { + if (auto it = kernel_map_.find(global_id); it != kernel_map_.end()) { + return it->second; // Return Kernel. + } + throw std::runtime_error(fmt::format("Kernel with global_id: {} used but doesn't exist.", global_id)); +} + +void LightMetalReplay::RemoveKernelFromMap(uint32_t global_id) { kernel_map_.erase(global_id); } + +void LightMetalReplay::AddCBHandleToMap(uint32_t global_id, ::tt::tt_metal::CBHandle cb_handle) { + if (cb_handle_map_.find(global_id) != cb_handle_map_.end()) { + log_warning(tt::LogMetalTrace, "CBHandle with global_id: {} already exists in map.", global_id); + } + cb_handle_map_[global_id] = cb_handle; // Shared ownership +} + +::tt::tt_metal::CBHandle LightMetalReplay::GetCBHandleFromMap(uint32_t global_id) const { + if (auto it = cb_handle_map_.find(global_id); it != cb_handle_map_.end()) { + return it->second; // Return CBHandle. + } + throw std::runtime_error(fmt::format("CBHandle with global_id: {} used but doesn't exist.", global_id)); +} + +void LightMetalReplay::RemoveCBHandleFromMap(uint32_t global_id) { cb_handle_map_.erase(global_id); } + ////////////////////////////////////// // Device Setup/Teardown // ////////////////////////////////////// @@ -165,6 +296,54 @@ void LightMetalReplay::Execute(const tt::target::Command* command) { Execute(command->cmd_as_LoadTraceCommand()); break; } + case ::tt::target::CommandType::ReleaseTraceCommand: { + Execute(command->cmd_as_ReleaseTraceCommand()); + break; + } + case ::tt::target::CommandType::CreateBufferCommand: { + Execute(command->cmd_as_CreateBufferCommand()); + break; + } + case ::tt::target::CommandType::DeallocateBufferCommand: { + Execute(command->cmd_as_DeallocateBufferCommand()); + break; + } + case ::tt::target::CommandType::EnqueueWriteBufferCommand: { + Execute(command->cmd_as_EnqueueWriteBufferCommand()); + break; + } + case ::tt::target::CommandType::EnqueueReadBufferCommand: { + Execute(command->cmd_as_EnqueueReadBufferCommand()); + break; + } + case ::tt::target::CommandType::FinishCommand: { + Execute(command->cmd_as_FinishCommand()); + break; + } + case ::tt::target::CommandType::CreateProgramCommand: { + Execute(command->cmd_as_CreateProgramCommand()); + break; + } + case ::tt::target::CommandType::EnqueueProgramCommand: { + Execute(command->cmd_as_EnqueueProgramCommand()); + break; + } + case ::tt::target::CommandType::CreateKernelCommand: { + Execute(command->cmd_as_CreateKernelCommand()); + break; + } + case ::tt::target::CommandType::SetRuntimeArgsUint32Command: { + Execute(command->cmd_as_SetRuntimeArgsUint32Command()); + break; + } + case ::tt::target::CommandType::SetRuntimeArgsCommand: { + Execute(command->cmd_as_SetRuntimeArgsCommand()); + break; + } + case ::tt::target::CommandType::CreateCircularBufferCommand: { + Execute(command->cmd_as_CreateCircularBufferCommand()); + break; + } default: throw std::runtime_error("Unsupported type: " + std::string(EnumNameCommandType(command->cmd_type()))); break; @@ -173,7 +352,7 @@ void LightMetalReplay::Execute(const tt::target::Command* command) { // Per API command handlers. void LightMetalReplay::Execute(const tt::target::EnqueueTraceCommand* cmd) { - log_info( + log_debug( tt::LogMetalTrace, "LightMetalReplay(EnqueueTrace) cq_id: {} tid: {} blocking: {}", cmd->cq_id(), @@ -194,12 +373,210 @@ void LightMetalReplay::Execute(const tt::target::ReplayTraceCommand* cmd) { } void LightMetalReplay::Execute(const tt::target::LoadTraceCommand* cmd) { - log_info(tt::LogMetalTrace, "LightMetalReplay(LoadTrace) cq_id: {} tid: {}", cmd->cq_id(), cmd->tid()); + log_debug(tt::LogMetalTrace, "LightMetalReplay(LoadTrace) cq_id: {} tid: {}", cmd->cq_id(), cmd->tid()); // Get the trace descriptor from flatbuffer and load it to device. auto trace_desc = GetTraceByTraceId(cmd->tid()); LoadTrace(this->device_, cmd->cq_id(), cmd->tid(), trace_desc.value()); } +void LightMetalReplay::Execute(const tt::target::ReleaseTraceCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(ReleaseTrace) tid: {}", cmd->tid()); + ReleaseTrace(this->device_, cmd->tid()); +} + +void LightMetalReplay::Execute(const tt::target::CreateBufferCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(CreateBuffer) global_id: {} size: {} page_size: {} layout: {} buffer_type: {}", + cmd->global_id(), + cmd->config()->size(), + cmd->config()->page_size(), + EnumNameTensorMemoryLayout(cmd->config()->buffer_layout()), + EnumNameBufferType(cmd->config()->buffer_type())); + + switch (cmd->config()->buffer_layout()) { + case tt::target::TensorMemoryLayout::Interleaved: { + tt::tt_metal::InterleavedBufferConfig config{ + .device = this->device_, + .size = cmd->config()->size(), + .page_size = cmd->config()->page_size(), + .buffer_type = FromFlatbuffer(cmd->config()->buffer_type())}; + + auto buffer = CreateBuffer(config); + AddBufferToMap(cmd->global_id(), buffer); + break; + } + default: + throw std::runtime_error( + "Unsupported buffer_layout: " + + std::string(EnumNameTensorMemoryLayout(cmd->config()->buffer_layout()))); + } +} + +void LightMetalReplay::Execute(const tt::target::DeallocateBufferCommand* cmd) { + auto buffer = GetBufferFromMap(cmd->global_id()); + if (!buffer) { + throw std::runtime_error( + "Buffer w/ global_id: " + std::to_string(cmd->global_id()) + " not previously created"); + } + + log_debug(tt::LogMetalTrace, "LightMetalReplay(DeallocateBuffer) global_id: {}", cmd->global_id()); + DeallocateBuffer(*buffer); // Buffer& expected. + RemoveBufferFromMap(cmd->global_id()); +} + +void LightMetalReplay::Execute(const tt::target::EnqueueWriteBufferCommand* cmd) { + auto buffer = GetBufferFromMap(cmd->buffer_global_id()); + if (!buffer) { + throw std::runtime_error( + "Buffer w/ global_id: " + std::to_string(cmd->buffer_global_id()) + " not previously created"); + } + + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueWriteBuffer) cq_global_id: {} buffer_global_id: {} addr: 0x{:x}", + cmd->cq_global_id(), + cmd->buffer_global_id(), + buffer->address()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + EnqueueWriteBuffer(cq, buffer, cmd->src()->data(), cmd->blocking()); +} + +void LightMetalReplay::Execute(const tt::target::EnqueueReadBufferCommand* cmd) { + auto buffer = GetBufferFromMap(cmd->buffer_global_id()); + if (!buffer) { + throw std::runtime_error( + "Buffer w/ global_id: " + std::to_string(cmd->buffer_global_id()) + " not previously created"); + } + + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueReadBuffer) cq_global_id: {} buffer_global_id: {} addr: 0x{:x} buf_size: {}", + cmd->cq_global_id(), + cmd->buffer_global_id(), + buffer->address(), + buffer->size()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + std::vector readback_data(buffer->size() / sizeof(uint32_t), 0); + EnqueueReadBuffer(cq, buffer, readback_data.data(), cmd->blocking()); + + // TODO (kmabee) - TBD what to do with readback data. For now, optionally print. + // One idea is to store in map by global_read_id that caller can access. + if (show_reads_) { + for (size_t i = 0; i < readback_data.size(); i++) { + log_info(tt::LogMetalTrace, " rd_data i: {:3d} => data: {} ({:x})", i, readback_data[i], readback_data[i]); + } + } +} + +void LightMetalReplay::Execute(const tt::target::FinishCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(Finish) cq_global_id: {}", cmd->cq_global_id()); + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + Finish(cq); +} + +void LightMetalReplay::Execute(const tt::target::CreateProgramCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(CreateProgram) global_id: {} ", cmd->global_id()); + auto program = CreateProgram(); + AddProgramToMap(cmd->global_id(), std::make_shared(std::move(program))); +} + +void LightMetalReplay::Execute(const tt::target::EnqueueProgramCommand* cmd) { + auto program = GetProgramFromMap(cmd->program_global_id()); + if (!program) { + throw std::runtime_error( + "Program with global_id: " + std::to_string(cmd->program_global_id()) + " not previously created"); + } + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueProgram) program_global_id: {} cq_global_id: {}", + cmd->program_global_id(), + cmd->cq_global_id()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + EnqueueProgram(cq, *program, cmd->blocking()); +} + +void LightMetalReplay::Execute(const tt::target::CreateKernelCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(CreateKernel) global_id: {} program_global_id: {}", + cmd->global_id(), + cmd->program_global_id()); + auto program = GetProgramFromMap(cmd->program_global_id()); + if (!program) { + throw std::runtime_error( + "Program with global_id: " + std::to_string(cmd->program_global_id()) + " not previously created"); + } + + auto core_spec = FromFlatbuffer(cmd->core_spec_type(), cmd->core_spec()); + auto kernel_config = FromFlatbuffer(cmd->config_type(), cmd->config()); + auto kernel_id = CreateKernel(*program, cmd->file_name()->c_str(), core_spec, kernel_config); + AddKernelHandleToMap(cmd->global_id(), kernel_id); + // Some APIs use Kernel, so convert to and store Kernel. + std::shared_ptr kernel = program->get_kernel(kernel_id); + AddKernelToMap(cmd->global_id(), kernel); +} + +void LightMetalReplay::Execute(const tt::target::SetRuntimeArgsUint32Command* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(SetRuntimeArgs). program_global_id: {} kernel_global_id: {}", + cmd->program_global_id(), + cmd->kernel_global_id()); + auto program = GetProgramFromMap(cmd->program_global_id()); + auto kernel_id = GetKernelHandleFromMap(cmd->kernel_global_id()); + + if (!program) { + throw std::runtime_error( + "Program with global_id: " + std::to_string(cmd->program_global_id()) + " not previously created"); + } + + // API expects a span so create from flatbuffer vector. + stl::Span args_span(cmd->args()->data(), cmd->args()->size()); + auto core_spec = FromFlatbuffer(cmd->core_spec_type(), cmd->core_spec()); + SetRuntimeArgs(*program, kernel_id, core_spec, args_span); +} + +void LightMetalReplay::Execute(const tt::target::SetRuntimeArgsCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(SetRuntimeArgs). kernel_global_id: {} rt_args_size: {}", + cmd->kernel_global_id(), + cmd->args()->size()); + auto core_spec = FromFlatbuffer(cmd->core_spec_type(), cmd->core_spec()); + auto runtime_args = FromFlatbufferRtArgs(cmd->args()); + auto kernel = GetKernelFromMap(cmd->kernel_global_id()); + if (!kernel) { + throw std::runtime_error( + "Kernel with global_id: " + std::to_string(cmd->kernel_global_id()) + " not previously created"); + } + SetRuntimeArgs(this->device_, kernel, core_spec, runtime_args); +} + +void LightMetalReplay::Execute(const tt::target::CreateCircularBufferCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(CreateCircularBuffer) global_id: {} program_global_id: {}", + cmd->global_id(), + cmd->program_global_id()); + auto program = GetProgramFromMap(cmd->program_global_id()); + if (!program) { + throw std::runtime_error( + "Program with global_id: " + std::to_string(cmd->program_global_id()) + " not previously created"); + } + + auto core_spec = FromFlatbuffer(cmd->core_spec_type(), cmd->core_spec()); + auto config = FromFlatbuffer(cmd->config()); + auto cb_handle = CreateCircularBuffer(*program, core_spec, config); + AddCBHandleToMap(cmd->global_id(), cb_handle); +} + // Main entry point to execute a light metal binary blob, return true if pass. bool LightMetalReplay::ExecuteLightMetalBinary() { if (!lm_binary_) { diff --git a/tt_metal/impl/lightmetal/lightmetal_replay.hpp b/tt_metal/impl/lightmetal/lightmetal_replay.hpp index 852a7a128752..ee0a6462c07c 100644 --- a/tt_metal/impl/lightmetal/lightmetal_replay.hpp +++ b/tt_metal/impl/lightmetal/lightmetal_replay.hpp @@ -22,6 +22,19 @@ struct Command; struct ReplayTraceCommand; struct EnqueueTraceCommand; struct LoadTraceCommand; +struct ReleaseTraceCommand; +struct CreateBufferCommand; +struct DeallocateBufferCommand; +struct EnqueueWriteBufferCommand; +struct EnqueueReadBufferCommand; +struct FinishCommand; +struct CreateProgramCommand; +struct EnqueueProgramCommand; +struct CreateKernelCommand; +struct SetRuntimeArgsUint32Command; +struct SetRuntimeArgsCommand; +struct CreateCircularBufferCommand; +struct RuntimeArg; // Forward decl for binary_generated.h namespace lightmetal { @@ -31,6 +44,9 @@ struct LightMetalBinary; } // namespace lightmetal } // namespace tt::target +using FlatbufferRuntimeArgVector = const flatbuffers::Vector>*; +using RuntimeArgs = std::vector>; + namespace tt::tt_metal { inline namespace v0 { @@ -48,6 +64,9 @@ class LightMetalReplay { // Return the TraceDescriptor for a given trace_id from flatbuffer. std::optional GetTraceByTraceId(uint32_t target_trace_id); + // fromFlatBuffer that need class state + std::shared_ptr FromFlatbufferRtArgs(const FlatbufferRuntimeArgVector flatbuffer_args); + // Execute the stored LightMetal binary bool ExecuteLightMetalBinary(); @@ -56,6 +75,39 @@ class LightMetalReplay { void Execute(const tt::target::EnqueueTraceCommand* command); void Execute(const tt::target::ReplayTraceCommand* command); void Execute(const tt::target::LoadTraceCommand* command); + void Execute(const tt::target::ReleaseTraceCommand* command); + void Execute(const tt::target::CreateBufferCommand* command); + void Execute(const tt::target::DeallocateBufferCommand* command); + void Execute(const tt::target::EnqueueWriteBufferCommand* command); + void Execute(const tt::target::EnqueueReadBufferCommand* command); + void Execute(const tt::target::FinishCommand* command); + void Execute(const tt::target::CreateProgramCommand* command); + void Execute(const tt::target::EnqueueProgramCommand* command); + void Execute(const tt::target::CreateKernelCommand* command); + void Execute(const tt::target::SetRuntimeArgsUint32Command* command); + void Execute(const tt::target::SetRuntimeArgsCommand* command); + void Execute(const tt::target::CreateCircularBufferCommand* command); + + // Object maps public accessors + void AddBufferToMap(uint32_t global_id, std::shared_ptr<::tt::tt_metal::Buffer> buffer); + std::shared_ptr<::tt::tt_metal::Buffer> GetBufferFromMap(uint32_t global_id) const; + void RemoveBufferFromMap(uint32_t global_id); + + void AddProgramToMap(uint32_t global_id, std::shared_ptr<::tt::tt_metal::Program> program); + std::shared_ptr<::tt::tt_metal::Program> GetProgramFromMap(uint32_t global_id) const; + void RemoveProgramFromMap(uint32_t global_id); + + void AddKernelHandleToMap(uint32_t global_id, ::tt::tt_metal::KernelHandle kernel_id); + ::tt::tt_metal::KernelHandle GetKernelHandleFromMap(uint32_t global_id) const; + void RemoveKernelHandleFromMap(uint32_t global_id); + + void AddKernelToMap(uint32_t global_id, std::shared_ptr<::tt::tt_metal::Kernel> kernel); + std::shared_ptr<::tt::tt_metal::Kernel> GetKernelFromMap(uint32_t global_id) const; + void RemoveKernelFromMap(uint32_t global_id); + + void AddCBHandleToMap(uint32_t global_id, ::tt::tt_metal::CBHandle cb_handle); + ::tt::tt_metal::CBHandle GetCBHandleFromMap(uint32_t global_id) const; + void RemoveCBHandleFromMap(uint32_t global_id); private: // Workload related members -------------------- @@ -63,6 +115,7 @@ class LightMetalReplay { std::vector blob_; // Stored binary blob const target::lightmetal::LightMetalBinary* lm_binary_; // Parsed FlatBuffer binary + bool show_reads_ = false; // Flag to show read buffer contents // System related members ---------------------- void SetupDevices(); @@ -70,6 +123,13 @@ class LightMetalReplay { tt::tt_metal::IDevice* device_; tt::ARCH arch_; + + // Object maps for storing objects by global_id + std::unordered_map> buffer_map_; + std::unordered_map> program_map_; + std::unordered_map kernel_handle_map_; + std::unordered_map> kernel_map_; + std::unordered_map cb_handle_map_; }; } // namespace v0 diff --git a/tt_metal/impl/tracehost/command.fbs b/tt_metal/impl/tracehost/command.fbs index f4b104034f9c..b9fe5aaac59e 100644 --- a/tt_metal/impl/tracehost/command.fbs +++ b/tt_metal/impl/tracehost/command.fbs @@ -1,4 +1,5 @@ // Define schema for tracing host API calls, called Commands in this context. +include "tracehost/types.fbs"; namespace tt.target; @@ -21,10 +22,93 @@ table LoadTraceCommand { cq_id: int; } +table ReleaseTraceCommand { + // TODO (kmabee) - add device. + tid: int; // Pointer to trace data. +} + +table CreateBufferCommand { + global_id: uint32; + config: InterleavedBufferConfig; // Later grow to union for Sharded. + address: uint32; // Optional for pre-allocated buffers. +} + +table DeallocateBufferCommand { + global_id: uint32; // Reference to Buffer to be deallocated +} + +table EnqueueWriteBufferCommand { + cq_global_id: uint32; // reference to CommandQueue + buffer_global_id: uint32; // Reference to Buffer used as destination + src: [uint32]; // Data to be written. Support only some types for now. + blocking: bool; +} + +table EnqueueReadBufferCommand { + cq_global_id: uint32; // reference to CommandQueue + buffer_global_id: uint32; // Reference to Buffer used as source + // dst unsure what to do here. + blocking: bool; +} + +table FinishCommand { + cq_global_id: uint32; // reference to CommandQueue +} + +table CreateProgramCommand { + global_id: uint32; +} + +table EnqueueProgramCommand { + cq_global_id: uint32; // reference to CommandQueue + program_global_id: uint32; // Reference to Program + blocking: bool; +} + +table CreateKernelCommand { + global_id: uint32; // Reference to Kernel + program_global_id: uint32; // Reference to Program + file_name: string; // Later replace with src, then binary + core_spec: CoreSpec; + config: KernelConfig; +} + +table SetRuntimeArgsUint32Command { + program_global_id: uint32; // Reference to Program + kernel_global_id: uint32; // Reference to Kernel + core_spec: CoreSpec; + args: [uint32]; // Arguments to be passed to kernel +} + +table SetRuntimeArgsCommand { + kernel_global_id: uint32; // Reference to Kernel + core_spec: CoreSpec; + args: [RuntimeArg]; // Arguments to be passed to kernel +} + +table CreateCircularBufferCommand { + global_id: uint32; // Reference to CBHandle + program_global_id: uint32; // Reference to Program + core_spec: CoreSpec; + config: CircularBufferConfig; +} + union CommandType { ReplayTraceCommand, EnqueueTraceCommand, - LoadTraceCommand + LoadTraceCommand, + ReleaseTraceCommand, + CreateBufferCommand, + DeallocateBufferCommand, + EnqueueWriteBufferCommand, + EnqueueReadBufferCommand, + FinishCommand, + CreateProgramCommand, + EnqueueProgramCommand, + CreateKernelCommand, + SetRuntimeArgsUint32Command, + SetRuntimeArgsCommand, + CreateCircularBufferCommand, } table Command { diff --git a/tt_metal/impl/tracehost/types.fbs b/tt_metal/impl/tracehost/types.fbs new file mode 100644 index 000000000000..fc9c3c9fea10 --- /dev/null +++ b/tt_metal/impl/tracehost/types.fbs @@ -0,0 +1,211 @@ +namespace tt.target; + +enum Arch: uint { + Grayskull = 0, + Wormhole_b0 = 1, + Blackhole = 2, +} + +enum BufferType: ushort { + DRAM = 0, + L1 = 1, + SystemMemory = 2, + L1Small = 3, + Trace = 4, +} + +enum TensorMemoryLayout: ushort { + None = 0, + Interleaved = 1, + SingleBank = 2, + HeightSharded = 3, + WidthSharded = 4, + BlockSharded = 5, +} + +table InterleavedBufferConfig { + device_id: int; // Device *device; + size: int; // Size in bytes + page_size: int; // Size of unit being interleaved. For non-interleaved buffers: size == page_size + buffer_type: BufferType; + buffer_layout: TensorMemoryLayout; +} + + +// Core Types ////////////////// + +table CoreCoord { + x: int; + y: int; +} + +table CoreRange { + start: CoreCoord; + end: CoreCoord; +} + +table CoreRangeSet { + ranges: [CoreRange]; +} + +union CoreSpec { + CoreCoord, + CoreRange, + CoreRangeSet +} + +enum DataMovementProcessor : byte { + RISCV_0, + RISCV_1 +} + +enum NOC : byte { + NOC_0, + NOC_1 +} + +enum NOC_MODE : byte { + DM_DEDICATED_NOC, + DM_DYNAMIC_NOC +} + +enum Eth : ubyte { + SENDER = 0, + RECEIVER = 1, + IDLE = 2 +} + +enum MathFidelity : ubyte { + LoFi = 0, + HiFi2 = 2, + HiFi3 = 3, + HiFi4 = 4, + Invalid = 255 +} + +enum UnpackToDestMode : byte { + Default, + UnpackToDestFp32 +} + +table DefineEntry { + key: string; + value: string; +} + +// Kernel Configurations ////////////////// + +table DataMovementConfig { + processor: DataMovementProcessor; + noc: NOC; + noc_mode: NOC_MODE; + compile_args: [uint32]; // Array of compile arguments + defines: [DefineEntry]; // Key-value pair map for defines +} + +table ComputeConfig { + math_fidelity: MathFidelity; + fp32_dest_acc_en: bool; + dst_full_sync_en: bool; + unpack_to_dest_mode: [UnpackToDestMode]; // Array of unpack modes + bfp8_pack_precise: bool; + math_approx_mode: bool; + compile_args: [uint32]; // Array of compile arguments + defines: [DefineEntry]; // Key-value pair map for defines +} + +table EthernetConfig { + eth_mode: Eth; + noc: NOC; + processor: DataMovementProcessor; + compile_args: [uint32]; // Array of compile arguments + defines: [DefineEntry]; // Key-value pair map for defines +} + +// Union to include multiple configurations +union KernelConfig { + DataMovementConfig, + ComputeConfig, + EthernetConfig +} + + +table Tile { + tile_shape: [uint32]; // Shape of the tile (e.g., height, width) + face_shape: [uint32]; // Shape of the face + tile_hw: uint32; // Tile hardware size + face_hw: uint32; // Face hardware size + num_faces: uint32; // Number of faces + partial_face: uint32; // Indicates if this is a partial face + narrow_tile: uint32; // Indicates if this is a narrow tile + transpose_within_face: bool; // Transpose within each face + transpose_of_faces: bool; // Transpose face order +} + +struct CBConfigPageSize { + index: uint32; // The index in the array + size: uint32; // The page-size value for this index +} + +enum DataFormat : uint8 { + Float32 = 0, + Float16 = 1, + Bfp8 = 2, + Bfp4 = 3, + Bfp2 = 11, + Float16_b = 5, + Bfp8_b = 6, + Bfp4_b = 7, + Bfp2_b = 15, + Lf8 = 10, + Fp8_e4m3 = 26, // 0x1A in decimal + Int8 = 14, + Tf32 = 4, + UInt8 = 30, + UInt16 = 9, + Int32 = 8, + UInt32 = 24, + RawUInt8 = 240, // 0xf0 in decimal + RawUInt16 = 241, // 0xf1 in decimal + RawUInt32 = 242, // 0xf2 in decimal + Invalid = 255 +} + +struct CBConfigDataFormat { + index: uint32; // The index in the array + format: DataFormat; // The data format for this index +} + +table CircularBufferConfig { + total_size: uint32; + globally_allocated_address: uint32; // Optional behavior can be handled with a default value (or union) + data_formats: [CBConfigDataFormat]; // Optional arrays are naturally nullable in FlatBuffers + page_sizes: [CBConfigPageSize]; // Mimic optional array in C++ by using KV map. + tiles: [Tile]; + // FIXME (kmabee) - buffer pointer missing + buffer_indices: [uint8]; + local_buffer_indices: [uint8]; + remote_buffer_indices: [uint8]; + dynamic_cb: bool; + max_size: uint32; + buffer_size: uint32; +} + +// Runtime Args + +table UInt32Value { + value: uint32; +} + +table BufferGlobalId { + id: uint32; +} + +union RuntimeArgValue { + UInt32Value, + BufferGlobalId, +} + +table RuntimeArg { + value: RuntimeArgValue; +} diff --git a/tt_metal/impl/tracehost/types_from_flatbuffer.hpp b/tt_metal/impl/tracehost/types_from_flatbuffer.hpp new file mode 100644 index 000000000000..ac921e028561 --- /dev/null +++ b/tt_metal/impl/tracehost/types_from_flatbuffer.hpp @@ -0,0 +1,310 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +////////////////////////////////////////////////////////////// +// From-flatbuffer helper functions // +////////////////////////////////////////////////////////////// + +namespace tt::tt_metal { +inline namespace v0 { + +inline BufferType FromFlatbuffer(tt::target::BufferType type) { + switch (type) { + case tt::target::BufferType::DRAM: return BufferType::DRAM; + case tt::target::BufferType::L1: return BufferType::L1; + case tt::target::BufferType::SystemMemory: return BufferType::SYSTEM_MEMORY; + case tt::target::BufferType::L1Small: return BufferType::L1_SMALL; + case tt::target::BufferType::Trace: return BufferType::TRACE; + default: throw std::invalid_argument("Unknown BufferType value in FromFlatbuffer()"); + } +} + +inline tt::tt_metal::DataMovementProcessor FromFlatbuffer(tt::target::DataMovementProcessor in) { + switch (in) { + case tt::target::DataMovementProcessor::RISCV_0: return tt::tt_metal::DataMovementProcessor::RISCV_0; + case tt::target::DataMovementProcessor::RISCV_1: return tt::tt_metal::DataMovementProcessor::RISCV_1; + default: throw std::invalid_argument("Unknown DataMovementProcessor value in FromFlatbuffer()"); + } +} + +inline tt::tt_metal::NOC FromFlatbuffer(tt::target::NOC in) { + switch (in) { + case tt::target::NOC::NOC_0: return tt::tt_metal::NOC::NOC_0; + case tt::target::NOC::NOC_1: return tt::tt_metal::NOC::NOC_1; + default: throw std::invalid_argument("Invalid NOC value passed to FromFlatbuffer"); + } +} + +inline tt::tt_metal::NOC_MODE FromFlatbuffer(tt::target::NOC_MODE in) { + switch (in) { + case tt::target::NOC_MODE::DM_DEDICATED_NOC: return tt::tt_metal::NOC_MODE::DM_DEDICATED_NOC; + case tt::target::NOC_MODE::DM_DYNAMIC_NOC: return tt::tt_metal::NOC_MODE::DM_DYNAMIC_NOC; + default: throw std::invalid_argument("Unknown NOC_MODE value in FromFlatbuffer()"); + } +} + +inline tt::tt_metal::Eth FromFlatbuffer(tt::target::Eth in) { + switch (in) { + case tt::target::Eth::SENDER: return tt::tt_metal::Eth::SENDER; + case tt::target::Eth::RECEIVER: return tt::tt_metal::Eth::RECEIVER; + case tt::target::Eth::IDLE: return tt::tt_metal::Eth::IDLE; + default: throw std::invalid_argument("Unknown Eth value in FromFlatbuffer()"); + } +} + +inline MathFidelity FromFlatbuffer(tt::target::MathFidelity input) { + switch (input) { + case tt::target::MathFidelity::LoFi: return MathFidelity::LoFi; + case tt::target::MathFidelity::HiFi2: return MathFidelity::HiFi2; + case tt::target::MathFidelity::HiFi3: return MathFidelity::HiFi3; + case tt::target::MathFidelity::HiFi4: return MathFidelity::HiFi4; + case tt::target::MathFidelity::Invalid: return MathFidelity::Invalid; + default: throw std::invalid_argument("Unknown MathFidelity value in FromFlatbuffer()"); + } +} + +inline UnpackToDestMode FromFlatbuffer(tt::target::UnpackToDestMode input) { + switch (input) { + case tt::target::UnpackToDestMode::UnpackToDestFp32: return UnpackToDestMode::UnpackToDestFp32; + case tt::target::UnpackToDestMode::Default: return UnpackToDestMode::Default; + default: throw std::invalid_argument("Invalid UnpackToDestMode value passed to FromFlatbuffer"); + } +} + +inline tt::DataFormat FromFlatbuffer(tt::target::DataFormat input) { + switch (input) { + case tt::target::DataFormat::Float32: return tt::DataFormat::Float32; + case tt::target::DataFormat::Float16: return tt::DataFormat::Float16; + case tt::target::DataFormat::Bfp8: return tt::DataFormat::Bfp8; + case tt::target::DataFormat::Bfp4: return tt::DataFormat::Bfp4; + case tt::target::DataFormat::Bfp2: return tt::DataFormat::Bfp2; + case tt::target::DataFormat::Float16_b: return tt::DataFormat::Float16_b; + case tt::target::DataFormat::Bfp8_b: return tt::DataFormat::Bfp8_b; + case tt::target::DataFormat::Bfp4_b: return tt::DataFormat::Bfp4_b; + case tt::target::DataFormat::Bfp2_b: return tt::DataFormat::Bfp2_b; + case tt::target::DataFormat::Lf8: return tt::DataFormat::Lf8; + case tt::target::DataFormat::Fp8_e4m3: return tt::DataFormat::Fp8_e4m3; + case tt::target::DataFormat::Int8: return tt::DataFormat::Int8; + case tt::target::DataFormat::Tf32: return tt::DataFormat::Tf32; + case tt::target::DataFormat::UInt8: return tt::DataFormat::UInt8; + case tt::target::DataFormat::UInt16: return tt::DataFormat::UInt16; + case tt::target::DataFormat::Int32: return tt::DataFormat::Int32; + case tt::target::DataFormat::UInt32: return tt::DataFormat::UInt32; + case tt::target::DataFormat::RawUInt8: return tt::DataFormat::RawUInt8; + case tt::target::DataFormat::RawUInt16: return tt::DataFormat::RawUInt16; + case tt::target::DataFormat::RawUInt32: return tt::DataFormat::RawUInt32; + case tt::target::DataFormat::Invalid: return tt::DataFormat::Invalid; + default: throw std::invalid_argument("Unknown DataFormat value in FromFlatbuffer()"); + } +} + +inline std::variant FromFlatbuffer( + const tt::target::CoreSpec core_spec, const void* flatbuffer_union) { + switch (core_spec) { + case tt::target::CoreSpec::CoreCoord: { + auto core_coord = static_cast(flatbuffer_union); + if (!core_coord) { + throw std::runtime_error("Invalid CoreCoord data"); + } + return CoreCoord{core_coord->x(), core_coord->y()}; + } + case tt::target::CoreSpec::CoreRange: { + auto core_range = static_cast(flatbuffer_union); + if (!core_range) { + throw std::runtime_error("Invalid CoreRange data"); + } + return CoreRange{ + {core_range->start()->x(), core_range->start()->y()}, {core_range->end()->x(), core_range->end()->y()}}; + } + case tt::target::CoreSpec::CoreRangeSet: { + auto core_range_set = static_cast(flatbuffer_union); + if (!core_range_set) { + throw std::runtime_error("Invalid CoreRangeSet data"); + } + std::vector ranges; + for (const auto range : *core_range_set->ranges()) { + ranges.emplace_back( + CoreCoord{range->start()->x(), range->start()->y()}, + CoreCoord{range->end()->x(), range->end()->y()}); + } + return CoreRangeSet{ranges}; + } + default: throw std::runtime_error("Unhandled CoreSpec type in FromFlatbuffer"); + } +} + +inline DataMovementConfig FromFlatbuffer(const tt::target::DataMovementConfig* fb_config) { + DataMovementConfig config; + + // Extract processor, noc, and noc_mode + config.processor = FromFlatbuffer(fb_config->processor()); + config.noc = FromFlatbuffer(fb_config->noc()); + config.noc_mode = FromFlatbuffer(fb_config->noc_mode()); + + // Extract compile_args + auto fb_compile_args = fb_config->compile_args(); + config.compile_args.assign(fb_compile_args->begin(), fb_compile_args->end()); + + // Extract defines + auto fb_defines = fb_config->defines(); + for (auto fb_define : *fb_defines) { + config.defines.emplace(fb_define->key()->str(), fb_define->value()->str()); + } + + return config; +} + +inline ComputeConfig FromFlatbuffer(const tt::target::ComputeConfig* fb_config) { + ComputeConfig config; + + // Extract math_fidelity and boolean flags + config.math_fidelity = FromFlatbuffer(fb_config->math_fidelity()); + config.fp32_dest_acc_en = fb_config->fp32_dest_acc_en(); + config.dst_full_sync_en = fb_config->dst_full_sync_en(); + config.bfp8_pack_precise = fb_config->bfp8_pack_precise(); + config.math_approx_mode = fb_config->math_approx_mode(); + + // Extract unpack_to_dest_mode + auto fb_unpack_modes = fb_config->unpack_to_dest_mode(); + config.unpack_to_dest_mode.reserve(fb_unpack_modes->size()); + for (auto fb_mode : *fb_unpack_modes) { + config.unpack_to_dest_mode.push_back(FromFlatbuffer(fb_mode)); + } + + // Extract compile_args + auto fb_compile_args = fb_config->compile_args(); + config.compile_args.assign(fb_compile_args->begin(), fb_compile_args->end()); + + // Extract defines + auto fb_defines = fb_config->defines(); + for (auto fb_define : *fb_defines) { + config.defines.emplace(fb_define->key()->str(), fb_define->value()->str()); + } + + return config; +} + +inline EthernetConfig FromFlatbuffer(const tt::target::EthernetConfig* fb_config) { + EthernetConfig config; + + // Extract eth_mode, noc, and processor + config.eth_mode = FromFlatbuffer(fb_config->eth_mode()); + config.noc = FromFlatbuffer(fb_config->noc()); + config.processor = FromFlatbuffer(fb_config->processor()); + + // Extract compile_args + auto fb_compile_args = fb_config->compile_args(); + config.compile_args.assign(fb_compile_args->begin(), fb_compile_args->end()); + + // Extract defines + auto fb_defines = fb_config->defines(); + for (auto fb_define : *fb_defines) { + config.defines.emplace(fb_define->key()->str(), fb_define->value()->str()); + } + + return config; +} + +inline std::variant FromFlatbuffer( + const tt::target::KernelConfig config_type, const void* flatbuffer_union) { + switch (config_type) { + case tt::target::KernelConfig::DataMovementConfig: + return FromFlatbuffer(static_cast(flatbuffer_union)); + case tt::target::KernelConfig::ComputeConfig: + return FromFlatbuffer(static_cast(flatbuffer_union)); + case tt::target::KernelConfig::EthernetConfig: + return FromFlatbuffer(static_cast(flatbuffer_union)); + default: throw std::runtime_error("Unhandled KernelConfig type in FromFlatbuffer."); + } +} + +inline Tile FromFlatbuffer(const tt::target::Tile* tile_fb) { + if (!tile_fb) { + throw std::runtime_error("Invalid Tile FlatBuffer object"); + } + + // Convert FlatBuffer vectors to std::array + std::array tile_shape = {tile_fb->tile_shape()->Get(0), tile_fb->tile_shape()->Get(1)}; + std::array face_shape = {tile_fb->face_shape()->Get(0), tile_fb->face_shape()->Get(1)}; + + // Create and return the Tile object, explicitly initializing the members + Tile tile; + tile.tile_shape = tile_shape; + tile.face_shape = face_shape; + tile.tile_hw = tile_fb->tile_hw(); + tile.face_hw = tile_fb->face_hw(); + tile.num_faces = tile_fb->num_faces(); + tile.partial_face = tile_fb->partial_face(); + tile.narrow_tile = tile_fb->narrow_tile(); + tile.transpose_within_face = tile_fb->transpose_within_face(); + tile.transpose_of_faces = tile_fb->transpose_of_faces(); + + return tile; +} + +inline std::array, NUM_CIRCULAR_BUFFERS> FromFlatbuffer( + const flatbuffers::Vector>* tiles_fb) { + std::array, NUM_CIRCULAR_BUFFERS> tiles = {}; + if (tiles_fb) { + for (size_t i = 0; i < tiles_fb->size() && i < NUM_CIRCULAR_BUFFERS; ++i) { + tiles[i] = FromFlatbuffer(tiles_fb->Get(i)); + } + } + return tiles; +} + +inline CircularBufferConfig FromFlatbuffer(const tt::target::CircularBufferConfig* config_fb) { + if (!config_fb) { + throw std::runtime_error("Invalid CircularBufferConfig FlatBuffer object"); + } + + // Create a CircularBufferConfig. Constructor doesn't matter much, since we serialized all + // members, will deserialize them here to get fully formed object. + CircularBufferConfig config(0, {}); + config.total_size_ = config_fb->total_size(); + + // Note: std::optional is not supported by FlatBuffers, so nullopt was serialized as value 0 in FlatBuffer. + config.globally_allocated_address_ = config_fb->globally_allocated_address() == 0 + ? std::nullopt + : std::optional(config_fb->globally_allocated_address()); + + if (config_fb->data_formats()) { + for (auto entry : *config_fb->data_formats()) { + config.data_formats_[entry->index()] = FromFlatbuffer(entry->format()); + } + } + + if (config_fb->page_sizes()) { + for (auto entry : *config_fb->page_sizes()) { + config.page_sizes_[entry->index()] = entry->size(); + } + } + + config.tiles_ = FromFlatbuffer(config_fb->tiles()); + + if (config_fb->buffer_indices()) { + config.buffer_indices_.insert(config_fb->buffer_indices()->begin(), config_fb->buffer_indices()->end()); + } + + if (config_fb->local_buffer_indices()) { + config.local_buffer_indices_.insert( + config_fb->local_buffer_indices()->begin(), config_fb->local_buffer_indices()->end()); + } + + if (config_fb->remote_buffer_indices()) { + config.remote_buffer_indices_.insert( + config_fb->remote_buffer_indices()->begin(), config_fb->remote_buffer_indices()->end()); + } + + config.dynamic_cb_ = config_fb->dynamic_cb(); + config.max_size_ = config_fb->max_size(); + config.buffer_size_ = config_fb->buffer_size(); + + return config; +} + +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/tracehost/types_to_flatbuffer.hpp b/tt_metal/impl/tracehost/types_to_flatbuffer.hpp new file mode 100644 index 000000000000..931fe4f7876c --- /dev/null +++ b/tt_metal/impl/tracehost/types_to_flatbuffer.hpp @@ -0,0 +1,387 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +////////////////////////////////////////////////////////////// +// To-flatbuffer helper functions // +////////////////////////////////////////////////////////////// + +namespace tt::tt_metal { + +// Original types defined in buffer_constants.hpp +inline tt::target::BufferType ToFlatbuffer(BufferType type) { + switch (type) { + case BufferType::DRAM: return tt::target::BufferType::DRAM; + case BufferType::L1: return tt::target::BufferType::L1; + case BufferType::SYSTEM_MEMORY: return tt::target::BufferType::SystemMemory; + case BufferType::L1_SMALL: return tt::target::BufferType::L1Small; + case BufferType::TRACE: return tt::target::BufferType::Trace; + default: throw std::invalid_argument("Unknown BufferType value in ToFlatbuffer()"); + } +} + +// Original types defined in buffer_constants.hpp +inline tt::target::TensorMemoryLayout ToFlatbuffer(TensorMemoryLayout layout) { + switch (layout) { + case TensorMemoryLayout::INTERLEAVED: return tt::target::TensorMemoryLayout::Interleaved; + case TensorMemoryLayout::SINGLE_BANK: return tt::target::TensorMemoryLayout::SingleBank; + case TensorMemoryLayout::HEIGHT_SHARDED: return tt::target::TensorMemoryLayout::HeightSharded; + case TensorMemoryLayout::WIDTH_SHARDED: return tt::target::TensorMemoryLayout::WidthSharded; + case TensorMemoryLayout::BLOCK_SHARDED: return tt::target::TensorMemoryLayout::BlockSharded; + default: throw std::invalid_argument("Unknown TensorMemoryLayout value in ToFlatbuffer()"); + } +} + +// Original types defined in data_types.hpp +inline tt::target::DataMovementProcessor ToFlatbuffer(tt::tt_metal::DataMovementProcessor in) { + switch (in) { + case tt::tt_metal::DataMovementProcessor::RISCV_0: return tt::target::DataMovementProcessor::RISCV_0; + case tt::tt_metal::DataMovementProcessor::RISCV_1: return tt::target::DataMovementProcessor::RISCV_1; + default: throw std::invalid_argument("Unknown DataMovementProcessor value in ToFlatbuffer()"); + } +} + +inline tt::target::NOC ToFlatbuffer(tt::tt_metal::NOC in) { + switch (in) { + case tt::tt_metal::NOC::NOC_0: return tt::target::NOC::NOC_0; + case tt::tt_metal::NOC::NOC_1: return tt::target::NOC::NOC_1; + default: throw std::invalid_argument("Invalid NOC value passed to ToFlatbuffer"); + } +} + +inline tt::target::NOC_MODE ToFlatbuffer(tt::tt_metal::NOC_MODE in) { + switch (in) { + case tt::tt_metal::NOC_MODE::DM_DEDICATED_NOC: return tt::target::NOC_MODE::DM_DEDICATED_NOC; + case tt::tt_metal::NOC_MODE::DM_DYNAMIC_NOC: return tt::target::NOC_MODE::DM_DYNAMIC_NOC; + default: throw std::invalid_argument("Unknown NOC_MODE value in ToFlatbuffer()"); + } +} + +inline tt::target::Eth ToFlatbuffer(tt::tt_metal::Eth in) { + switch (in) { + case tt::tt_metal::Eth::SENDER: return tt::target::Eth::SENDER; + case tt::tt_metal::Eth::RECEIVER: return tt::target::Eth::RECEIVER; + case tt::tt_metal::Eth::IDLE: return tt::target::Eth::IDLE; + default: throw std::invalid_argument("Unknown Eth value in ToFlatbuffer()"); + } +} + +// Original types defined in base_types.hpp +inline tt::target::MathFidelity ToFlatbuffer(MathFidelity input) { + switch (input) { + case MathFidelity::LoFi: return tt::target::MathFidelity::LoFi; + case MathFidelity::HiFi2: return tt::target::MathFidelity::HiFi2; + case MathFidelity::HiFi3: return tt::target::MathFidelity::HiFi3; + case MathFidelity::HiFi4: return tt::target::MathFidelity::HiFi4; + case MathFidelity::Invalid: return tt::target::MathFidelity::Invalid; + default: throw std::invalid_argument("Unknown MathFidelity value in ToFlatbuffer()"); + } +} + +inline tt::target::UnpackToDestMode ToFlatbuffer(UnpackToDestMode input) { + switch (input) { + case UnpackToDestMode::UnpackToDestFp32: return tt::target::UnpackToDestMode::UnpackToDestFp32; + case UnpackToDestMode::Default: return tt::target::UnpackToDestMode::Default; + default: throw std::invalid_argument("Invalid UnpackToDestMode value passed to ToFlatbuffer"); + } +} + +// Original types defined in tt_backend_api_types.hpp +inline tt::target::DataFormat ToFlatbuffer(tt::DataFormat input) { + switch (input) { + case tt::DataFormat::Float32: return tt::target::DataFormat::Float32; + case tt::DataFormat::Float16: return tt::target::DataFormat::Float16; + case tt::DataFormat::Bfp8: return tt::target::DataFormat::Bfp8; + case tt::DataFormat::Bfp4: return tt::target::DataFormat::Bfp4; + case tt::DataFormat::Bfp2: return tt::target::DataFormat::Bfp2; + case tt::DataFormat::Float16_b: return tt::target::DataFormat::Float16_b; + case tt::DataFormat::Bfp8_b: return tt::target::DataFormat::Bfp8_b; + case tt::DataFormat::Bfp4_b: return tt::target::DataFormat::Bfp4_b; + case tt::DataFormat::Bfp2_b: return tt::target::DataFormat::Bfp2_b; + case tt::DataFormat::Lf8: return tt::target::DataFormat::Lf8; + case tt::DataFormat::Fp8_e4m3: return tt::target::DataFormat::Fp8_e4m3; + case tt::DataFormat::Int8: return tt::target::DataFormat::Int8; + case tt::DataFormat::Tf32: return tt::target::DataFormat::Tf32; + case tt::DataFormat::UInt8: return tt::target::DataFormat::UInt8; + case tt::DataFormat::UInt16: return tt::target::DataFormat::UInt16; + case tt::DataFormat::Int32: return tt::target::DataFormat::Int32; + case tt::DataFormat::UInt32: return tt::target::DataFormat::UInt32; + case tt::DataFormat::RawUInt8: return tt::target::DataFormat::RawUInt8; + case tt::DataFormat::RawUInt16: return tt::target::DataFormat::RawUInt16; + case tt::DataFormat::RawUInt32: return tt::target::DataFormat::RawUInt32; + case tt::DataFormat::Invalid: return tt::target::DataFormat::Invalid; + default: throw std::invalid_argument("Unknown DataFormat value in ToFlatbuffer()"); + } +} + +// Original types defined in core_coord.hpp +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const std::variant& core_spec) { + return std::visit( + [&](auto&& spec) -> std::pair> { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto core_coord = tt::target::CreateCoreCoord(builder, spec.x, spec.y); + return {tt::target::CoreSpec::CoreCoord, core_coord.Union()}; + } else if constexpr (std::is_same_v) { + auto start = tt::target::CreateCoreCoord(builder, spec.start_coord.x, spec.start_coord.y); + auto end = tt::target::CreateCoreCoord(builder, spec.end_coord.x, spec.end_coord.y); + auto core_range = tt::target::CreateCoreRange(builder, start, end); + return {tt::target::CoreSpec::CoreRange, core_range.Union()}; + } else if constexpr (std::is_same_v) { + std::vector> range_offsets; + for (const auto& range : spec.ranges()) { + auto start = tt::target::CreateCoreCoord(builder, range.start_coord.x, range.start_coord.y); + auto end = tt::target::CreateCoreCoord(builder, range.end_coord.x, range.end_coord.y); + range_offsets.push_back(tt::target::CreateCoreRange(builder, start, end)); + } + auto ranges_vector = builder.CreateVector(range_offsets); + auto core_range_set = tt::target::CreateCoreRangeSet(builder, ranges_vector); + return {tt::target::CoreSpec::CoreRangeSet, core_range_set.Union()}; + } else { + throw std::runtime_error("Unhandled variant type in ToFlatbuffer"); + } + }, + core_spec); +} + +// Original types defined in kernel_types.hpp +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const DataMovementConfig& config) { + // Convert defines (map) to FlatBuffer format + std::vector> defines_vector; + for (const auto& [key, value] : config.defines) { + auto key_offset = builder.CreateString(key); + auto value_offset = builder.CreateString(value); + defines_vector.push_back(tt::target::CreateDefineEntry(builder, key_offset, value_offset)); + } + auto defines_offset = builder.CreateVector(defines_vector); + + // Convert compile_args to FlatBuffer format + auto compile_args_offset = builder.CreateVector(config.compile_args); + + // Create the FlatBuffer DataMovementConfig object + auto config_offset = tt::target::CreateDataMovementConfig( + builder, + ToFlatbuffer(config.processor), + ToFlatbuffer(config.noc), + ToFlatbuffer(config.noc_mode), + compile_args_offset, + defines_offset); + + return {tt::target::KernelConfig::DataMovementConfig, config_offset.Union()}; +} + +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const ComputeConfig& config) { + // Convert defines (map) to FlatBuffer format + std::vector> defines_vector; + for (const auto& [key, value] : config.defines) { + auto key_offset = builder.CreateString(key); + auto value_offset = builder.CreateString(value); + defines_vector.push_back(tt::target::CreateDefineEntry(builder, key_offset, value_offset)); + } + auto defines_offset = builder.CreateVector(defines_vector); + + // Convert unpack_to_dest_mode to FlatBuffer format + std::vector unpack_modes; + for (const auto& mode : config.unpack_to_dest_mode) { + unpack_modes.push_back(ToFlatbuffer(mode)); + } + auto unpack_modes_offset = builder.CreateVector(unpack_modes); + + // Convert compile_args to FlatBuffer format + auto compile_args_offset = builder.CreateVector(config.compile_args); + + // Create the FlatBuffer ComputeConfig object + auto config_offset = tt::target::CreateComputeConfig( + builder, + ToFlatbuffer(config.math_fidelity), + config.fp32_dest_acc_en, + config.dst_full_sync_en, + unpack_modes_offset, + config.bfp8_pack_precise, + config.math_approx_mode, + compile_args_offset, + defines_offset); + + return {tt::target::KernelConfig::ComputeConfig, config_offset.Union()}; +} + +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const EthernetConfig& config) { + // Convert defines (map) to FlatBuffer format + std::vector> defines_vector; + for (const auto& [key, value] : config.defines) { + auto key_offset = builder.CreateString(key); + auto value_offset = builder.CreateString(value); + defines_vector.push_back(tt::target::CreateDefineEntry(builder, key_offset, value_offset)); + } + auto defines_offset = builder.CreateVector(defines_vector); + + // Convert compile_args to FlatBuffer format + auto compile_args_offset = builder.CreateVector(config.compile_args); + + // Create the FlatBuffer EthernetConfig object + auto config_offset = tt::target::CreateEthernetConfig( + builder, + ToFlatbuffer(config.eth_mode), + ToFlatbuffer(config.noc), + ToFlatbuffer(config.processor), + compile_args_offset, + defines_offset); + + return {tt::target::KernelConfig::EthernetConfig, config_offset.Union()}; +} + +// Generic function for variant, specialized for each type above. +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, + const std::variant& config) { + return std::visit( + [&](auto&& cfg) -> std::pair> { + using T = std::decay_t; + if constexpr ( + std::is_same_v || std::is_same_v || + std::is_same_v) { + return ToFlatbuffer(builder, cfg); + } else { + throw std::runtime_error("Unhandled config type in ToFlatbuffer."); + } + }, + config); +} + +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const ReaderDataMovementConfig& config) { + const DataMovementConfig& base_config = config; // Cast to base + return ToFlatbuffer(builder, base_config); +} + +inline std::pair> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const WriterDataMovementConfig& config) { + const DataMovementConfig& base_config = config; // Cast to base + return ToFlatbuffer(builder, base_config); +} + +inline flatbuffers::Offset createRuntimeArg( + flatbuffers::FlatBufferBuilder& builder, const std::variant& arg) { + flatbuffers::Offset value_offset; + tt::target::RuntimeArgValue value_type; + + if (std::holds_alternative(arg)) { + // Create UInt32Value table + uint32_t value = std::get(arg); + auto uint32_offset = tt::target::CreateUInt32Value(builder, value); + value_offset = uint32_offset.Union(); + value_type = tt::target::RuntimeArgValue::UInt32Value; + } else if (std::holds_alternative(arg)) { + // Create BufferGlobalId table + Buffer* buffer = std::get(arg); + auto& ctx = LightMetalCaptureContext::Get(); + uint32_t buffer_global_id = ctx.GetGlobalId(buffer); + auto buffer_offset = tt::target::CreateBufferGlobalId(builder, buffer_global_id); + value_offset = buffer_offset.Union(); + value_type = tt::target::RuntimeArgValue::BufferGlobalId; + } else { + throw std::runtime_error("Unexpected variant type in createRuntimeArg"); + } + + // Create RuntimeArg + return tt::target::CreateRuntimeArg(builder, value_type, value_offset); +} + +inline flatbuffers::Offset>> ToFlatbuffer( + flatbuffers::FlatBufferBuilder& builder, const std::shared_ptr& runtime_args) { + std::vector> arg_offsets; + + for (const auto& arg : *runtime_args) { + auto runtime_arg_offset = createRuntimeArg(builder, arg); + arg_offsets.push_back(runtime_arg_offset); + } + + return builder.CreateVector(arg_offsets); +} + +inline flatbuffers::Offset ToFlatbuffer(const Tile& tile, flatbuffers::FlatBufferBuilder& builder) { + auto tile_shape_fb = builder.CreateVector(tile.get_tile_shape().data(), tile.get_tile_shape().size()); + auto face_shape_fb = builder.CreateVector(tile.get_face_shape().data(), tile.get_face_shape().size()); + + return tt::target::CreateTile( + builder, + tile_shape_fb, + face_shape_fb, + tile.get_tile_hw(), + tile.get_face_hw(), + tile.get_num_faces(), + tile.get_partial_face(), + tile.get_narrow_tile(), + tile.get_transpose_within_face(), + tile.get_transpose_of_faces()); +} + +inline flatbuffers::Offset>> ToFlatbuffer( + const std::array, NUM_CIRCULAR_BUFFERS>& tiles, flatbuffers::FlatBufferBuilder& builder) { + std::vector> tiles_fb; + for (const auto& tile_opt : tiles) { + if (tile_opt) { + tiles_fb.push_back(ToFlatbuffer(*tile_opt, builder)); + } + } + + return builder.CreateVector(tiles_fb); +} + +inline flatbuffers::Offset ToFlatbuffer( + const tt::tt_metal::CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder) { + // Note: std::optional not supported by FlatBuffers, so serialize it as a uint32_t with a default val 0 + auto global_address = config.globally_allocated_address_ ? *config.globally_allocated_address_ : 0; + + // Note: std::optional data_formats array represented as vec of k (idx) v (format) pairs + std::vector data_formats_vec; + for (size_t i = 0; i < config.data_formats_.size(); i++) { + if (config.data_formats_[i]) { + data_formats_vec.push_back({i, ToFlatbuffer(*config.data_formats_[i])}); + } + } + auto data_formats_fb = builder.CreateVectorOfStructs(data_formats_vec); + + // Note: std::optional page_sizes array represented as vec of k (idx) v (size) pairs + std::vector page_sizes_vec; + for (size_t i = 0; i < config.page_sizes_.size(); i++) { + if (config.page_sizes_[i]) { + page_sizes_vec.push_back({i, *config.page_sizes_[i]}); + } + } + auto page_sizes_fb = builder.CreateVectorOfStructs(page_sizes_vec); + auto tiles_fb = ToFlatbuffer(config.tiles_, builder); + + // FIXME (kmabee) - Handle const Buffer* shadow_global_buffer too, and other things missed. + + // Serialize buffer_indices_ and variants as a FlatBuffer vector + std::vector buf_ind_vec(config.buffer_indices_.begin(), config.buffer_indices_.end()); + auto buffer_indices_fb = builder.CreateVector(buf_ind_vec); + std::vector local_buf_ind_vec(config.local_buffer_indices_.begin(), config.local_buffer_indices_.end()); + auto local_buffer_indices_fb = builder.CreateVector(local_buf_ind_vec); + std::vector remote_buf_ind_vec(config.remote_buffer_indices_.begin(), config.remote_buffer_indices_.end()); + auto remote_buffer_indices_fb = builder.CreateVector(remote_buf_ind_vec); + + // Create the FlatBuffer object + return tt::target::CreateCircularBufferConfig( + builder, + config.total_size_, + global_address, + data_formats_fb, + page_sizes_fb, + tiles_fb, + buffer_indices_fb, + local_buffer_indices_fb, + remote_buffer_indices_fb, + config.dynamic_cb_, + config.max_size_, + config.buffer_size_); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 6881815f5a7d..ad2f1e04ad14 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -898,6 +898,7 @@ DeviceAddr AllocateBuffer(Buffer* buffer) { } void DeallocateBuffer(Buffer* buffer) { + TRACE_FUNCTION_CALL(CaptureDeallocateBuffer, buffer); GraphTracker::instance().track_deallocate(buffer); if (GraphTracker::instance().hook_deallocate(buffer)) { return; @@ -980,7 +981,11 @@ bool CloseDevice(IDevice* device) { return tt::DevicePool::instance().close_device(device_id); } -Program CreateProgram() { return Program(); } +Program CreateProgram() { + auto program = Program(); + TRACE_FUNCTION_CALL(CaptureCreateProgram, program); + return program; +} KernelHandle CreateDataMovementKernel( Program& program, @@ -1066,7 +1071,7 @@ KernelHandle CreateKernel( const std::string& file_name, const std::variant& core_spec, const std::variant& config) { - return std::visit( + KernelHandle kernel = std::visit( [&](auto&& cfg) -> KernelHandle { CoreRangeSet core_ranges = GetCoreRangeSet(core_spec); KernelSource kernel_src(file_name, KernelSource::FILE_PATH); @@ -1080,6 +1085,9 @@ KernelHandle CreateKernel( } }, config); + + TRACE_FUNCTION_CALL(CaptureCreateKernel, kernel, program, file_name, core_spec, config); + return kernel; } KernelHandle CreateKernelFromString( @@ -1108,7 +1116,9 @@ CBHandle CreateCircularBuffer( const std::variant& core_spec, const CircularBufferConfig& config) { CoreRangeSet core_ranges = GetCoreRangeSet(core_spec); - return program.add_circular_buffer(core_ranges, config); + auto cb_handle = program.add_circular_buffer(core_ranges, config); + TRACE_FUNCTION_CALL(CaptureCreateCircularBuffer, cb_handle, program, core_spec, config); + return cb_handle; } const CircularBufferConfig& GetCircularBufferConfig(Program& program, CBHandle cb_handle) { @@ -1195,7 +1205,7 @@ GlobalSemaphore CreateGlobalSemaphore( } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { - return Buffer::create( + auto buffer = Buffer::create( config.device, config.size, config.page_size, @@ -1204,6 +1214,9 @@ std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { std::nullopt, std::nullopt, std::nullopt); + + TRACE_FUNCTION_CALL(CaptureCreateBuffer, buffer, config); + return buffer; } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config, DeviceAddr address) { return Buffer::create( @@ -1274,6 +1287,7 @@ void SetRuntimeArgs( KernelHandle kernel_id, const std::variant& core_spec, stl::Span runtime_args) { + TRACE_FUNCTION_CALL(CaptureSetRuntimeArgsUint32, program, kernel_id, core_spec, runtime_args); ZoneScoped; TT_FATAL( not CommandQueue::async_mode_set(), @@ -1309,6 +1323,7 @@ void SetRuntimeArgs( const std::variant& core_spec, std::shared_ptr runtime_args) { detail::DispatchStateCheck(not device->using_slow_dispatch()); + TRACE_FUNCTION_CALL(CaptureSetRuntimeArgs, device, kernel, core_spec, runtime_args); SetRuntimeArgsImpl(device->command_queue(), kernel, core_spec, std::move(runtime_args), false); } @@ -1369,6 +1384,7 @@ void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) { // When light metal tracing is enabled, TraceDescriptor will be serialized via end_trace() and this // will serialize the LightMetalLoadTraceId call to be used during replay to load trace back to device. TRACE_FUNCTION_CALL(CaptureLoadTrace, device, cq_id, tid); + TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, true); // blocking=true } void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const bool blocking) { @@ -1376,7 +1392,10 @@ void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const device->replay_trace(cq_id, tid, blocking); } -void ReleaseTrace(IDevice* device, const uint32_t tid) { device->release_trace(tid); } +void ReleaseTrace(IDevice* device, const uint32_t tid) { + TRACE_FUNCTION_CALL(CaptureReleaseTrace, device, tid); + device->release_trace(tid); +} void LightMetalBeginCapture(IDevice* device) { device->light_metal_begin_capture(); }