From b515f5a7ab34ce4f9b590c982b289ae64943516f Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Sat, 1 Feb 2025 01:09:48 -0800
Subject: [PATCH] Re-organize SLL ops, pt 3

Summary:
- Re-organize `jagged_dense_elementwise_add`

Reviewed By: sryap

Differential Revision: D68923208
---
 fbgemm_gpu/fbgemm_gpu/sll/__init__.py         |  5 --
 fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py  | 12 ++++-
 .../triton_jagged_dense_elementwise_add.py    | 52 +++++++++++++++++++
 ...=> triton_jagged_dense_flash_attention.py} |  0
 ...iton_multi_head_jagged_flash_attention.py} |  0
 fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py       | 43 ---------------
 .../sll/jagged_dense_elementwise_add_test.py  |  3 +-
 7 files changed, 64 insertions(+), 51 deletions(-)
 create mode 100644 fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
 rename fbgemm_gpu/fbgemm_gpu/sll/triton/{jagged_dense_flash_attention.py => triton_jagged_dense_flash_attention.py} (100%)
 rename fbgemm_gpu/fbgemm_gpu/sll/triton/{multi_head_jagged_flash_attention.py => triton_multi_head_jagged_flash_attention.py} (100%)

diff --git a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
index 4afaaaebc0..3d359e8d3b 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -39,7 +39,6 @@
     jagged2_softmax,
     jagged2_to_padded_dense,
     jagged_dense_bmm,
-    jagged_dense_elementwise_add,
     jagged_dense_elementwise_mul_jagged_out,
     jagged_flash_attention_basic,
     jagged_jagged_bmm,
@@ -316,10 +315,6 @@
         "CUDA": jagged_flash_attention_basic,
         "AutogradCUDA": jagged_flash_attention_basic,
     },
-    "sll_jagged_dense_elementwise_add": {
-        "CUDA": jagged_dense_elementwise_add,
-        "AutogradCUDA": jagged_dense_elementwise_add,
-    },
 }
 
 for op_name, dispatches in sll_cpu_registrations.items():
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
index 54ad4fe073..447b8c2034 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -8,18 +8,26 @@
 
 # pyre-strict
 
-from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import (  # noqa F401
+from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
+    jagged_dense_elementwise_add,
+    JaggedDenseAdd,  # noqa F401
+)
+from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
     jagged_dense_flash_attention,
     JaggedDenseFlashAttention,  # noqa F401
 )
-from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (  # noqa F401
+from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
     multi_head_jagged_flash_attention,
     MultiHeadJaggedFlashAttention,  # noqa F401
 )
 
 
 # pyre-ignore[5]
 op_registrations = {
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
     "sll_jagged_dense_flash_attention": {
         "CUDA": jagged_dense_flash_attention,
         "AutogradCUDA": jagged_dense_flash_attention,
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
new file mode 100644
index 0000000000..ed883b8343
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+
+from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
+    dense_to_jagged,
+    jagged_to_dense,
+)
+
+
+class JaggedDenseAdd(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
+    ):
+        ctx.save_for_backward(x_offsets)
+        ctx.max_seq_len = max_seq_len
+        # TODO: what should be the correct behavior when jagged values has length > max seq len?
+        # current behavior is to not truncate jagged values
+        # similar for backward grad_output
+        return dense_to_jagged(
+            y, [x_offsets], operation_function="add", operation_jagged_values=x
+        )[0]
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        (offsets,) = ctx.saved_tensors
+        grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
+        return grad_output, None, grad_dense, None
+
+
+def jagged_dense_elementwise_add(
+    x: torch.Tensor,
+    x_offsets: torch.Tensor,
+    y: torch.Tensor,
+    max_seq_len: int,
+    use_fbgemm_kernel: bool = True,
+):
+    if use_fbgemm_kernel:
+        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
+            x, [x_offsets], y
+        )[0]
+    else:
+        return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/jagged_dense_flash_attention.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py
similarity index 100%
rename from fbgemm_gpu/fbgemm_gpu/sll/triton/jagged_dense_flash_attention.py
rename to fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/multi_head_jagged_flash_attention.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py
similarity index 100%
rename from fbgemm_gpu/fbgemm_gpu/sll/triton/multi_head_jagged_flash_attention.py
rename to fbgemm_gpu/fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
index 53f77949ba..179f3b788d 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
@@ -12,11 +12,6 @@
 import triton
 import triton.language as tl
 
-from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
-    dense_to_jagged,
-    jagged_to_dense,
-)
-
 
 def set_block_size(N: int) -> int:
     if N > 64:
@@ -2591,41 +2586,3 @@ def jagged_flash_attention_basic(
     )
 
     return jagged_O
-
-
-class JaggedDenseAdd(torch.autograd.Function):
-    @staticmethod
-    # pyre-fixme
-    def forward(
-        ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
-    ):
-        ctx.save_for_backward(x_offsets)
-        ctx.max_seq_len = max_seq_len
-        # TODO: what should be the correct behavior when jagged values has length > max seq len?
-        # current behavior is to not truncate jagged values
-        # similar for backward grad_output
-        return dense_to_jagged(
-            y, [x_offsets], operation_function="add", operation_jagged_values=x
-        )[0]
-
-    @staticmethod
-    # pyre-fixme
-    def backward(ctx, grad_output: torch.Tensor):
-        (offsets,) = ctx.saved_tensors
-        grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
-        return grad_output, None, grad_dense, None
-
-
-def jagged_dense_elementwise_add(
-    x: torch.Tensor,
-    x_offsets: torch.Tensor,
-    y: torch.Tensor,
-    max_seq_len: int,
-    use_fbgemm_kernel: bool = True,
-):
-    if use_fbgemm_kernel:
-        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
-            x, [x_offsets], y
-        )[0]
-    else:
-        return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)
diff --git a/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py b/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py
index 010e7448d0..75850e9170 100644
--- a/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py
+++ b/fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py
@@ -9,9 +9,10 @@
 
 import unittest
 
+import fbgemm_gpu.sll  # noqa F401
 import hypothesis.strategies as st
 import torch
-from fbgemm_gpu.sll.triton_sll import jagged_dense_elementwise_add  # noqa
+
 from hypothesis import given, settings
 
 from .common import open_source
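For context, a minimal usage sketch of the relocated op (not part of the patch): it calls jagged_dense_elementwise_add through its new home in fbgemm_gpu.sll.triton, with use_fbgemm_kernel=False to exercise the JaggedDenseAdd autograd path moved here. The batch size, sequence lengths, and embedding dimension are illustrative assumptions, and a CUDA device plus a Triton-enabled fbgemm_gpu install are assumed.

# Hedged sketch: shapes and values below are illustrative, not taken from the patch or its tests.
import torch

import fbgemm_gpu.sll  # noqa F401  (importing the package registers the sll_* ops)
from fbgemm_gpu.sll.triton import jagged_dense_elementwise_add

B, max_seq_len, D = 3, 8, 4
lengths = torch.tensor([2, 8, 5], device="cuda")
# x_offsets[i] .. x_offsets[i + 1] delimit sample i's rows inside the jagged values tensor x.
x_offsets = torch.cat(
    [torch.zeros(1, dtype=torch.long, device="cuda"), lengths.cumsum(0)]
)
x = torch.randn(int(lengths.sum()), D, device="cuda", requires_grad=True)  # jagged values
y = torch.randn(B, max_seq_len, D, device="cuda", requires_grad=True)  # dense addend

# Jagged-output add: each jagged row of x gets the matching dense row of y added to it.
out = jagged_dense_elementwise_add(x, x_offsets, y, max_seq_len, use_fbgemm_kernel=False)
out.sum().backward()  # backward produces gradients for both x and y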