Skip to content

Commit

Permalink
Hotswap allow different alpha scalings and ranks (#2177)
Browse files Browse the repository at this point in the history
Hotswapping of LoRA adapters is already implemented, but when alpha
scalings or ranks differ, this triggers recompilation of the model is
compiled, which is inefficient. Users can now call
prepare_model_for_compiled_hotswap to prevent recompilation in many
cases (see the doc update for caveats).
  • Loading branch information
BenjaminBossan authored Feb 5, 2025
1 parent db9dd3f commit eaab05e
Show file tree
Hide file tree
Showing 6 changed files with 854 additions and 320 deletions.
39 changes: 35 additions & 4 deletions docs/source/package_reference/hotswap.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The idea of hotswapping an adapter is the following: We can already load multipl

In general, this should be faster than deleting one adapter and loading the adapter in its place, which would be the how to achieve the same final outcome without hotswapping. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled using `torch.compile`. This can save quite a lot of time.

## Example without `torch.compile`

```python
import torch
from transformers import AutoModelForCausalLM
Expand All @@ -21,7 +23,6 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
model = torch.compile(model) # optionally compile the model
with torch.inference_mode():
output_adapter_0 = model(inputs)

Expand All @@ -31,12 +32,42 @@ with torch.inference_mode():
output_adapter_1 = model(inputs).logits
```

## Example with `torch.compile`

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap

model_id = ...
inputs = ...
device = ...
max_rank = ... # maximum rank among all LoRA adapters that will be used
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
# Prepare the model to allow hotswapping even if ranks/scalings of 2nd adapter differ.
# You can skip this step if all ranks and scalings are identical.
prepare_model_for_compiled_hotswap(model, target_rank=max_rank)
model = torch.compile(model)
with torch.inference_mode():
output_adapter_0 = model(inputs)

# replace the "default" lora adapter with the new one
hotswap_adapter(model, <path-adapter-1>, adapter_name="default", torch_device=device)
with torch.inference_mode():
output_adapter_1 = model(inputs).logits
```

## Caveats

Hotswapping works with transformers models and diffusers models. However, there are some caveats:

- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
- If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same.
- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- The adapter that is being swapped in must target the same layers as the previous adapter or a subset of those layers. It cannot target new layers. Therefore, if possible, start with the adapter that targets most layers.

[[autodoc]] utils.hotswap.hotswap_adapter
- all
Expand Down
Loading

0 comments on commit eaab05e

Please sign in to comment.