[Operator] fix: fix pytorch 2.4 tanh backward bug #263
base: master
Conversation
Thanks for the contribution! |
@StrongSpoon Understood. So the current version does not use a custom op; the custom op code has been commented out, which may be a bit hard to see in the diff... The approach here is to register the Triton operators once more under the gems namespace, which keeps torch.compile happy. This actually makes sense: in PyTorch's native code, an operator like gelu, traced down to the lowest level, ends up with something similar to

```python
@torch.library.impl("gems::tanh_forward", "cuda")
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return tanh_forward_kernel(x)


@torch.library.impl("gems::tanh_backward", "cuda")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return tanh_backward_kernel(y, dy)


@torch.library.impl_abstract("gems::tanh_forward")
def fake_tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return x


@torch.library.impl_abstract("gems::tanh_backward")
def fake_tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return dy
```

The current approach is compatible from torch 2.2 through torch 2.4. But since custom operators account for a large share of the changes in these PyTorch releases, a newer solution is also provided in the comments. |
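(For context, a minimal sketch of how ops registered this way can be wrapped in a torch.autograd.Function and exposed as a Python-level tanh. This is not the PR's exact code; it assumes the gems::tanh_forward / gems::tanh_backward schemas are defined elsewhere, e.g. via torch.library.define.)

```python
import torch

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Go through the dispatcher so torch.compile sees the registered op
        # rather than a Python closure.
        y = torch.ops.gems.tanh_forward(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensors
        return torch.ops.gems.tanh_backward(y, dy)


def tanh(A: torch.Tensor):
    return Tanh.apply(A)
```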
Thanks so much. @StrongSpoon we need to dig into how torch.compile resolves the mapping from these decorated functions to its own lowering targets. |
We need to figure out why the previous solution went wrong and whether the new solution will cause performance loss. |
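(One quick way to check the performance question is to time torch.tanh with and without flag_gems patched in. A sketch only, assuming flag_gems exposes the use_gems() context manager from its README and a CUDA device is available.)

```python
import torch
import flag_gems
from torch.utils.benchmark import Timer

x = torch.randn(1024, 1024, device="cuda")

# Baseline: eager torch.tanh.
t_ref = Timer("torch.tanh(x)", globals={"torch": torch, "x": x}).blocked_autorange()

# With flag_gems patched in, torch.tanh dispatches to the gems implementation.
with flag_gems.use_gems():
    t_gems = Timer("torch.tanh(x)", globals={"torch": torch, "x": x}).blocked_autorange()

print(f"eager tanh:     {t_ref.median * 1e6:.1f} us")
print(f"flag_gems tanh: {t_gems.median * 1e6:.1f} us")
```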
Got it. |
you could reformat the code according to CONTRIBUTING.md first :) |
That said, this approach does define a new op. |
@iclementine My understanding is that __init__.py defines a torch.library.Library that overrides aten, and the impl call there replaces the outermost tanh operator. The internally registered ops are only used internally and do not affect anything else, so dispatch still lands on the tanh we define. A test is attached as well:

```python
def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)
```

Terminal:

```
using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')
```
|
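(A sketch of what the __init__.py registration described above might boil down to, assuming tanh is the wrapper from the test; this is not the PR's exact code.)

```python
import torch
from torch.library import Library

# "aten" with "IMPL" only adds implementations to existing aten schemas.
# Registering on the AutogradCUDA key routes torch.tanh on CUDA tensors
# to the tanh wrapper above (which calls Tanh.apply).
aten_lib = Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCUDA")
```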
Understood. So it actually registers a new op, gems::tanh_forward. However, this usage of impl does not return a tanh_forward the way custom_op does, so the name tanh_forward ends up with the value None. There is one layer of indirection at call time: torch.ops.aten.tanh -> the flag_gems.ops.tanh function. |
Yes, exactly. |
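(To illustrate the indirection mentioned above: with this usage of torch.library.impl in the PyTorch versions discussed here, the decorated name is not handed back, so calls go through torch.ops. A sketch.)

```python
import torch

x = torch.randn(8, device="cuda")

# The module-level name was consumed by @torch.library.impl and is now None,
# so calling it directly raises "TypeError: 'NoneType' object is not callable":
#   tanh_forward(x)

# The registered kernel is reached through the dispatcher instead:
y = torch.ops.gems.tanh_forward(x)
```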
Hi, I have reformatted the code according to the code style in CONTRIBUTING.md. |
Hi, while looking at the flash-attn repository, I noticed that they handle these interfaces like this:

```python
if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper
```

and they register both the fake impl and the real impl like:

```python
@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # some fake impl
    ...


@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # some real impl
    ...
```

The custom op is then wrapped in a torch.autograd.Function. Maybe we could reference this implementation. |
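(If we went this route, the gems tanh ops might look like the sketch below. It assumes the flash-attn-style _torch_custom_op_wrapper / _torch_register_fake_wrapper shims above are adopted and that tanh_forward_kernel / tanh_backward_kernel are the existing Triton kernels; it is not a drop-in patch.)

```python
import torch

@_torch_custom_op_wrapper("gems::tanh_forward", mutates_args=(), device_types="cuda")
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return tanh_forward_kernel(x)


@_torch_register_fake_wrapper("gems::tanh_forward")
def tanh_forward_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake/meta impl: only shape, dtype and device matter here.
    return torch.empty_like(x)


@_torch_custom_op_wrapper("gems::tanh_backward", mutates_args=(), device_types="cuda")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return tanh_backward_kernel(y, dy)


@_torch_register_fake_wrapper("gems::tanh_backward")
def tanh_backward_fake(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(dy)
```

Note that on torch < 2.4 the noop wrappers simply return the plain functions without registering anything, so the older impl/impl_abstract path would still be needed there.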
I think we were so focused on custom ops that we overlooked the real cause of the error. It is not autograd.Function that impedes symbolic tracing: I was able to write a simple function using autograd.Function, register it to aten, and pass torch.compile, which shows that using the low-level API while staying compatible with torch.compile is plausible. Declaring functions under ... The real issue is ... |
Thank you for your response. I tried several methods to wrap ... I also looked through the PyTorch 2.5.0 release notes; this PR may be related to our problem: pytorch/pytorch#133125 |
Thanks for the additional information. I am also digging into the details of fake tensors, the meta device, and how they relate. We should have some results soon. |
We have found the reason. When registering a backend-specific autograd kernel, fake tensors on the CUDA device are passed into the AutogradCUDA kernel. That is not the expected behavior: ideally, fake tensors on the Meta device are passed into Meta kernels, and everything works. I don't know exactly why there are fake tensors on the CUDA device; I have opened an issue on the PyTorch repository: pytorch/pytorch#139707. Adding another layer of indirection does not really solve this problem; there are still CUDA fake tensors passed to the AutogradCUDA kernel. |
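(For reference, a minimal sketch of the kind of registration that exhibits this, using a hypothetical mylib namespace; the print shows what actually reaches the AutogradCUDA kernel while torch.compile traces.)

```python
import torch
from torch.library import Library

lib = Library("mylib", "DEF")  # hypothetical namespace for the repro
lib.define("my_tanh(Tensor x) -> Tensor")


def my_tanh_autograd_cuda(x):
    # Expected: only real CUDA tensors. Observed during torch.compile tracing:
    # FakeTensors whose device is cuda, as described in the comment above.
    print(type(x), x.device)
    return torch.tanh(x)


lib.impl("my_tanh", my_tanh_autograd_cuda, "AutogradCUDA")


@torch.compile
def f(x):
    return torch.ops.mylib.my_tanh(x)


f(torch.randn(8, device="cuda"))
```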
cool! |
PR Category: Operator
Type of Change: Bug Fix
Description
Fix the bug described in issue 249.
Problem code:
We also suspect that other wrappers using torch.autograd.Function and AutogradCUDA may need the same fix.
Issue
#249
Progress
Performance