Use scan and host offloading for llama model #123

Draft
zpcore wants to merge 5 commits into main
Conversation

zpcore (Collaborator) commented on Feb 25, 2025

Port scan and host offloading for the llama model, based on @tengyifei's prototype in 1 and 2.

The sharding schema in torchprime/torch_xla_models/configs/model/scaling/llama-fsdp.yaml also plays well with the scan code.

Currently there is a NaN issue when we use scan with the flash attention kernel, related to pytorch/xla#8734. We need to resolve that issue before the model produces correct output.
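
For reference, here is a minimal sketch (not the code in this PR) of what looping the decoder layers with scan looks like. It assumes torch_xla's experimental `scan_layers` helper; the `DecoderStack` wrapper and its arguments are hypothetical, and the host-offloading piece is omitted since its API isn't shown in this excerpt.

```python
import torch.nn as nn
from torch_xla.experimental.scan_layers import scan_layers  # assumed experimental API


class DecoderStack(nn.Module):
    """Hypothetical wrapper around a stack of identical decoder layers."""

    def __init__(self, layer_cls, num_layers, scan_decoder_layers=True):
        super().__init__()
        self.layers = nn.ModuleList(layer_cls() for _ in range(num_layers))
        self.scan_decoder_layers = scan_decoder_layers

    def forward(self, hidden_states):
        if self.scan_decoder_layers:
            # Trace a single layer and run it under an XLA loop construct,
            # instead of unrolling every layer into the HLO graph.
            return scan_layers(self.layers, hidden_states)
        # Fallback: plain Python loop; every layer is traced separately.
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
```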

@@ -20,3 +20,4 @@ attention_dropout: false
 attention_bias: false
 flash_attention: true
 rope_theta: 500000.0
+scan_decoder_layers: true
zpcore (Collaborator, Author) commented on Feb 26, 2025:
Move to the default yaml file.
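
A hedged sketch of how the new flag could be consumed once it lives in the default yaml. The Hydra/OmegaConf-style loading is an assumption, not torchprime's actual config code, and the yaml snippet is an inline stand-in for the edited file (which isn't named in this excerpt).

```python
from omegaconf import OmegaConf

# Inline stand-in for the model yaml edited above; only the flag of
# interest plus a couple of neighboring keys are reproduced.
cfg = OmegaConf.create(
    """
    flash_attention: true
    rope_theta: 500000.0
    scan_decoder_layers: true
    """
)

if cfg.get("scan_decoder_layers", False):
    print("decoder layers will be compiled through torch_xla scan")
else:
    print("decoder layers will be unrolled in a Python for-loop")
```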
