
The calculation of w in the SO3 to_quaternion() function may cause gradient explosion #670

Open

HiOnes opened this issue Nov 28, 2024

🐛 Bug

I suffered gradient explosion in my training process. I used with autograd.detect_anomaly() and got the hint Function 'SqrtBackward0' returned nan values in its 0th output. However, I didn't use functions like torch.sqrt() in my code, so I thought the bug may lie in the internal calculations of Theseus. And I have noticed this #661 relevant fix, so I checked the implementation of to_quarternion() function in so3.py.

The calculation of sine_half_theta is already guarded by an eps; the relevant code is:

sqrt_eps = _THESEUS_GLOBAL_PARAMS.get_eps("so3", "to_quaternion_sqrt", w.dtype)
sine_half_theta = (
    (0.5 * (1 - cosine_near_pi)).clamp(sqrt_eps, 1).sqrt().view(-1, 1)
)

However, there is another use of sqrt in the calculation of w:

w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(0, 4).sqrt()

Here the result is only clamped to [0, 4]. When the argument of the sqrt is at (or clamped to) 0, the derivative 1/(2*sqrt(x)) is infinite, so the backward pass fails.
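
A minimal sketch of the failure mode in plain PyTorch (my own toy example, independent of Theseus):

import torch

x = torch.zeros(1, requires_grad=True)
y = x.clamp(0, 4).sqrt().sum()  # clamp(0, 4) does not keep the argument away from 0
y.backward()
print(x.grad)  # tensor([inf]): d/dx sqrt(x) = 1/(2*sqrt(x)) diverges at x = 0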

Steps to Reproduce

I prepared a simple test script to reproduce this bug:

import theseus as th
import torch
import torch.nn.functional as F

# 180-degree rotation about the x-axis
rot = torch.tensor([[1.0, 0.0, 0.0],
                    [0.0, -1.0, 0.0],
                    [0.0, 0.0, -1.0]], requires_grad=True).reshape(1, 3, 3)
rot_so3 = th.SO3(tensor=rot)
identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0]).reshape(1, 4)
err = F.mse_loss(rot_so3.to_quaternion(), identity_quat)
rot.retain_grad()  # rot is non-leaf after reshape, so its grad must be retained
err.backward()
print(rot.grad)
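
This matrix triggers the issue because tr(R) = 1 - 1 - 1 = -1, so w = 0.5 * sqrt(1 + tr(R)) = 0.5 * sqrt(0); the sqrt is evaluated exactly at zero, where its derivative is infinite.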

The output will be:

tensor([[[-inf, 0., 0.],
         [0., -inf, 0.],
         [0., 0., -inf]]])

If I add an eps (1e-6 in my test) to the clamp in the calculation of w:

w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(1e-6, 4).sqrt()

The grad will be:

tensor([[[-0.0625,  0.0000,  0.0000],
         [ 0.0000, -0.0625,  0.0000],
         [ 0.0000,  0.0000, -0.0625]]])
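
For consistency with the sine_half_theta guard quoted above, the fix could presumably go through the global eps mechanism instead of a hard-coded 1e-6. A sketch of what I have in mind (reusing the "to_quaternion_sqrt" key is my assumption, not something Theseus defines for w):

# Hypothetical patch inside SO3.to_quaternion(), mirroring the sine_half_theta guard;
# reusing the "to_quaternion_sqrt" eps key here is an assumption
sqrt_eps = _THESEUS_GLOBAL_PARAMS.get_eps("so3", "to_quaternion_sqrt", self.dtype)
w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(sqrt_eps, 4).sqrt()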

System Info

  • OS: Ubuntu 20.04
  • Python version: 3.8
  • CUDA version: 11.8