#16557: Implement JointAttention (#17079)
cglagovichTT authored Jan 29, 2025
1 parent d5d979a commit 9cc0216
Showing 13 changed files with 1,766 additions and 5 deletions.
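
For orientation, the reference behavior that the new unit test checks can be sketched in plain PyTorch: with the "rear" joint strategy, the joint Q/K/V are appended to the end of the sequence dimension, a single non-causal scaled-dot-product attention is computed over the combined sequence, and the result is split back into the regular and joint outputs. The sketch below mirrors the ground-truth construction used in the test; it is not the device kernel, and the helper name is illustrative.

import torch

def joint_sdpa_reference(Q, K, V, joint_Q, joint_K, joint_V):
    # "rear" strategy: append the joint tokens after the regular sequence (dim 2 is seq_len)
    full_Q = torch.cat([Q, joint_Q], dim=2)
    full_K = torch.cat([K, joint_K], dim=2)
    full_V = torch.cat([V, joint_V], dim=2)
    # One non-causal attention over the combined sequence
    out = torch.nn.functional.scaled_dot_product_attention(full_Q, full_K, full_V, is_causal=False)
    # Split the combined output back into the regular and joint parts
    seq_len = Q.shape[2]
    return out[:, :, :seq_len, :], out[:, :, seq_len:, :]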
@@ -564,3 +564,155 @@ def test_sdpa_chunked_iterate_batch(
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


def run_test_joint_sdpa(
device,
b,
nh,
seq_len,
joint_seq_len,
d,
q_chunk_size,
k_chunk_size,
dtype,
use_high_precision_compute=False,
grid_size=None,
):
torch.manual_seed(1234)

program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=grid_size or device.compute_with_storage_grid_size(),
q_chunk_size=q_chunk_size,
k_chunk_size=k_chunk_size,
exp_approx_mode=False,
)

if use_high_precision_compute:
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
math_approx_mode=False,
fp32_dest_acc_en=True,
packer_l1_acc=False,
)
else:
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=True,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)

Q = fa_rand(b, nh, seq_len, d)
K = fa_rand(b, nh, seq_len, d)
V = fa_rand(b, nh, seq_len, d)

joint_Q = fa_rand(b, nh, joint_seq_len, d)
joint_K = fa_rand(b, nh, joint_seq_len, d)
joint_V = fa_rand(b, nh, joint_seq_len, d)

    # Log shapes of all inputs
    logger.debug(f"Q: {Q.shape}")
    logger.debug(f"K: {K.shape}")
    logger.debug(f"V: {V.shape}")
    logger.debug(f"joint_Q: {joint_Q.shape}")
    logger.debug(f"joint_K: {joint_K.shape}")
    logger.debug(f"joint_V: {joint_V.shape}")

tt_Q = ttnn.from_torch(Q, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
tt_K = ttnn.from_torch(K, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
tt_V = ttnn.from_torch(V, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
tt_joint_Q = ttnn.from_torch(joint_Q, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
tt_joint_K = ttnn.from_torch(joint_K, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
tt_joint_V = ttnn.from_torch(joint_V, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
tt_out, tt_joint_out = ttnn.transformer.joint_scaled_dot_product_attention(
tt_Q,
tt_K,
tt_V,
tt_joint_Q,
tt_joint_K,
tt_joint_V,
joint_strategy="rear",
program_config=program_config,
compute_kernel_config=compute_kernel_config,
)
tt_out = ttnn.to_torch(tt_out)
tt_joint_out = ttnn.to_torch(tt_joint_out)
# Slice out any tile-padding
tt_out = tt_out[:, :, :seq_len, :]
tt_joint_out = tt_joint_out[:, :, :joint_seq_len, :]
logger.debug(f"tt_out: {tt_out.shape}")
logger.debug(f"tt_joint_out: {tt_joint_out.shape}")

pt_Q = torch.cat([Q, joint_Q], dim=2)
pt_K = torch.cat([K, joint_K], dim=2)
pt_V = torch.cat([V, joint_V], dim=2)
gt = torch.nn.functional.scaled_dot_product_attention(pt_Q, pt_K, pt_V, is_causal=False)
gt_out = gt[:, :, :seq_len, :]
gt_joint_out = gt[:, :, seq_len:, :]

for out, gt in [(tt_out, gt_out), (tt_joint_out, gt_joint_out)]:
out_pass, out_pcc = comp_pcc(gt, out, 0.994)
logger.debug(f"python vs pytorch: {out_pcc}")
logger.debug(f"mse: {((gt - out) ** 2).mean()}")
assert out_pass


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
@pytest.mark.parametrize("q_chunk_size", [128, 512], ids=["q128", "q512"])
@pytest.mark.parametrize("k_chunk_size", [128, 512], ids=["k128", "k512"])
@pytest.mark.parametrize("b", [1, 2], ids=["b1", "b2"])
@pytest.mark.parametrize("nh", [1, 3], ids=["nh1", "nh3"])
@pytest.mark.parametrize(
"seq_len, joint_seq_len",
[
(15, 19),
(2048, 256),
(3000, 100),
(20 * 1024 + 1, 118),
],
)
@pytest.mark.parametrize(
"d",
[128],
ids=[
"d128",
],
)
def test_joint_sdpa(device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype):
if q_chunk_size == 512 and k_chunk_size == 512:
pytest.skip("OOM config.")
ttnn.device.DisablePersistentKernelCache()
run_test_joint_sdpa(device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
@pytest.mark.parametrize("q_chunk_size", [128], ids=["q128"])
@pytest.mark.parametrize("k_chunk_size", [128], ids=["k128"])
@pytest.mark.parametrize("b", [1], ids=["b1"])
@pytest.mark.parametrize("nh", [1], ids=["nh1"])
@pytest.mark.parametrize(
"seq_len, joint_seq_len",
[
(3000, 100),
],
)
@pytest.mark.parametrize(
"d",
[128],
ids=[
"d128",
],
)
def test_joint_sdpa_program_cache(
device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype, use_program_cache
):
    dummy_tensors = []
    for _ in range(3):
        # Allocate a fresh dummy tensor each iteration so buffer addresses differ between runs
        dummy_tensors.append(
            ttnn.from_torch(fa_rand(b, nh, seq_len, d), dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
        )
        run_test_joint_sdpa(device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype)
    assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
        device.num_program_cache_entries()
    )
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
@@ -377,6 +377,8 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
204 changes: 204 additions & 0 deletions ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp
@@ -0,0 +1,204 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "joint_sdpa_op.hpp"

#include "joint_sdpa_program_factory.hpp"
#include "ttnn/run_operation.hpp"
#include <tt-metalium/constants.hpp>

using namespace tt::tt_metal;

namespace ttnn::operations::transformer {

void JointScaledDotProductAttention::validate(const std::vector<Tensor>& input_tensors) const {
tt::log_info("Validating Joint SDPA inputs");
TT_FATAL(input_tensors.size() == 6, "Must have 6 input tensors (Q, K, V, joint_Q, joint_K, joint_V)");

const auto& input_tensor_q = input_tensors.at(0);
const auto& input_tensor_k = input_tensors.at(1);
const auto& input_tensor_v = input_tensors.at(2);
const auto& joint_tensor_q = input_tensors.at(3);
const auto& joint_tensor_k = input_tensors.at(4);
const auto& joint_tensor_v = input_tensors.at(5);

// Validate joint strategy is 'rear'
TT_FATAL(this->joint_strategy == "rear", "Joint strategy must be 'rear'. Got: {}", this->joint_strategy);

// Validate all tensors have the same dtype
const auto dtype = input_tensor_q.get_dtype();
for (const auto& tensor : input_tensors) {
TT_FATAL(
tensor.get_dtype() == dtype,
"All tensors must have the same dtype. Expected {}, got {}",
dtype,
tensor.get_dtype());
}

// Get shapes
const auto q_shape = input_tensor_q.get_logical_shape();
const auto k_shape = input_tensor_k.get_logical_shape();
const auto v_shape = input_tensor_v.get_logical_shape();
const auto joint_q_shape = joint_tensor_q.get_logical_shape();
const auto joint_k_shape = joint_tensor_k.get_logical_shape();
const auto joint_v_shape = joint_tensor_v.get_logical_shape();

// Validate storage types and buffers
for (auto& tensor : input_tensors) {
TT_FATAL(tensor.storage_type() == StorageType::DEVICE, "Operands to Joint SDPA need to be on device");
TT_FATAL(tensor.buffer() != nullptr, "Operands to Joint SDPA need to be allocated in buffers on device");
TT_FATAL(tensor.get_layout() == Layout::TILE, "Inputs to Joint SDPA must be tilized");
TT_FATAL(
tensor.get_dtype() == DataType::BFLOAT16 || tensor.get_dtype() == DataType::BFLOAT8_B,
"Inputs to Joint SDPA must be BF16 or BF8");
TT_FATAL(
tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM,
"Operands to Joint SDPA need to be in DRAM");
}

// Validate input shapes match
TT_FATAL(
k_shape[0] == q_shape[0] && v_shape[0] == q_shape[0],
"Batch sizes must match. Got Q: {}, K: {}, V: {}",
q_shape[0],
k_shape[0],
v_shape[0]);

// Validate joint input shapes match
TT_FATAL(
joint_k_shape[0] == joint_q_shape[0] && joint_v_shape[0] == joint_q_shape[0],
"Joint batch sizes must match. Got Q: {}, K: {}, V: {}",
joint_q_shape[0],
joint_k_shape[0],
joint_v_shape[0]);

    // Validate Q and joint Q have the same batch size (num_heads is checked below)
TT_FATAL(
q_shape[0] == joint_q_shape[0],
"Q and joint Q must have same batch size. Got Q: {}, joint Q: {}",
q_shape[0],
joint_q_shape[0]);

// Validate head dimensions match
TT_FATAL(
k_shape[3] == q_shape[3] && v_shape[3] == q_shape[3],
"Head dimensions must match. Got Q: {}, K: {}, V: {}",
q_shape[3],
k_shape[3],
v_shape[3]);

TT_FATAL(
joint_k_shape[3] == joint_q_shape[3] && joint_v_shape[3] == joint_q_shape[3],
"Joint head dimensions must match. Got Q: {}, K: {}, V: {}",
joint_q_shape[3],
joint_k_shape[3],
joint_v_shape[3]);

TT_FATAL(
q_shape[3] == joint_q_shape[3],
"Q and joint Q must have same head dimension. Got Q: {}, joint Q: {}",
q_shape[3],
joint_q_shape[3]);

// Validate num_heads relationship
const auto nqh = q_shape[1];
const auto nkv = k_shape[1];
const auto joint_nqh = joint_q_shape[1];
const auto joint_nkv = joint_k_shape[1];

TT_FATAL(nqh == nkv, "Q num_heads must be equal to K num_heads. Got Q: {}, K: {}", nqh, nkv);

TT_FATAL(
joint_nqh == joint_nkv,
"Joint Q num_heads must be equal to Joint K num_heads. Got Q: {}, K: {}",
joint_nqh,
joint_nkv);
TT_FATAL(
joint_nkv == nkv, "Joint K num_heads must be equal to K num_heads. Got Joint K: {}, K: {}", joint_nkv, nkv);

    // Validate chunk sizes (defaults of 32 are used when no program config is provided)
auto q_chunk_size = this->get_q_chunk_size();
auto k_chunk_size = this->get_k_chunk_size();

TT_FATAL(
q_chunk_size % tt::constants::TILE_WIDTH == 0,
"q_chunk_size must be divisible by TILE_SIZE. Got q_chunk_size: {}, TILE_SIZE: {}",
q_chunk_size,
tt::constants::TILE_WIDTH);
TT_FATAL(
k_chunk_size % tt::constants::TILE_WIDTH == 0,
"k_chunk_size must be divisible by TILE_SIZE. Got k_chunk_size: {}, TILE_SIZE: {}",
k_chunk_size,
tt::constants::TILE_WIDTH);

// Validate padding: Only the sequence dimension may be padded
auto validate_padding = [](const Tensor& tensor) {
auto logical_shape = tensor.get_logical_shape();
auto padded_shape = tensor.get_padded_shape();
TT_FATAL(logical_shape[0] == padded_shape[0], "Padding is not supported on the batch dimension");
TT_FATAL(logical_shape[1] == padded_shape[1], "Padding is not supported on the num_heads dimension");
TT_FATAL(logical_shape[3] == padded_shape[3], "Padding is not supported on the head_dim dimension");
};

for (const auto& tensor : input_tensors) {
validate_padding(tensor);
}
}

std::uint32_t JointScaledDotProductAttention::get_q_chunk_size() const {
return this->program_config ? this->program_config->q_chunk_size : 32;
}

std::uint32_t JointScaledDotProductAttention::get_k_chunk_size() const {
return this->program_config ? this->program_config->k_chunk_size : 32;
}

std::vector<TensorSpec> JointScaledDotProductAttention::compute_output_specs(
const std::vector<Tensor>& input_tensors) const {
auto& input = input_tensors.at(0);
auto& joint_input = input_tensors.at(3);
return {
TensorSpec(
input.get_logical_shape(), TensorLayout(input.get_dtype(), PageConfig(Layout::TILE), output_mem_config)),
TensorSpec(
joint_input.get_logical_shape(),
TensorLayout(joint_input.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
}

operation::ProgramWithCallbacks JointScaledDotProductAttention::create_program(
const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
auto& input_tensor_q = input_tensors.at(0);
auto& input_tensor_k = input_tensors.at(1);
auto& input_tensor_v = input_tensors.at(2);
auto& joint_tensor_q = input_tensors.at(3);
auto& joint_tensor_k = input_tensors.at(4);
auto& joint_tensor_v = input_tensors.at(5);
auto& output_tensor = output_tensors.at(0);
auto& joint_output_tensor = output_tensors.at(1);

auto scale = this->scale;
if (not scale.has_value()) {
scale = 1.0f / std::sqrt(static_cast<float>(input_tensor_q.get_logical_shape()[-1]));
}

std::size_t q_chunk_size = this->get_q_chunk_size();
std::size_t k_chunk_size = this->get_k_chunk_size();

return detail::joint_sdpa(
input_tensor_q,
input_tensor_k,
input_tensor_v,
joint_tensor_q,
joint_tensor_k,
joint_tensor_v,
output_tensor,
joint_output_tensor,
scale,
q_chunk_size,
k_chunk_size,
this->compute_kernel_config,
this->program_config);
}

} // namespace ttnn::operations::transformer
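
Usage note: when no SDPAProgramConfig is supplied, get_q_chunk_size()/get_k_chunk_size() above fall back to 32, and the scale defaults to 1/sqrt(head_dim), so a minimal Python call can rely on those defaults. The sketch below is illustrative only; the shapes and tensor names are assumptions, and the omission of program_config and compute_kernel_config assumes both are optional keyword arguments, as the C++ defaults suggest.

import torch
import ttnn

def run_joint_sdpa_with_defaults(device):
    # Illustrative shapes: batch=1, heads=3, seq_len=2048, joint_seq_len=256, head_dim=128
    Q, K, V = (torch.randn(1, 3, 2048, 128) for _ in range(3))
    joint_Q, joint_K, joint_V = (torch.randn(1, 3, 256, 128) for _ in range(3))

    def to_tt(t):
        return ttnn.from_torch(t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    out, joint_out = ttnn.transformer.joint_scaled_dot_product_attention(
        to_tt(Q), to_tt(K), to_tt(V),
        to_tt(joint_Q), to_tt(joint_K), to_tt(joint_V),
        joint_strategy="rear",  # the only strategy accepted by the validator above
        # program_config / compute_kernel_config omitted: assumed optional, with chunk
        # sizes defaulting to 32 and scale to 1/sqrt(head_dim) per joint_sdpa_op.cpp
    )
    return ttnn.to_torch(out), ttnn.to_torch(joint_out)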