[LoRA] Implement hot-swapping of LoRA #9453
base: main

Conversation
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights from one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter.

Generally, hot-swapping may not appear super useful, but when the model is compiled, it is necessary to prevent recompilation. See huggingface#9279 for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: given that we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter.

Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

I don't know enough about compilation to determine whether this is problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a minimum PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
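For illustration, usage could look roughly like the following sketch (based on the test scripts later in this thread; the `hotswap` argument and the fixed `adapter_name` reflect the API added by this PR, and the LoRA repo ids are placeholders):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load the first LoRA adapter and compile the UNet once.
pipe.load_lora_weights("some-user/first-lora", adapter_name="default_0")  # placeholder repo id
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
image = pipe("a prompt").images[0]

# Swap the second LoRA's weights in place of the first one, without recompiling.
pipe.load_lora_weights("some-user/second-lora", hotswap=True, adapter_name="default_0")  # placeholder repo id
image = pipe("another prompt").images[0]
```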
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for working on this. I left some comments.

cc @apolinario
Do most LoRAs have the same scaling?
So I played around a little bit. I have two main questions:

1. Do we support hotswap with different LoRA ranks? The rank config is not checked, as far as I can tell.
2. I think we should also look into supporting hot-swap with different scaling. I checked some popular LoRAs on our Hub, and I think most of them have different ranks/alphas, so this feature will be a lot more impactful if we are able to support different rank & scaling.

Based on this thread #9279, I understand that the change in the "scaling" dict would trigger a recompilation. But maybe there are ways to avoid it? For example, this triggers a recompile:

```python
import torch

scaling = {}

def fn(x, key):
    return x * scaling[key]

opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)
scaling["first"] = 1.0
opt_fn(x, "first")

print(" finish first run, updating scaling")
scaling["first"] = 2.0
opt_fn(x, "first")
```

this won't:

```python
import torch

scaling = {}

def fn(x, key):
    return x * scaling[key]

opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)
scaling["first"] = torch.tensor(1.0)
opt_fn(x, "first")

print(" finish first run, updating scaling")
scaling["first"] = torch.tensor(2.0)
opt_fn(x, "first")
```

I'm very excited about having this in diffusers! I think it would be a super nice feature, especially for production use cases :)
I agree with your point on supporting LoRAs with different scaling in this context. With backend="eager", we may not get the full benefits of torch.compile, though. A good way to verify it would be to measure the performance of a pipeline compiled with the eager backend.

Cc: @anijain2305.
I will let @BenjaminBossan comment further, but this might require a lot of changes within the tuner modules inside peft.
Thanks for all the feedback. I haven't forgotten about this PR, I was just occupied with other things. I'll come back to this as soon as I have a bit of time on my hands.

The idea of using a tensor instead of a float for scaling is intriguing, thanks for testing it. It might just work OOTB, as torch broadcasts 0-dim tensors automatically. Another possibility would be to multiply the scaling directly into one of the weights, so that the original alpha can be retained, but that is probably very error prone.

Regarding different ranks, I have yet to test that.
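As a minimal sketch of that last idea (assumed names and shapes, not the PR's implementation): the ratio of new to old scaling can be folded into `lora_B`, so the guarded `scaling` dict can keep its original value.

```python
import torch

def fold_new_scaling_into_lora_B(lora_B_weight: torch.Tensor, old_scaling: float, new_scaling: float) -> None:
    # The LoRA delta is scaling * lora_B @ lora_A. Keeping `old_scaling` in the
    # scaling dict but multiplying lora_B by (new / old) gives the same result
    # as using `new_scaling` directly:
    #   old_scaling * ((new_scaling / old_scaling) * lora_B) @ lora_A
    #   == new_scaling * lora_B @ lora_A
    with torch.no_grad():
        lora_B_weight.mul_(new_scaling / old_scaling)
```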
Yes, …
If different ranks become a problem, then https://huggingface.co/sayakpaul/lower-rank-flux-lora could provide a meaningful direction.
Indeed, although avoiding recompilation altogether with different ranks would be even better for real-time swap applications.
Yep, can be a nice feature indeed!
Indeed. For different ranks, things that come to mind:

A reverse direction of what I showed in #9453 is also possible (increase the rank of a LoRA):
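To illustrate that direction, a rough sketch of increasing a LoRA's rank by zero-padding its factors (the function name and shape conventions are assumptions, not PEFT's actual implementation):

```python
import torch

def pad_lora_to_rank(lora_A: torch.Tensor, lora_B: torch.Tensor, target_rank: int):
    # lora_A: [r, in_features], lora_B: [out_features, r]
    # Zero-padding keeps lora_B @ lora_A (and thus the adapter output) unchanged,
    # while giving every adapter identical weight shapes, which is what the
    # compiled model cares about.
    r, in_features = lora_A.shape
    out_features, _ = lora_B.shape
    if target_rank < r:
        raise ValueError("target_rank must be >= the adapter's current rank")
    A_padded = torch.zeros(target_rank, in_features, dtype=lora_A.dtype, device=lora_A.device)
    B_padded = torch.zeros(out_features, target_rank, dtype=lora_B.dtype, device=lora_B.device)
    A_padded[:r] = lora_A
    B_padded[:, :r] = lora_B
    return A_padded, B_padded
```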
hi @BenjaminBossan, I made some experimental changes and they work for the 4 LoRAs I tested (all with different ranks and scaling). I'm not as familiar with peft and just made enough changes for the purpose of the experiment and to provide a reference point, so the code there is very hacky, sorry for that!

To test:

```python
# testing hotswap PR
# TORCH_LOGS="guards,recompiles" TORCH_COMPILE_DEBUG=1 TORCH_LOGS_OUT=traces.txt python yiyi_test_3.py
from diffusers import DiffusionPipeline
import torch
import time

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora",  # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5",  # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA",  # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1",  # rank 128, scaling 1.0
]

prompts = [
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",
]


def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")

    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name="default_0")
    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_3_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")
```

Output: confirms the outputs are the same as on main.
Very cool! Could you also try logging the traces just to confirm it does not trigger any recompilation?

```
TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt python my_code.py
```
I did, and it doesn't.
Also, I think from the user experience perspective it might be more convenient to have a "hotswap" mode: once it's on, everything will be hot-swapped by default. I think it is not something you switch on and off, no? Maybe a question for @apolinario.
I think that is the case, yes! I also agree that the ability to hot-swap LoRAs (with different ranks/scalings) would be very useful. But just in case it becomes a memory problem, users can explore the LoRA resizing path to bring everything to a small unified rank (if it doesn't lead to too much quality degradation).
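For illustration, resizing a LoRA down to a smaller unified rank could be sketched with a truncated SVD along these lines (function name and shape conventions are assumptions; the linked repo may use a different method):

```python
import torch

def reduce_lora_rank(lora_A: torch.Tensor, lora_B: torch.Tensor, new_rank: int):
    # lora_A: [r, in_features], lora_B: [out_features, r]
    # Approximate the full delta lora_B @ lora_A with a lower-rank factorization.
    delta = lora_B @ lora_A                                   # [out_features, in_features]
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    U, S, Vh = U[:, :new_rank], S[:new_rank], Vh[:new_rank]
    new_B = U * S.unsqueeze(0)                                # [out_features, new_rank]
    new_A = Vh                                                # [new_rank, in_features]
    return new_A, new_B
```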
Thanks, @BenjaminBossan! Just left some comments. Overall, this is looking very nice!
src/diffusers/loaders/peft.py (outdated)

```
There are some limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/v0.14.0/en/package_reference/hotswap
```
Should we always refer the users to the main branch of the documentation?
Done
src/diffusers/loaders/peft.py (outdated)

```
@@ -296,11 +325,47 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
        if is_peft_version(">=", "0.13.1"):
            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

        if hotswap:
            try:
                from peft.utils.hotswap import _check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
```
We're relying on a private method from peft here. I guess this is okay? Just checking.
I made it public in huggingface/peft#2366.
src/diffusers/loaders/peft.py (outdated)

```
                    k = k[:-7] + f".{adapter_name}.weight"
                elif k.endswith("lora_B.bias"):  # lora_bias=True option
                    k = k[:-5] + f".{adapter_name}.bias"
```
`-7` and `-5` seem to be two magic numbers. Prefer clarifying.
Updated.
src/diffusers/loaders/unet.py (outdated)

```
@@ -64,7 +64,12 @@ class UNet2DConditionLoadersMixin:
    unet_name = UNET_NAME

    @validate_hf_hub_args
-   def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+   def load_attn_procs(
```
Since this method will be deprecated, I think it's okay to not propagate these changes: `def test_load_attn_procs_raise_warning(self):`
Reverted.
tests/pipelines/test_pipelines.py (outdated)

```
@@ -2175,3 +2176,155 @@ def test_ddpm_ddim_equality_batched(self):

    # the values aren't exactly equal, but the images look the same visually
    assert np.abs(ddpm_images - ddim_images).max() < 1e-1


class TestLoraHotSwapping(unittest.TestCase):
```
Thanks for the test! I think it's checking recompilation at the individual model level, which is good.

However, to cover a more realistic scenario that is aligned with how users will use this, I think we will also need a full pipeline example.
I added a pipeline test based on your previous comment but adjusted to load 2 different adapters.
Since we're only testing at the model level here, I would prefer this to live in tests/models/test_modeling_common.py. We could try generalizing the test similar to test_save_load_lora_adapter() so that we know this is supported across all the models that are a subclass of PeftAdapterMixin. WDYT?
tests/pipelines/test_pipelines.py (outdated)

```
        if do_compile:
            unet = torch.compile(unet, mode="reduce-overhead")
```
Do we want to also do unet.compile()?
Honestly, not sure if it's worth it, as the tests are already slow. Do users use this pattern? I've only seen torch.compile in the wild, and I'm not sure what the purpose of this is.
This is very strange. Moreover, we have the slow marker already, which should prevent the test from being invoked with pr_tests.yml. Will try to debug.
- Revert deprecated method
- Fix PEFT doc link to main
- Don't use private function
- Clarify magic numbers
- Add pipeline test

Moreover:

- Extend docstrings
- Extend existing test for outputs != 0
- Extend existing test for wrong adapter name
I updated the PR, your comments should be addressed. Please check again.

I would also extend the request not to push on my branch to avoid conflicts :) We can do a merge with the latest main before merging.

> This is very strange. Moreover, we have the slow marker already which should prevent the test to be invoked with pr_tests.yml. Will try to debug.

Thanks.
parameterized.expand seems to ignore skip decorators if added in last place (i.e. innermost decorator).
@sayakpaul I think I figured out the reason why the tests ran despite the decorators. It appears to be a bad interaction with parameterized.expand. I've never seen this issue with pytest-style parametrization. Here is a reproducer:

```python
import unittest

import pytest
from parameterized import parameterized


class TestUnittestStyle(unittest.TestCase):
    @parameterized.expand([0, 1, 2, 3])
    @unittest.skipIf(True, reason="foo")
    def test_parameterize_decorator_first(self, x):
        assert True

    @unittest.skipIf(True, reason="foo")
    @parameterized.expand([0, 1, 2, 3])
    def test_parameterize_decorator_last(self, x):
        assert True


class TestPytestStyle(unittest.TestCase):
    @pytest.mark.parametrize("x", [0, 1, 2, 3])
    @pytest.mark.skipif(True, reason="foo")
    def test_parameterize_decorator_first(self, x):
        assert True

    @pytest.mark.skipif(True, reason="foo")
    @pytest.mark.parametrize("x", [0, 1, 2, 3])
    def test_parameterize_decorator_last(self, x):
        assert True
```

This gives me:

As you can see, pytest does the right thing, but parameterized.expand ignores the skip decorator when it is added last (i.e. as the innermost decorator).

PS: Just checked, …
Thanks for addressing the comments and for the iterations. Some comments on the tests.
I guess after that, we will have a review from Yiyi. And then, we can propagate the changes to the other LoRA loader classes, add docs, etc.
tests/pipelines/test_pipelines.py
Outdated
############ | ||
# PIPELINE # | ||
############ | ||
|
||
def get_lora_state_dicts(self, modules_to_save, adapter_name): | ||
from peft import get_peft_model_state_dict | ||
|
||
state_dicts = {} | ||
for module_name, module in modules_to_save.items(): | ||
if module is not None: | ||
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict( | ||
module, adapter_name=adapter_name | ||
) | ||
return state_dicts | ||
|
||
def get_dummy_input_pipeline(self): |
WDYT about moving this to tests/lora/utils.py?
@yiyixuxu could you give this another round of review?
@sayakpaul Thanks for the review. Could I ask for some clarification?

So this would mean that I move the whole content of TestLoraHotSwapping over to tests/models/test_modeling_common.py?
It would go under tests/models/test_modeling_common.py.

Similar to test_save_load_lora_adapter().

The pipeline-level tests could go to tests/lora/utils.py. Does this make sense?
So I tried to understand the test structure. AFAICT, the relevant tests are organized around shared mixin classes. Also, IIUC, if I add these tests to one of these mixin classes, I'd have to go to most of the child classes and override the test function to pass, right (remember that …)?
Okay, thanks for explaining the context. Model-level tests should still be under tests/models/test_modeling_common.py. Would it make more sense?
Hope it's not too late. On second thought, I think what you're suggesting makes more sense, i.e., I would be fine with the current state of the tests. Sorry about the bother.
Also increase test coverage by targeting conv2d layers (support of which was added recently on the PEFT PR).
I saw this too late; my latest commit has split the test into two tests as per the previous suggestion. However, I don't think it's bad, as splitting the test adds more clarity at the cost of some code duplication. LMK if you prefer the new way or the old one.

In addition to that, I also extended the test to cover a wider range, most notably by also targeting Conv2d layers.

One more question: I saw that some compile tests have …
Hmm, …

This is used based on …
No problem then. Let's go with the current changes. Thank you!
Yeah, safe to ignore.
Looking very nice. Let's wait for @yiyixuxu to review this once before propagating the changes.
```
@@ -1519,3 +1523,188 @@ def test_push_to_hub_library_name(self):

    # Reset repo
    delete_repo(self.repo_id, token=TOKEN)


class TestLoraHotSwappingForModel(unittest.TestCase):
```
We could mark the entire class here instead of marking the individual methods.
Good point, done.
... instead of having them on each test method.
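For illustration, marking the whole class could look roughly like this (the decorators are taken from diffusers' test utilities; the exact set used in the PR and the test method name are assumptions):

```python
import unittest

from diffusers.utils.testing_utils import require_peft_backend, require_torch_gpu, slow


@slow
@require_torch_gpu
@require_peft_backend
class TestLoraHotSwappingForModel(unittest.TestCase):
    # The class-level decorators apply to every test method, so the individual
    # methods no longer need to repeat them.
    def test_hotswapping_does_not_trigger_recompilation(self):
        ...
```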
huggingface/peft#2366 is merged, so installing PEFT from main should allow the new tests to pass. To call:

```
RUN_SLOW=1 RUN_COMPILE=1 pytest tests/pipelines/test_pipelines.py tests/models/test_modeling_common.py -k hotswap
```
I'm testing this with this script here and getting an error - did I do something wrong?

Click to see testing script I used

```python
# testing hotswap PR
from diffusers import DiffusionPipeline
import torch
import time
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora",  # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5",  # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA",  # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1",  # rank 128, scaling 1.0
]

prompts = [
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",
]


def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

prepare_model_for_compiled_hotswap(pipe.unet, target_rank=128)

pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")

    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name="default_0")
    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_14_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")
```
cc @hlky here
I'll run some tests. Expected state is:
- Changing scale does not trigger recompilation
- Different rank does not trigger recompilation
Correct?
```python
        try:
            from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
        except ImportError as exc:
            msg = (
                "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
                "from source."
            )
            raise ImportError(msg) from exc
```
Can we check the version instead of relying on an exception? We have is_peft_version:

diffusers/src/diffusers/utils/import_utils.py, line 808 in 97abdd2:

```python
def is_peft_version(operation: str, version: str):
```
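A version-gated check along those lines might look roughly like this (a sketch; the helper name, the exact version bound, and the message are assumptions, and `hotswap` is the keyword argument introduced by this PR):

```python
from diffusers.utils import is_peft_version


def _ensure_hotswap_supported(hotswap: bool) -> None:
    # Fail early with a clear message instead of waiting for an ImportError
    # from peft.utils.hotswap.
    if hotswap and is_peft_version("<=", "0.14.0"):
        raise ImportError(
            "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version "
            "or install it from source."
        )
```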
I did this on purpose, as it allows testing the feature by installing PEFT from main. Otherwise, we'd have to wait for the next PEFT release. Normally, I'd also avoid try import ... for the side effect, but at this point, PEFT is already imported, so that's not a factor.

If you still want me to change this, LMK.
```python
        try:
            from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
        except ImportError as exc:
            msg = (
                "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
                "from source."
            )
            raise ImportError(msg) from exc
```

Can we check the version instead of relying on an exception? We have is_peft_version:

diffusers/src/diffusers/utils/import_utils.py, line 808 in 97abdd2:

```python
def is_peft_version(operation: str, version: str):
```

Same comment as above.
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") | ||
raise |
Can we raise the exception properly instead of logging an error?
Also, are we testing if this error is raised?
Note that I just copied the pattern from here:

diffusers/src/diffusers/loaders/peft.py, lines 301 to 316 in 97abdd2:

```python
        try:
            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
        except Exception as e:
            # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
            if hasattr(self, "peft_config"):
                for module in self.modules():
                    if isinstance(module, BaseTunerLayer):
                        active_adapters = module.active_adapters
                        for active_adapter in active_adapters:
                            if adapter_name in active_adapter:
                                module.delete_adapter(adapter_name)
                self.peft_config.pop(adapter_name)
            logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
            raise
```
So this is just for consistency.
Yes, the general goal is to stay compiled as much as possible across different LoRAs. These are the two things likely to cause a re-compile: different scaling and different ranks.
Co-authored-by: hlky <[email protected]>
Thanks for the feedback, I have addressed the comments.

@yiyixuxu There is an error in your script: the prepare_model_for_compiled_hotswap function should be called after the first LoRA adapter has been loaded, otherwise there are no LoRA weights that can be padded. I amended the script to do that and it works for me (after some iterations, I get an error RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (CUDABFloat16Type) should be the same, but I don't think it's directly related). Since this is an easy error to make, I'll check if PEFT can raise an error if no LoRA was found.

Edit: Created a PR on PEFT to raise an error if the function is called too early. For this diffusers PR to work, it's not necessary, though.
Click to see testing script I used

```python
from diffusers import DiffusionPipeline
import torch
import time
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora",  # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5",  # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA",  # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1",  # rank 128, scaling 1.0
]

prompts = [
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",
]


def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

first_round = True  # <============================================

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")

    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name="default_0")

    if first_round:  # <========================================
        prepare_model_for_compiled_hotswap(pipe.unet, target_rank=128)
        pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        first_round = False

    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=5, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_14_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")
```