
Question on mathematical equivalence #2

Open
nlpfollower opened this issue Feb 2, 2025 · 1 comment
nlpfollower commented Feb 2, 2025

Hey, really like this idea!

I was experimenting with the mask a bit and was curious: have you ever tested the masked and unmasked forward passes for mathematical equivalence on Llama?

I tried running `unmasked = model(prompt + rejected, mask=None)` and then `masked = model(prompt + chosen + rejected, mask=block_mask)`, and compared the logits of `unmasked` against `masked[:prompt] + masked[prompt + chosen:]`. However, the logits for the rejected part differ in my experiments, at least with Llama. I saw the same discrepancy when using a 2D bool mask with SDPA instead of the block mask.
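To make the comparison concrete, here is a minimal sketch of the check at the raw attention level, with made-up lengths and `F.scaled_dot_product_attention` standing in for the full model forward (so positional encodings are deliberately out of the picture):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
P, C, R, D = 4, 3, 5, 8  # prompt / chosen / rejected lengths, head dim (made up)
x = torch.randn(P + C + R, D)  # stand-in for per-token q = k = v vectors

def causal(n):
    return torch.tril(torch.ones(n, n, dtype=torch.bool))

# Prefix-sharing mask over the packed sequence: causal everywhere,
# except that rejected tokens never attend to chosen tokens.
mask = causal(P + C + R)
mask[P + C:, P:P + C] = False

# "Unmasked" pass: plain causal attention over prompt + rejected only.
x_pr = torch.cat([x[:P], x[P + C:]])
unmasked = F.scaled_dot_product_attention(x_pr, x_pr, x_pr, attn_mask=causal(P + R))

# Masked pass over prompt + chosen + rejected, then slice out prompt + rejected.
masked = F.scaled_dot_product_attention(x, x, x, attn_mask=mask)
sliced = torch.cat([masked[:P], masked[P + C:]])

print(torch.allclose(unmasked, sliced, atol=1e-6))  # True: same key/value sets
```

At this level the two agree, since each surviving query attends to exactly the same set of key/value vectors in both passes. Note, though, that in the packed sequence the rejected tokens sit at absolute positions `prompt + chosen + j` rather than `prompt + j`, so anything position-dependent in the real model would have to be accounted for before comparing logits.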

It's possible this is an issue with my implementation. I'll see if I can reproduce it in your repo.

nlpfollower commented Feb 2, 2025

Okay, I was able to reproduce the experiment on a fork (nlpfollower#1), running single-rank Llama 3.1-8B.

In my original implementation I also tested the document mask (e.g. https://github.com/pytorch-labs/attention-gym/blob/main/attn_gym/masks/document_mask.py) with packing, i.e. the boring approach, and there was still a difference between the logits, though it was significantly smaller than with the prefix-sharing mask.
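For completeness, this is roughly how I set up the document mask, along the lines of the attention-gym example above; the lengths are made up, and the `create_block_mask` call assumes a recent PyTorch with flex attention:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Made-up packing: three documents whose lengths sum to a multiple of 128,
# which lines up with create_block_mask's default block size.
lengths = torch.tensor([64, 32, 32])
document_id = torch.repeat_interleave(torch.arange(len(lengths)), lengths)

def document_causal(b, h, q_idx, kv_idx):
    # Causal within a document; no attention across document boundaries.
    return (q_idx >= kv_idx) & (document_id[q_idx] == document_id[kv_idx])

seq_len = int(lengths.sum())
block_mask = create_block_mask(
    document_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
    device="cpu",  # use "cuda" when running against the real model
)
```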

Do you see any potential issues with my code or this experiment?
