diff --git a/fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py
new file mode 100644
index 000000000..2bbd6f31e
--- /dev/null
+++ b/fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from typing import Tuple
+
+import torch
+
+if torch.cuda.is_available():
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+    from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+        grouped_gemm,
+        grouped_gemm_fp8_rowwise,
+    )
+    from fbgemm_gpu.experimental.gemm.triton_gemm.utils import HAS_TMA_DESC
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available()
+    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9
+    or not HAS_TMA_DESC,
+    "Skip when an H100 GPU or TMA support is not available",
+)
+class TestGroupedGEMM(unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(0)
+
+    def test_grouped_gemm_fp8_rowwise(self) -> None:
+        def _test_grouped_gemm_fp8_rowwise(
+            shape: Tuple[int, int, int, int],
+            device: torch.device,
+        ) -> None:
+            G, M, N, K = shape
+            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
+            m_offsets, _ = torch.sort(
+                torch.randint(low=0, high=M, size=[G], device=device, dtype=torch.int32)
+            )
+            m_offsets[G - 1] = M
+
+            a_fp8, a_scale = quantize_fp8_row(a)
+            b_fp8, b_scale = quantize_fp8_row(b)
+
+            result = grouped_gemm_fp8_rowwise(
+                a_fp8,
+                b_fp8,
+                m_offsets,
+                a_scale,
+                b_scale,
+            )
+            self.assertTrue(result.shape == (M, N))
+
+            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
+            # Run the baseline on the quantized inputs to exclude quantization error from the comparison; that error has nothing to do with the correctness of the kernel implementation.
+            for g in range(G):
+                m_start = 0 if g == 0 else m_offsets[g - 1]
+                m_end = m_offsets[g]
+                n_start = g * N
+                n_end = (g + 1) * N
+
+                expected_result[m_start:m_end, :] = (
+                    a_fp8[m_start:m_end, :].to(torch.float32)
+                    @ b_fp8[n_start:n_end, :].to(torch.float32).T
+                    * a_scale[m_start:m_end][:, None]
+                    * b_scale[n_start:n_end][None, :]
+                ).to(torch.bfloat16)
+
+            torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)
+
+        _test_grouped_gemm_fp8_rowwise((16, 512, 256, 256), torch.device("cuda"))
+        _test_grouped_gemm_fp8_rowwise((8, 512, 256, 256), torch.device("cuda"))
+        _test_grouped_gemm_fp8_rowwise((4, 512, 256, 256), torch.device("cuda"))
+        _test_grouped_gemm_fp8_rowwise((2, 512, 256, 256), torch.device("cuda"))
+        # TODO(shikaili): G=1 could produce NaN results with on-device TMA store. Need to debug.
+        # _test_grouped_gemm_fp8_rowwise((1, 512, 256, 256), torch.device("cuda"))
+
+    def test_grouped_gemm_bf16(self) -> None:
+        def _test_grouped_gemm_bf16(
+            shape: Tuple[int, int, int, int],
+            device: torch.device,
+        ) -> None:
+            G, M, N, K = shape
+            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
+            m_offsets, _ = torch.sort(
+                torch.randint(low=0, high=M, size=[G], device=device, dtype=torch.int32)
+            )
+            m_offsets[G - 1] = M
+
+            result = grouped_gemm(
+                a,
+                b,
+                m_offsets,
+            )
+            self.assertTrue(result.shape == (M, N))
+
+            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
+            for g in range(G):
+                m_start = 0 if g == 0 else m_offsets[g - 1]
+                m_end = m_offsets[g]
+                expected_result[m_start:m_end, :] = (
+                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
+                )
+
+            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
+
+        _test_grouped_gemm_bf16((16, 512, 256, 256), torch.device("cuda"))
+        _test_grouped_gemm_bf16((8, 512, 256, 256), torch.device("cuda"))
+        _test_grouped_gemm_bf16((4, 512, 256, 256), torch.device("cuda"))
+        _test_grouped_gemm_bf16((2, 512, 256, 256), torch.device("cuda"))
+        # TODO(shikaili): G=1 could produce NaN results with on-device TMA store. Need to debug.
+        # _test_grouped_gemm_bf16((1, 512, 256, 256), torch.device("cuda"))
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
new file mode 100644
index 000000000..3026b46dc
--- /dev/null
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional
+
+import torch
+
+import triton
+import triton.language as tl
+
+from fbgemm_gpu.experimental.gemm.triton_gemm import utils
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": block_size_m,
+                "BLOCK_SIZE_N": block_size_n,
+                "BLOCK_SIZE_K": block_size_k,
+            },
+            num_stages=num_stages,
+            num_warps=num_warps,
+            num_ctas=num_ctas,
+        )
+        for block_size_m in [64, 128]
+        for block_size_n in [128, 256]
+        for block_size_k in [128, 256]
+        for num_stages in [3, 4]
+        for num_warps in [4, 8]
+        for num_ctas in [1]
+    ],
+    key=["G", "M_BUCKET", "N", "K"],
+)
+@triton.jit
+def _kernel_grouped_gemm(
+    a_desc_ptr,
+    b_desc_ptr,
+    c_ptr,
+    workspace,
+    m_offsets,
+    # problem sizes
+    G: tl.constexpr,
+    M_BUCKET: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    # tile sizes
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+) -> None:
+    tidx = tl.program_id(0)
+
+    dtype: tl.dtype = c_ptr.dtype.element_ty
+    TMA_SIZE: tl.constexpr = tl.constexpr(128)
+    c_desc_ptr = workspace + tidx * TMA_SIZE
+
+    M_end_offset = 0
+    iterated_tiles = 0
+    for g in tl.range(G):
+        # Move across groups
+        M_start_offset = M_end_offset
+        M_end_offset = tl.load(m_offsets + g)
+        m_size = M_end_offset - M_start_offset
+
+        if m_size > 0:
+            N_start_offset = g * N
+            n_size = N
+            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
+            num_tiles = num_m_tiles * num_n_tiles
+
+            # pyre-ignore
+            tl.extra.cuda.experimental_device_tensormap_create2d(
+                desc_ptr=c_desc_ptr,
+                global_address=c_ptr + M_start_offset * N,
+                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                global_size=[m_size, n_size],
+                element_ty=c_ptr.dtype.element_ty,
+            )
+            # pyre-ignore
+            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
+            # Move across tiles
+            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+                gidx = tidx - iterated_tiles
+                # Split M first and N second.
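+                # gidx enumerates this group's tiles with the M index varying
+                # fastest; each program then strides ahead by NUM_SMS tiles.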
+                tile_m_idx = gidx % num_m_tiles
+                tile_n_idx = gidx // num_m_tiles
+
+                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+                tl.static_assert(K % BLOCK_SIZE_K == 0)
+                m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                for k_offset in range(0, K, BLOCK_SIZE_K):
+                    a = tl._experimental_descriptor_load(
+                        a_desc_ptr,
+                        [m_offset, k_offset],
+                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    b = tl._experimental_descriptor_load(
+                        b_desc_ptr,
+                        [n_offset, k_offset],
+                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    accumulator += tl.dot(a, b.T)
+
+                m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                tl._experimental_descriptor_store(
+                    c_desc_ptr,
+                    accumulator.to(c_ptr.dtype.element_ty),
+                    [m_offset, n_offset],
+                )
+                tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": block_size_m,
+                "BLOCK_SIZE_N": block_size_n,
+                "BLOCK_SIZE_K": block_size_k,
+            },
+            num_stages=num_stages,
+            num_warps=num_warps,
+            num_ctas=num_ctas,
+        )
+        for block_size_m in [64, 128]
+        for block_size_n in [128, 256]
+        for block_size_k in [128, 256]
+        for num_stages in [3, 4]
+        for num_warps in [4, 8]
+        for num_ctas in [1]
+    ],
+    key=["G", "M_BUCKET", "N", "K"],
+)
+@triton.jit
+def _kernel_grouped_gemm_fp8_rowwise(
+    a_desc_ptr,
+    a_scale_ptr,
+    b_desc_ptr,
+    b_scale_ptr,
+    c_ptr,
+    workspace,
+    m_offsets,
+    # problem sizes
+    G: tl.constexpr,
+    M_BUCKET: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    # tile sizes
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+) -> None:
+    tidx = tl.program_id(0)
+
+    dtype = tl.float8e4nv
+    TMA_SIZE: tl.constexpr = tl.constexpr(128)
+    c_desc_ptr = workspace + tidx * TMA_SIZE
+
+    M_end_offset = 0
+    iterated_tiles = 0
+    for g in tl.range(G):
+        # Move across groups
+        M_start_offset = M_end_offset
+        M_end_offset = tl.load(m_offsets + g)
+        m_size = M_end_offset - M_start_offset
+
+        if m_size > 0:
+            N_start_offset = g * N
+            n_size = N
+            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
+            num_tiles = num_m_tiles * num_n_tiles
+
+            # pyre-ignore
+            tl.extra.cuda.experimental_device_tensormap_create2d(
+                desc_ptr=c_desc_ptr,
+                global_address=c_ptr + M_start_offset * N,
+                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                global_size=[m_size, n_size],
+                element_ty=c_ptr.dtype.element_ty,
+            )
+            # pyre-ignore
+            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
+            # Move across tiles
+            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+                gidx = tidx - iterated_tiles
+                # Split M first and N second.
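+                # Same tile ordering as the bf16 kernel; the row-wise scales are
+                # applied to the fp32 accumulator after the K loop below.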
+                tile_m_idx = gidx % num_m_tiles
+                tile_n_idx = gidx // num_m_tiles
+
+                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+                tl.static_assert(K % BLOCK_SIZE_K == 0)
+                m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                for k_offset in range(0, K, BLOCK_SIZE_K):
+                    a = tl._experimental_descriptor_load(
+                        a_desc_ptr,
+                        [m_offset, k_offset],
+                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    b = tl._experimental_descriptor_load(
+                        b_desc_ptr,
+                        [n_offset, k_offset],
+                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    accumulator += tl.dot(a, b.T)
+
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                a_scale = tl.load(
+                    a_scale_ptr + M_start_offset + offs_am[:, None],
+                    mask=offs_am[:, None] < m_size,
+                )
+                b_scale = tl.load(
+                    b_scale_ptr + N_start_offset + offs_bn[None, :],
+                    mask=offs_bn[None, :] < n_size,
+                )
+                c = accumulator.to(tl.float32) * a_scale * b_scale
+
+                m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                tl._experimental_descriptor_store(
+                    c_desc_ptr,
+                    c.to(c_ptr.dtype.element_ty),
+                    [m_offset, n_offset],
+                )
+                tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+
+
+_ON_DEVICE_TMA_WORKSPACE = {}
+
+
+def _grouped_gemm(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_offsets: torch.Tensor,
+    x_scale: Optional[torch.Tensor] = None,
+    w_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if not utils.HAS_TMA_DESC:
+        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
+
+    G = m_offsets.shape[0]
+
+    # TODO(shikaili): G=1 could produce NaN results with on-device TMA store. Need to debug.
+    if G == 1:
+        raise NotImplementedError("Grouped GEMM with NUM_GROUPS=1 is not supported yet")
+
+    assert x.is_contiguous()
+    assert w.is_contiguous()
+    assert m_offsets.is_contiguous()
+
+    M, K = x.shape
+    N = w.shape[0] // G
+    assert K == w.shape[1]
+
+    y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+
+    desc_helper = utils.TmaAutoTuneHelper()
+    desc_helper.init_tma_descriptor("x")
+    desc_helper.init_tma_descriptor("w")
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    global _ON_DEVICE_TMA_WORKSPACE
+    if x.device not in _ON_DEVICE_TMA_WORKSPACE:
+        _ON_DEVICE_TMA_WORKSPACE[x.device] = torch.empty(
+            NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
+            device=x.device,
+            dtype=torch.uint8,
+        )
+    workspace = _ON_DEVICE_TMA_WORKSPACE[x.device]
+
+    def grid(META):
+        nonlocal desc_helper
+        desc_helper.fill_2d_tma_descriptor(
+            "x",
+            x.data_ptr(),
+            M,
+            K,
+            META["BLOCK_SIZE_M"],
+            META["BLOCK_SIZE_K"],
+            x.element_size(),
+        )
+
+        desc_helper.fill_2d_tma_descriptor(
+            "w",
+            w.data_ptr(),
+            N * G,
+            K,
+            META["BLOCK_SIZE_N"],
+            META["BLOCK_SIZE_K"],
+            w.element_size(),
+        )
+
+        return (NUM_SMS,)
+
+    desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
+    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
+
+    M_BUCKET = triton.next_power_of_2(M)
+    if x_scale is not None and w_scale is not None:
+        assert x_scale.is_contiguous()
+        assert w_scale.is_contiguous()
+        _kernel_grouped_gemm_fp8_rowwise[grid](
+            desc_x,
+            x_scale,
+            desc_w,
+            w_scale,
+            y,
+            workspace,
+            m_offsets,
+            G,
+            M_BUCKET,
+            N,
+            K,
+            NUM_SMS,
+        )
+    else:
+        assert x_scale is None
+        assert w_scale is None
+        _kernel_grouped_gemm[grid](
+            desc_x,
+            desc_w,
+            y,
+            workspace,
+            m_offsets,
+            G,
+            M_BUCKET,
+            N,
+            K,
+            NUM_SMS,
+        )
+
+    return y
+
+
+def grouped_gemm(
+    x: torch.Tensor, w: torch.Tensor, m_offsets: torch.Tensor
+) -> torch.Tensor:
+    return _grouped_gemm(x, w, m_offsets)
+
+
+def grouped_gemm_fp8_rowwise(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_offsets: torch.Tensor,
+    x_scale: torch.Tensor,
+    w_scale: torch.Tensor,
+) -> torch.Tensor:
+    return _grouped_gemm(x, w, m_offsets, x_scale, w_scale)
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py b/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py
index 4a2893d7a..1bc83e709 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py
@@ -31,6 +31,8 @@ def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
         return tl.float32
     elif dtype == torch.int32:
         return tl.int32
+    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
+        return tl.float8e4nv
     else:
        raise ValueError(f"Unsupported dtype {dtype}")
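
Usage note: a minimal sketch of how the new entry points are expected to be called, mirroring the unit tests above; the group count, shapes, and offsets below are illustrative, and an H100-class GPU with TMA support is assumed.

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm,
    grouped_gemm_fp8_rowwise,
)

device = torch.device("cuda")
G, M, N, K = 4, 512, 256, 256  # illustrative sizes, matching the tests

x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
# m_offsets[g] is the exclusive end row of group g in x; the last entry must equal M.
m_offsets = torch.tensor([128, 256, 384, M], dtype=torch.int32, device=device)

# bf16 path: group g computes x[start:end] @ w[g * N : (g + 1) * N].T
y = grouped_gemm(x, w, m_offsets)  # (M, N), bfloat16

# fp8 row-wise path: quantize activations and weights per row and pass the scales.
x_fp8, x_scale = quantize_fp8_row(x)
w_fp8, w_scale = quantize_fp8_row(w)
y_fp8 = grouped_gemm_fp8_rowwise(x_fp8, w_fp8, m_offsets, x_scale, w_scale)  # (M, N), bfloat16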