Skip to content

Commit

Permalink
LightMetal - Add+Use TRACE_FUNCTION_ENTRY() macro to prevent recursive host API tracing
Browse files Browse the repository at this point in the history

 - Add TRACE_FUNCTION_ENTRY() at the very start of each function that is
   traced with TRACE_FUNCTION_CALL() to increment the scope guard counter.

 - Really liked single macro usage per traced function, but some APIs
   like EnqueueProgram() and CreateDevice() (not currently traced, maybe
   one day) call other host APIs, which would then be traced recursively

 - Can't bundle with existing TRACE_FUNCTION_CALL() macro because
   sometimes it's called at end of traced function (when it needs to
   capture the return object) rather than beginning
  • Loading branch information
kmabeeTT committed Jan 13, 2025
1 parent 92407da commit 494b285
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
5 changes: 5 additions & 0 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,7 @@ void EnqueueReadBuffer(
void* dst,
bool blocking,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureEnqueueReadBuffer, cq, buffer, dst, blocking); // FIXME (kmabee) consider sub_device_ids added recently.
detail::DispatchStateCheck(true);
cq.run_command(CommandInterface{
Expand All @@ -1923,6 +1924,7 @@ void EnqueueWriteBuffer(
HostDataType src,
bool blocking,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureEnqueueWriteBuffer, cq, buffer, src, blocking); // FIXME (kmabee) consider sub_device_ids added recently.
detail::DispatchStateCheck(true);
cq.run_command(CommandInterface{
Expand All @@ -1931,6 +1933,7 @@ void EnqueueWriteBuffer(

void EnqueueProgram(
CommandQueue& cq, Program& program, bool blocking) {
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureEnqueueProgram, cq, program, blocking);
detail::DispatchStateCheck(true);
cq.run_command(
Expand Down Expand Up @@ -1993,6 +1996,7 @@ bool EventQuery(const std::shared_ptr<Event>& event) {
}

void Finish(CommandQueue& cq, tt::stl::Span<const SubDeviceId> sub_device_ids) {
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureFinish, cq); // FIXME (kmabee) consider sub_device_ids added recently.
detail::DispatchStateCheck(true);
cq.run_command(CommandInterface{.type = EnqueueCommandType::FINISH, .blocking = true, .sub_device_ids = sub_device_ids});
Expand All @@ -2006,6 +2010,7 @@ void Finish(CommandQueue& cq, tt::stl::Span<const SubDeviceId> sub_device_ids) {
}

void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) {
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureEnqueueTrace, cq, trace_id, blocking);
detail::DispatchStateCheck(true);
TT_FATAL(cq.device()->get_trace(trace_id) != nullptr, "Trace instance {} must exist on device", trace_id);
Expand Down
42 changes: 37 additions & 5 deletions tt_metal/impl/lightmetal/host_api_capture_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,49 @@
// FIXME (kmabee) - Temp hack, remove before merge and integrate as cmake define.
#define ENABLE_TRACING 1

//////////////////////////////////////////////////////////////
// TRACE GUARD & TRACE MACRO //
//////////////////////////////////////////////////////////////

namespace tt::tt_metal {

// RAII guard that counts how deeply the current thread is nested inside traced
// host API calls. Constructing a guard increments the per-thread depth and
// destroying it decrements the depth again. TRACE_FUNCTION_CALL() only captures
// when the depth is exactly 1, so host APIs invoked from inside another traced
// host API are not traced recursively.
struct TraceScope {
    // Per-thread nesting depth. `static inline` provides the single definition
    // directly in this header.
    static inline thread_local int depth = 0;

    // Entering a traced scope.
    TraceScope() noexcept { ++depth; }
    // Leaving a traced scope.
    ~TraceScope() noexcept { --depth; }

    // Non-copyable and non-movable: a copied/moved guard would skip the
    // increment but still run the decrement, unbalancing the counter.
    TraceScope(const TraceScope&) = delete;
    TraceScope& operator=(const TraceScope&) = delete;
    TraceScope(TraceScope&&) = delete;
    TraceScope& operator=(TraceScope&&) = delete;
};

} // namespace tt::tt_metal

#ifdef ENABLE_TRACING

// Place at the very start of every host API that uses TRACE_FUNCTION_CALL().
// It declares a tt::tt_metal::TraceScope guard that raises the per-thread trace
// depth for the remainder of the enclosing scope. It is a separate macro because
// TRACE_FUNCTION_CALL() is sometimes issued at the end of a function (to capture
// the returned object) rather than at the beginning.
// Naming is still open; TRACE_FUNCTION_THIS_SCOPE is another candidate.
// The guard name deliberately avoids a double underscore, which is reserved for
// the implementation.
#define TRACE_FUNCTION_ENTRY() tt::tt_metal::TraceScope lightmetal_trace_scope_guard

// Logs the request, then invokes capture_func(args...) only when light metal
// tracing is active AND this is the outermost traced host API on this thread
// (depth == 1). Nested host API calls are therefore not captured.
#define TRACE_FUNCTION_CALL(capture_func, ...)                                                     \
    do {                                                                                           \
        log_trace(                                                                                 \
            tt::LogMetalTrace,                                                                     \
            "TRACE_FUNCTION_CALL: {} via {} istracing: {} depth: {}",                              \
            #capture_func,                                                                         \
            __FUNCTION__,                                                                          \
            LightMetalCaptureContext::Get().IsTracing(),                                           \
            tt::tt_metal::TraceScope::depth);                                                      \
        if (LightMetalCaptureContext::Get().IsTracing() && tt::tt_metal::TraceScope::depth == 1) { \
            capture_func(__VA_ARGS__);                                                             \
        }                                                                                          \
    } while (0)

#else

// Tracing compiled out: both macros expand to nothing observable.
#define TRACE_FUNCTION_ENTRY()
#define TRACE_FUNCTION_CALL(capture_func, ...) \
    do {                                       \
    } while (0)

#endif

namespace tt::tt_metal {
Expand Down
10 changes: 10 additions & 0 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ DeviceAddr AllocateBuffer(Buffer* buffer) {
}

void DeallocateBuffer(Buffer* buffer) {
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureDeallocateBuffer, buffer);
GraphTracker::instance().track_deallocate(buffer);
if (GraphTracker::instance().hook_deallocate(buffer)) {
Expand Down Expand Up @@ -982,6 +983,7 @@ bool CloseDevice(IDevice* device) {
}

Program CreateProgram() {
TRACE_FUNCTION_ENTRY();
auto program = Program();
TRACE_FUNCTION_CALL(CaptureCreateProgram, program);
return program;
Expand Down Expand Up @@ -1071,6 +1073,7 @@ KernelHandle CreateKernel(
const std::string& file_name,
const std::variant<CoreCoord, CoreRange, CoreRangeSet>& core_spec,
const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig>& config) {
TRACE_FUNCTION_ENTRY();
KernelHandle kernel = std::visit(
[&](auto&& cfg) -> KernelHandle {
CoreRangeSet core_ranges = GetCoreRangeSet(core_spec);
Expand Down Expand Up @@ -1115,6 +1118,7 @@ CBHandle CreateCircularBuffer(
Program& program,
const std::variant<CoreCoord, CoreRange, CoreRangeSet>& core_spec,
const CircularBufferConfig& config) {
TRACE_FUNCTION_ENTRY();
CoreRangeSet core_ranges = GetCoreRangeSet(core_spec);
auto cb_handle = program.add_circular_buffer(core_ranges, config);
TRACE_FUNCTION_CALL(CaptureCreateCircularBuffer, cb_handle, program, core_spec, config);
Expand Down Expand Up @@ -1205,6 +1209,7 @@ GlobalSemaphore CreateGlobalSemaphore(
}

std::shared_ptr<Buffer> CreateBuffer(const InterleavedBufferConfig& config) {
TRACE_FUNCTION_ENTRY();
auto buffer = Buffer::create(
config.device,
config.size,
Expand Down Expand Up @@ -1288,6 +1293,7 @@ void SetRuntimeArgs(
const std::variant<CoreCoord, CoreRange, CoreRangeSet>& core_spec,
stl::Span<const uint32_t> runtime_args) {
// The scope guard must be created before TRACE_FUNCTION_CALL(): the capture only
// fires when the trace depth is exactly 1, i.e. after this function's own guard
// has incremented it.
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureSetRuntimeArgsUint32, program, kernel_id, core_spec, runtime_args);
ZoneScoped;
TT_FATAL(
not CommandQueue::async_mode_set(),
Expand Down Expand Up @@ -1322,6 +1328,7 @@ void SetRuntimeArgs(
const std::shared_ptr<Kernel>& kernel,
const std::variant<CoreCoord, CoreRange, CoreRangeSet>& core_spec,
std::shared_ptr<RuntimeArgs> runtime_args) {
TRACE_FUNCTION_ENTRY();
detail::DispatchStateCheck(not device->using_slow_dispatch());
TRACE_FUNCTION_CALL(CaptureSetRuntimeArgs, device, kernel, core_spec, runtime_args);
SetRuntimeArgsImpl(device->command_queue(), kernel, core_spec, std::move(runtime_args), false);
Expand Down Expand Up @@ -1380,6 +1387,7 @@ uint32_t BeginTraceCapture(IDevice* device, const uint8_t cq_id) {
}

void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) {
TRACE_FUNCTION_ENTRY();
device->end_trace(cq_id, tid);
// When light metal tracing is enabled, TraceDescriptor will be serialized via end_trace() and this
// will serialize the LightMetalLoadTraceId call to be used during replay to load trace back to device.
Expand All @@ -1388,11 +1396,13 @@ void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) {
}

// Host API wrapper that forwards to device->replay_trace(cq_id, tid, blocking).
// Before forwarding, the call is offered to light metal capture via
// CaptureReplayTrace; the capture only happens when tracing is active and this
// is the outermost traced host API call on the current thread.
void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const bool blocking) {
// Must stay first: raises the trace depth for the rest of this function.
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, blocking);
device->replay_trace(cq_id, tid, blocking);
}

// Host API wrapper that forwards to device->release_trace(tid).
// Before forwarding, the call is offered to light metal capture via
// CaptureReleaseTrace; the capture only happens when tracing is active and this
// is the outermost traced host API call on the current thread.
void ReleaseTrace(IDevice* device, const uint32_t tid) {
// Must stay first: raises the trace depth for the rest of this function.
TRACE_FUNCTION_ENTRY();
TRACE_FUNCTION_CALL(CaptureReleaseTrace, device, tid);
device->release_trace(tid);
}
Expand Down

0 comments on commit 494b285

Please sign in to comment.