From 13e4612903ecadcf6706aa8be33b6cb049444bf0 Mon Sep 17 00:00:00 2001
From: marty1885
Date: Mon, 11 Nov 2024 15:20:39 +0800
Subject: [PATCH] enable bcast_batch when needed

---
 ggml/src/ggml-metalium.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ggml/src/ggml-metalium.cpp b/ggml/src/ggml-metalium.cpp
index 052d1cef48d4d..59233fcdea0f4 100644
--- a/ggml/src/ggml-metalium.cpp
+++ b/ggml/src/ggml-metalium.cpp
@@ -781,9 +781,11 @@ static void ggml_backend_metalium_mul_mat(ggml_backend_metalium_context * ctx, s
     } else {
         auto aT = ttnn::transpose(a, -2, -1);
+        bool bcast_batch = aT.shape()[0] != b.shape()[0];
         // TODO: Ask TT to support multiplication of pre-transposed tensors. Calling transpose here is inefficient
         // https://github.com/tenstorrent/tt-metal/issues/9709
         ttnn::operations::matmul::Matmul cfg = ttnn::operations::matmul::Matmul{
+            .bcast_batch = bcast_batch,
             .compute_kernel_config = make_compute_kernel_config(a.device()),
             // XXX: Why output_tile doesn't have a default value?
             .output_tile = std::nullopt
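
Note (not part of the patch): a minimal stand-alone C++ sketch of the decision the first added line makes, assuming a ttnn-style 4D shape where index 0 is the batch dimension. Shape4D and needs_bcast_batch are hypothetical names used only for illustration, not ttnn API. Setting bcast_batch this way lets a single weight matrix be broadcast across all activation batches instead of requiring matching batch dimensions.

// Hypothetical illustration of the check `aT.shape()[0] != b.shape()[0]`.
#include <array>
#include <cstdint>
#include <iostream>

// Stand-in for a ttnn-style 4D shape; index 0 is the batch dimension.
using Shape4D = std::array<uint32_t, 4>;

static bool needs_bcast_batch(const Shape4D& a, const Shape4D& b) {
    // Broadcasting is needed when the outermost (batch) dimensions differ,
    // e.g. one weight matrix multiplied against a batch of activations.
    return a[0] != b[0];
}

int main() {
    Shape4D weight     = {1, 1, 4096, 4096}; // single weight matrix
    Shape4D activation = {8, 1, 32, 4096};   // batch of 8
    std::cout << std::boolalpha
              << needs_bcast_batch(weight, activation) << "\n"; // prints: true
}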