feat: add Llama-3.2-[1B/3B] support #203

Open · stillmatic wants to merge 3 commits into main

Conversation

@stillmatic commented Oct 1, 2024

This PR adds support for the small models in the Llama 3.2 release - 1B and 3B.

There are two pieces of custom logic we needed to implement for this to work (a sketch of both follows the list):

  1. Support safetensors files without merging. Larger models need multiple safetensors files to fit, but the 1B model fits in a single exported file. We therefore refactor the HF conversion script to skip the state-dict merging when a single file already contains all the weights.
  2. Support tied embedding weights. The smaller models share input/output embeddings, presumably to save parameters, while the larger models don't. In the conversion step, we check whether the output weights are missing and, if so, copy the input embeddings over.
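
Roughly, the conversion changes amount to the following. This is a minimal sketch rather than the exact diff: the function name and the file layout (a single model.safetensors vs. model-*.safetensors shards) and the HF key names (model.embed_tokens.weight, lm_head.weight) reflect my reading of the HF export format.

```python
from pathlib import Path
from safetensors.torch import load_file

def load_hf_weights(checkpoint_dir: Path) -> dict:
    """Load HF safetensors weights, skipping the merge step when the whole
    model fits in one exported file (as with Llama-3.2-1B)."""
    single_file = checkpoint_dir / "model.safetensors"
    if single_file.exists():
        # (1) Single-file export: no state-dict merging needed.
        state_dict = load_file(str(single_file))
    else:
        # Sharded export: merge every shard into one state dict.
        state_dict = {}
        for shard in sorted(checkpoint_dir.glob("model-*.safetensors")):
            state_dict.update(load_file(str(shard)))

    # (2) Tied embeddings: the 1B/3B exports omit the output projection and
    # reuse the input embeddings, so copy them over when the key is missing.
    if "lm_head.weight" not in state_dict:
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
    return state_dict
```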

I think this is all we need: local testing generates reasonable results for 1B, 3B, and 8B (eager and compiled). Note: in local testing I needed to comment out `torch._functorch.config.enable_autograd_cache = True` on torch 2.4.1. I assume this flag has since been merged and is supported by default?
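
If it helps, one way to keep the flag without breaking older installs is a version guard; the 2.5.0 cutoff below is an assumption on my part, not a tested boundary:

```python
import torch
import torch._functorch.config
from packaging import version

# Assumed cutoff: only enable the AOTAutograd cache on newer torch releases;
# on torch 2.4.1 this line had to be commented out locally.
if version.parse(torch.__version__) >= version.parse("2.5.0"):
    torch._functorch.config.enable_autograd_cache = True
```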

Testing:
3B model:

export MODEL_REPO=meta-llama/Llama-3.2-3B-Instruct; python generate.py --max_new_tokens 16 --checkpoint_path "checkpoints/${MODEL_REPO}/model.pth" --compile --prompt "Hello, my name is"
Compilation time: 38.25 seconds
<|begin_of_text|>Hello, my name is [Your Name] and I am excited to be participating in the [Event/
Time for inference 1: 0.14 sec total, 116.02 tokens/sec
Bandwidth achieved: 745.49 GB/s
FLOPS achieved: 1.03 TF/s

<|begin_of_text|>Hello, my name is Ryan. I'm a photographer and artist, and I'm thrilled to be a
Time for inference 2: 0.19 sec total, 82.60 tokens/sec
Bandwidth achieved: 530.73 GB/s
FLOPS achieved: 0.73 TF/s

<|begin_of_text|>Hello, my name is Ryan and I'll be guiding you through a series of challenges designed to test your
Time for inference 3: 0.19 sec total, 82.31 tokens/sec
Bandwidth achieved: 528.87 GB/s
FLOPS achieved: 0.73 TF/s

<|begin_of_text|>Hello, my name is Emily, and I'm here to visit the top 11 most amazing aquariums
Time for inference 4: 0.19 sec total, 82.12 tokens/sec
Bandwidth achieved: 527.65 GB/s
FLOPS achieved: 0.73 TF/s

<|begin_of_text|>Hello, my name is (name), and I'm a (your profession). Welcome to our practice!
Time for inference 5: 0.20 sec total, 80.32 tokens/sec
Bandwidth achieved: 516.08 GB/s
FLOPS achieved: 0.71 TF/s

==========
Batch Size: 1
Prompt Length: 6
Generated tokens: 16
Average tokens/sec: 88.67
Memory used: 7.55 GB
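
If I read generate.py's metric correctly, the bandwidth figure is roughly checkpoint size times decode throughput. A quick back-of-the-envelope check against the first 3B run (the 3.21B parameter count is approximate):

```python
params = 3.21e9                   # approximate Llama-3.2-3B parameter count
checkpoint_gb = params * 2 / 1e9  # bf16 weights, 2 bytes/param -> ~6.4 GB
print(checkpoint_gb * 116.02)     # ~745 GB/s, in line with "Bandwidth achieved" above
```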

1B model:

export MODEL_REPO=meta-llama/Llama-3.2-1B-Instruct; python generate.py --max_new_tokens 16 --checkpoint_path "checkpoints/${MODEL_REPO}/model.pth" --compile --prompt "Hello, my name is"
Compilation time: 25.37 seconds
<|begin_of_text|>Hello, my name is [Your Name] and I am excited to be a part of the [Company
Time for inference 1: 0.06 sec total, 249.87 tokens/sec
Bandwidth achieved: 617.58 GB/s
FLOPS achieved: 0.85 TF/s

<|begin_of_text|>Hello, my name is Bertha.

I am a 16-year-old girl, living in a world
Time for inference 2: 0.06 sec total, 247.61 tokens/sec
Bandwidth achieved: 612.01 GB/s
FLOPS achieved: 0.84 TF/s

<|begin_of_text|>Hello, my name is Svetlana, and I am a 32-year-old yoga teacher from
Time for inference 3: 0.06 sec total, 257.50 tokens/sec
Bandwidth achieved: 636.45 GB/s
FLOPS achieved: 0.88 TF/s

<|begin_of_text|>Hello, my name is Emily and I'm the founder of Green Gloop & Co., a sustainable food
Time for inference 4: 0.06 sec total, 251.88 tokens/sec
Bandwidth achieved: 622.55 GB/s
FLOPS achieved: 0.86 TF/s

<|begin_of_text|>Hello, my name is Emily and I'm a huge fan of your channel. Your content is always so
Time for inference 5: 0.07 sec total, 245.82 tokens/sec
Bandwidth achieved: 607.57 GB/s
FLOPS achieved: 0.84 TF/s

==========
Batch Size: 1
Prompt Length: 6
Generated tokens: 16
Average tokens/sec: 250.54
Memory used: 3.30 GB

8B model (regression check):

export MODEL_REPO=meta-llama/Meta-Llama-3-8B-Instruct; python generate.py --max_new_tokens 16 --checkpoint_path "checkpoints/${MODEL_REPO}/model.pth" --compile --prompt "Hello, my name is"
Compilation time: 10.54 seconds
<|begin_of_text|>Hello, my name is [Your Name] and I am excited to be here. I have been writing
Time for inference 1: 0.30 sec total, 53.07 tokens/sec
Bandwidth achieved: 796.64 GB/s
FLOPS achieved: 1.10 TF/s

<|begin_of_text|>Hello, my name is Ryan and I'm a photographer based in the Pacific Northwest. I specialize in outdoor
Time for inference 2: 0.30 sec total, 53.73 tokens/sec
Bandwidth achieved: 806.46 GB/s
FLOPS achieved: 1.11 TF/s

<|begin_of_text|>Hello, my name is Svetlana, and I am a certified yoga instructor. I am passionate
Time for inference 3: 0.30 sec total, 53.24 tokens/sec
Bandwidth achieved: 799.12 GB/s
FLOPS achieved: 1.10 TF/s

<|begin_of_text|>Hello, my name is Emily and I'm the founder of Art for Recovery. Art for Recovery is a
Time for inference 4: 0.30 sec total, 52.53 tokens/sec
Bandwidth achieved: 788.53 GB/s
FLOPS achieved: 1.08 TF/s

<|begin_of_text|>Hello, my name is Emily and I'm a yoga instructor and a Reiki Master. I'm here
Time for inference 5: 0.31 sec total, 52.29 tokens/sec
Bandwidth achieved: 784.94 GB/s
FLOPS achieved: 1.08 TF/s

==========
Batch Size: 1
Prompt Length: 6
Generated tokens: 16
Average tokens/sec: 52.97
Memory used: 16.35 GB

Commits:

1B does not split the model into multiple files, so we do not need to merge the weights.
necessary for the smol guys
@facebook-github-bot commented

Hi @stillmatic!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 1, 2024
@facebook-github-bot commented

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
