Skip to content

Commit

Permalink
New Operation: Fill_Tile_Pad ; Op to fill tile padding with a specifi…
Browse files Browse the repository at this point in the history
…c value (#16785)

### Ticket
[#16393 ](#16393)

### Problem description
We need an op that fills the tile padding of a tensor with a specified value.

### What's changed
For now, work is parallelized only across the dimensions other than the
last two. The op simply iterates over the region between the logical shape
and the padded shape and fills it with the requested value.

### Checklist
- [ ] Post commit CI passes:
https://github.com/tenstorrent/tt-metal/actions/runs/12835147070
  • Loading branch information
yugi957 authored Jan 20, 2025
1 parent a5cf197 commit 2ef3e06
Show file tree
Hide file tree
Showing 12 changed files with 564 additions and 0 deletions.
98 changes: 98 additions & 0 deletions tests/ttnn/unit_tests/operations/test_fill_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import torch_random, run_for_wormhole_b0


def create_nd_padded_tiled_tensor(shape, tile_size, fill_value, dtype):
    """
    Build a random reference tensor and its tile-padded counterpart.

    The padded tensor holds the same values as the random tensor, but its last two
    dimensions are rounded up to the next multiple of ``tile_size`` and the extra
    elements are set to ``fill_value``.

    Args:
        shape (tuple): Shape of the unpadded tensor.
        tile_size (int): Granularity the last two dimensions are rounded up to.
        fill_value (float or int): Value stored in the padded region.
        dtype (torch.dtype): Data type of both tensors.

    Returns:
        tuple: ``(tensor, padded_tensor)``.
    """
    # Random payload: floats in [-15, 15] for float32, small non-negative integers otherwise.
    if dtype == torch.float32:
        tensor = torch_random(shape, -15.0, 15.0, dtype=dtype)
    else:
        tensor = torch.randint(0, 10, shape, dtype=dtype)

    # Round the two innermost extents up to a whole number of tiles.
    padded_shape = list(shape)
    for axis in (-2, -1):
        tiles_needed = (padded_shape[axis] + tile_size - 1) // tile_size
        padded_shape[axis] = tiles_needed * tile_size

    # Start from a tensor made entirely of fill_value, then overwrite the logical region.
    rows, cols = shape[-2], shape[-1]
    padded_tensor = torch.full(padded_shape, fill_value, dtype=dtype)
    padded_tensor[..., :rows, :cols] = tensor

    return tensor, padded_tensor


import pytest
import torch
import ttnn

# Torch dtype used to build the host-side reference tensor for each ttnn dtype under test.
ttnn_dtype_to_torch_dtype = {
    # Integer data is generated as int32 on the host.
    ttnn.uint32: torch.int32,
    # bfloat16 reference data is generated in float32 on the host.
    ttnn.bfloat16: torch.float32,
}


# @pytest.mark.parametrize("shape", [(2, 32, 300, 256)])
@pytest.mark.parametrize(
    "shape",
    [
        # 2D shapes around the tile boundaries (below, at and just past multiples of 16/32)
        (1, 16),
        (16, 1),
        (1, 17),
        (17, 1),
        (16, 16),
        (17, 17),
        (31, 31),
        (33, 33),
        (65, 65),
        # High-rank shape with several leading dimensions
        (1, 2, 3, 2, 1, 2, 97, 97),
    ],
)
@pytest.mark.parametrize("fill_value", [1])
@pytest.mark.parametrize("dtype", [ttnn.uint32, ttnn.bfloat16])
@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_fill_pad(device, shape, fill_value, dtype, input_mem_config, output_mem_config):
    """Run fill_implicit_tile_padding and compare against a torch-built padded reference."""
    # Fixed seed so the random reference data is reproducible.
    torch.manual_seed(1234)

    tile_size = 32
    reference_dtype = ttnn_dtype_to_torch_dtype[dtype]
    torch_input_tensor, expected_padded = create_nd_padded_tiled_tensor(
        shape, tile_size, fill_value, reference_dtype
    )

    # Tilize on the host, then move the tensor to the device.
    host_tensor = ttnn.from_torch(torch_input_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT)
    device_tensor = ttnn.to_device(host_tensor, device, memory_config=input_mem_config)

    filled_tensor = ttnn.fill_implicit_tile_padding(device_tensor, fill_value, memory_config=output_mem_config)
    actual_padded = ttnn.from_device(filled_tensor).to_torch()

    assert_with_pcc(expected_padded, actual_padded)
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/copy/typecast/typecast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/copy/typecast/typecast_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/data_transfer/data_transfer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_pad/fill_pad.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_pad/fill_pad_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_rm/device/fill_rm_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_rm/fill_rm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/fill_rm/fill_rm_pybind.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "ttnn/operations/data_movement/concat/concat_pybind.hpp"
#include "ttnn/operations/data_movement/copy/copy_pybind.hpp"
#include "ttnn/operations/data_movement/expand/expand_pybind.hpp"
#include "ttnn/operations/data_movement/fill_pad/fill_pad_pybind.hpp"
#include "ttnn/operations/data_movement/fill_rm/fill_rm_pybind.hpp"
#include "ttnn/operations/data_movement/fold/fold_pybind.hpp"
#include "ttnn/operations/data_movement/indexed_fill/indexed_fill_pybind.hpp"
Expand Down Expand Up @@ -48,6 +49,7 @@ namespace operations {
namespace data_movement {

void py_module(py::module& module) {
bind_fill_pad(module);
bind_fill_rm(module);
bind_fold_operation(module);
bind_non_zero_indices(module);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/data_movement/fill_pad/device/fill_pad_op.hpp"
#include "ttnn/operations/core/core.hpp"
#include <tt-metalium/host_api.hpp>
#include <tt-metalium/constants.hpp>
#include "ttnn/operations/data_movement/fill_pad/device/fill_pad_program_factory.hpp"

using namespace tt;

namespace ttnn::operations::data_movement {

// Validates the arguments of the FillPad op before a program is built.
//
// Requirements enforced below:
//   * the first input tensor uses TILE_LAYOUT;
//   * the first input tensor's memory layout is INTERLEAVED;
//   * the op's configured output memory layout is INTERLEAVED.
//
// Only input_tensors.at(0) is inspected; any further entries are ignored here.
void FillPad::validate(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor_a = input_tensors.at(0);
// Padding is filled per tile, so the op is restricted to tiled tensors.
TT_FATAL(input_tensor_a.get_layout() == TILE_LAYOUT, "FillPad should only be used for tile layout");
// Sharded inputs are rejected.
TT_FATAL(
input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"FillPad does not currently support sharding");
// Sharded outputs are rejected as well.
TT_FATAL(
this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
"FillPad does not currently support sharding");
}

std::vector<TensorSpec> FillPad::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
return {input_tensor.get_tensor_spec()};
}

std::vector<Tensor> FillPad::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
return {input_tensor};
}

// Builds the device program for the op.
//
// The program is created from the first input tensor and this op's fill_value only.
// `output_tensors` is not consulted: create_output_tensors() returns the input tensor
// itself, so the program factory only needs the input.
operation::ProgramWithCallbacks FillPad::create_program(
    const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    return detail::fill_pad_multi_core(input_tensor, this->fill_value);
}

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/run_operation.hpp"

namespace ttnn::operations::data_movement {
// Device operation that fills tile padding with a given value.
//
// create_output_tensors() returns the input tensor itself, so the result is
// delivered through the tensor that was passed in.
struct FillPad {
// Value forwarded to the program factory to be written into the padding.
float fill_value;
// Requested output memory configuration; validate() requires it to be INTERLEAVED.
const tt::tt_metal::MemoryConfig output_mem_config;

// Checks layout and memory-layout requirements on input_tensors.at(0) and output_mem_config.
void validate(const std::vector<Tensor>& input_tensors) const;
// Returns one spec, equal to the spec of input_tensors.at(0).
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
// Returns a one-element vector holding input_tensors.at(0).
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
// Builds the program via detail::fill_pad_multi_core for input_tensors.at(0).
tt::tt_metal::operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
};

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/data_movement/fill_pad/device/fill_pad_op.hpp"

#include <cstring>

#include "ttnn/operations/core/core.hpp"
#include <tt-metalium/host_api.hpp>
#include <tt-metalium/work_split.hpp>
#include <tt-metalium/constants.hpp>
#include <tt-metalium/util.hpp>
#include <tt-metalium/host_api.hpp>
#include <tt-metalium/tt_log.h>

// Returns true when `value` is a power of two that is no smaller than 32.
bool is_power_of_two_at_least_32(uint32_t value) {
    if (value < 32) {
        return false;
    }
    // A power of two has exactly one bit set, so clearing its lowest set bit leaves zero.
    const uint32_t without_lowest_bit = value & (value - 1);
    return without_lowest_bit == 0;
}

using namespace tt;

// Size in bytes of one element for each data type listed here.
// Data types that are absent from this table have no entry.
std::map<DataType, uint32_t> data_type_to_size{
    // 1-byte elements
    {DataType::UINT8, 1},
    // 2-byte elements
    {DataType::BFLOAT16, 2},
    // 4-byte elements
    {DataType::FLOAT32, 4},
    {DataType::UINT32, 4},
};

namespace ttnn::operations::data_movement::detail {

// Builds the multi-core, writer-only program that fills the tile padding of `input_tensor`
// with `fill_value`.
//
// Work is split across cores along dimension -3 of the logical shape: each core is handed a
// contiguous run of 2D (height x width) slices, identified by a starting tile offset and a
// slice count passed as runtime arguments.
//
// Parameters:
//   input_tensor - tensor whose buffer the writer kernel operates on; its dtype must have an
//                  entry in data_type_to_size.
//   fill_value   - value to write; packed into a 32-bit compile-time kernel argument
//                  (two bfloat16 copies for BFLOAT16, the raw bit pattern for FLOAT32,
//                  an integer conversion otherwise).
//
// Returns the program together with a callback that refreshes the buffer address runtime
// argument on every core.
operation::ProgramWithCallbacks fill_pad_multi_core(const Tensor& input_tensor, float fill_value) {
    tt::tt_metal::IDevice* device = input_tensor.device();
    tt::tt_metal::Program program = tt::tt_metal::CreateProgram();

    const DataType input_dtype = input_tensor.get_dtype();
    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_dtype);

    tt::tt_metal::Buffer* tens_buffer = input_tensor.buffer();
    TT_ASSERT(tens_buffer != nullptr, "Input buffer should be allocated on device!");

    // Use find() rather than operator[]: operator[] would silently insert an entry with
    // element size 0 for a dtype that is not in the table.
    const auto element_size_it = data_type_to_size.find(input_dtype);
    TT_FATAL(element_size_it != data_type_to_size.end(), "FillPad does not support the data type of the input tensor");
    const uint32_t input_element_size_bytes = element_size_it->second;

    // One circular-buffer page holds FACE_HEIGHT elements plus sizeof(uint16_t) extra bytes.
    const uint32_t cb_page_size = input_element_size_bytes * tt::constants::FACE_HEIGHT + sizeof(uint16_t);
    const uint32_t height = input_tensor.get_logical_shape()[-2];
    const uint32_t width = input_tensor.get_logical_shape()[-1];

    // Number of 2D slices to distribute; only dimension -3 is taken into account here.
    // NOTE(review): higher-rank inputs presumably get collapsed to 3D by the caller - confirm.
    const uint32_t problem_size = input_tensor.get_logical_shape()[-3];

    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
    const uint32_t num_cores_x = compute_with_storage_grid_size.x;
    const uint32_t num_cores_y = compute_with_storage_grid_size.y;

    auto [num_cores, all_cores, core_group_1, core_group_2, num_blocks_per_core_group_1, num_blocks_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, problem_size);
    const uint32_t g1_numcores = core_group_1.num_cores();

    // Circular buffer on every participating core, sized for two pages.
    constexpr uint32_t src0_cb_index = tt::CBIndex::c_0;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(cb_page_size * 2, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, cb_page_size);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    const bool src_is_dram = tens_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;

    // Pack the fill value into 32 bits according to the element type. Each branch performs only
    // the conversion that is meaningful for its dtype, so a float fill value is never pushed
    // through a float -> unsigned conversion it does not need.
    uint32_t packed_fill_value = 0;
    if (input_dtype == DataType::BFLOAT16) {
        // Two copies of the bfloat16 value side by side in one 32-bit word.
        packed_fill_value = pack_two_bfloat16_into_uint32({bfloat16(fill_value), bfloat16(fill_value)});
    } else if (input_dtype == DataType::FLOAT32) {
        // Pass the IEEE-754 bit pattern; an integer conversion would truncate the value and
        // yield bits that do not represent fill_value as a float.
        static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");
        std::memcpy(&packed_fill_value, &fill_value, sizeof(packed_fill_value));
    } else {
        packed_fill_value = static_cast<uint32_t>(fill_value);
    }

    // Logical extents rounded up to whole tiles, and the resulting tile counts.
    const uint32_t padded_height = tt::div_up(height, tt::constants::TILE_HEIGHT) * tt::constants::TILE_HEIGHT;
    const uint32_t padded_width = tt::div_up(width, tt::constants::TILE_WIDTH) * tt::constants::TILE_WIDTH;
    const uint32_t tiles_per_tile_row = padded_width / tt::constants::TILE_WIDTH;
    const uint32_t tiles_per_2d_tensor = (padded_height / tt::constants::TILE_HEIGHT) * tiles_per_tile_row;

    // create kernel
    // writer compile time args
    std::vector<uint32_t> writer_compile_time_args = {
        static_cast<uint32_t>(src0_cb_index),
        static_cast<uint32_t>(src_is_dram),
        packed_fill_value,
        input_element_size_bytes,
        height,
        width,
        padded_height,
        padded_width,
        tiles_per_2d_tensor,
        tiles_per_tile_row,
        static_cast<uint32_t>(tt::constants::TILE_HEIGHT),
        static_cast<uint32_t>(tt::constants::FACE_HEIGHT)};

    tt::tt_metal::KernelHandle writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/kernels/dataflow/fill_pad_writer.cpp",
        all_cores,
        tt_metal::WriterDataMovementConfig(writer_compile_time_args));  // writer only for in-place operation

    auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false);

    // Runtime args: [0] buffer address, [1] CB page size, [2] starting tile offset, [3] slice count.
    std::vector<uint32_t> writer_runtime_args = {
        static_cast<uint32_t>(tens_buffer->address()), cb_page_size, 0u, 0u};

    uint32_t tile_offset = 0;
    for (uint32_t i = 0; i < cores.size(); ++i) {
        const CoreCoord& core = cores[i];
        // The first g1_numcores cores take the group-1 share; the remaining cores take the group-2 share.
        const uint32_t local_num_2d_tensors =
            i < g1_numcores ? num_blocks_per_core_group_1 : num_blocks_per_core_group_2;

        writer_runtime_args[2] = tile_offset;
        writer_runtime_args[3] = local_num_2d_tensors;
        tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args);

        // The next core starts right after the tiles of the slices assigned to this one.
        tile_offset += local_num_2d_tensors * tiles_per_2d_tensor;
    }

    // Only the buffer address is refreshed when the program is reused; every other argument
    // is left as set above.
    auto override_runtime_args_callback = [writer_kernel_id, cores](
                                              const Program& program,
                                              const std::vector<Buffer*>& input_buffers,
                                              const std::vector<Buffer*>& output_buffers) {
        auto tens_buffer = input_buffers.at(0);

        auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id);

        for (const auto& core : cores) {
            auto& runtime_args = writer_runtime_args[core.x][core.y];
            runtime_args[0] = tens_buffer->address();
        }
    };

    return {std::move(program), override_runtime_args_callback};
}

} // namespace ttnn::operations::data_movement::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

// NOTE: this header declares its function in terms of `Tensor` and
// `operation::ProgramWithCallbacks` but includes nothing itself; it must be included after
// a header that provides those names (e.g. fill_pad_op.hpp).

namespace ttnn::operations::data_movement::detail {

// Builds the multi-core, writer-only program that fills the tile padding of `input_tensor`
// with `fill_value`, and returns it together with a callback that refreshes the buffer
// address runtime argument.
operation::ProgramWithCallbacks fill_pad_multi_core(const Tensor& input_tensor, float fill_value);

}  // namespace ttnn::operations::data_movement::detail
Loading

0 comments on commit 2ef3e06

Please sign in to comment.