Commit: add alignment fix
sjameelTT committed Jan 21, 2025
1 parent 4c0e24c commit 70ac32a
Showing 4 changed files with 61 additions and 8 deletions.
tests/ttnn/unit_tests/operations/test_permute.py (4 changes: 2 additions & 2 deletions)
@@ -499,15 +499,15 @@ def generate_fixed_no_dim0_dim1_transpose_permutations(N, dim0, dim1):

 @pytest.mark.parametrize("shape", [[7, 7, 7, 17, 17]])
 @pytest.mark.parametrize("perm", [[0, 1, 4, 3, 2]])
-@pytest.mark.parametrize("dtype", [ttnn.float32])
+@pytest.mark.parametrize("dtype", [ttnn.bfloat16])
 def test_permute_5d_yw(shape, perm, dtype, device):
     torch.set_printoptions(threshold=300000000)
     if is_grayskull() and dtype == ttnn.float32:
         pytest.skip("Grayskull doesn't support float32")
     torch.manual_seed(2005)
     torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
     input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
-    output_tensor = ttnn.permute(input_tensor, perm, pad_value=None)
+    output_tensor = ttnn.permute(input_tensor, perm, pad_value=0.0)
     print(ttnn.from_device(output_tensor).to_torch())
     output_tensor = ttnn.to_torch(output_tensor)
     torch_output = torch.permute(torch_tensor, perm)
ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp
@@ -9,14 +9,21 @@
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/untilize.h"
#include "compute_kernel_api/pack_untilize.h"
#include "tt_metal/hw/inc/circular_buffer.h"

namespace NAMESPACE {

void MAIN {
constexpr uint32_t N = get_compile_time_arg_val(0);
constexpr uint32_t x_blocks = get_compile_time_arg_val(1);
constexpr uint32_t w_blocks = get_compile_time_arg_val(2);
constexpr uint32_t H = get_compile_time_arg_val(3);
constexpr uint32_t read_alignment = get_compile_time_arg_val(4);
constexpr uint32_t SUBTILE_LINE_BYTES = get_compile_time_arg_val(5);

constexpr uint32_t misalignment = read_alignment - SUBTILE_LINE_BYTES;
constexpr uint32_t misalignment_div_16 = misalignment >> cb_addr_shift;
uint32_t offset_div_16 = 0;
uint32_t start_block = get_arg_val<uint32_t>(0);
uint32_t end_block = get_arg_val<uint32_t>(1);

@@ -28,6 +35,7 @@ void MAIN {

     for (uint32_t block = start_block; block < end_block; block++) {
         // Decompose block into w_block, x_block, and xw_block indices
+#ifdef TRISC_UNPACK
         uint32_t rem = block;
         uint32_t w_block = rem % w_blocks;  // Which W block are we in?
         rem /= w_blocks;
@@ -36,14 +44,39 @@ void MAIN {
         rem /= x_blocks;
 
         uint32_t h = rem % H;
+#endif
 
         // tilize input via unpack and then pack
         tilize_init_short(cb_in, 1, cb_tilize);
 
         cb_wait_front(cb_in, 1);
         cb_reserve_back(cb_tilize, 1);
 
-        tilize_block(cb_in, 1, cb_tilize);  // tilize and pack into cb_tilize
+        // For BH, DRAM read alignment is 64B, but each subtile/face line is 32B,
+        // so every odd-numbered row in BFLOAT16 is misaligned
+#ifdef TRISC_UNPACK
+        if constexpr (misalignment > 0) {
+            // if h is odd, advance the CB read pointer past the alignment padding
+            if ((h & 1) == 1) {
+                std::uint32_t operand_id = get_operand_id(cb_in);
+                get_local_cb_interface(operand_id).fifo_rd_ptr += misalignment_div_16;
+            }
+        }
+#endif
+        tilize_block(cb_in, 1, cb_tilize);  // tilize and pack into cb_tilize
+
+#ifdef TRISC_UNPACK
+        if constexpr (misalignment > 0) {
+            // if h is odd, restore the CB read pointer so cb_pop_front stays page-aligned
+            if ((h & 1) == 1) {
+                std::uint32_t operand_id = get_operand_id(cb_in);
+                get_local_cb_interface(operand_id).fifo_rd_ptr -= misalignment_div_16;
+            }
+        }
+#endif
 
         cb_push_back(cb_tilize, 1);
         cb_pop_front(cb_in, 1);
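A quick check on the constants above, as a standalone sketch: the 64 B DRAM alignment and 32 B face line come from the diff's own comment, while the value of cb_addr_shift (4, i.e. CB addresses counted in 16 B units) is an assumption consistent with the misalignment_div_16 name.

    #include <cstdint>

    // Worked example of the misalignment constants for BFLOAT16 on Blackhole.
    constexpr uint32_t read_alignment = 64;                  // assumed BH DRAM read alignment
    constexpr uint32_t SUBTILE_LINE_BYTES = 16 * 2;          // one 16-element face row in BFLOAT16 = 32 B
    constexpr uint32_t misalignment = read_alignment - SUBTILE_LINE_BYTES;   // 32 B
    constexpr uint32_t cb_addr_shift = 4;                    // assumed: CB addresses in 16 B units
    constexpr uint32_t misalignment_div_16 = misalignment >> cb_addr_shift;  // 2

    static_assert(misalignment == 32, "odd face rows sit 32 B past a 64 B boundary");
    static_assert(misalignment_div_16 == 2, "pointer bump expressed in 16 B CB units");

Only the unpacker thread adjusts fifo_rd_ptr (hence the TRISC_UNPACK guards), and it restores the pointer after tilize_block so that the subsequent cb_pop_front still advances from the original page boundary.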
[reader kernel for MultiCoreTiledGeneric; file path not shown in this view]
@@ -37,6 +37,7 @@ void kernel_main() {
     constexpr bool needs_x_padding = static_cast<bool>(get_compile_time_arg_val(24));
     constexpr bool needs_y_padding = static_cast<bool>(get_compile_time_arg_val(25));
     constexpr uint32_t non_x_rows = get_compile_time_arg_val(26);
+    constexpr uint32_t read_alignment = get_compile_time_arg_val(27);
 
     // ------------------------------------------------------------------------
     // 2) Derived Constants (kept as constexpr)
@@ -53,6 +54,8 @@ void kernel_main() {
     constexpr uint32_t FACE_H_STRIDE_BYTES = NUM_FACES_W * FACE_HW_BYTES;
     constexpr uint32_t tile_bytes = TILE_HW * element_size;
 
+    constexpr uint32_t misalignment = read_alignment - SUBTILE_LINE_BYTES;
+
     // For x-padding logic:
     constexpr uint32_t final_face_real_w = (W % FACE_WIDTH);
     constexpr uint32_t ratio = sizeof(uint32_t) / element_size;
@@ -165,6 +168,11 @@ void kernel_main() {
         // Reserve a slot in the circular buffer, get L1 pointer
         cb_reserve_back(tt::CBIndex::c_0, 1);
         uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0);
+        if constexpr (misalignment > 0) {
+            if ((h & 1) == 1) {
+                src_buffer_l1_addr += misalignment;
+            }
+        }
 
         // --------------------------------------------------------------------
         // 5.1) Async read for [x_start..x_end)
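The reader side of the same fix, sketched under one assumption: a NoC read wants its L1 destination offset to match the DRAM source offset modulo the alignment, so an odd face row sitting 32 B past a 64 B DRAM boundary must also land 32 B past an aligned L1 address. The helper name below is hypothetical.

    #include <cstdint>

    constexpr uint32_t read_alignment = 64;      // assumed BH DRAM alignment
    constexpr uint32_t subtile_line_bytes = 32;  // BFLOAT16 face row
    constexpr uint32_t misalignment = read_alignment - subtile_line_bytes;

    // Hypothetical helper mirroring the kernel's adjustment: even rows are
    // already aligned in DRAM; odd rows are offset by 32 B, so the L1 write
    // pointer is bumped by the same amount before the async read.
    uint32_t aligned_l1_dst(uint32_t cb_write_ptr, uint32_t h) {
        return ((h & 1) == 1) ? cb_write_ptr + misalignment : cb_write_ptr;
    }

The 32 B the write pointer moves into is the slack the host adds to each CB page (see the page-size change in the next file), and the compute kernel's fifo_rd_ptr bump skips back over it when unpacking.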
[host-side program factory for the permute device operation; file path not shown in this view]
@@ -5,6 +5,7 @@
#include "cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp"
#include <tt-metalium/work_split.hpp>
#include <vector>
#include <tt-metalium/hal_exp.hpp>

namespace ttnn::operations::data_movement {

@@ -72,6 +73,13 @@ ttnn::SmallVector<uint32_t> get_inverse_permutation(const ttnn::SmallVector<uint
     return inverse_permutation;
 }
 
+uint32_t get_buffer_alignment(const ttnn::Tensor& tensor) {
+    return (
+        tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM
+            ? tt::tt_metal::experimental::hal::get_dram_alignment()
+            : tt::tt_metal::experimental::hal::get_l1_alignment());
+}
+
 }  // namespace detail
 
 PermuteDeviceOperation::MultiCoreTileInvariant::cached_program_t PermuteDeviceOperation::MultiCoreTileInvariant::create(
@@ -497,6 +505,10 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe
     uint32_t H_t = H_p / tile_shape[0];
     uint32_t W_t = W_p / tile_shape[1];
 
+    uint32_t subtile_line_bytes = face_shape[1] * element_size;
+    uint32_t read_alignment = detail::get_buffer_alignment(input_tensor);
+    uint32_t misalignment = read_alignment - subtile_line_bytes;
+
     uint32_t permuted_w_dim = 0;  // Will hold the position of w_dim in the permuted array
     for (uint32_t i = 0; i < N; ++i) {
         if (dims[i] == N - 1) {
@@ -591,10 +603,7 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe
     all_cores = num_cores > padded_num_cores ? all_cores : padded_all_cores;
 
     tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype());
-    uint32_t input_page_size = detail::tile_size(tensor_return_value);
-
-    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype());
-    uint32_t output_page_size = detail::tile_size(tensor_return_value);
+    uint32_t input_page_size = detail::tile_size(tensor_return_value) + misalignment;
 
     tt::tt_metal::CircularBufferConfig cb_src0_config =
         tt::tt_metal::CircularBufferConfig(num_input_pages_to_read * input_page_size, {{src0_cb_index, cb_data_format}})
@@ -650,6 +659,7 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe
         (uint32_t)needs_x_padding,
         (uint32_t)needs_y_padding,
         non_x_rows,
+        read_alignment,
     };
 
     tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
@@ -664,9 +674,11 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe
         x_blocks,
         w_blocks,
         input_shape[N - 2],
+        read_alignment,
+        subtile_line_bytes,
     };
 
-    bool fp32_dest_acc_en = cb_data_format_output == tt::DataFormat::Float32;
+    bool fp32_dest_acc_en = cb_data_format == tt::DataFormat::Float32;
     auto compute_kernel_id = tt::tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_tiled.cpp",
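A back-of-the-envelope check on that page sizing, assuming a 32x32 BFLOAT16 tile in DRAM on Blackhole; the 2048 B tile size stands in for detail::tile_size and the 64 B figure for get_dram_alignment():

    #include <cstdint>

    constexpr uint32_t element_size = 2;                        // BFLOAT16
    constexpr uint32_t tile_bytes = 32 * 32 * element_size;     // 2048 B, stand-in for detail::tile_size
    constexpr uint32_t subtile_line_bytes = 16 * element_size;  // face_shape[1] * element_size = 32 B
    constexpr uint32_t read_alignment = 64;                     // assumed get_dram_alignment() on BH
    constexpr uint32_t misalignment = read_alignment - subtile_line_bytes;

    // Each CB page is padded by `misalignment` so the reader can shift its
    // write pointer on odd rows without spilling into the next page.
    constexpr uint32_t input_page_size = tile_bytes + misalignment;  // 2080 B
    static_assert(input_page_size == 2080, "one tile plus 32 B of alignment slack");

On parts where the DRAM alignment already equals the 32 B face line, misalignment works out to 0 and every `if constexpr (misalignment > 0)` branch in the kernels compiles out.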
