
[WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551

Draft
wants to merge 10 commits into base: main

Conversation

hongpeng-guo
Collaborator

Summary

In RLHF workflows such as verl, the actor forward pass usually produces two losses: cross_entropy_loss (-log_probs) and entropy_loss, the latter being used to discourage the policy from becoming over-deterministic.
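For context, here is a minimal PyTorch sketch of the naive actor-side computation (the names are illustrative, not from this PR). Both terms require the full logits tensor, which is exactly what the fused kernel is meant to avoid materializing:

```python
import torch
import torch.nn.functional as F

# Illustrative only: the naive RLHF actor losses. `hidden` is the last hidden
# state flattened to [num_tokens, hidden_dim], `lm_head_weight` is
# [vocab_size, hidden_dim], `targets` is [num_tokens].
def naive_actor_losses(hidden, lm_head_weight, targets):
    logits = hidden @ lm_head_weight.T             # materializes [num_tokens, vocab_size]
    log_probs = F.log_softmax(logits, dim=-1)
    ce_loss = F.nll_loss(log_probs, targets)       # cross_entropy_loss = mean(-log_probs[target])
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(-1).mean()  # entropy_loss term
    return ce_loss, entropy
```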

There is a real need for a kernel that generates both losses without materializing the huge logits tensor. Liger-Kernel's fused_linear_cross_entropy_loss already computes the cross_entropy_loss well, but it is missing the second part, i.e., the entropy loss.

This PR adds an entropy loss option to the existing FLCE loss and serves as an important step toward supporting verl.

  1. Add the entropy calculation to the second pass of the online softmax in cross_entropy.py::liger_cross_entropy_kernel; both the loss and its gradient with respect to the input are calculated and stored (see the sketch after this list);
  2. Propagate the changes to the relevant modules in fused_linear_cross_entropy.py;
  3. Propagate the relevant changes to the other functional modules in the PyTorch interface.
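A minimal PyTorch sketch of the math the kernel's second pass implements (a reference for the formulas, not the actual Triton code): given the per-row running max `m` and normalizer `Z` from the first pass, the entropy and its input gradient follow from `log p = x - (m + log Z)`:

```python
import torch

# Reference for the second-pass math, assuming pass 1 produced, per row,
# the running max `m` and the normalizer `Z = sum(exp(x - m))`.
def entropy_second_pass(x, m, Z):
    lse = (m + Z.log()).unsqueeze(-1)   # logsumexp, available after pass 1
    log_p = x - lse                     # per-class log-probabilities
    p = log_p.exp()
    entropy = -(p * log_p).sum(-1)      # H = -sum_i p_i * log(p_i)
    # Gradient of H with respect to the input logits:
    #   dH/dx_j = -p_j * (log(p_j) + H)
    d_entropy = -p * (log_p + entropy.unsqueeze(-1))
    return entropy, d_entropy
```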

Testing Done

Made the existing unit tests pass; adding new unit tests is WIP.

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@hongpeng-guo hongpeng-guo marked this pull request as draft January 30, 2025 04:38
@hongpeng-guo hongpeng-guo changed the title [Feature] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [WIP][Feature][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Jan 30, 2025
@hongpeng-guo hongpeng-guo changed the title [WIP][Feature][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Jan 30, 2025
@hongpeng-guo hongpeng-guo requested a review from ByronHsu January 30, 2025 09:56
@Tcc0403
Collaborator

Tcc0403 commented Jan 30, 2025

Please add a unit test with return_entropy_loss. You can write a new PyTorch reference implementation like CrossEntropyWithZLoss, or add the return_entropy_loss functionality on top of it.
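A hedged sketch of what such a test could look like (the `return_entropy_loss` flag and the tuple return follow this WIP PR and may differ in the final API):

```python
import torch
import torch.nn.functional as F

# The `return_entropy_loss` flag and tuple return are assumptions
# based on this WIP PR and may change.
from liger_kernel.transformers import LigerCrossEntropyLoss

def test_cross_entropy_with_entropy_loss():
    torch.manual_seed(0)
    logits = torch.randn(8, 128, device="cuda", requires_grad=True)
    targets = torch.randint(0, 128, (8,), device="cuda")

    # Plain PyTorch reference for both losses
    log_probs = F.log_softmax(logits, dim=-1)
    ref_ce = F.nll_loss(log_probs, targets)
    ref_entropy = -(log_probs.exp() * log_probs).sum(-1).mean()

    ce, entropy = LigerCrossEntropyLoss(return_entropy_loss=True)(logits, targets)
    torch.testing.assert_close(ce, ref_ce)
    torch.testing.assert_close(entropy, ref_entropy)
```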
