Skip to content

Commit

Permalink
LightMetal - New APIs LightMetalCompareToCapture() / LightMetalCompar…
Browse files Browse the repository at this point in the history
…eToGolden() for verif

 - Put them in lightmetal_capture_utils.hpp since they are used purely
   at capture time and, being verification-only, do not belong in
   host_api.hpp.

 - Update test_lightmetal_sanity.cpp tests to use these APIs for
   functional correctness checking between capture + replay.
  • Loading branch information
kmabeeTT committed Jan 22, 2025
1 parent cda0466 commit a7c19da
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 26 deletions.
41 changes: 17 additions & 24 deletions tests/tt_metal/tt_metal/lightmetal/test_lightmetal_sanity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <tt-metalium/logger.hpp>
#include "tt_metal/common/scoped_timer.hpp"
#include <tt-metalium/host_api.hpp>
#include "lightmetal_capture_utils.hpp"

using std::vector;
using namespace tt;
Expand Down Expand Up @@ -162,7 +163,9 @@ TEST_F(SingleDeviceLightMetalFixture, CreateBufferEnqueueWriteRead_Sanity) {

// Write data to buffer, then read outputs and verify against expected.
EnqueueWriteBuffer(command_queue, *buffer, input_data.data(), true);
EnqueueReadBuffer(command_queue, *buffer, readback_data.data(), true);
// This will verify that readback matches between capture + replay
LightMetalCompareToCapture(command_queue, *buffer, readback_data.data());

EXPECT_TRUE(input_data == readback_data);

// For dev/debug go ahead and print the results. Had a replay bug, was seeing wrong data.
Expand Down Expand Up @@ -203,7 +206,9 @@ TEST_F(SingleDeviceLightMetalFixture, SingleRISCDataMovementSanity) {
// Write data to buffer, enqueue program, then read outputs and verify against expected.
EnqueueWriteBuffer(command_queue, *input, input_data.data(), true);
EnqueueProgram(command_queue, simple_program, true);
EnqueueReadBuffer(command_queue, *output, eager_output_data.data(), true);
// This will verify that outputs match between capture + replay
LightMetalCompareToCapture(command_queue, *output, eager_output_data.data());

EXPECT_TRUE(eager_output_data == input_data);

// For dev/debug go ahead and print the results
Expand Down Expand Up @@ -234,18 +239,11 @@ TEST_F(SingleDeviceLightMetalFixture, ThreeRISCDataMovementComputeSanity) {
input_data[i] = i;
}

vector<uint32_t> eager_output_data;
eager_output_data.resize(input_data.size());

// Write data to buffer, enqueue program, then read outputs.
EnqueueWriteBuffer(command_queue, *input, input_data.data(), true);
EnqueueProgram(command_queue, simple_program, true);
EnqueueReadBuffer(command_queue, *output, eager_output_data.data(), true);

// For dev/debug go ahead and print the results
for (size_t i = 0; i < eager_output_data.size(); i++) {
log_debug(tt::LogMetalTrace, "i: {:3d} input: {} output: {}", i, input_data[i], eager_output_data[i]);
}
// This will verify that outputs match between capture + replay
LightMetalCompareToCapture(command_queue, *output); // No read return

Finish(command_queue);
}
Expand Down Expand Up @@ -275,18 +273,11 @@ TEST_F(SingleDeviceLightMetalFixture, ThreeRISCDataMovementComputeSanityDynamicC
input_data[i] = i;
}

vector<uint32_t> eager_output_data;
eager_output_data.resize(input_data.size());

// Write data to buffer, enqueue program, then read outputs.
EnqueueWriteBuffer(command_queue, *input, input_data.data(), true);
EnqueueProgram(command_queue, simple_program, true);
EnqueueReadBuffer(command_queue, *output, eager_output_data.data(), true);

// For dev/debug go ahead and print the results
for (size_t i = 0; i < eager_output_data.size(); i++) {
log_info(tt::LogMetalTrace, "i: {:3d} input: {} output: {}", i, input_data[i], eager_output_data[i]);
}
// This will verify that outputs match between capture + replay
LightMetalCompareToCapture(command_queue, *output); // No read return

Finish(command_queue);
}
Expand Down Expand Up @@ -314,7 +305,8 @@ TEST_F(SingleDeviceLightMetalFixture, SingleProgramTraceCapture) {
// Initial run w/o trace. Preloads binary cache, and captures golden output.
EnqueueWriteBuffer(command_queue, *input, input_data.data(), true);
EnqueueProgram(command_queue, simple_program, true);
EnqueueReadBuffer(command_queue, *output, eager_output_data.data(), true);
// This will verify that outputs match between capture + replay.
LightMetalCompareToCapture(command_queue, *output, eager_output_data.data());

// Write junk to output buffer to help make sure trace run from standalone binary works.
lightmetal_test_helpers::write_junk_to_buffer(command_queue, *output);
Expand All @@ -325,7 +317,7 @@ TEST_F(SingleDeviceLightMetalFixture, SingleProgramTraceCapture) {
EndTraceCapture(this->device_, command_queue.id(), tid);

// Verify trace output during replay matches expected output from original capture.
// LightMetalCompareToGolden(command_queue, *output, eager_output_data.data());
LightMetalCompareToGolden(command_queue, *output, eager_output_data.data());

// Done
Finish(command_queue);
Expand Down Expand Up @@ -359,7 +351,8 @@ TEST_F(SingleDeviceLightMetalFixture, TwoProgramTraceCapture) {
EnqueueWriteBuffer(command_queue, *input, input_data.data(), true);
EnqueueProgram(command_queue, op0, true);
EnqueueProgram(command_queue, op1, true);
EnqueueReadBuffer(command_queue, *output, eager_output_data.data(), true);
// This will verify that outputs match between capture + replay.
LightMetalCompareToCapture(command_queue, *output, eager_output_data.data());
Finish(command_queue);

// Write junk to output buffer to help make sure trace run from standalone binary works.
Expand All @@ -372,7 +365,7 @@ TEST_F(SingleDeviceLightMetalFixture, TwoProgramTraceCapture) {
EndTraceCapture(this->device_, command_queue.id(), tid);

// Verify trace output during replay matches expected output from original capture.
// LightMetalCompareToGolden(command_queue, *output, eager_output_data.data());
LightMetalCompareToGolden(command_queue, *output, eager_output_data.data());

// Done
Finish(command_queue);
Expand Down
36 changes: 36 additions & 0 deletions tt_metal/api/tt-metalium/host_api_capture_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,40 @@ inline void CaptureCreateCircularBuffer(
CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateCircularBufferCommand, cmd.Union());
}

// Records a LightMetalCompareCommand in the active Light Metal capture context.
//
//   cq           - command queue whose id() is stored in the command.
//   buffer       - device buffer (by reference or shared_ptr) whose global id is stored.
//   golden_data  - expected contents of `buffer`, read as buffer->size() / sizeof(uint32_t)
//                  uint32_t words. NOTE(review): must be non-null and at least that large;
//                  this function does not validate it.
//   is_user_data - stored in the command and logged; marks whether the golden data came
//                  from the user (true) or from a capture-time read (false).
inline void CaptureLightMetalCompare(
    CommandQueue& cq,
    std::variant<std::reference_wrapper<Buffer>, std::shared_ptr<Buffer>> buffer,
    void* golden_data,
    bool is_user_data) {
    auto& ctx = LightMetalCaptureContext::Get();

    // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map.
    Buffer* buffer_ptr = std::holds_alternative<std::shared_ptr<Buffer>>(buffer)
                             ? std::get<std::shared_ptr<Buffer>>(buffer).get()
                             : &std::get<std::reference_wrapper<Buffer>>(buffer).get();

    uint32_t cq_global_id = cq.id();  // TODO (kmabee) - consider storing/getting CQ from global map instead.
    uint32_t buffer_global_id = ctx.GetGlobalId(buffer_ptr);

    // Number of uint32_t elements in the buffer; golden data is viewed as that many words.
    const size_t golden_data_len = buffer_ptr->size() / sizeof(uint32_t);
    const uint32_t* golden_data_uint32 = static_cast<const uint32_t*>(golden_data);

    log_debug(
        tt::LogMetalTrace,
        "{}: buffer_global_id: {} is_user_data: {} golden_data_len: {}",
        __FUNCTION__,
        buffer_global_id,
        is_user_data,
        golden_data_len);

    // Serialize golden_data into FlatBuffer directly from the caller's memory. The pointer/length
    // overload avoids first copying the data into a temporary std::vector.
    auto golden_data_fb = ctx.GetBuilder().CreateVector(golden_data_uint32, golden_data_len);

    auto cmd = tt::tt_metal::flatbuffer::CreateLightMetalCompareCommand(
        ctx.GetBuilder(), cq_global_id, buffer_global_id, golden_data_fb, is_user_data);
    CaptureCommand(tt::tt_metal::flatbuffer::CommandType::LightMetalCompareCommand, cmd.Union());
}

} // namespace tt::tt_metal
75 changes: 75 additions & 0 deletions tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "command_generated.h"
#include "tracehost/types_to_flatbuffer.hpp"
#include "host_api_capture_helpers.hpp"
#include <tt-metalium/host_api.hpp>
#include <tt-metalium/buffer.hpp>

namespace tt::tt_metal {

// Note: LightMetalCompare functions could have been inside host_api.hpp / command_queue.cpp but seems better
// to not make as visible, since these are APIs used at light-metal capture time for verification purposes.

// clang-format off
/**
* Reads a buffer from the device and captures return data as golden inside Light Metal Binary, and optionally returns to user.
* When replaying Light Metal Binary, buffer is read and data is compared to the capture-time golden data.
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |----------------|-----------------------------------------------------------------------------------|-------------------------------------|----------------------------------------|----------|
* | cq | The command queue object which dispatches the command to the hardware | CommandQueue & | | Yes |
* | buffer | The device buffer we are reading from | Buffer & or std::shared_ptr<Buffer> | | Yes |
* | dst | The memory where the result will be stored, if provided | void* | | No |
*/
// clang-format on
inline void LightMetalCompareToCapture(
    CommandQueue& cq,
    const std::variant<std::reference_wrapper<Buffer>, std::shared_ptr<Buffer>>& buffer,
    void* dst = nullptr) {
    TRACE_FUNCTION_ENTRY();

    // If dst ptr is not provided, just allocate temp space for rd return capture/usage.
    std::vector<uint32_t> rd_data_tmp;
    if (!dst) {
        const size_t buffer_size = std::holds_alternative<std::reference_wrapper<Buffer>>(buffer)
                                       ? std::get<std::reference_wrapper<Buffer>>(buffer).get().size()
                                       : std::get<std::shared_ptr<Buffer>>(buffer)->size();
        // Round the word count up so the scratch storage covers every byte of the buffer even if
        // its size is not a multiple of sizeof(uint32_t); a truncating division would leave the
        // allocation up to 3 bytes short of what the read below is presumably going to write.
        rd_data_tmp.resize((buffer_size + sizeof(uint32_t) - 1) / sizeof(uint32_t));
        dst = rd_data_tmp.data();
    }

    EnqueueReadBuffer(cq, buffer, dst, true);  // Blocking read to get golden value.
    // Record the data just read as the golden reference (is_user_data = false).
    TRACE_FUNCTION_CALL(CaptureLightMetalCompare, cq, buffer, dst, false);
}

// clang-format off
/**
* Accepts user-supplied golden data, stored inside Light Metal Binary.
* When replaying Light Metal Binary, buffer is read and data is compared to the user-supplied golden data.
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |----------------|-----------------------------------------------------------------------------------|-------------------------------------|----------------------------------------|----------|
* | cq | The command queue object which dispatches the command to the hardware | CommandQueue & | | Yes |
* | buffer | The device buffer we are reading from | Buffer & or std::shared_ptr<Buffer> | | Yes |
* | golden_data | User supplied expected/golden data for buffer | void* | | Yes |
*/
// clang-format on

inline void LightMetalCompareToGolden(
CommandQueue& cq,
const std::variant<std::reference_wrapper<Buffer>, std::shared_ptr<Buffer>>& buffer,
void* golden_data) {
// Unlike LightMetalCompareToCapture, no device read is issued here: the user-supplied golden
// data is only forwarded to CaptureLightMetalCompare, tagged with is_user_data = true.
// NOTE(review): CaptureLightMetalCompare reads buffer size / sizeof(uint32_t) words from
// golden_data without validation, so callers should pass a non-null pointer of at least
// that size - confirm whether TRACE_FUNCTION_CALL only invokes it while capture is active.
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureLightMetalCompare, cq, buffer, golden_data, true);
}

} // namespace tt::tt_metal
3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/lightmetal_replay.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct CreateKernelCommand;
struct SetRuntimeArgsUint32Command;
struct SetRuntimeArgsCommand;
struct CreateCircularBufferCommand;
struct LightMetalCompareCommand;
struct RuntimeArg;

struct TraceDescriptor;
Expand Down Expand Up @@ -83,6 +84,7 @@ class LightMetalReplay {
void Execute(const tt::tt_metal::flatbuffer::SetRuntimeArgsUint32Command* command);
void Execute(const tt::tt_metal::flatbuffer::SetRuntimeArgsCommand* command);
void Execute(const tt::tt_metal::flatbuffer::CreateCircularBufferCommand* command);
void Execute(const tt::tt_metal::flatbuffer::LightMetalCompareCommand* command);

// Object maps public accessors
void AddBufferToMap(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Buffer>& buffer);
Expand Down Expand Up @@ -112,6 +114,7 @@ class LightMetalReplay {
LightMetalBinary binary_blob_; // Stored binary blob
const tt::tt_metal::flatbuffer::LightMetalBinary* fb_binary_; // Parsed FlatBuffer binary
bool show_reads_ = false; // Flag to show read buffer contents
bool disable_checking_ = false; // Optionally disable equality checking in Compare command.

// System related members ----------------------
void SetupDevices();
Expand Down
60 changes: 58 additions & 2 deletions tt_metal/impl/lightmetal/lightmetal_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ LightMetalReplay::LightMetalReplay(LightMetalBinary&& binary_blob) :
binary_blob_(std::move(binary_blob)), fb_binary_(nullptr) {

show_reads_ = parse_env("TT_LIGHT_METAL_SHOW_READS", false);
disable_checking_ = parse_env("TT_LIGHT_METAL_DISABLE_CHECKING", false);
fb_binary_ = ParseFlatBufferBinary(); // Parse and store the FlatBuffer binary
if (!fb_binary_) {
throw std::runtime_error("Failed to parse FlatBuffer binary during initialization.");
Expand Down Expand Up @@ -325,6 +326,10 @@ void LightMetalReplay::Execute(const tt::tt_metal::flatbuffer::Command* command)
Execute(command->cmd_as_CreateCircularBufferCommand());
break;
}
case ::tt::tt_metal::flatbuffer::CommandType::LightMetalCompareCommand: {
Execute(command->cmd_as_LightMetalCompareCommand());
break;
}
default:
throw std::runtime_error("Unsupported type: " + std::string(EnumNameCommandType(command->cmd_type())));
break;
Expand Down Expand Up @@ -446,8 +451,7 @@ void LightMetalReplay::Execute(const tt::tt_metal::flatbuffer::EnqueueReadBuffer
EnqueueReadBuffer(cq, buffer, readback_data.data(), cmd->blocking());

// TODO (kmabee) - TBD what to do with readback data. For now, optionally print.
// One idea is to store in map by global_read_id that caller can access. Plan to also
// partially replace this by mechanism to capture and treat some reads as golden.
// One idea is to store in map by global_read_id that caller can access.
if (show_reads_) {
for (size_t i = 0; i < readback_data.size(); i++) {
log_info(tt::LogMetalTrace, " rd_data i: {:3d} => data: {} ({:x})", i, readback_data[i], readback_data[i]);
Expand Down Expand Up @@ -574,6 +578,58 @@ void LightMetalReplay::Execute(const tt::tt_metal::flatbuffer::CreateCircularBuf
AddCBHandleToMap(cmd->global_id(), cb_handle);
}

// Verification command to compare readback of a buffer with golden from either capture or user expected values.
//
// Looks up the buffer by global id, performs a blocking read of it on the command queue named by
// the command, and (unless checking is disabled) requires the readback to equal the golden data
// stored in the command.
//
// Throws std::runtime_error if the buffer was never created, if the command carries no golden
// data, if the readback and golden element counts differ, or if any element differs.
void LightMetalReplay::Execute(const ::tt::tt_metal::flatbuffer::LightMetalCompareCommand* cmd) {
    log_debug(
        tt::LogMetalTrace,
        "LightMetalReplay(LightMetalCompare) cq_global_id: {} buffer_global_id: {} is_user_data: {}",
        cmd->cq_global_id(),
        cmd->buffer_global_id(),
        cmd->is_user_data());

    auto buffer = GetBufferFromMap(cmd->buffer_global_id());
    if (!buffer) {
        throw std::runtime_error(
            "Buffer w/ global_id: " + std::to_string(cmd->buffer_global_id()) + " not previously created");
    }

    // TODO (kmabee) - consider storing/getting CQ from global map instead.
    CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id());
    std::vector<uint32_t> rd_data(buffer->size() / sizeof(uint32_t), 0);
    // The read is issued even when checking is disabled; only the comparison is skipped.
    EnqueueReadBuffer(cq, buffer, rd_data.data(), true);

    if (disable_checking_) {
        log_debug(
            tt::LogMetalTrace, "Skipping LightMetalCompareCommand for buffer_global_id: {}.", cmd->buffer_global_id());
        return;
    }

    // FlatBuffers accessors return nullptr for an absent vector field; fail with a clear error
    // instead of dereferencing a null pointer on a binary that lacks golden data.
    const auto* golden_data = cmd->golden_data();
    if (!golden_data) {
        throw std::runtime_error(
            "Golden data missing for buffer_global_id: " + std::to_string(cmd->buffer_global_id()));
    }

    if (rd_data.size() != golden_data->size()) {
        throw std::runtime_error(
            "Readback data size: " + std::to_string(rd_data.size()) +
            " does not match golden data size: " + std::to_string(golden_data->size()));
    }

    // Optional debug to show verbose comparison
    if (show_reads_) {
        for (size_t i = 0; i < rd_data.size(); i++) {
            bool match = rd_data[i] == golden_data->Get(i);
            log_info(
                tt::LogMetalTrace,
                "LightMetalCompare i: {:3d} match: {} RdData: {:x} Golden: {:x}",
                i,
                match,
                rd_data[i],
                golden_data->Get(i));
        }
    }

    // Do simple equality comparison between two vectors (sizes already verified equal above).
    if (!std::equal(rd_data.begin(), rd_data.end(), golden_data->begin())) {
        throw std::runtime_error(
            "Golden vs rd_data mismatch for buffer_global_id: " + std::to_string(cmd->buffer_global_id()));
    }
}

// Main entry point to execute a light metal binary blob, return true if pass.
bool LightMetalReplay::ExecuteLightMetalBinary() {
if (!fb_binary_) {
Expand Down
8 changes: 8 additions & 0 deletions tt_metal/impl/tracehost/command.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ table CreateCircularBufferCommand {
config: CircularBufferConfig;
}

// Verification command: at replay the referenced buffer is read back and compared against golden_data.
table LightMetalCompareCommand {
cq_global_id: uint32; // Id of the CommandQueue used for the readback at replay
buffer_global_id: uint32; // Global id of the Buffer that is read back and compared
golden_data: [uint32]; // Expected buffer contents as uint32 words, compared element-wise at replay
is_user_data: bool; // Informational only: true if golden data is user-supplied, false if from a capture-time read
}

union CommandType {
ReplayTraceCommand,
EnqueueTraceCommand,
Expand All @@ -110,6 +117,7 @@ union CommandType {
SetRuntimeArgsUint32Command,
SetRuntimeArgsCommand,
CreateCircularBufferCommand,
LightMetalCompareCommand,
}

table Command {
Expand Down

0 comments on commit a7c19da

Please sign in to comment.