Skip to content

Commit

Permalink
PRCleanup: FBS files cleanup and splitting up types.fbs, types_to/fro…
Browse files Browse the repository at this point in the history
…m_flatbuffer.hpp

 - Move all fbs files into flatbuffer folder (was split between lightmetal, tracehost)
 - Split types.fbs into somewhat reasonable grouping of
   base_types.fbs, buffer_types.fbs and program_types.fbs, and do the
   same for to/from_flatbuffer.hpp files
  • Loading branch information
kmabeeTT committed Jan 23, 2025
1 parent 478442d commit 65138f4
Show file tree
Hide file tree
Showing 17 changed files with 1,007 additions and 947 deletions.
4 changes: 3 additions & 1 deletion tt_metal/api/tt-metalium/host_api_capture_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "command_generated.h"
#include <tt-metalium/logger.hpp>
#include "span.hpp"
#include "tracehost/types_to_flatbuffer.hpp"
#include "flatbuffer/base_types_to_flatbuffer.hpp"
#include "flatbuffer/program_types_to_flatbuffer.hpp"
#include "flatbuffer/buffer_types_to_flatbuffer.hpp"

//////////////////////////////////////////////////////////////
// TRACE GUARD & TRACE MACRO //
Expand Down
8 changes: 5 additions & 3 deletions tt_metal/impl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ set(IMPL_SRC
include(${PROJECT_SOURCE_DIR}/cmake/flatbuffers.cmake)

set(FLATBUFFER_SCHEMAS
${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/binary.fbs
${CMAKE_CURRENT_SOURCE_DIR}/tracehost/command.fbs
${CMAKE_CURRENT_SOURCE_DIR}/tracehost/types.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/binary.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/command.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types.fbs
)
foreach(FBS_FILE ${FLATBUFFER_SCHEMAS})
GENERATE_FBS_HEADER(${FBS_FILE})
Expand Down
102 changes: 102 additions & 0 deletions tt_metal/impl/flatbuffer/base_types.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
namespace tt.tt_metal.flatbuffer;


enum Arch: uint {
Grayskull = 0,
Wormhole_b0 = 1,
Blackhole = 2,
}

enum BufferType: ushort {
DRAM = 0,
L1 = 1,
SystemMemory = 2,
L1Small = 3,
Trace = 4,
}

enum TensorMemoryLayout: ushort {
None = 0,
Interleaved = 1,
SingleBank = 2,
HeightSharded = 3,
WidthSharded = 4,
BlockSharded = 5,
}


enum DataMovementProcessor : byte {
RISCV_0,
RISCV_1
}

enum NOC : byte {
NOC_0,
NOC_1
}

enum NOC_MODE : byte {
DM_DEDICATED_NOC,
DM_DYNAMIC_NOC
}

enum Eth : ubyte {
SENDER = 0,
RECEIVER = 1,
IDLE = 2
}

enum MathFidelity : ubyte {
LoFi = 0,
HiFi2 = 2,
HiFi3 = 3,
HiFi4 = 4,
Invalid = 255
}

enum DataFormat : uint8 {
Float32 = 0,
Float16 = 1,
Bfp8 = 2,
Bfp4 = 3,
Bfp2 = 11,
Float16_b = 5,
Bfp8_b = 6,
Bfp4_b = 7,
Bfp2_b = 15,
Lf8 = 10,
Fp8_e4m3 = 26, // 0x1A in decimal
Int8 = 14,
Tf32 = 4,
UInt8 = 30,
UInt16 = 9,
Int32 = 8,
UInt32 = 24,
RawUInt8 = 240, // 0xf0 in decimal
RawUInt16 = 241, // 0xf1 in decimal
RawUInt32 = 242, // 0xf2 in decimal
Invalid = 255
}

enum UnpackToDestMode : byte {
Default,
UnpackToDestFp32
}

table DefineEntry {
key: string;
value: string;
}


table Tile {
tile_shape: [uint32]; // Shape of the tile (e.g., height, width)
face_shape: [uint32]; // Shape of the face
tile_hw: uint32; // Tile hardware size
face_hw: uint32; // Face hardware size
num_faces: uint32; // Number of faces
partial_face: uint32; // Indicates if this is a partial face
narrow_tile: uint32; // Indicates if this is a narrow tile
transpose_within_face: bool; // Transpose within each face
transpose_of_faces: bool; // Transpose face order
}
140 changes: 140 additions & 0 deletions tt_metal/impl/flatbuffer/base_types_from_flatbuffer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "base_types_generated.h"

namespace tt::tt_metal {
inline namespace v0 {

inline BufferType FromFlatbuffer(tt::tt_metal::flatbuffer::BufferType type) {
switch (type) {
case tt::tt_metal::flatbuffer::BufferType::DRAM: return BufferType::DRAM;
case tt::tt_metal::flatbuffer::BufferType::L1: return BufferType::L1;
case tt::tt_metal::flatbuffer::BufferType::SystemMemory: return BufferType::SYSTEM_MEMORY;
case tt::tt_metal::flatbuffer::BufferType::L1Small: return BufferType::L1_SMALL;
case tt::tt_metal::flatbuffer::BufferType::Trace: return BufferType::TRACE;
default: throw std::invalid_argument("Unknown BufferType value in FromFlatbuffer()");
}
}

inline tt::tt_metal::DataMovementProcessor FromFlatbuffer(tt::tt_metal::flatbuffer::DataMovementProcessor in) {
switch (in) {
case tt::tt_metal::flatbuffer::DataMovementProcessor::RISCV_0:
return tt::tt_metal::DataMovementProcessor::RISCV_0;
case tt::tt_metal::flatbuffer::DataMovementProcessor::RISCV_1:
return tt::tt_metal::DataMovementProcessor::RISCV_1;
default: throw std::invalid_argument("Unknown DataMovementProcessor value in FromFlatbuffer()");
}
}

inline tt::tt_metal::NOC FromFlatbuffer(tt::tt_metal::flatbuffer::NOC in) {
switch (in) {
case tt::tt_metal::flatbuffer::NOC::NOC_0: return tt::tt_metal::NOC::NOC_0;
case tt::tt_metal::flatbuffer::NOC::NOC_1: return tt::tt_metal::NOC::NOC_1;
default: throw std::invalid_argument("Invalid NOC value passed to FromFlatbuffer");
}
}

inline tt::tt_metal::NOC_MODE FromFlatbuffer(tt::tt_metal::flatbuffer::NOC_MODE in) {
switch (in) {
case tt::tt_metal::flatbuffer::NOC_MODE::DM_DEDICATED_NOC: return tt::tt_metal::NOC_MODE::DM_DEDICATED_NOC;
case tt::tt_metal::flatbuffer::NOC_MODE::DM_DYNAMIC_NOC: return tt::tt_metal::NOC_MODE::DM_DYNAMIC_NOC;
default: throw std::invalid_argument("Unknown NOC_MODE value in FromFlatbuffer()");
}
}

inline tt::tt_metal::Eth FromFlatbuffer(tt::tt_metal::flatbuffer::Eth in) {
switch (in) {
case tt::tt_metal::flatbuffer::Eth::SENDER: return tt::tt_metal::Eth::SENDER;
case tt::tt_metal::flatbuffer::Eth::RECEIVER: return tt::tt_metal::Eth::RECEIVER;
case tt::tt_metal::flatbuffer::Eth::IDLE: return tt::tt_metal::Eth::IDLE;
default: throw std::invalid_argument("Unknown Eth value in FromFlatbuffer()");
}
}

inline MathFidelity FromFlatbuffer(tt::tt_metal::flatbuffer::MathFidelity input) {
switch (input) {
case tt::tt_metal::flatbuffer::MathFidelity::LoFi: return MathFidelity::LoFi;
case tt::tt_metal::flatbuffer::MathFidelity::HiFi2: return MathFidelity::HiFi2;
case tt::tt_metal::flatbuffer::MathFidelity::HiFi3: return MathFidelity::HiFi3;
case tt::tt_metal::flatbuffer::MathFidelity::HiFi4: return MathFidelity::HiFi4;
case tt::tt_metal::flatbuffer::MathFidelity::Invalid: return MathFidelity::Invalid;
default: throw std::invalid_argument("Unknown MathFidelity value in FromFlatbuffer()");
}
}

inline UnpackToDestMode FromFlatbuffer(tt::tt_metal::flatbuffer::UnpackToDestMode input) {
switch (input) {
case tt::tt_metal::flatbuffer::UnpackToDestMode::UnpackToDestFp32: return UnpackToDestMode::UnpackToDestFp32;
case tt::tt_metal::flatbuffer::UnpackToDestMode::Default: return UnpackToDestMode::Default;
default: throw std::invalid_argument("Invalid UnpackToDestMode value passed to FromFlatbuffer");
}
}

inline tt::DataFormat FromFlatbuffer(tt::tt_metal::flatbuffer::DataFormat input) {
switch (input) {
case tt::tt_metal::flatbuffer::DataFormat::Float32: return tt::DataFormat::Float32;
case tt::tt_metal::flatbuffer::DataFormat::Float16: return tt::DataFormat::Float16;
case tt::tt_metal::flatbuffer::DataFormat::Bfp8: return tt::DataFormat::Bfp8;
case tt::tt_metal::flatbuffer::DataFormat::Bfp4: return tt::DataFormat::Bfp4;
case tt::tt_metal::flatbuffer::DataFormat::Bfp2: return tt::DataFormat::Bfp2;
case tt::tt_metal::flatbuffer::DataFormat::Float16_b: return tt::DataFormat::Float16_b;
case tt::tt_metal::flatbuffer::DataFormat::Bfp8_b: return tt::DataFormat::Bfp8_b;
case tt::tt_metal::flatbuffer::DataFormat::Bfp4_b: return tt::DataFormat::Bfp4_b;
case tt::tt_metal::flatbuffer::DataFormat::Bfp2_b: return tt::DataFormat::Bfp2_b;
case tt::tt_metal::flatbuffer::DataFormat::Lf8: return tt::DataFormat::Lf8;
case tt::tt_metal::flatbuffer::DataFormat::Fp8_e4m3: return tt::DataFormat::Fp8_e4m3;
case tt::tt_metal::flatbuffer::DataFormat::Int8: return tt::DataFormat::Int8;
case tt::tt_metal::flatbuffer::DataFormat::Tf32: return tt::DataFormat::Tf32;
case tt::tt_metal::flatbuffer::DataFormat::UInt8: return tt::DataFormat::UInt8;
case tt::tt_metal::flatbuffer::DataFormat::UInt16: return tt::DataFormat::UInt16;
case tt::tt_metal::flatbuffer::DataFormat::Int32: return tt::DataFormat::Int32;
case tt::tt_metal::flatbuffer::DataFormat::UInt32: return tt::DataFormat::UInt32;
case tt::tt_metal::flatbuffer::DataFormat::RawUInt8: return tt::DataFormat::RawUInt8;
case tt::tt_metal::flatbuffer::DataFormat::RawUInt16: return tt::DataFormat::RawUInt16;
case tt::tt_metal::flatbuffer::DataFormat::RawUInt32: return tt::DataFormat::RawUInt32;
case tt::tt_metal::flatbuffer::DataFormat::Invalid: return tt::DataFormat::Invalid;
default: throw std::invalid_argument("Unknown DataFormat value in FromFlatbuffer()");
}
}

inline Tile FromFlatbuffer(const tt::tt_metal::flatbuffer::Tile* tile_fb) {
if (!tile_fb) {
throw std::runtime_error("Invalid Tile FlatBuffer object");
}

// Convert FlatBuffer vectors to std::array
std::array<uint32_t, 2> tile_shape = {tile_fb->tile_shape()->Get(0), tile_fb->tile_shape()->Get(1)};
std::array<uint32_t, 2> face_shape = {tile_fb->face_shape()->Get(0), tile_fb->face_shape()->Get(1)};

// Create and return the Tile object, explicitly initializing the members
Tile tile;
tile.tile_shape = tile_shape;
tile.face_shape = face_shape;
tile.tile_hw = tile_fb->tile_hw();
tile.face_hw = tile_fb->face_hw();
tile.num_faces = tile_fb->num_faces();
tile.partial_face = tile_fb->partial_face();
tile.narrow_tile = tile_fb->narrow_tile();
tile.transpose_within_face = tile_fb->transpose_within_face();
tile.transpose_of_faces = tile_fb->transpose_of_faces();

return tile;
}

inline std::array<std::optional<Tile>, NUM_CIRCULAR_BUFFERS> FromFlatbuffer(
const flatbuffers::Vector<flatbuffers::Offset<tt::tt_metal::flatbuffer::Tile>>* tiles_fb) {
std::array<std::optional<Tile>, NUM_CIRCULAR_BUFFERS> tiles = {};
if (tiles_fb) {
for (size_t i = 0; i < tiles_fb->size() && i < NUM_CIRCULAR_BUFFERS; ++i) {
tiles[i] = FromFlatbuffer(tiles_fb->Get(i));
}
}
return tiles;
}

} // namespace v0
} // namespace tt::tt_metal
Loading

0 comments on commit 65138f4

Please sign in to comment.