#16988: fix program cache bug
sjameelTT committed Jan 22, 2025
1 parent 77ee3cf commit fd7b17a
Showing 5 changed files with 62 additions and 38 deletions.
@@ -1120,3 +1120,32 @@ def test_transpose_16411(device):
    assert_with_pcc(p_c2, ttnn.to_torch(c2), 0.9999)
    assert_with_pcc(p_c3, ttnn.to_torch(c3), 0.9999)
    assert_with_pcc(p_c4, ttnn.to_torch(c4), 0.9999)


@pytest.mark.parametrize("rank", [5])
@pytest.mark.parametrize("indices", [[0, 1], [0, 2], [0, 3], [0, 4], [1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]])
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_transpose_high_rank(*, device: ttnn.Device, rank: int, indices, layout):
    torch.manual_seed(2005)
    ttnn.disable_and_clear_program_cache(device)
    ttnn.enable_program_cache(device)

    shape = [2] * rank

    a = torch.randn(shape, dtype=torch.bfloat16)
    b = torch.randn(shape, dtype=torch.bfloat16)

    tt_a = ttnn.from_torch(a, device=device, layout=layout)
    tt_b = ttnn.from_torch(b, device=device, layout=layout)

    a = a.transpose(*indices)
    b = b.transpose(*indices)

    tt_a = ttnn.transpose(tt_a, *indices)
    tt_b = ttnn.transpose(tt_b, *indices)

    output_a = ttnn.to_torch(tt_a)
    output_b = ttnn.to_torch(tt_b)

    assert torch.allclose(a, output_a)
    assert torch.allclose(b, output_b)
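
The two-tensor structure is what exercises the cached path: the first ttnn.transpose call compiles and caches the program for this shape and layout, and the second call should hit that cache with a different input buffer. A cached program that fails to refresh its runtime arguments (the bug this commit fixes) then surfaces as an allclose failure on output_b instead of passing silently.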
@@ -108,6 +108,7 @@ struct PermuteDeviceOperation {
struct shared_variables_t {
    tt::tt_metal::KernelHandle unary_reader_kernel_id;
    tt::tt_metal::KernelHandle unary_writer_kernel_id;
    tt::tt_metal::KernelHandle compute_kernel_id;
    CoreRangeSet core_range;
};
using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;
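
Why the cached shared variables matter: on a program-cache hit the create() factory is skipped and only override_runtime_arguments runs, so every handle that function needs (the reader and writer kernels, the newly added compute kernel, and the core range to iterate) must be recorded when the program is first built. The struct already had a core_range member, but the old return statements in the program factory below never populated it, so on a cache hit the override step had nothing to walk and the kernels kept the buffer addresses from the first invocation. The sketch below illustrates that override pattern using the member names above and the tt-metal helpers that appear elsewhere in this diff (GetRuntimeArgs, corerange_to_cores); it is an illustration of the mechanism, not the repository's exact implementation.

```cpp
// Sketch only: how an override_runtime_arguments-style hook consumes the
// cached shared variables. On a cache hit this is the only code that runs,
// so it must be able to reach every kernel and core the program was built for.
void override_runtime_arguments_sketch(
    tt::tt_metal::Program& program,
    const shared_variables_t& shared,
    tt::tt_metal::Buffer* src_buffer,
    tt::tt_metal::Buffer* dst_buffer) {
    // If shared.core_range was never populated by create(), this loop is empty
    // and the cached kernels keep stale buffer addresses.
    for (const auto& core : corerange_to_cores(shared.core_range, std::nullopt)) {
        auto& reader_args = tt::tt_metal::GetRuntimeArgs(program, shared.unary_reader_kernel_id, core);
        reader_args[0] = src_buffer->address();
        auto& writer_args = tt::tt_metal::GetRuntimeArgs(program, shared.unary_writer_kernel_id, core);
        writer_args[0] = dst_buffer->address();
    }
}
```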
@@ -132,7 +132,10 @@ PermuteDeviceOperation::MultiCoreRowInvariant::cached_program_t PermuteDeviceOpe

return {
std::move(program),
{.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
{.unary_reader_kernel_id = unary_reader_kernel_id,
.unary_writer_kernel_id = unary_writer_kernel_id,
.core_range = all_cores},
};
}

void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments(
@@ -179,11 +179,13 @@ PermuteDeviceOperation::MultiCoreTileInvariant::cached_program_t PermuteDeviceOp
auto inv_perm = detail::get_inverse_permutation(operation_attributes.dims);

std::vector<uint32_t> reader_runtime_args = {src_buffer->address(), 0, 0};

reader_runtime_args.insert(reader_runtime_args.end(), output_shape_view.begin(), output_shape_view.end());
reader_runtime_args.insert(reader_runtime_args.end(), inv_perm.begin(), inv_perm.end());
reader_runtime_args.insert(reader_runtime_args.end(), input_tile_strides.begin(), input_tile_strides.end());

std::vector<uint32_t> writer_runtime_args = {dst_buffer->address(), 0, 0};

std::vector<uint32_t> compute_runtime_args = {0};

auto cores = corerange_to_cores(all_cores, std::nullopt);
@@ -216,7 +218,10 @@ PermuteDeviceOperation::MultiCoreTileInvariant::cached_program_t PermuteDeviceOp

return {
std::move(program),
{.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
{.unary_reader_kernel_id = unary_reader_kernel_id,
.unary_writer_kernel_id = unary_writer_kernel_id,
.compute_kernel_id = compute_kernel_id,
.core_range = all_cores}};
}

void PermuteDeviceOperation::MultiCoreTileInvariant::override_runtime_arguments(
@@ -239,6 +244,7 @@ void PermuteDeviceOperation::MultiCoreTileInvariant::override_runtime_arguments(
for (const auto& core : cores) {
auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core);
runtime_args[0] = src_buffer->address();

auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core);
runtime_args_writer[0] = dst_buffer->address();
}
@@ -492,7 +498,10 @@ PermuteDeviceOperation::MultiCoreTileRowInvariant::create(

return {
std::move(program),
{.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
{.unary_reader_kernel_id = unary_reader_kernel_id,
.unary_writer_kernel_id = unary_writer_kernel_id,
.compute_kernel_id = compute_kernel_id,
.core_range = all_cores}};
}

void PermuteDeviceOperation::MultiCoreTileRowInvariant::override_runtime_arguments(
@@ -855,7 +864,10 @@ PermuteDeviceOperation::MultiCoreTiledGeneric::cached_program_t PermuteDeviceOpe

return {
std::move(program),
{.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
{.unary_reader_kernel_id = unary_reader_kernel_id,
.unary_writer_kernel_id = unary_writer_kernel_id,
.compute_kernel_id = compute_kernel_id,
.core_range = all_cores}};
}

void PermuteDeviceOperation::MultiCoreTiledGeneric::override_runtime_arguments(
47 changes: 13 additions & 34 deletions ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp
@@ -26,24 +26,17 @@ inline Tensor transpose_(
TransposeOpDim transpose_dim,
const MemoryConfig& output_mem_config,
const std::optional<float>& pad_value) {
bool tiled_only = false;
constexpr uint32_t FACE_WIDTH =
tt::constants::FACE_WIDTH; // this is a highly restrictive constraint on the RM transpose_wh kernel, and with
// all the other bugs/limitations we should rewrite it
// use device->get_allocator_alignment when the it reflects the alignment of the buffer and doesn't just default to
// DRAM
auto BUFFER_ALIGNMENT = a.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? hal::get_dram_alignment()
: hal::get_l1_alignment();
uint32_t W = a.get_padded_shape()[-1];
uint32_t H = a.get_padded_shape()[-2];
auto prim_permute = [&](const ttnn::Tensor& input, ttnn::SmallVector<uint32_t> dims) -> ttnn::Tensor {
return ttnn::prim::permute(input, dims, output_mem_config, std::nullopt, pad_value);
};

bool interleaved_rm = !a.is_sharded() && a.layout() == Layout::ROW_MAJOR;
switch (transpose_dim) {
case TransposeOpDim::HC:
tiled_only = a.get_layout() == Layout::TILE;
if ((!tiled_only) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { //
tiled_only = true;
if (interleaved_rm) {
return prim_permute(a, ttnn::SmallVector<uint32_t>{0, 2, 1, 3});
}
break;
// bubble dim around to make it possible as these implementations don't have a kernel
case TransposeOpDim::NH:
return ttnn::permute(
(const ttnn::Tensor)a, ttnn::SmallVector<int64_t>({2, 1, 0, 3}), output_mem_config, pad_value);
@@ -54,32 +47,18 @@ inline Tensor transpose_(
return ttnn::permute(
(const ttnn::Tensor)a, ttnn::SmallVector<int64_t>({0, 3, 2, 1}), output_mem_config, pad_value);
case TransposeOpDim::CN:
tiled_only = true; // CN only has a tiled implementation at the moment
if (interleaved_rm) {
return prim_permute(a, ttnn::SmallVector<uint32_t>({1, 0, 2, 3}));
}
break;
case TransposeOpDim::WH:
if (!a.is_sharded() && a.layout() == Layout::ROW_MAJOR) {
return ttnn::prim::permute(
a, ttnn::SmallVector<uint32_t>({0, 1, 3, 2}), output_mem_config, std::nullopt);
if (interleaved_rm) {
return prim_permute(a, ttnn::SmallVector<uint32_t>({0, 1, 3, 2}));
}
break;
default: break;
}
if (a.get_layout() == Layout::ROW_MAJOR) {
// the assorted cases where only tiled works right now (HC with stick width constraint, WH with stick width
// constraint, CN).
if (tiled_only) {
// convert to tiled
Tensor b = ttnn::to_layout(a, Layout::TILE, std::nullopt, std::nullopt, (IDevice*)nullptr);
// run the transpose.
b = operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {b}).at(0);
// back to original layout
b = ttnn::to_layout(b, a.get_layout(), std::nullopt, std::nullopt, (IDevice*)nullptr);
return b;
}
return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
} else {
return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
}
return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
}

ttnn::Tensor transpose_nd(
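
The transpose.cpp change above removes the row-major fallback that converted to TILE layout, ran the Transpose op, and converted back; for interleaved row-major inputs the HC, CN, and WH cases now dispatch straight to ttnn::prim::permute through the prim_permute lambda, and the deleted FACE_WIDTH and buffer-alignment checks no longer gate that path. The hard-coded SmallVectors encode each two-dimension transpose as a rank-4 permutation: HC swaps dims 1 and 2 ({0, 2, 1, 3}), CN swaps dims 0 and 1 ({1, 0, 2, 3}), and WH swaps dims 2 and 3 ({0, 1, 3, 2}). A hypothetical helper (not part of this commit) that derives such a vector for an arbitrary dim pair would look like this:

```cpp
// Hypothetical helper, not in this commit: build the permutation vector that
// ttnn::prim::permute expects for a transpose of dims i and j on a rank-4 tensor.
// Example: transpose_as_permutation(1, 2) == {0, 2, 1, 3}, the HC case above.
ttnn::SmallVector<uint32_t> transpose_as_permutation(uint32_t i, uint32_t j) {
    ttnn::SmallVector<uint32_t> dims = {0, 1, 2, 3};
    std::swap(dims[i], dims[j]);
    return dims;
}
```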
