#0: make row-invariant permute kernel multicore
sjameelTT committed Dec 12, 2024
1 parent 3210f6a commit e6d86ea
Showing 5 changed files with 62 additions and 32 deletions.
[changed file 1 of 5]
@@ -12,14 +12,16 @@ void kernel_main() {
constexpr uint32_t num_rows = get_compile_time_arg_val(3);

const uint32_t src_addr = get_arg_val<uint32_t>(0);
const uint32_t start_row = get_arg_val<uint32_t>(1);
const uint32_t end_row = get_arg_val<uint32_t>(2);

const InterleavedAddrGen<src0_is_dram> s0 = {.bank_base_address = src_addr, .page_size = page_size};

uint32_t curr_addr = src_addr;
for (uint32_t i = 0; i < num_rows; ++i) {
for (uint32_t row = start_row; row < end_row; ++row) {
cb_reserve_back(tt::CBIndex::c_0, 1);
uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0);
noc_async_read_page(i, s0, src_buffer_l1_addr);
noc_async_read_page(row, s0, src_buffer_l1_addr);
noc_async_read_barrier();
cb_push_back(tt::CBIndex::c_0, 1);
}
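The reader change above swaps the full 0..num_rows loop for a half-open range [start_row, end_row) supplied as runtime args 1 and 2 (arg 0 stays the source buffer address), so each core only fetches the pages it owns. A minimal standalone C++ sketch of that contract follows; the example ranges are made up here, while the real per-core ranges come from the host-side split later in this commit.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Illustration of the runtime-arg contract: each core's reader loops over a
// half-open row range [start_row, end_row). Ranges that meet end-to-start
// tile [0, num_rows) with no gap and no overlap.
int main() {
    const uint32_t num_rows = 10;
    // e.g. three cores with contiguous ranges handed in as runtime args 1 and 2
    const std::vector<std::pair<uint32_t, uint32_t>> ranges = {{0, 4}, {4, 8}, {8, 10}};

    std::vector<uint32_t> reads(num_rows, 0);
    for (const auto& [start_row, end_row] : ranges) {
        for (uint32_t row = start_row; row < end_row; ++row) {
            ++reads[row];  // stands in for noc_async_read_page(row, s0, ...)
        }
    }
    for (uint32_t count : reads) {
        assert(count == 1);  // every row is read exactly once across all cores
    }
    return 0;
}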
[changed file 2 of 5]
@@ -12,19 +12,21 @@ void kernel_main() {
constexpr uint32_t num_rows = get_compile_time_arg_val(3);

const uint32_t dst_addr = get_arg_val<uint32_t>(0);
const uint32_t start_row = get_arg_val<uint32_t>(1);
const uint32_t end_row = get_arg_val<uint32_t>(2);

const InterleavedAddrGen<dst_is_dram> s0 = {.bank_base_address = dst_addr, .page_size = page_size};

uint32_t input_shape[N], perm[N], dest_strides[N];
for (uint32_t i = 1; i <= N; i++) {
input_shape[i - 1] = get_arg_val<uint32_t>(i);
perm[i - 1] = get_arg_val<uint32_t>(i + N);
dest_strides[i - 1] = get_arg_val<uint32_t>(i + 2 * N);
for (uint32_t i = 3; i < N + 3; i++) {
input_shape[i - 3] = get_arg_val<uint32_t>(i);
perm[i - 3] = get_arg_val<uint32_t>(i + N);
dest_strides[i - 3] = get_arg_val<uint32_t>(i + 2 * N);
}

uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0);
uint32_t curr_addr = dst_addr;
for (uint32_t row = 0; row < num_rows; ++row) {
for (uint32_t row = start_row; row < end_row; ++row) {
// Compute multi-dimensional index for the source row
uint32_t src_multi_idx[N];
size_t remaining = row;
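In the writer, runtime args 0-2 are now dst_addr, start_row, and end_row, which is why the unpack loop for input_shape, perm, and dest_strides starts at index 3 instead of 1. The rest of the loop body (below the shown hunk) maps each source row to its destination row. The sketch below illustrates that kind of mapping in plain C++; it is not a verbatim copy of the kernel, and the stride convention is an assumption.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch (not the kernel code): map a linear source row index
// to its destination row index for a permute that leaves the last dim alone.
// input_shape and perm have N entries; dst_row_strides has N-1 entries, the
// row-major strides (counted in rows) of the output's first N-1 dimensions.
uint32_t dest_row_index(uint32_t row,
                        const std::vector<uint32_t>& input_shape,
                        const std::vector<uint32_t>& perm,
                        const std::vector<uint32_t>& dst_row_strides) {
    const size_t N = input_shape.size();
    std::vector<uint32_t> src_idx(N, 0);
    size_t remaining = row;
    // row-major decomposition of the row index over input dims 0..N-2
    for (size_t d = N - 1; d-- > 0;) {
        src_idx[d] = remaining % input_shape[d];
        remaining /= input_shape[d];
    }
    uint32_t dst_row = 0;
    for (size_t d = 0; d + 1 < N; ++d) {
        // output dimension d takes its coordinate from input dimension perm[d]
        dst_row += src_idx[perm[d]] * dst_row_strides[d];
    }
    return dst_row;
}

For example, with input_shape {2, 3, 4, 8}, perm {0, 2, 1, 3}, and output row strides {12, 3, 1}, the source row at coordinates (1, 2, 3) lands at destination row 1*12 + 3*3 + 2*1 = 23.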
[changed file 3 of 5]
@@ -12,9 +12,11 @@ namespace ttnn::operations::data_movement {

PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program_factory(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
// If the last dimension is not permuted, we can use the row-invariant kernel
if (operation_attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1) {
return SingleCore{};
return MultiCoreRowInvariant{};
}
// Otherwise, we need to use the blocked generic, row moving kernel
return MultiCoreBlockedGeneric{};
}

@@ -33,7 +35,7 @@ void PermuteDeviceOperation::validate_on_program_cache_hit(

PermuteDeviceOperation::shape_return_value_t PermuteDeviceOperation::compute_output_shapes(
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
SmallVector<uint32_t> shape, padded_shape;
SmallVector<uint32_t> shape;
auto input_shape = tensor_args.input_tensor.get_logical_shape();
shape.reserve(input_shape.rank());
for (auto dim : attributes.dims) {
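The factory choice hinges on whether the permutation keeps the innermost dimension in place: if it does, whole rows can be copied unchanged and the multicore row-invariant path applies; otherwise the blocked generic path is used. A tiny sketch of the same check, using a hypothetical free function rather than the ttnn types:

#include <cstdint>
#include <vector>

// Hypothetical sketch of the factory-selection rule: a permutation is
// "row-invariant" when the innermost dimension keeps its position.
bool is_row_invariant(const std::vector<uint32_t>& dims) {
    return !dims.empty() && dims.back() == dims.size() - 1;
}

// e.g. for rank 4: {0, 2, 1, 3} -> true  (row-invariant multicore kernel)
//                  {0, 3, 1, 2} -> false (blocked generic kernel)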
[changed file 4 of 5]
@@ -30,11 +30,12 @@ struct PermuteDeviceOperation {

using tensor_return_value_t = Tensor;

struct SingleCore {
struct MultiCoreRowInvariant {
// Shared variables are the variables that are shared between the create and override_runtime_arguments methods
struct shared_variables_t {
KernelHandle unary_reader_kernel_id;
KernelHandle unary_writer_kernel_id;
CoreRangeSet all_cores;
};
using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

@@ -72,7 +73,7 @@ struct PermuteDeviceOperation {
tensor_return_value_t& tensor_return_value);
};

using program_factory_t = std::variant<SingleCore, MultiCoreBlockedGeneric>;
using program_factory_t = std::variant<MultiCoreRowInvariant, MultiCoreBlockedGeneric>;

// Mandatory methods

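Besides the rename, shared_variables_t now caches all_cores so that override_runtime_arguments can revisit every core on a program-cache hit. Below is a generic sketch of how a std::variant-based program factory can be dispatched; the std::visit wiring and type names are illustrative assumptions, not the ttnn framework internals.

#include <iostream>
#include <variant>

// Generic sketch of a variant-based program factory: each alternative carries
// its own create() logic and the caller dispatches with std::visit.
struct MultiCoreRowInvariantFactory {
    void create() const { std::cout << "row-invariant multicore path\n"; }
};
struct MultiCoreBlockedGenericFactory {
    void create() const { std::cout << "blocked generic path\n"; }
};
using ProgramFactory = std::variant<MultiCoreRowInvariantFactory, MultiCoreBlockedGenericFactory>;

int main() {
    ProgramFactory factory = MultiCoreRowInvariantFactory{};  // result of a select step
    std::visit([](const auto& f) { f.create(); }, factory);
    return 0;
}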
[changed file 5 of 5]
@@ -32,7 +32,7 @@ std::vector<uint32_t> get_row_strides(const ttnn::SimpleShape& shape) {

} // namespace detail

PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::SingleCore::create(
PermuteDeviceOperation::MultiCoreRowInvariant::cached_program_t PermuteDeviceOperation::MultiCoreRowInvariant::create(
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& tensor_return_value) {
@@ -60,54 +60,78 @@ PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::Sin
uint32_t src0_cb_index = tt::CBIndex::c_0;
uint32_t num_input_pages_to_read = 2;

CoreRange core({0, 0}, {0, 0});
uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1];

auto compute_with_storage_grid_size = input_tensor.device()->compute_with_storage_grid_size();
auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_rows);

tt::tt_metal::CircularBufferConfig cb_src0_config =
tt::tt_metal::CircularBufferConfig(
num_input_pages_to_read * input_rm_page_size, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, input_rm_page_size);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

uint32_t N = operation_attributes.dims.size();
uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1];

bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram, N, input_rm_page_size, num_rows};

tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp",
core,
"ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/"
"reader_permute_interleaved_rm_row_invariant.cpp",
all_cores,
tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)dst_is_dram, N, output_rm_page_size, num_rows};
tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp",
core,
"ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/"
"writer_permute_interleaved_rm_row_invariant.cpp",
all_cores,
tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

std::vector<uint32_t> reader_runtime_args = {src_buffer->address()};

tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args);
std::vector<uint32_t> reader_runtime_args = {src_buffer->address(), 0, 0};

auto input_shape_view = input_tensor.get_logical_shape().view();
auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); // in anticipation of RM padding

std::vector<uint32_t> writer_runtime_args = {dst_buffer->address()};
std::vector<uint32_t> writer_runtime_args = {dst_buffer->address(), 0, 0};
writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end());
writer_runtime_args.insert(
writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end());
writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end());

tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args);
auto cores = corerange_to_cores(all_cores, std::nullopt);
uint32_t start_row = 0;
uint32_t num_rows_per_core = 0;
for (const auto& core : cores) {
if (core_group_1.contains(core)) {
num_rows_per_core = num_tiles_per_core_group_1;
} else if (core_group_2.contains(core)) {
num_rows_per_core = num_tiles_per_core_group_2;
} else {
// no-op
num_rows_per_core = 0;
}
uint32_t end_row = start_row + num_rows_per_core;
reader_runtime_args[1] = start_row;
reader_runtime_args[2] = end_row;
writer_runtime_args[1] = start_row;
writer_runtime_args[2] = end_row;
tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args);
tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args);
start_row = end_row;
}

return {
std::move(program),
{.unary_reader_kernel_id = unary_reader_kernel_id,
.unary_writer_kernel_id = unary_writer_kernel_id,
.all_cores = all_cores}};
}

void PermuteDeviceOperation::SingleCore::override_runtime_arguments(
void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments(
cached_program_t& cached_program,
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
@@ -121,15 +145,14 @@ void PermuteDeviceOperation::SingleCore::override_runtime_arguments(

auto src_buffer = input_tensor.buffer();
auto dst_buffer = output_tensor.buffer();
auto& all_cores = cached_program.shared_variables.all_cores;

{
auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0});
auto cores = corerange_to_cores(all_cores, std::nullopt);
for (const auto& core : cores) {
auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core);
runtime_args[0] = src_buffer->address();
}

{
auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0});
runtime_args[0] = dst_buffer->address();
auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core);
runtime_args_writer[0] = dst_buffer->address();
}
}

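create() now asks tt::tt_metal::split_work_to_cores to spread num_rows over the compute grid, which yields two core groups with different per-core row counts, then walks the cores filling runtime-arg slots 1 and 2 with each core's half-open range; on a cache hit only slot 0 (the buffer address) is refreshed. The standalone sketch below mirrors that split arithmetic under the assumption that the helper balances work into a "ceil" group and a "floor" group; it is not the tt_metal implementation.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch (under assumptions) of a two-group work split: the first group of
// cores handles ceil(num_rows / num_cores) rows each, the second group
// handles floor(num_rows / num_cores), so g1*r1 + g2*r2 == num_rows.
int main() {
    const uint32_t num_rows = 37;
    const uint32_t num_cores = 8;  // cores actually used
    const uint32_t r1 = (num_rows + num_cores - 1) / num_cores;  // rows per core, group 1
    const uint32_t r2 = num_rows / num_cores;                    // rows per core, group 2
    const uint32_t g1 = num_rows % num_cores == 0 ? num_cores : num_rows % num_cores;
    const uint32_t g2 = num_cores - g1;
    assert(g1 * r1 + g2 * r2 == num_rows);

    // Walk the cores like the loop in create(): assign contiguous half-open
    // row ranges that would go into runtime-arg slots [1] and [2].
    uint32_t start_row = 0;
    std::vector<std::pair<uint32_t, uint32_t>> per_core_args;  // (start_row, end_row)
    for (uint32_t core = 0; core < num_cores; ++core) {
        const uint32_t rows_here = core < g1 ? r1 : r2;
        per_core_args.emplace_back(start_row, start_row + rows_here);
        start_row += rows_here;
    }
    assert(start_row == num_rows);  // ranges tile [0, num_rows) with no gaps
    return 0;
}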
