diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index 7f3e6fbf6f..eb4eb01d97 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -257,9 +257,9 @@ def sin_triton(x): # Prefer this to using ``torch.autograd.Function`` (which has various composability footguns # with ``torch.compile``). -def backward(ctx, grad_output): +def backward(ctx, grad): x, = ctx.saved_tensors - return grad_input * x.cos() + return grad * x.cos() def setup_context(ctx, inputs, output): x, = inputs @@ -293,9 +293,9 @@ def mycos(x: torch.Tensor) -> torch.Tensor: wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) return out -def backward(ctx, grad_output): +def backward(ctx, grad): x, = ctx.saved_tensors - return grad_input * mycos(x) + return grad * mycos(x) def setup_context(ctx, inputs, output): x, = inputs