Bug: Numerically unstable loss at reward model #423

Closed
s-isaev opened this issue Apr 25, 2023 · 8 comments

s-isaev commented Apr 25, 2023

Hi! I got an infinite loss when training the critic model at step 2:

Epoch 1/1 with loss inf

I've found the source of this problem: the reward model loss is calculated with an unstable formula:

loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
I propose to replace it with this expression:

loss += nn.functional.softplus(r_truncated_reward - c_truncated_reward).mean()

Mathematically, -log(sigmoid(x)) is equal to softplus(-x): since sigmoid(x) = 1/(1 + e^(-x)), we have -log(sigmoid(x)) = log(1 + e^(-x)) = softplus(-x). But the second form is numerically stable. Here are the outputs of these functions, respectively, in fp32:

x: (-log(sigmoid(x)), softplus(-x))
-100.0: (inf, 100.0)
-90.0: (inf, 90.0)
-80.0: (80.0, 80.0)
-70.0: (70.0, 70.0)
-60.0: (60.0, 60.0)
-50.0: (50.0, 50.0)
-40.0: (40.0, 40.0)
-30.0: (30.0, 30.0)
-20.0: (20.0, 20.0)
-10.0: (10.000045776367188, 10.000045776367188)
0.0: (0.6931471824645996, 0.6931471824645996)
10.0: (4.541977250482887e-05, 4.5398901420412585e-05)
20.0: (-0.0, 2.06115369216775e-09)
30.0: (-0.0, 9.357622912219837e-14)
40.0: (-0.0, 4.24835413113866e-18)
50.0: (-0.0, 1.9287498933537385e-22)
60.0: (-0.0, 8.75651089272076e-27)
70.0: (-0.0, 3.975449954226706e-31)
80.0: (-0.0, 1.8048513285848406e-35)
90.0: (-0.0, 8.194008692231508e-40)
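These numbers can be reproduced with a short script along the following lines (a sketch; the loop bounds are illustrative). The inf values appear because sigmoid(x) underflows to exactly 0 in fp32 for large negative x, and log(0) = -inf:

```python
import torch

# Compare the unstable and stable forms in fp32.
# For x below about -88, sigmoid(x) underflows to 0 in fp32, so
# -log(sigmoid(x)) = inf; softplus(-x) computes log(1 + e^(-x))
# without ever forming sigmoid(x) explicitly.
for i in range(-100, 100, 10):
    x = torch.tensor(float(i), dtype=torch.float32)
    unstable = -torch.log(torch.sigmoid(x))
    stable = torch.nn.functional.softplus(-x)
    print(f"{float(i)}: ({unstable.item()}, {stable.item()})")
```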

@s-isaev s-isaev changed the title Bug: Numerically unstable at reward model Bug: Numerically unstable loss at reward model Apr 25, 2023

s-isaev commented Apr 25, 2023

If you agree with this change, here is a PR: #424


minjiaz commented May 4, 2023

Hi @s-isaev,

Thank you for bringing this issue to our attention! I can help evaluate the impact of the loss calculation and the PR.

From a quick glance, this seems quite reasonable. If I understand correctly, softplus(x) is essentially max(0, x) + log(1 + e^(-|x|)), right? softplus(-x) avoids inf values while having a very small delta (the two should be identical in theory) in comparison to -log(sigmoid(x)).
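For reference, that identity is easy to sanity-check numerically (a quick sketch, assuming PyTorch defaults):

```python
import torch
import torch.nn.functional as F

# softplus(x) = max(0, x) + log1p(exp(-|x|)): the exp() argument is
# always <= 0, so nothing overflows regardless of how large |x| gets.
x = torch.linspace(-100.0, 100.0, steps=201)
manual = torch.clamp(x, min=0.0) + torch.log1p(torch.exp(-x.abs()))
print(torch.allclose(manual, F.softplus(x)))  # True
```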

In the meantime, could you please provide more information about how you encountered this issue? Did it occur while following the step 2 training script or when changing certain parameters such as datasets or learning rate?

Best,
Minjia


DanqingZ commented May 7, 2023

@minjiaz @s-isaev
I also had an infinite loss at step 2 and asked GPT-4 for a solution.
When calculating the difference between c_truncated_reward and r_truncated_reward, the values might be too large, causing the sigmoid function to saturate to exactly 0 or 1 in floating point; log(0) then becomes -inf. To mitigate this issue, you can use the logsigmoid function instead:

loss += -torch.nn.functional.logsigmoid(c_truncated_reward - r_truncated_reward).mean()

This is consistent with the trl implementation: https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py#L186

I believe this repository primarily re-implements the reward model based on trlx, and it inherits the unstable loss implementation from there: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py#L89
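For concreteness, here is a minimal sketch (the reward tensors are made-up stand-ins for c_truncated_reward and r_truncated_reward) showing where the original formulation produces inf while logsigmoid stays finite:

```python
import torch
import torch.nn.functional as F

# Made-up rewards with a large gap in favor of the rejected response,
# which is exactly the case that breaks the original formulation:
# sigmoid(-100) underflows to 0 in fp32 and log(0) gives -inf.
c_truncated_reward = torch.full((8,), -50.0)  # chosen-response rewards
r_truncated_reward = torch.full((8,), 50.0)   # rejected-response rewards

diff = c_truncated_reward - r_truncated_reward
print(-torch.log(torch.sigmoid(diff)).mean())  # tensor(inf)
print(-F.logsigmoid(diff).mean())              # tensor(100.)
```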

@HermitSun

I found that if I use all four datasets given in the training script (i.e., Dahoas/rm-static, Dahoas/full-hh-rlhf, Dahoas/synthetic-instruct-gptj-pairwise, yitingxie/rlhf-reward-datasets), I get an infinite loss; but if I use only Dahoas/rm-static, the loss is a normal value.
Does anyone know why this happens?


minjiaz commented May 9, 2023

@s-isaev and @DanqingZ, thank you both for your suggestions!

I have tested both softplus and logsigmoid. Both are numerically stable. -torch.nn.functional.logsigmoid(x) provides a slightly more accurate value than torch.nn.functional.softplus(-x) while also avoiding the infinite loss, so we will replace -log(sigmoid(x)) with torch.nn.functional.logsigmoid(x). I have created a PR (#501) to fix the issue.

i, -torch.log(torch.sigmoid(x)), torch.nn.functional.softplus(-x), -torch.nn.functional.logsigmoid(x)
-100 , tensor(inf) tensor(100.) tensor(100.)
-90.0 , tensor(inf) tensor(90.) tensor(90.)
-80.0 , tensor(80.) tensor(80.) tensor(80.)
-70.0 , tensor(70.) tensor(70.) tensor(70.)
-60.0 , tensor(60.) tensor(60.) tensor(60.)
-50.0 , tensor(50.) tensor(50.) tensor(50.)
-40.0 , tensor(40.) tensor(40.) tensor(40.)
-30.0 , tensor(30.) tensor(30.) tensor(30.)
-20.0 , tensor(20.) tensor(20.) tensor(20.)
-10.0 , tensor(10.0000) tensor(10.0000) tensor(10.0000)
0.0 , tensor(0.6931) tensor(0.6931) tensor(0.6931)
10.0 , tensor(4.5420e-05) tensor(4.5399e-05) tensor(4.5418e-05)
20.0 , tensor(-0.) tensor(2.0612e-09) tensor(0.)
30.0 , tensor(-0.) tensor(9.3576e-14) tensor(0.)
40.0 , tensor(-0.) tensor(4.2484e-18) tensor(0.)
50.0 , tensor(-0.) tensor(1.9287e-22) tensor(0.)
60.0 , tensor(-0.) tensor(8.7565e-27) tensor(0.)
70.0 , tensor(-0.) tensor(3.9754e-31) tensor(0.)
80.0 , tensor(-0.) tensor(1.8049e-35) tensor(0.)
90.0 , tensor(-0.) tensor(8.1940e-40) tensor(0.)
100.0 , tensor(-0.) tensor(3.7835e-44) tensor(0.)
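The table above can be regenerated with a loop along these lines (a sketch; the step size and default fp32 dtype are assumed):

```python
import torch
import torch.nn.functional as F

# Print all three candidate implementations side by side (default fp32).
for i in range(-100, 101, 10):
    x = torch.tensor(float(i))
    print(f"{float(i)} , {-torch.log(torch.sigmoid(x))} "
          f"{F.softplus(-x)} {-F.logsigmoid(x)}")
```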

Best,
Minjia


minjiaz commented May 9, 2023

Hi @HermitSun,

I found that if I use all four datasets given in the training script (i.e., Dahoas/rm-static, Dahoas/full-hh-rlhf, Dahoas/synthetic-instruct-gptj-pairwise, yitingxie/rlhf-reward-datasets), I get an infinite loss; but if I use only Dahoas/rm-static, the loss is a normal value. Does anyone know why this happens?

Yes, we have observed that adding certain datasets leads to performance degradation. We have updated the step 3 training scripts to use just one dataset for now.

Best,
Minjia

@HermitSun

Thank you for your reply 🥰.

Is there any conclusion about what kinds of datasets cause this unexpected performance degradation?

Do these datasets perhaps contain specific texts, or is it something else?

@samadejacobs samadejacobs added the deespeed chat DeepSpeed Chat label May 9, 2023

minjiaz commented May 12, 2023

Multiple log(sigmoid) implementations have been tested with end-to-end evaluation; the changes are merged in #501.

@minjiaz minjiaz closed this as completed May 12, 2023