
[LoRA] Implement hot-swapping of LoRA #9453

Open

BenjaminBossan wants to merge 21 commits into main

Conversation

BenjaminBossan
Member

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights from one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter.

Generally, hot-swapping may not appear super useful, but when the model is compiled, it is necessary to prevent recompilation. See #9279 for more context.

Caveats

To hot-swap one LoRA adapter for another, the two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: given that we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter.

Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.
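
For reference, a minimal sketch of where the scaling comes from (this uses the usual PEFT convention scaling = lora_alpha / r; the numbers are made up):

# LoRA forward pass, schematically: y = W x + scaling * B(A(x)), with scaling = lora_alpha / r
scaling_adapter_1 = 8 / 8    # alpha=8,  rank=8  -> scaling 1.0
scaling_adapter_2 = 16 / 8   # alpha=16, rank=8  -> scaling 2.0
# Keeping adapter 1's scaling while swapping in adapter 2's weights would
# effectively halve adapter 2's intended contribution.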

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is problematic or not.
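
For illustration, a minimal sketch (my understanding, not a statement about the PR's internals) of why the data pointer matters: mode="reduce-overhead" uses CUDA graphs, which replay kernels against fixed memory addresses, so swapped-in weights have to reuse the existing tensor storage rather than being rebound to a new tensor:

import torch

weight = torch.randn(4, 4)
new_weight = torch.randn(4, 4)

ptr_before = weight.data_ptr()
weight.copy_(new_weight)        # in-place swap: storage and data pointer stay the same
assert weight.data_ptr() == ptr_before

weight = new_weight             # rebinding instead points at a different storage
assert weight.data_ptr() != ptr_before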

Current state

This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet; the feature has yet to be implemented for the other pipeline components.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@sayakpaul sayakpaul left a comment

Thanks a lot for working on this. I left some comments.

@yiyixuxu
Collaborator

cc @apolinario
can you take a look at this initial draft?

@yiyixuxu
Collaborator

yiyixuxu commented Sep 20, 2024

Do most LoRAs have the same scaling?
Just wondering how important (or not important) it is to be able to support hot-swapping with a different scale (without recompiling) - maybe more of a question for @apolinario

@yiyixuxu
Collaborator

yiyixuxu commented Sep 22, 2024

So I played around a little bit, and I have two main questions:

Do we support hotswap with different LoRA ranks? The rank config is not checked in the _check_hotswap_configs_compatible step, so it is a bit unclear. However, I would imagine a LoRA with a different rank would most likely trigger recompilation because the weight shapes are different. If we want to support LoRAs with different ranks, maybe we need to pad the weights to a fixed size.
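
A minimal sketch of the padding idea (hypothetical shapes, not the PR's implementation): zero-padding lora_A and lora_B along the rank dimension leaves the product B @ A, and hence the adapter output, unchanged:

import torch

r, max_rank, in_features, out_features = 16, 128, 320, 320
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)

# Pad both matrices with zeros up to the maximum rank.
lora_A_padded = torch.zeros(max_rank, in_features)
lora_A_padded[:r] = lora_A
lora_B_padded = torch.zeros(out_features, max_rank)
lora_B_padded[:, :r] = lora_B

# The extra rows/columns are zero, so the low-rank update is identical.
assert torch.allclose(lora_B @ lora_A, lora_B_padded @ lora_A_padded, atol=1e-5)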

I think we should also look into supporting hot-swap with different scaling. I checked some popular LoRAs on our Hub, and I think most of them have different ranks/alphas, so this feature will be a lot more impactful if we are able to support different ranks & scalings. Based on this thread #9279, I understand that the change in the "scaling" dict would trigger a recompilation. But maybe there are ways to avoid it? For example, if the scaling value is a tensor, torch.compile will put different guards on it. I played around with this dummy example a little bit.

This triggers a recompile:

import torch
scaling = {}

def fn(x, key):
    return x * scaling[key]


opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)

scaling["first"] = 1.0
opt_fn(x, "first")

print(f" finish first run, updating scaling")

scaling["first"] = 2.0
opt_fn(x, "first")

This won't:

import torch
scaling = {}

def fn(x, key):
    return x * scaling[key]


opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)

scaling["first"] = torch.tensor(1.0)
opt_fn(x, "first")

print(f" finish first run, updating scaling")

scaling["first"] = torch.tensor(2.0)
opt_fn(x, "first")

I'm very excited about having this in diffusers! I think it would be a super nice feature, especially for production use cases :)

@sayakpaul
Member

I agree with your point on supporting LoRAs with different scaling in this context.

With backend="eager", I think we may not get the full benefits of torch.compile(), because parts of the graph would run in eager mode and the benefits of a compiled graph could diminish.

A good way to verify it would be to measure the performance of a pipeline with eager torch.compile() and non-eager torch.compile() 👀

Cc: @anijain2305.

> If we want to support Lora with different rank, maybe we need to pad the weights to a fixed size.

I will let @BenjaminBossan comment further but this might require a lot of changes within the tuner modules inside peft.

@BenjaminBossan
Member Author

Thanks for all the feedback. I haven't forgotten about this PR, I was just occupied with other things. I'll come back to this as soon as I have a bit of time on my hands.

The idea of using a tensor instead of float for scaling is intriguing, thanks for testing it. It might just work OOTB, as torch broadcasts 0-dim tensors automatically. Another possibility would be to multiply the scaling directly into one of the weights, so that the original alpha can be retained, but that is probably very error prone.

Regarding different ranks, I have yet to test that.

@anijain2305

Yes, torch.compile(backend="eager") is just for debugging purposes. In order to see the benefits of torch.compile, you will have to use the Inductor backend; simply using torch.compile uses the Inductor backend by default. If your model is overhead-bound, you should use torch.compile(mode="reduce-overhead") to use CUDA graphs.
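
For reference, the invocations being contrasted here (a small standalone sketch, not code from the PR):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)

debug_model = torch.compile(model, backend="eager")        # debugging only, runs eagerly
default_model = torch.compile(model)                       # Inductor backend (the default)
fast_model = torch.compile(model, mode="reduce-overhead")  # Inductor + CUDA graphs, for overhead-bound models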

@sayakpaul
Member

If different ranks become a problem, then https://huggingface.co/sayakpaul/lower-rank-flux-lora could provide a meaningful direction.

@apolinario
Collaborator

Indeed, although avoiding recompilation altogether with different ranks would be even better for real-time swap applications.

@yiyixuxu
Collaborator

yiyixuxu commented Sep 25, 2024

Yep, it can be a nice feature indeed!
But for this PR, we should aim to support different ranks without reducing the rank, as we are targeting production use cases.

@sayakpaul
Member

Indeed. For different ranks, things that come to mind:

  1. Scaling value as a torch tensor --> but this was tested with "eager", so practically no torch.compile().
  2. Some form of padding in the parameter space --> I think this could be interesting but I am not sure how much code changes we're talking about here.
  3. Even if we were to pad, what is going to be the maximum length? Should this be requested from the user? I think we cannot know this value beforehand unless a user specifies it. A sensible choice for this value would be the highest rank that a user is expecting in their pool of LoRAs.
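
A hedged sketch of how such a target rank could be picked from a pool of LoRA checkpoints (max_lora_rank is a hypothetical helper; the key patterns assume the common PEFT/kohya naming):

from safetensors.torch import load_file

def max_lora_rank(checkpoint_paths):
    """Hypothetical helper: largest LoRA rank across a pool of checkpoints."""
    ranks = []
    for path in checkpoint_paths:
        state_dict = load_file(path)
        for key, weight in state_dict.items():
            # For lora_A (PEFT) / lora_down (kohya) weights, the first dim is the rank.
            if "lora_A" in key or "lora_down" in key:
                ranks.append(weight.shape[0])
    return max(ranks)

# target_rank = max_lora_rank(["lora_1.safetensors", "lora_2.safetensors"])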

@sayakpaul
Member

A reverse direction of what I showed in #9453 is also possible (increase the rank of a LoRA):
https://huggingface.co/sayakpaul/flux-lora-resizing#lora-rank-upsampling

@yiyixuxu
Collaborator

yiyixuxu commented Sep 27, 2024

Hi @BenjaminBossan,
I tested out padding the weights + using a tensor to store the scaling in this commit: c738f14.

It works for the 4 LoRAs I tested (all with different ranks and scalings). I'm not as familiar with peft and just made enough changes for the purpose of the experiment & to provide a reference point, so the code there is very hacky - sorry for that!

To test:

# testing hotswap PR

# TORCH_LOGS="guards,recompiles" TORCH_COMPILE_DEBUG=1 TORCH_LOGS_OUT=traces.txt python yiyi_test_3.py
from diffusers import DiffusionPipeline
import torch
import time

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora", # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5", # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA", # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1", # rank 128, scaling 1.0
]

prompts =[
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",    
]

def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")
    
    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name = "default_0")
    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_3_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")

Output:

Loading pipeline components...:  57%|████████████████████████████████████████████████████████████████                                                | 4/7 [00:00<00:00,  4.45it/s]/home/yiyi/diffusers/.venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.47it/s]
PyTorch version: 2.4.1+cu121
CUDA available: True
Device: cuda

Processing LoRA 0: Norod78/sd15-megaphone-lora
 prompt: Marge Simpson holding a megaphone in her hand with her town in the background
 hotswap: False
 rank: 16
 scaling: {'default_0': tensor(0.5000, device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:36<00:00,  1.94s/it]
Image generation time: 99.18 seconds
Total time for LoRA 0: 105.29 seconds

Processing LoRA 1: artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5
 prompt: A lion, minimalist, Coloring Book, ColoringBookAF
 hotswap: True
 rank: 64
 scaling: {'default_0': tensor(1., device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 44.67it/s]
Image generation time: 2.13 seconds
Total time for LoRA 1: 3.09 seconds

Processing LoRA 2: Norod78/SD15-Rubber-Duck-LoRA
 prompt: The girl with a pearl earring Rubber duck
 hotswap: True
 rank: 16
 scaling: {'default_0': tensor(0.5000, device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 44.73it/s]
Image generation time: 1.23 seconds
Total time for LoRA 2: 1.85 seconds

Processing LoRA 3: wooyvern/sd-1.5-dark-fantasy-1.1
 prompt: <lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art
 hotswap: True
 rank: 128
 scaling: {'default_0': tensor(1., device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 44.53it/s]
Image generation time: 2.01 seconds
Total time for LoRA 3: 3.46 seconds
total Memory: 3417.621 MB

Confirmed the outputs are the same as on main:

yiyi_test_3_out_test-hotswap_lora0

yiyi_test_3_out_test-hotswap_lora1

yiyi_test_3_out_test-hotswap_lora2

yiyi_test_3_out_test-hotswap_lora3

@sayakpaul
Member

Very cool!

Could you also try logging the traces just to confirm it does not trigger any recompilation?

TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt python my_code.py

@yiyixuxu
Collaborator

I did and it doesn't

@yiyixuxu
Collaborator

yiyixuxu commented Sep 27, 2024

Also, I think from the user experience perspective it might be more convenient to have a "hotswap" mode such that, once it's on, everything will be hot-swapped by default. I think it is not something you toggle on and off, no?

Maybe a question for @apolinario.

@sayakpaul
Member

> also, I think, from the user experience perspective, it might be more convenient to have a "hotswap" mode that, once it's on, everything will be hot-swapped by default. I think, it is not something you use on and off, no?

I think that is the case, yes! I also agree that the ability to hot-swap LoRAs (with torch.compile()) is a far better and more appealing UX.

But just in case it becomes a memory problem, users can explore the LoRA resizing path to bring everything to a small unified rank (if it doesn't lead to too much quality degradation).

Member

@sayakpaul sayakpaul left a comment

Thanks, @BenjaminBossan! Just left some comments. Overall, this is looking very nice!

Comment on lines 207 to 208
There are some limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/v0.14.0/en/package_reference/hotswap
Member

Should we always refer users to the main branch of the documentation?

Member Author

Done

@@ -296,11 +325,47 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

if hotswap:
try:
from peft.utils.hotswap import _check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
Member

We're relying on a private method from peft here. I guess this is okay? Just checking.

Member Author

I made it public in huggingface/peft#2366.

Comment on lines 345 to 347
k = k[:-7] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[:-5] + f".{adapter_name}.bias"
Member

-7 and -5 seem to be two magic numbers. Prefer clarifying.
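
One way the slicing could be made self-documenting (a sketch with a hypothetical key, not necessarily what the follow-up commit does):

adapter_name = "default_0"

k = "down_blocks.0.attentions.0.proj_in.lora_A.weight"
k = k[: -len(".weight")] + f".{adapter_name}.weight"  # instead of k[:-7]
assert k == "down_blocks.0.attentions.0.proj_in.lora_A.default_0.weight"

k = "down_blocks.0.attentions.0.proj_in.lora_B.bias"  # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"  # instead of k[:-5]
assert k == "down_blocks.0.attentions.0.proj_in.lora_B.default_0.bias"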

Member Author

Updated.

@@ -64,7 +64,12 @@ class UNet2DConditionLoadersMixin:
unet_name = UNET_NAME

@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_attn_procs(
Member

Since this method will be deprecated, I think it's okay to not propagate these changes:

def test_load_attn_procs_raise_warning(self):

Member Author

Reverted.

@@ -2175,3 +2176,155 @@ def test_ddpm_ddim_equality_batched(self):

# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1


class TestLoraHotSwapping(unittest.TestCase):
Member

Thanks for the test! I think it's checking recompilation on an individual model-level, which is good.

However, to have a more realistic scenario covered, that is aligned with the users, I think we will need to have a full pipeline example, too.

Member Author

I added a pipeline test based on your previous comment but adjusted to load 2 different adapters.

Member

Since we're only testing at the model-level here, would prefer this to live in tests/models/test_modeling_common.py. We could try generalizing the test similar to test_save_load_lora_adapter() so that we know this is supported across all the models that are a subclass of PeftAdapterMixin. WDYT?

Comment on lines 2300 to 2301
if do_compile:
unet = torch.compile(unet, mode="reduce-overhead")
Member

Do we wanna also do unet.compile()?

Member Author

Honestly, not sure if it's worth it, as the tests are already slow. Do users use this pattern? I've only seen torch.compile in the wild and I'm not sure what the purpose of this is.
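
For context, a minimal sketch of the difference, assuming a recent PyTorch (>= 2.2) where nn.Module.compile() is available: it compiles the module's forward in place, whereas torch.compile(module) returns a wrapper object; the compilation itself should be the same:

import torch
import torch.nn as nn

unet_like = nn.Linear(8, 8)

wrapped = torch.compile(unet_like, mode="reduce-overhead")  # returns a compiled wrapper module
unet_like.compile(mode="reduce-overhead")                   # compiles forward in place, object identity preserved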

@sayakpaul
Member

sayakpaul commented Feb 8, 2025

> So the tests are running after all. I'm not sure why, as these seem to be CPU tests but I use the @require_torch_accelerator decorator. Moreover, the @require_peft_backend seems to be ignored, as the tests fail with ModuleNotFoundError: No module named 'peft'.

This is very strange. Moreover, we already have the slow marker, which should prevent the test from being invoked by pr_tests.yml. Will try to debug.

- Revert deprecated method
- Fix PEFT doc link to main
- Don't use private function
- Clarify magic numbers
- Add pipeline test

Moreover:
- Extend docstrings
- Extend existing test for outputs != 0
- Extend existing test for wrong adapter name
Member Author

@BenjaminBossan BenjaminBossan left a comment

I updated the PR, your comments should be addressed. Please check again.

I would also like to repeat the request not to push to my branch, to avoid conflicts :) We can do a merge with the latest main before merging.

> This is very strange. Moreover, we have slow marker already which should prevent the test to be invoked with pr_tests.yml. Will try to debug.

Thanks.


parameterized.expand seems to ignore skip decorators if added in last
place (i.e. innermost decorator).
@BenjaminBossan
Member Author

BenjaminBossan commented Feb 10, 2025

@sayakpaul I think I figured out the reason why the tests ran despite the decorators. It appears to be a bad interaction with @parameterized.expand, which will only honor the first skip but then still run the remaining tests. When I moved that decorator to the top (i.e. making it the outermost decorator), the tests are correctly skipped when run locally.

I've never seen this issue with @pytest.mark.parametrize, so I think it is a problem with the parameterized package.

Here is a reproducer:

import unittest
import pytest
from parameterized import parameterized

class TestUnittestStyle(unittest.TestCase):
    @parameterized.expand([0, 1, 2, 3])
    @unittest.skipIf(True, reason="foo")
    def test_parameterize_decorator_first(self, x):
        assert True

    @unittest.skipIf(True, reason="foo")
    @parameterized.expand([0, 1, 2, 3])
    def test_parameterize_decorator_last(self, x):
        assert True

class TestPytestStyle(unittest.TestCase):
    @pytest.mark.parametrize("x", [0, 1, 2, 3])
    @pytest.mark.skipif(True, reason="foo")
    def test_parameterize_decorator_first(self, x):
        assert True

    @pytest.mark.skipif(True, reason="foo")
    @pytest.mark.parametrize("x", [0, 1, 2, 3])
    def test_parameterize_decorator_last(self, x):
        assert True

This gives me:

foo.py::TestUnittestStyle::test_parameterize_decorator_first_0 SKIPPED (foo)                                                                                                                                                                                                 [  9%]
foo.py::TestUnittestStyle::test_parameterize_decorator_first_1 SKIPPED (foo)                                                                                                                                                                                                 [ 18%]
foo.py::TestUnittestStyle::test_parameterize_decorator_first_2 SKIPPED (foo)                                                                                                                                                                                                 [ 27%]
foo.py::TestUnittestStyle::test_parameterize_decorator_first_3 SKIPPED (foo)                                                                                                                                                                                                 [ 36%]
foo.py::TestUnittestStyle::test_parameterize_decorator_last SKIPPED (foo)                                                                                                                                                                                                    [ 45%]
foo.py::TestUnittestStyle::test_parameterize_decorator_last_0 PASSED                                                                                                                                                                                                         [ 54%]
foo.py::TestUnittestStyle::test_parameterize_decorator_last_1 PASSED                                                                                                                                                                                                         [ 63%]
foo.py::TestUnittestStyle::test_parameterize_decorator_last_2 PASSED                                                                                                                                                                                                         [ 72%]
foo.py::TestUnittestStyle::test_parameterize_decorator_last_3 PASSED                                                                                                                                                                                                         [ 81%]
foo.py::TestPytestStyle::test_parameterize_decorator_first SKIPPED (foo)                                                                                                                                                                                                     [ 90%]
foo.py::TestPytestStyle::test_parameterize_decorator_last SKIPPED (foo)

As you can see, pytest does the right thing but parameterize only works correctly if the decorator is used first.

PS: Just checked, parameterized hasn't been updated for almost 2 years, it would probably be a good idea to transition to pytest style tests (in PEFT, we use pytest style tests for new tests but don't strictly refactor existing unittest style tests).

Member

@sayakpaul sayakpaul left a comment

Thanks for addressing the comments and for the iterations. Some comments on the tests.

I guess after that, we will have a review from Yiyi. And then, we can propagate the changes to the other LoRA loader classes, add docs, etc.

Comment on lines 2340 to 2355
############
# PIPELINE #
############

def get_lora_state_dicts(self, modules_to_save, adapter_name):
from peft import get_peft_model_state_dict

state_dicts = {}
for module_name, module in modules_to_save.items():
if module is not None:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(
module, adapter_name=adapter_name
)
return state_dicts

def get_dummy_input_pipeline(self):
Member

WDYT about moving this to tests/lora/utils.py?


@sayakpaul sayakpaul requested a review from yiyixuxu February 11, 2025 03:43
@sayakpaul
Member

@yiyixuxu could you give this another round of review?

@BenjaminBossan
Member Author

@sayakpaul Thanks for the review. Could I ask for some clarification?

> WDYT about moving this to tests/lora/utils.py?

> Since we're only testing at the model-level here, would prefer this to live in tests/models/test_modeling_common.py. We could try generalizing the test similar to test_save_load_lora_adapter() so that we know this is supported across all the models that are a subclass of PeftAdapterMixin. WDYT?

So this would mean that I move the whole content of TestLoraHotSwapping to PeftLoraLoaderMixinTests? Or only check_hotswap and then create a new subclass of PeftLoraLoaderMixinTests that defines get_dummy_input? That subclass would go to tests/models/test_modeling_common.py? Where would test_hotswapping_compiled_diffusers_model go?

@sayakpaul
Member

> That subclass would go to tests/models/test_modeling_common.py? Where would test_hotswapping_compiled_diffusers_model go?

It would go under

class ModelTesterMixin:

Similar to

def test_save_load_lora_adapter(self, use_dora=False):

The pipeline level tests could go to PeftLoraLoaderMixinTests.

Does this make sense?

@BenjaminBossan
Member Author

Does this make sense?

So I tried to understand the test structure. AFAICT, PeftLoraLoaderMixinTests and ModelTesterMixin are completely independent. However, I think that the check_hotswap and check_pipeline_hotswap tests are almost identical. Yes, the latter uses a pipeline, and hence some method calls like inference and checkpoint saving are different. But in the end, we still do the same thing, namely add LoRA to the unet (pipeline.unet.add_adapter), compile the unet (torch.compile(pipeline.unet, mode="reduce-overhead")), and then hotswap. Therefore, I think these tests should be on the same level, not in completely separate classes. WDYT?

Also, IIUC, if I add these tests to one of these mixin classes, I'd have to go to most of the child classes and override the test function to pass, right (remember that Conv2d support is not fully there yet)? Finally, also consider that compiling these models is quite slow; not sure if that's an issue for CI or not.

@sayakpaul
Member

> Therefore, I think these tests should be on the same level, not on completely separate classes. WDYT?

> Also, IIUC, if I add these tests to one of these mixin classes, I'd have to go to most the child classes and override the test function to pass, right (remember that Conv2d support is not fully there yet)? Finally, also consider that compiling these models is quite slow, not sure if that's an issue for CI or not.

Okay thanks for explaining the context.

Model-level tests should still be under test_modeling_common.py. In this case we could rename TestLoraHotSwapping to TestLoraHotSwappingForModel. The pipeline-level tests (test_hotswapping_compiled_diffusers_pipline()) should remain where they are currently (TestLoraHotSwappingForPipeline).

Would it make more sense?

@sayakpaul
Member

Hope it's not too late. On second thought, I think what you're suggesting makes more sense, i.e., I would be fine with the current state of the tests. Sorry for the bother.

Also increase test coverage by also targeting conv2d layers (support of
which was added recently on the PEFT PR).
@BenjaminBossan
Member Author

> Hope it's not too late. On a second thought, I think what you're suggesting makes more sense i.e., I would be fine with the current state of tests. Sorry about the bother.

I saw this too late, my latest commit has split the test into two tests as per the previous suggestion. However, I don't think it's bad, as splitting the test adds more clarity at the cost of some code duplication. LMK if you prefer the new way or the old one.

In addition to that, I also extended the test to cover a wider range, most notably by also targeting Conv2d layers. The support for those was just added on the PEFT PR (it was much easier than I initially thought, so I added it instead of waiting).

One more question, I saw that some compile tests have @is_torch_compile, should I add that too?

@BenjaminBossan
Member Author

Hmm, the Hub tests for models, schedulers, and pipelines are failing repeatedly, but it looks unrelated.

@sayakpaul
Member

> One more question, I saw that some compile tests have @is_torch_compile, should I add that too?

This is gated by the RUN_COMPILE env var. I think it's okay to use.

> I saw this too late, my latest commit has split the test into two tests as per the previous suggestion. However, I don't think it's bad, as splitting the test adds more clarity at the cost of some code duplication. LMK if you prefer the new way or the old one.

No problem then. Let's go with the current changes. Thank you!

> Hmm, Hub tests for models, schedulers, and pipelines failing repeatedly, but it looks unrelated.

Yeah safe to ignore.

Member

@sayakpaul sayakpaul left a comment

Looking very nice. Let's wait for @yiyixuxu to review this once before propagating the changes.

@@ -1519,3 +1523,188 @@ def test_push_to_hub_library_name(self):

# Reset repo
delete_repo(self.repo_id, token=TOKEN)


class TestLoraHotSwappingForModel(unittest.TestCase):
Member

We could mark the entire class here instead of marking the individual methods.

Member Author

Good point, done.

@BenjaminBossan
Member Author

huggingface/peft#2366 is merged, so installing PEFT from main should allow the new tests to pass. To call:

RUN_SLOW=1 RUN_COMPILE=1 pytest tests/pipelines/test_pipelines.py tests/models/test_modeling_common.py -k hotswap

@yiyixuxu
Collaborator

I'm testing this with this script here and getting an error - did I do something wrong?

ValueError: Incompatible shapes found for LoRA weights down_blocks.0.attentions.0.proj_in.lora_A.default_0.weight: torch.Size([16, 320, 1, 1]) vs torch.Size([64, 320, 1, 1]). Please ensure that all ranks are padded to the largest rank among all LoRA adapters by using peft.utils.hotswap.prepare_model_for_compiled_hotswap.
Click to see testing script I used
# testing hotswap PR
from diffusers import DiffusionPipeline
import torch
import time

from peft.utils.hotswap import prepare_model_for_compiled_hotswap

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora", # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5", # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA", # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1", # rank 128, scaling 1.0
]

prompts =[
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",    
]

def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")


prepare_model_for_compiled_hotswap(pipe.unet, target_rank=128)
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")
    
    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name = "default_0")
    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_14_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")

@yiyixuxu
Collaborator

cc @hlky here
in case you are interested in playing around, helping test, and contributing more ideas on how we can further optimize for this use case.
This is a super impactful PR IMO!

Collaborator

@hlky hlky left a comment

I'll run some tests. Expected state is:

  • Changing scale does not trigger recompilation
  • Different rank does not trigger recompilation

Correct?

Comment on lines +329 to +336
try:
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
except ImportError as exc:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg) from exc
Collaborator

Can we check the version instead of relying on exception? We have is_peft_version

def is_peft_version(operation: str, version: str):

Member Author

I did this on purpose, as it allows testing the feature by installing PEFT from main. Otherwise, we'd have to wait for the next PEFT release. Normally, I'd also avoid try import ... for the side effect, but at this point, PEFT is already imported, so that's not a factor.

If you still want me to change this, LMK.

Comment on lines +381 to +388
try:
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
except ImportError as exc:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg) from exc
Collaborator

Can we check the version instead of relying on exception? We have is_peft_version

def is_peft_version(operation: str, version: str):

Member Author

Same comment as above.

Comment on lines +431 to +432
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
raise
Collaborator

Can we raise the exception properly instead of logging an error?

Member

Also, are we testing if this error is raised?

Member Author

Note that I just copied the pattern from here:

try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
raise

So this is just for consistency.

@yiyixuxu
Collaborator

@hlky

Yes, the general goal is to stay compiled as much as possible across different LoRAs; these are the two things most likely to cause a re-compile.

> Changing scale does not trigger recompilation
> Different rank does not trigger recompilation
> Correct?

Member Author

@BenjaminBossan BenjaminBossan left a comment

Thanks for the feedback, I have addressed the comments.

@yiyixuxu There is an error in your script, namely the prepare_model_for_compiled_hotswap function should be called after the first LoRA adapter has been loaded, otherwise there are no LoRA weights that can be padded. I amended the script to do that and it works for me (after some iterations, I get an error RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (CUDABFloat16Type) should be the same, but I don't think it's directly related). Since this is an easy error to make, I'll check if PEFT can raise an error if no LoRA was found.

Edit: Created a PR on PEFT to raise an error if the function is called too early. For this diffusers PR to work, it's not necessary, though.

Click to see testing script I used
from diffusers import DiffusionPipeline
import torch
import time

from peft.utils.hotswap import prepare_model_for_compiled_hotswap

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora", # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5", # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA", # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1", # rank 128, scaling 1.0
]

prompts =[
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",
]

def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

first_round = True  # <============================================

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")

    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name = "default_0")
    if first_round:   # <========================================
        prepare_model_for_compiled_hotswap(pipe.unet, target_rank=128)
        pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        first_round = False

    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=5, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_14_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time

    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")

