
[Operator] fix: fix pytorch 2.4 tanh backward bug #263

Open
wants to merge 2 commits into master

Conversation

yinfan98

@yinfan98 yinfan98 commented Oct 28, 2024

PR Category

Operator

Type of Change

Bug Fix

Description

Fix the bug reported in issue #249.

Code that reproduces the problem:

import flag_gems
import torch

flag_gems.enable()

def f(x, y):
    a = torch.tanh(y)
    b = x - y
    return flag_gems.fused.gelu_and_mul(a, b)

x = torch.randn(10,device="cuda")
y = torch.randn(10,device="cuda")

F = torch.compile(f)

print(F(x, y))
  • Issue #249 describes this as a bug in backward, but from my observations it does not seem to be a backward problem. Testing under torch 2.4, the Fake Tensor error appears when an autograd.Function is used together with torch.compile; testing operators that do not go through autograd, such as cos, does not trigger the error.
  • As I understand it, pointwise_dynamic automatically generates Triton kernels inside the autograd.Function. It seems safer to wrap the generated Triton kernels in an extra layer of custom ops and register a FakeTensor implementation for that extra layer. Since this repository targets torch >= 2.2, two implementations are provided here.
  • The approach left in the comments strictly requires custom ops, a feature that was only added in torch 2.4 (see the sketch after this list).
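
A minimal sketch of the torch >= 2.4 approach mentioned in the last bullet, using torch.library.custom_op with a registered fake implementation (torch.tanh stands in for the Triton kernel that pointwise_dynamic would generate; this is not the PR's actual code):

import torch


@torch.library.custom_op("gems::tanh_forward", mutates_args=(), device_types="cuda")
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    # Real CUDA implementation; FlagGems would call the generated Triton
    # kernel here (torch.tanh is only a stand-in).
    return torch.tanh(x)


@tanh_forward.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake implementation: torch.compile only needs output metadata
    # (shape, dtype, device), never real storage.
    return torch.empty_like(x)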

I also suspect that the other wrappers that use torch.autograd.Function and AutogradCUDA need the same fix.

Issue

#249

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

@StrongSpoon
Collaborator

Thanks for the contribution!
The FakeTensor incompatibility matches our own conclusion. However, given the need to stay compatible with multiple PyTorch versions, the custom_op approach may not be suitable to merge for now.

@yinfan98
Author

yinfan98 commented Oct 29, 2024

Thanks for the contribution! The FakeTensor incompatibility matches our own conclusion. However, given the need to stay compatible with multiple PyTorch versions, the custom_op approach may not be suitable to merge for now.

@StrongSpoon Understood. That is why the current version does not use custom ops; the custom op code is commented out, which may be a little hard to see in the diff...

The approach here is to register the Triton kernels once more under the gems namespace, so that they pass through torch.compile without problems. This actually makes sense: in native torch, capturing an operator like gelu down to the lowest level also yields something like torch.ops.aten.gelu_backward.

import torch

# CUDA (real) implementations; the gems:: schemas are assumed to be defined
# before these registrations, and *_kernel are the generated Triton kernels.
@torch.library.impl("gems::tanh_forward", "cuda")
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return tanh_forward_kernel(x)

@torch.library.impl("gems::tanh_backward", "cuda")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return tanh_backward_kernel(y, dy)

# Abstract (fake) implementations used by torch.compile: they only need to
# produce outputs with the right metadata.
@torch.library.impl_abstract("gems::tanh_forward")
def fake_tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return x

@torch.library.impl_abstract("gems::tanh_backward")
def fake_tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return dy
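
For completeness, a sketch of the schema definitions that the registrations above require (an assumption on my side; the PR presumably defines these elsewhere):

import torch

# Declare the op schemas in the "gems" namespace before registering impls.
torch.library.define("gems::tanh_forward", "(Tensor x) -> Tensor")
torch.library.define("gems::tanh_backward", "(Tensor y, Tensor dy) -> Tensor")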

The current approach is compatible from torch 2.2 through torch 2.4. However, since custom operators have been a major theme of the last few PyTorch releases, a newer solution is also provided in the comments.

@tongxin
Contributor

tongxin commented Oct 29, 2024

Thanks so much. @StrongSpoon we need to dig into how torch.compile resolves the mapping from these decorated functions to its own lowering targets.

@Bowen12992
Collaborator

We need to figure out why the previous solution went wrong and whether the new solution will cause performance loss.

@yinfan98
Author

We need to figure out why the previous solution went wrong and whether the new solution will cause performance loss.

Got it.

@StrongSpoon
Collaborator

you could reformat the code according to CONTRIBUTING.md first :)

@iclementine
Collaborator

This does define a new op, though: gems::tanh_forward is not the same as aten::tanh. So with this approach, calling torch.tanh directly would not dispatch to the function we defined, right?

@yinfan98
Author

yinfan98 commented Oct 29, 2024

This does define a new op, though: gems::tanh_forward is not the same as aten::tanh. So with this approach, calling torch.tanh directly would not dispatch to the function we defined, right?

@iclementine My understanding is that __init__.py sets up the aten replacement via torch.library.Library, and the impl there replaces the outermost tanh operator. The ops registered internally are only used internally and do not affect anything else, so the call still dispatches to the tanh we defined. A test is attached:

def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)

Terminal:

using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')
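
For reference, a simplified sketch (my own approximation, not the actual flag_gems __init__.py) of the aten-level registration described above; `tanh` is the wrapper shown in the test and `Tanh` is the torch.autograd.Function it applies:

import torch

# Replace the AutogradCUDA kernel of aten::tanh with flag_gems' tanh wrapper,
# which internally calls Tanh.apply and, through it, the gems:: ops.
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCUDA")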

@iclementine
Collaborator

iclementine commented Oct 29, 2024

This does define a new op, though: gems::tanh_forward is not the same as aten::tanh. So with this approach, calling torch.tanh directly would not dispatch to the function we defined, right?

@iclementine My understanding is that __init__.py sets up the aten replacement via torch.library.Library, and the impl there replaces the outermost tanh operator. The ops registered internally are only used internally and do not affect anything else, so the call still dispatches to the tanh we defined. A test is attached:

def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)

Terminal:

using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')

I see. So this actually registers a new op, gems::tanh_forward.

However, impl used this way does not return a tanh_forward the way custom_op does, so the name tanh_forward ends up bound to None.

There is a layer of indirection when using it:

torch.ops.aten.tanh -> the flag_gems.ops.tanh function
torch.ops.gems.tanh_forward -> the function defined above

@yinfan98
Author

This does define a new op, though: gems::tanh_forward is not the same as aten::tanh. So with this approach, calling torch.tanh directly would not dispatch to the function we defined, right?

@iclementine My understanding is that __init__.py sets up the aten replacement via torch.library.Library, and the impl there replaces the outermost tanh operator. The ops registered internally are only used internally and do not affect anything else, so the call still dispatches to the tanh we defined. A test is attached:

def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)

Terminal:

using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')

I see. So this actually registers a new op, gems::tanh_forward.

However, impl used this way does not return a tanh_forward the way custom_op does, so the name tanh_forward ends up bound to None.

There is a layer of indirection when using it:

torch.ops.aten.tanh -> the flag_gems.ops.tanh function; torch.ops.gems.tanh_forward -> the function defined above

Yes, that's right.

@yinfan98
Author

you could reformat the code according to CONTRIBUTING.md first :)

Hi, I have reformatted the code according to the code style in CONTRIBUTING.md.

@yinfan98
Author

yinfan98 commented Nov 1, 2024

Hi, when I was looking at the flash-attn repository, I noticed that they handle the interfaces like this.

if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn
    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn
    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper

and they register both the fake impl and the real impl like this:

  • fake:
@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # some fake impl
    ...
  • real:
@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # some real impl
    ...

And then the custom op is used from torch.autograd.Function. Maybe we could use this implementation as a reference, roughly as sketched below.
cc: @StrongSpoon @tongxin @Bowen12992 @iclementine
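
A hedged sketch of how such a registered op could be wired into torch.autograd.Function (my own approximation using the op names from this PR, not flash-attn's or FlagGems' actual code):

import torch

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # The registered custom op; under torch.compile it is traced through
        # its fake implementation instead of hitting the real CUDA kernel.
        y = torch.ops.gems.tanh_forward(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensors
        return torch.ops.gems.tanh_backward(y, dy)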

@tongxin
Contributor

tongxin commented Nov 2, 2024

I think we were so focused on custom ops that we overlooked the real cause of the error. It's not autograd.Function that impedes symbolic tracing. I was able to write a simple function using autograd.Function, register it to aten, and pass torch.compile. That proves that using the low-level API while remaining compatible with torch.compile is feasible. Declaring functions under the gems:: namespace is not necessary, although it might be a good idea for other reasons.

The real issue is that _base.data_ptr() in StridedBuffer is not traceable with FakeTensor. I guess we need a symbolic wrapper for StridedBuffer instead.

@yinfan98
Author

yinfan98 commented Nov 3, 2024

I think we were so focused on custom ops that we overlooked the real cause of the error. It's not autograd.Function that impedes symbolic tracing. I was able to write a simple function using autograd.Function, register it to aten, and pass torch.compile. That proves that using the low-level API while remaining compatible with torch.compile is feasible. Declaring functions under the gems:: namespace is not necessary, although it might be a good idea for other reasons.

The real issue is that _base.data_ptr() in StridedBuffer is not traceable with FakeTensor. I guess we need a symbolic wrapper for StridedBuffer instead.

Thank you for your response. I tried several methods to wrap StridedBuffer using torch.fx, but they were unsuccessful: despite implementing various wrappers, FakeTensor was still being caught inside the Triton code. I also believe this is not an autograd.Function issue; rather, whether AutogradCUDA or CUDA is chosen during registration determines whether the problem occurs. I think at the capture stage, the registration method affects whether the generated code can be called successfully (for example, the cos operator does not hit these issues). torch.compile ultimately relies on dynamo, whose deeper details I don't fully understand, and dynamo has been changing continuously across recent versions. Since I'm not very familiar with FlagGems' underlying Triton code generation, for now I can only proceed with the registration approach.

I also looked through the PyTorch 2.5.0 release notes; this PR may be related to our problem: pytorch/pytorch#133125

@iclementine
Collaborator

Thanks for the additional information. I am also digging into the details of fake tensors, the meta device, and their relationship. We should have some results soon.

@iclementine
Collaborator

iclementine commented Nov 5, 2024

We have found the reason. When registering a backend-specific autograd kernel:

Fake tensors on the CUDA device are passed into an AutogradCUDA kernel.

That is not expected behavior.

In the ideal case, fake tensors on the Meta device are passed into Meta kernels, and everything works.

I don't know exactly why fake tensors on cuda get involved when compiling a function whose aten operators have backend-specific autograd kernels.

I have opened an issue on the pytorch repository: pytorch/pytorch#139707.

Adding another layer of indirection does not really solve this problem. There are still fake tensors on cuda passed to the torch.autograd.Function; it just happens that the function can now handle them. If you add an x.data_ptr() call to the torch.autograd.Function's forward method, it fails again. This is a common pattern in custom kernels, since accessing the data pointer is commonplace.
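
A minimal illustration of the data_ptr() point above (my own reproduction sketch, assuming recent PyTorch FakeTensor behavior):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

device = "cuda" if torch.cuda.is_available() else "cpu"

with FakeTensorMode():
    t = torch.empty(10, device=device)  # a FakeTensor: metadata only, no real storage
    try:
        t.data_ptr()
        print("data_ptr() returned without error")
    except Exception as e:
        print(f"data_ptr() failed: {type(e).__name__}: {e}")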

@yinfan98
Author

yinfan98 commented Nov 5, 2024

We have found the reason. When registering a backend-specific autograd kernel:

Fake tensors on the CUDA device are passed into an AutogradCUDA kernel.

That is not expected behavior.

In the ideal case, fake tensors on the Meta device are passed into Meta kernels, and everything works.

I don't know exactly why fake tensors on cuda get involved when compiling a function whose aten operators have backend-specific autograd kernels.

I have opened an issue on the pytorch repository: pytorch/pytorch#139707.

Adding another layer of indirection does not really solve this problem. There are still fake tensors on cuda passed to the torch.autograd.Function; it just happens that the function can now handle them. If you add an x.data_ptr() call to the torch.autograd.Function's forward method, it fails again. This is a common pattern in custom kernels, since accessing the data pointer is commonplace.

cool!
