Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add export --output-snapshot-path snap.tc, and --snapshot-path snap.tc #1465

Merged
merged 9 commits into from
Jan 31, 2025

Conversation

mikekgfb
Copy link
Contributor

Add ability to save and restore quantized models #1032

mgschwind@mgschwind-mlt torchchat % python3 torchchat.py generate stories15M --quant torchchat/quant_config/desktop.json --prompt "once upon a time"
NumExpr defaulting to 12 threads.
PyTorch version 2.6.0.dev20241218 available.
Unabled to import torchao experimental quant_api with error:  [Errno 2] No such file or directory: '/Users/mgschwind/tc/torchchat/torchao-build/src/ao/torchao/experimental/quant_api.py'
Using device=mps 
Loading model...
Time to load model: 0.18 seconds
Quantizing the model with: {'executor': {'accelerator': 'fast'}, 'precision': {'dtype': 'fast16'}}
Time to quantize model: 0.00 seconds
-----------------------------------------------------------
once upon a time, there was a little girl named Lily. She loved to play outside in the park with her friends. One day, Lily saw a big, scary dog. She was frightened and didn't know what to do. 
Lily's friend, Timmy, saw her and asked, "What's wrong, Lily?" 
"I'm frightened of the dog," Lily said. 
Timmy said, "Don't worry, I'll call my mom and she will come to save us." 
After Timmy called, he ran to Lily's mom and told her about the scary dog. Lily's mom called the street workers to come and take the dog away. 
Lily was happy and said, "Thank you, Timmy and Timmy's mom." Once upon a time, there was a little girl named Lily. She loved to play outside in
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                
Generated 199 tokens                 
Time for inference 1: 2.0419 sec total                 
Time to first token: 0.1196 sec with parallel prefill.                

      Total throughput: 97.9489 tokens/sec, 0.0102 s/token                 
First token throughput: 8.3582 tokens/sec, 0.1196 s/token                 
 Next token throughput: 103.5252 tokens/sec, 0.0097 s/token                     

Bandwidth achieved: 4.78 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================


Warning: Excluding compile in calculations                 
      Average tokens/sec (total): 97.95                 
Average tokens/sec (first token): 8.36                 
Average tokens/sec (next tokens): 103.53 
                
mgschwind@mgschwind-mlt torchchat % python3 torchchat.py export stories15M --quant torchchat/quant_config/desktop.json --output-snap stories15-quant.tc 
NumExpr defaulting to 12 threads.
PyTorch version 2.6.0.dev20241218 available.
Unabled to import torchao experimental quant_api with error:  [Errno 2] No such file or directory: '/Users/mgschwind/tc/torchchat/torchao-build/src/ao/torchao/experimental/quant_api.py'
Using device=mps
Loading model...
Time to load model: 0.30 seconds
Quantizing the model with: {'executor': {'accelerator': 'fast'}, 'precision': {'dtype': 'fast16'}}
Time to quantize model: 0.00 seconds
-----------------------------------------------------------
Exporting model using Snapshot to /Users/mgschwind/tc/torchchat/stories15-quant.tc
 
mgschwind@mgschwind-mlt torchchat % python3 torchchat.py generate stories15M --quant torchchat/quant_config/desktop.json --prompt "once upon a time" --snap stories15-quant.tc

NumExpr defaulting to 12 threads.
PyTorch version 2.6.0.dev20241218 available.
Unabled to import torchao experimental quant_api with error:  [Errno 2] No such file or directory: '/Users/mgschwind/tc/torchchat/torchao-build/src/ao/torchao/experimental/quant_api.py'
Using device=mps 
Loading model...
Time to load model: 0.42 seconds
-----------------------------------------------------------
once upon a time, there was a boy called Jack. He was three years old and very excited. One day, when he was playing in the park, a nosy person came up to him and started pointing at something.
Jack was very curious and he stopped to stare. He saw a big pond, and he wondered what it was. Suddenly, a frog jumped out of the pond and winked at Jack. He was very surprised and the two of them started talking.
“Hey buddy, why are you in the pond? I won't hurt you,” said the frog.
Jack smiled at the frog and asked why he was there. The frog explained that he was looking for some food to feed himself. Jack told the frog that if he helped the frog, he could become friends with him and they could play together in the park.
The frog agreed and Jack carried him out of the p
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                
Generated 199 tokens                 
Time for inference 1: 1.9778 sec total                 
Time to first token: 0.1744 sec with parallel prefill.                

      Total throughput: 101.1238 tokens/sec, 0.0099 s/token                 
First token throughput: 5.7332 tokens/sec, 0.1744 s/token                 
 Next token throughput: 110.3501 tokens/sec, 0.0091 s/token                     

Bandwidth achieved: 17.49 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================


Warning: Excluding compile in calculations                 
      Average tokens/sec (total): 101.12                 
Average tokens/sec (first token): 5.73                 
Average tokens/sec (next tokens): 110.35 
                

Copy link

pytorch-bot bot commented Jan 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1465

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit da16f6a with merge base 5684175 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 18, 2025
@mikekgfb mikekgfb changed the title Add export --output-snapshot snap.tc, and --snapshot snap.tc Add export --output-snapshot-path snap.tc, and --snapshot-path snap.tc Jan 18, 2025
@mikekgfb
Copy link
Contributor Author

