Bug: Numerically unstable loss at reward model #423
Comments
If you agree with this, here is a PR: #424
Hi @s-isaev,
Thank you for bringing this issue to our attention! I can help evaluate the impact of the loss calculation and the PR. From a quick glance, this seems quite reasonable. If I understand correctly, softplus(x) is essentially max(0, x) + log(1 + e^(-|x|)), right? softplus(-x) avoids Inf values while having a very small delta (it should be identical in theory) compared to -log(sigmoid(x)).
In the meantime, could you please provide more information about how you encountered this issue? Did it occur while following the step 2 training script, or after changing certain parameters such as datasets or learning rate?
Best,
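For reference, here is a quick fp32 check of that identity and of the overflow behaviour (a sketch added for illustration; it is not part of the original comment):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-100.0, 100.0, steps=21)  # fp32 by default

# softplus(x) == max(0, x) + log(1 + exp(-|x|)), both sides computed stably
manual = torch.clamp(x, min=0) + torch.log1p(torch.exp(-x.abs()))
print(torch.allclose(F.softplus(x), manual))  # expected: True

# softplus(-x) stays finite where -log(sigmoid(x)) overflows
# (per the fp32 table in the issue, e.g. x = -100)
x0 = torch.tensor(-100.0)
print(-torch.log(torch.sigmoid(x0)), F.softplus(-x0))  # inf vs. 100.0
```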
@minjiaz @s-isaev
This is consistent with the trl implementation: https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py#L186 I believe this repository primarily re-implements the reward model based on trlx and inherits the same unstable loss implementation: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py#L89
I found that if I use all four datasets given in the training script (i.e. Dahoas/rm-static, Dahoas/full-hh-rlhf, Dahoas/synthetic-instruct-gptj-pairwise, yitingxie/rlhf-reward-datasets), I get an infinite loss; but if I use only Dahoas/rm-static, the loss stays at a normal value.
@s-isaev and @DanqingZ, thank you both for your suggestions! I have tested both softplus and logsigmoid; both are numerically stable. -torch.nn.functional.logsigmoid(x) gives a slightly more accurate value while still avoiding infinite loss compared to torch.nn.functional.softplus(-x), so we will replace log(sigmoid) with torch.nn.functional.logsigmoid(x). I have created a PR (#501) to fix the issue.
Implementations compared: -torch.log(torch.sigmoid(x)), torch.nn.functional.softplus(-x), -torch.nn.functional.logsigmoid(x).
Best,
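For clarity, a minimal sketch of the logsigmoid form of the pairwise loss (variable names are borrowed from the issue; the exact code landed in #501 may differ):

```python
import torch
import torch.nn.functional as F

def pairwise_reward_loss(c_truncated_reward: torch.Tensor,
                         r_truncated_reward: torch.Tensor) -> torch.Tensor:
    # -logsigmoid(chosen - rejected) equals -log(sigmoid(chosen - rejected))
    # mathematically, but it never overflows to inf for large negative inputs.
    return -F.logsigmoid(c_truncated_reward - r_truncated_reward).mean()
```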
Hi @HermitSun,
Yes, we have observed that adding certain datasets leads to performance degradation. We have updated the step 3 training scripts to use just one dataset for now, as shown below. Line 28 in 2ec4be7
Best,
Thank you for your reply🥰. Is there any conclusion about what kind of datasets cause this unexpected performance degradation? Do these datasets contain specific texts, or is it something else?
Multiple log(sigmoid) implementations have been tested with end-to-end evaluation; the changes were merged in #501.
Hi! I got an infinite loss when training the critic model at step 2:
Epoch 1/1 with loss inf
I've found the source of this problem: the reward model loss is calculated with a numerically unstable formula:
DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py
Line 102 in ab4e2e5
I propose to replace it with this expression:
loss += nn.functional.softplus(r_truncated_reward - c_truncated_reward).mean()
Mathematically, -log(sigmoid(x)) is equal to softplus(-x), but the second form is numerically stable. Here are the outputs of these functions in fp32, listed as x: (-log(sigmoid(x)), softplus(-x)):
-100.0: (inf, 100.0)
-90.0: (inf, 90.0)
-80.0: (80.0, 80.0)
-70.0: (70.0, 70.0)
-60.0: (60.0, 60.0)
-50.0: (50.0, 50.0)
-40.0: (40.0, 40.0)
-30.0: (30.0, 30.0)
-20.0: (20.0, 20.0)
-10.0: (10.000045776367188, 10.000045776367188)
0.0: (0.6931471824645996, 0.6931471824645996)
10.0: (4.541977250482887e-05, 4.5398901420412585e-05)
20.0: (-0.0, 2.06115369216775e-09)
30.0: (-0.0, 9.357622912219837e-14)
40.0: (-0.0, 4.24835413113866e-18)
50.0: (-0.0, 1.9287498933537385e-22)
60.0: (-0.0, 8.75651089272076e-27)
70.0: (-0.0, 3.975449954226706e-31)
80.0: (-0.0, 1.8048513285848406e-35)
90.0: (-0.0, 8.194008692231508e-40)