Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotswap allow different alpha scalings and ranks #2177

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
355e2ab
[WIP] Allow different alphas for hotswapping
BenjaminBossan Oct 24, 2024
135dfda
Only enable hotswap tests in CI for now
BenjaminBossan Oct 24, 2024
9b7d979
More debugging
BenjaminBossan Oct 24, 2024
3a6d4b5
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Oct 24, 2024
088a9d9
More debugging
BenjaminBossan Oct 24, 2024
3da33ba
Take alpha from new config
BenjaminBossan Oct 24, 2024
0bba669
Remove some debug code
BenjaminBossan Oct 24, 2024
c73dafc
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Nov 28, 2024
552ecac
Compiled hotswapping works w/ different scalings
BenjaminBossan Nov 28, 2024
2dcd6a1
[WIP] Hotswap with different ranks
BenjaminBossan Nov 28, 2024
620d43e
Make hotswapping with different ranks work
BenjaminBossan Nov 29, 2024
efaf2b0
More tests, docs, fixes
BenjaminBossan Dec 2, 2024
25d7a77
Make style, fix
BenjaminBossan Dec 2, 2024
88bd814
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Jan 8, 2025
2a29370
Remove obsolete test
BenjaminBossan Jan 8, 2025
9b132ec
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Jan 24, 2025
ddf1bad
Disable test skips for CI (only testing)
BenjaminBossan Jan 24, 2025
3251c6c
Clean up, documentation
BenjaminBossan Jan 28, 2025
9966c10
Reviewer feedback: Refactor _pad_lora_weights
BenjaminBossan Jan 30, 2025
01924ad
Reviewer feedback: Add clarifying comments
BenjaminBossan Jan 30, 2025
66b987e
Fix issue with bias term
BenjaminBossan Jan 31, 2025
d32d23f
Clarify docstring
BenjaminBossan Jan 31, 2025
8b7e602
Reviewer feedback:
BenjaminBossan Feb 3, 2025
d0a1fc5
Reviewer feedback: add clarifying comments
BenjaminBossan Feb 4, 2025
7772e2e
Reviewer feedback: address model.compile()
BenjaminBossan Feb 4, 2025
fe19da5
Revert testing
BenjaminBossan Feb 4, 2025
5bc0dbe
Simplify tests for recompilation
BenjaminBossan Feb 5, 2025
a412f7a
Reviewer feedback: Better handling of torch cache
BenjaminBossan Feb 5, 2025
0943519
Change fixture to clean up dynamo cache after test
BenjaminBossan Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions docs/source/package_reference/hotswap.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The idea of hotswapping an adapter is the following: We can already load multipl

In general, this should be faster than deleting one adapter and loading the adapter in its place, which would be the how to achieve the same final outcome without hotswapping. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled using `torch.compile`. This can save quite a lot of time.

## Example without `torch.compile`

```python
import torch
from transformers import AutoModelForCausalLM
Expand All @@ -21,7 +23,6 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
model = torch.compile(model) # optionally compile the model
with torch.inference_mode():
output_adapter_0 = model(inputs)

Expand All @@ -31,12 +32,42 @@ with torch.inference_mode():
output_adapter_1 = model(inputs).logits
```

## Example with `torch.compile`

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap

model_id = ...
inputs = ...
device = ...
max_rank = ... # maximum rank among all LoRA adapters that will be used
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
# Prepare the model to allow hotswapping even if ranks/scalings of 2nd adapter differ.
# You can skip this step if all ranks and scalings are identical.
prepare_model_for_compiled_hotswap(model, target_rank=max_rank)
Comment on lines +49 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to allow a util like infer_max_rank_from_state_dict()? Because it can be difficult to know for the users to know max_rank from the get-go. WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about how an application would use this in practice for the reason you mentioned. If we break it down like this:

  1. Some business use case of serving different LoRAs: They can probably determine the max rank ahead of time and set it correctly.
  2. An image generation app like Comfy: This is hard as they cannot know in advance what LoRAs the user plans on loading. Theoretically, all available LoRAs could be scanned (which might be expensive) and the largest rank be chosen (which might be total overkill and damage runtime performance).

If we had a infer_max_rank_from_state_dict() it would be for the 2nd case, which has the mentioned issues, but correct me if I'm wrong. But regardless of that, where would the state dicts come from? Since we have multiple LoRA formats in the diffusion ecosystem and diffusers knows how to deal with them, would that function not be better added to diffusers?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.

model = torch.compile(model)
with torch.inference_mode():
output_adapter_0 = model(inputs)

# replace the "default" lora adapter with the new one
hotswap_adapter(model, <path-adapter-1>, adapter_name="default", torch_device=device)
with torch.inference_mode():
output_adapter_1 = model(inputs).logits
```

## Caveats

Hotswapping works with transformers models and diffusers models. However, there are some caveats:

- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
- If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same.
- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- The adapter that is being swapped in must target the same layers as the previous adapter or a subset of those layers. It cannot target new layers. Therefore, if possible, start with the adapter that targets most layers.

[[autodoc]] utils.hotswap.hotswap_adapter
- all
Expand Down
Loading