The snapshot load path may need some python imports to pull in all quantization custom ops and custom kernels that quantize.py may make available, so that they are available when a model snapshot gets reloaded. The best way may be to wholesale import all of them, rather than saving additional info in the snapshot and doing selective import, because there just aren't enough custom ops for wholesale import on reloading a snapshot prohibitive

# helper that generate just model.config
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args)
device_sync(device=builder_args.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Does the saved artifact still work if the device has changed? I recall this being an issue with AO (one of the reasons why we didn't add saving earlier)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it might not. Most likely. It depends really on what quantizations are performed and whether they're implemented on the multiple platforms, and in the same way. I.e., if it's the same pytorch/python code, for doing computation with quantized numbers and the same quantization formats are supported, then yes.

If it's a C/C++/CUDA operator, it needs to be supported, with the same name, or with a suitable if/then/else (i.e., don't bake. the "device" setting in)

Quantization weight format layouts need to be consistent, or the loader needs to repack them at load time. (This is totally plausible to do, but I don't think we do that today. I think in the 4b case ""we just know". I tried to change that, but the need/priority wasn't similarly perceived by everybody.)

If it's saved, and reloaded, most (all?) decisions you made are set in stone, like quantization schemes etc (Otherwise, you'd be loading from scratch?). In some sense that's similar to how dso/aoti/pte-output-path / load-dso/aoti/pte-path work, and that's why it's modeled after that export and reload facility. You don't get to change the PTE target on reload, or the device that an aoti model has been compiled for.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think expecting the exporting conditions to be the the same as the executing conditions is a fair start


def export_snapshot(
model: nn.Module,
device: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused arg?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a cut and paste thing, and wanted to keep args consistent. Mind you, we could put the device in the file or some other such, and check on reload that it's the same. (I think MPS/CPU are sorta fungible, which might help on Macs with quantization when you run out of kernel memory to quantize large models. CPU could use paging to ssd for that. eg discussion on #1483)

The path to the exported model.
"""
assert output_path.endswith(".tc"), "use .tc extension for snapshots"
torch.save(model, output_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the whole model or can we get away with just the state_dict? https://github.com/pytorch/torchchat/pull/1280/files

That said if we go with the slimmer state_dict, that's dependent on migration to the AO quant that supports this saving

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You rewrite the code as part of the quantization. If you don't save the code, then you must exactly replicate what quantization options you used, create an empty model, quantize it, and then load the state dict over it. Either you transfer the whole responsibility on the user (good luck, you'll die the death of many cuts when users make mistakes and complain that this facility does not work), or you need to save an ungodly amount of information about options used for the original quantization process.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can update the serialization logic as the AO migration finishes (i.e. this PR is good), but I'm not sure that's the case anymore with AO. I was under the impression that the model itself is unaffected and that only the weights are changed

https://github.com/pytorch/ao/blob/48fdd310b3977a0db2ceba37a7725192cd2aafd4/docs/source/serialization.rst#L62

cc: @vmpuri @HDCharles

Copy link
Contributor

@HDCharles HDCharles Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should be able to get by with just the state dict, there are a few apis that don't work that way, but all the subclass ones do. Thats like 75% of the reason that we went with subclasses instead of module swaps.

jack, you're link is a good resource, an alternate reference are our serialization tests to see what is explicitly tests

see e.g. https://github.com/pytorch/ao/blob/48fdd310b3977a0db2ceba37a7725192cd2aafd4/test/integration/test_integration.py#L1322-L1334

https://github.com/pytorch/ao/blob/48fdd310b3977a0db2ceba37a7725192cd2aafd4/test/dtypes/test_affine_quantized.py#L101-L111

@Jack-Khuu
Copy link
Contributor

Jack-Khuu commented Jan 29, 2025

Love the PR. This has been on the backlog for a bit (and unblocks a potential path to quantize at download time + discard the unquantized snapshot)

Left some comments

@mikekgfb
Copy link
Contributor Author

Love the PR. This has been on the backlog for a bit (and unblocks a potential path to quantize at download time + discard the unquantized snapshot)

Left some comments

you're welcome :) Glad you like it!

PS: If you quantize at download, you need to have a way to specify quantization options. Totally doable, but not sure about discarding the unquantized snapshot. As a user, if I change my mind, or forgot device, and if I want to change quantization, it has to re-download. Even on my corpnet in Silicon Valley, the larger models are painful to download -- even worse for the rest of the world. But then again, I have a fairly large disk on my laptop (but on the tradeoff of download bw vs ssd size, i think ssd size wins for most?

Then again, if the user thinks s/he might change her mind, she should just download and quantize separately...

@Jack-Khuu
Copy link
Contributor

Definitely want to spin up a RFC before we push the quantized download through

If you quantize at download, you need to have a way to specify quantization options... As a user, if I change my mind, or forgot device, and if I want to change quantization, it has to re-download.

Our reference here is Ollama which gives a "it just works" experience due to their default to quantized inference (GGUF format aside). We'd give a default quantization setting when downloading in this fashion.

Not sure about discarding the unquantized snapshot...
If the user thinks s/he might change her mind, she should just download and quantize separately

This will always be an option :D
And if we do sell things right will be something that users gravitate to (or they point to a pre-downloaded snapshot)

@Jack-Khuu Jack-Khuu merged commit 7cbf2a3 into pytorch:main Jan 31, 2025
69 checks passed
vmpuri pushed a commit that referenced this pull request Feb 4, 2025
…p.tc` (#1465)

* support model snapshots to save quantized models

* import set backend

---------

Co-authored-by: Michael Gschwind <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants