Skip to content

Commit

Permalink
LightMetal - Add initial Trace/Replay support for many popular host APIs
Browse files Browse the repository at this point in the history
 - Not comprehensive, just initial changes for some of these paths:

   CreateBuffer()
   EnqueueWriteBuffer()
   EnqueueReadBuffer()
   Finish()
   DeallocateBuffer
   ReleaseTrace()
   CreateProgram()
   EnqueueProgram()
   CreateKernel()
   SetRuntimeArgs(uint32)
   SetRuntimeArgs(Kernel,RuntimeArgs)
   CreateCircularBuffer()

 - When Metal Trace is enabled, don't capture EnqueueProgram(); instead,
   inject ReplayTrace(), which would be used alongside LoadTrace()

 - Serialization / Deserialization of structs/enums defined in the
   flatbuffer schema types.fbs is handled via ToFlatbuffer() and
   FromFlatbuffer() functions.
  • Loading branch information
kmabeeTT committed Jan 13, 2025
1 parent 8c90cf8 commit bad6a24
Show file tree
Hide file tree
Showing 14 changed files with 1,939 additions and 9 deletions.
1 change: 1 addition & 0 deletions tt_metal/impl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/flatbuffers.cmake)
# Flatbuffer schema (.fbs) files; each entry is passed to GENERATE_FBS_HEADER
# by the foreach loop that follows this list.
set(FLATBUFFER_SCHEMAS
${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/binary.fbs
${CMAKE_CURRENT_SOURCE_DIR}/tracehost/command.fbs
${CMAKE_CURRENT_SOURCE_DIR}/tracehost/types.fbs
)
foreach(FBS_FILE ${FLATBUFFER_SCHEMAS})
GENERATE_FBS_HEADER(${FBS_FILE})
Expand Down
19 changes: 19 additions & 0 deletions tt_metal/impl/buffers/circular_buffer_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <map>
#include <optional>
#include <unordered_set>
#include <flatbuffers/flatbuffers.h>
#include "flatbuffers/flatbuffer_builder.h"

#include "tt_metal/common/logger.hpp"
#include "tt_metal/common/tt_backend_api_types.hpp"
Expand All @@ -18,12 +20,29 @@

#include "tt_metal/hw/inc/circular_buffer_constants.h"

// Forward declarations for the external flatbuffer serialization functions.
// CircularBufferConfig names ToFlatbuffer() and FromFlatbuffer() in friend declarations,
// so both are declared here, ahead of the class, spelled exactly as in those
// friend declarations.
namespace tt::target {
// Type in tt::target used as the flatbuffer-side representation (forward-declared only).
class CircularBufferConfig;
}  // namespace tt::target
namespace tt::tt_metal {
inline namespace v0 {
class CircularBufferConfig;
}
// Serializes `config` via `builder`; returns a flatbuffers::Offset to the
// tt::target::CircularBufferConfig that was written.
inline flatbuffers::Offset<tt::target::CircularBufferConfig> ToFlatbuffer(
    const tt::tt_metal::CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder);
// Builds a CircularBufferConfig from its flatbuffer representation `config_fb`.
inline CircularBufferConfig FromFlatbuffer(const tt::target::CircularBufferConfig* config_fb);
}  // namespace tt::tt_metal

namespace tt::tt_metal {
inline namespace v0 {

using CBHandle = uintptr_t;

class CircularBufferConfig {
friend flatbuffers::Offset<tt::target::CircularBufferConfig> tt::tt_metal::ToFlatbuffer(
const tt::tt_metal::CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder);
friend CircularBufferConfig FromFlatbuffer(const tt::target::CircularBufferConfig* config_fb);

public:
// Static circular buffer spec
CircularBufferConfig(uint32_t total_size, const std::map<uint8_t, tt::DataFormat>& data_format_spec);
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,7 @@ void EnqueueReadBuffer(
void* dst,
bool blocking,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
TRACE_FUNCTION_CALL(CaptureEnqueueReadBuffer, cq, buffer, dst, blocking); // FIXME (kmabee) consider sub_device_ids added recently.
detail::DispatchStateCheck(true);
cq.run_command(CommandInterface{
.type = EnqueueCommandType::ENQUEUE_READ_BUFFER, .blocking = blocking, .buffer = buffer, .dst = dst, .sub_device_ids = sub_device_ids});
Expand All @@ -1922,13 +1923,15 @@ void EnqueueWriteBuffer(
HostDataType src,
bool blocking,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
TRACE_FUNCTION_CALL(CaptureEnqueueWriteBuffer, cq, buffer, src, blocking); // FIXME (kmabee) consider sub_device_ids added recently.
detail::DispatchStateCheck(true);
cq.run_command(CommandInterface{
.type = EnqueueCommandType::ENQUEUE_WRITE_BUFFER, .blocking = blocking, .buffer = buffer, .src = std::move(src), .sub_device_ids = sub_device_ids});
}

void EnqueueProgram(
CommandQueue& cq, Program& program, bool blocking) {
TRACE_FUNCTION_CALL(CaptureEnqueueProgram, cq, program, blocking);
detail::DispatchStateCheck(true);
cq.run_command(
CommandInterface{.type = EnqueueCommandType::ENQUEUE_PROGRAM, .blocking = blocking, .program = &program});
Expand Down Expand Up @@ -1990,6 +1993,7 @@ bool EventQuery(const std::shared_ptr<Event>& event) {
}

void Finish(CommandQueue& cq, tt::stl::Span<const SubDeviceId> sub_device_ids) {
TRACE_FUNCTION_CALL(CaptureFinish, cq); // FIXME (kmabee) consider sub_device_ids added recently.
detail::DispatchStateCheck(true);
cq.run_command(CommandInterface{.type = EnqueueCommandType::FINISH, .blocking = true, .sub_device_ids = sub_device_ids});
TT_ASSERT(
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/dispatch/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ class HWCommandQueue {
friend void FinishImpl(CommandQueue& cq, tt::stl::Span<const SubDeviceId> sub_device_ids);
friend CommandQueue;
friend detail::Program_;
friend void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking);
};

// Common interface for all command queue types
Expand Down
Loading

0 comments on commit bad6a24

Please sign in to comment.