Support dynamic LoRA loading with torch.compile
#9279
Some references:
|
@apolinario thanks for the detailed thread. It would also be nice to include a small snippet that we could quickly verify. And yes, cc'ing @BenjaminBossan for his comments here. I assume we could just add checks in diffusers/src/diffusers/loaders/lora_pipeline.py (line 1218 at 4cfb216).
|
I have not dealt with hot-swapping LoRA weights on compiled models in PEFT, so I'm not surprised that it doesn't work out of the box. The PEFT experiments with compiled models all apply the compilation step after loading the PEFT weights. Maybe it's as easy as remapping the state dict. I'll make a note to look into this on the PEFT side when I have a bit of extra time, but do let me know if you make any progress. |
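(A minimal sketch of the state-dict remapping idea, assuming the mismatch is the `_orig_mod.` prefix mentioned later in this issue; the helper name is mine, not an existing diffusers or PEFT function.)

```python
def strip_compile_prefix(state_dict: dict) -> dict:
    # torch.compile wraps the model in an OptimizedModule whose state dict keys are
    # prefixed with "_orig_mod."; stripping it makes the keys match the uncompiled names.
    prefix = "_orig_mod."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```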
Added! |
Applying this change seems to work. However, I then get the following warning: `skipping cudagraphs due to cpu device (arg25_1)`. Found from:

```
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 494, in forward
    encoder_hidden_states, hidden_states = block(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 165, in forward
    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
  File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 137, in forward
    emb = self.linear(self.silu(emb))
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/peft/tuners/lora/layer.py", line 509, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
```
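(Not a fix from the thread, just a hedged debugging note: "reduce-overhead" enables CUDA graphs, which bail out whenever a captured input lives on CPU, so compiling without cudagraphs silences the warning at the cost of that optimization.)

```python
# Assumption on my part: while investigating, compile without CUDA graphs so the
# CPU-resident input (arg25_1 above) no longer triggers the "skipping cudagraphs" message.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
# or simply: pipe.transformer = torch.compile(pipe.transformer)
```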
|
Cool, 0e7204a works here too. However, it triggers a recompilation when you load/swap a LoRA:

```python
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

prompt = "a photo of an astronaut riding a horse on mars"
# This will compile for the first time
image = pipe(prompt).images[0]

pipe.load_lora_weights("multimodalart/flux-tarot-v1")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
# This will re-compile
image = pipe(prompt).images[0]

pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt).images[0]
```
|
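(A hedged way to confirm and attribute these recompilations, assuming PyTorch 2.x; the script name is illustrative.)

```python
# Option 1: set the env var before launching the script above:
#   TORCH_LOGS="guards,recompiles" python swap_loras.py
# Option 2: enable the same logging programmatically; each recompile then prints the
# guard that failed.
import torch

torch._logging.set_logs(guards=True, recompiles=True)
```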
Oh okay. Let's perhaps discuss this with the PyTorch team. |
Recompiles are coming from this
On the first compile, TorchDynamo guarded on the |
Yeah, it does. When we're hitting diffusers/src/diffusers/loaders/lora_pipeline.py (line 1217 at 4cfb216), the codepath changes. To get around the problem, we can likely just fuse the adapter modules into the respective base model modules to prevent recompilation, but this is somewhat restrictive in terms of a smooth UX. Perhaps @apolinario can explain this a bit better. So, we wanted to see if there's any alternative we could try here. Cc: @BenjaminBossan |
No easy alternative I can think of as of now. Is it possible to somehow ensure that the type of the class is the same as before after all the weights have been loaded? If we can do that, there won't be any recompiles. A bad way would be to monkeypatch the |
If you're loading a LoRA, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation when swapping in a new LoRA. The other thing is that if you call |
Agreed. I think loading the first LoRA triggering recompilation is probably fine. However, hot-swapping different LoRAs without recompiling would be very valuable: it would let dynamic apps that change LoRAs benefit from compilation performance (waiting for a recompile every time LoRAs are swapped would rule out a live-swapping application). |
I see. In that case, can someone run TORCH_LOGS="guards,recompiles" and share the log? For some reason, my run fails at
|
I am getting you the log but recompilation upon swapping or loading a new LoRA also seems reasonable to me:
Both of these combined should lead to different operations, I'd imagine. |
@anijain2305 here you go: https://huggingface.co/datasets/sayakpaul/torchao-diffusers/blob/main/traces/regular_lora_compile.txt

I ran @apolinario's code from the comment above on the following machine:

```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12     CUDA Version: 12.2    |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC  |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M.  |
|                                         |                      |               MIG M.  |
|=========================================+======================+=======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:97:00.0 Off |                    0  |
| N/A   36C   P0             68W / 700W   |     2MiB / 81559MiB  |  0%           Default |
|                                         |                      |             Disabled  |
+-----------------------------------------+----------------------+-----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory  |
|        ID   ID                                                             Usage       |
|=========================================================================================|
|  No running processes found                                                            |
+---------------------------------------------------------------------------------------+
```

LMK if you need more information. |
Note that when PEFT loads the 2nd LoRA adapter, no layers should be swapped. Instead, PEFT will detect that LoRA layers are already in place and will update them. If this is indeed the cause, the whole loading procedure probably needs to be rewritten to directly overwrite the |
Thanks @sayakpaul. These are the recompilation reasons.

More information: it seems both the first recompilation and the second recompilation involve the per-adapter scaling dictionary changing.

Is it possible to keep the dictionary the same on re-loading? |
@anijain2305, thanks!
Yes, scaling can change depending on varying
Well, consider the following:

```python
...
pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
print(pipe.unet.state_dict())
```

So, I am afraid this won't be possible, because on a reload we're essentially adding things to the existing state dict of the modules affected by a given LoRA checkpoint. Cc'ing @BenjaminBossan in case I missed something. |
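(A quick check of the point above; a hedged sketch that reuses the Flux pipeline and repo ID from earlier in this thread.)

```python
# Loading a LoRA adds per-adapter entries (e.g. lora_A.<adapter>.weight, lora_B.<adapter>.weight)
# to the state dict of every module the checkpoint touches, so the dict grows on each load.
before = set(pipe.transformer.state_dict().keys())
pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="first")
after = set(pipe.transformer.state_dict().keys())
print(len(after - before), "new keys, e.g.", sorted(after - before)[:3])
```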
Out of curiosity, would this be different if
I think there could be a way to replace the data from the first LoRA adapter directly with the second, instead of updating the dicts to add a separate adapter. To try this, I wanted to pass the same adapter name, but ran into this guard: diffusers/src/diffusers/loaders/lora_pipeline.py, lines 1721 to 1724 at 4f495b0.

After removing the guard, it appears that I could load the second adapter without increasing the size of the dicts. However, I'm not sure if this prevents recompilation. |
Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, this check does not account for the (edit: compiled) model: diffusers/src/diffusers/loaders/lora_base.py, lines 371 to 377 at e417d02. Therefore, it is kept as is. Maybe it would be better to also check for compiled models there.

However, I don't think that fixing this would solve the initial issue. If the LoRA layers are completely unloaded, it means they're removed, and the second adapter will create completely new LoRA layers, which I guess would always trigger a recompilation. Maybe it's better to just offload the first adapter? |
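(A hedged sketch of the "offload the first adapter" idea, using the set_lora_device API that comes up later in this thread.)

```python
# Keep the first adapter's LoRA modules in place (so the module structure and dict keys
# don't change) but move its weights off the GPU before loading the second adapter.
pipe.set_lora_device(adapter_names=["first"], device="cpu")
pipe.load_lora_weights("davisbro/half_illustration", adapter_name="second")
```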
Could you explain more?
And diffusers/src/diffusers/loaders/peft.py Line 305 in e417d02
|
It appears that the type check fails when the model is compiled. When I run this code:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="foobar")
lora_ids_0 = {name: id(module) for name, module in pipe.transformer.named_modules()}
pipe.unload_lora_weights()
```

the LoRA weights of the transformer are not actually removed. When I remove the compilation step, the unloading works. Using |
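(For reference, a hedged illustration of why the type check fails: torch.compile returns a wrapper module rather than an instance of the original class.)

```python
# The compiled transformer is a torch._dynamo.eval_frame.OptimizedModule wrapping the
# original model, so isinstance checks against the original class fail and state dict
# keys gain an "_orig_mod." prefix. The wrapped module is still reachable.
compiled = pipe.transformer
print(type(compiled).__name__)             # OptimizedModule
original = getattr(compiled, "_orig_mod", compiled)
print(type(original).__name__)             # the original transformer class
print(next(iter(compiled.state_dict())))   # key starting with "_orig_mod."
```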
Sorry, missed it.

Thanks for this info. But I don't understand why we still need the check (edit: you meant for compiled models).

Agreed. I can follow the code trails in diffusers. Would you be able to check it for PEFT? |
Not sure if something needs to be done on the PEFT side, or are you aware of something? If this is addressed on the diffusers side, we can check whether anything has to change for PEFT too and fix it then. |
I was mainly referring to the isinstance checks used in both codebases and the cases where we may have to include a check for compiled models too (or rejig the condition, like the one you mentioned with hasattr). |
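(A hedged sketch of what the hasattr-based rejig could look like; the helper names are illustrative, not actual diffusers or PEFT code.)

```python
import torch

def is_compiled_module(module: torch.nn.Module) -> bool:
    # torch.compile exposes the wrapped model as ._orig_mod on the compiled wrapper.
    return hasattr(module, "_orig_mod")

def maybe_unwrap(module: torch.nn.Module) -> torch.nn.Module:
    # Run isinstance checks against the unwrapped module so they behave the same
    # whether or not the model was compiled.
    return module._orig_mod if is_compiled_module(module) else module
```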
I think on the level of PEFT layers, which diffusers is using, we should be good. Maybe there are other parts of PEFT where |
Fair. I will keep this thread posted! |
I have been testing a few options. My testbed is the Flux pipeline shown below; this is the codebase I am using:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

prompt = "a photo of an astronaut riding a horse on mars"
# This will compile for the first time
image = pipe(prompt, num_inference_steps=5).images[0]

pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="first")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
# This will re-compile
image = pipe(prompt, num_inference_steps=5).images[0]

# pipe.unload_lora_weights()
pipe.set_lora_device(adapter_names=["first"], device="cpu")
pipe.load_lora_weights("davisbro/half_illustration", adapter_name="second")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt, num_inference_steps=5).images[0]

# pipe.unload_lora_weights()
pipe.set_lora_device(adapter_names=["second"], device="cpu")
pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt, num_inference_steps=5).images[0]
```

I have tried two options: unloading the previous adapter (the commented-out unload_lora_weights() calls) and moving it to the CPU with set_lora_device().
There are still recompiles. However, for the second option I see more recompiles than for the first one. @anijain2305, any comments? Cc: @BenjaminBossan since we were discussing this. |
These are the guards and recompilation reasons: https://gist.github.com/anijain2305/9f3654e3a25b38446d572cfe2f9b7566

So I think the recompiles can't be avoided easily, because the codepath truly changes. A small example that repros the above scenario is this:
|
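(The snippet referenced above is not preserved in this transcript; the following is a stand-in toy repro, my own, of a guard on dict keys causing a recompilation.)

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.scaling = {"first": 1.0}  # stands in for PEFT's per-adapter scaling dict

    def forward(self, x):
        out = self.linear(x)
        for scale in self.scaling.values():  # Dynamo guards on this dict
            out = out * scale
        return out

model = Toy()
compiled = torch.compile(model)
x = torch.randn(2, 4)
compiled(x)                     # compiles once
model.scaling["second"] = 0.5   # adding a key invalidates the guard
compiled(x)                     # recompiles (visible with TORCH_LOGS="recompiles")
```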
@anijain2305 I'm working on a hot-swapping method. It's just a quick and dirty implementation for now, but I would like to get some early feedback on whether it's worth putting more time into or not. For this, could you please check the trace that I received using |
@BenjaminBossan It looks good to me. As long as the scaling dictionary remains the same, we should be good on the recompilation front. |
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hot-swap" adapters yet, i.e. substitute the weights of one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter. Generally, hot-swapping may not appear super useful, but when the model is compiled, it is necessary to prevent recompilation. See huggingface#9279 for more context.

Caveats

To hot-swap one LoRA adapter for another, the two adapters should target exactly the same layers and their "hyper-parameters" should be identical. For instance, the LoRA alpha has to be the same: given that we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter. Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

I don't know enough about compilation to determine whether this is problematic or not.

Current state

This is obviously WIP right now, to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a minimum PEFT version to use this feature). Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT. Furthermore, as of now, this is only implemented for the unet; other pipeline components have yet to implement this feature. Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
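(A minimal sketch, mine rather than the PR's actual implementation, of the hot-swap idea under the caveats above: copy the incoming adapter's weights into the existing adapter's parameters in place, so no modules are added or removed and the scaling dict never changes.)

```python
import torch

@torch.no_grad()
def hotswap_lora_weights(model: torch.nn.Module, new_lora_state_dict: dict) -> None:
    # Assumes both adapters target exactly the same layers with the same rank and alpha,
    # and that the incoming keys already match the model's parameter names.
    params = dict(model.named_parameters())
    for name, new_weight in new_lora_state_dict.items():
        current = params[name]
        if current.shape != new_weight.shape:
            raise ValueError(f"Shape mismatch for {name}: {current.shape} vs {new_weight.shape}")
        # In-place copy: tensor identity, device, and dtype stay the same, so compiled
        # graphs and their guards are untouched.
        current.data.copy_(new_weight.to(device=current.device, dtype=current.dtype))
```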
I created a draft PR based on what I tested: #9453. |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
Not stale, work still in progress. |
@BenjaminBossan This is awesome, hope it can be checked in soon! |
@BenjaminBossan Any updates on this? |
Is your feature request related to a problem? Please describe.
It would be great to be able to load a LoRA into a model compiled with torch.compile.
Describe the solution you'd like.
Do load_lora_weights with a compiled pipe (ideally without triggering recompilation). Currently, running this code:

It errors:
When compiled, the state dict of a model seems to add an _orig_mod prefix to all keys.
Describe alternatives you've considered.
An alternative is to fuse the LoRA into the model and then compile; however, this does not allow for hot-swapping LoRAs (a new pipeline and a new compilation are needed for every LoRA).
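(For completeness, a hedged sketch of this alternative with diffusers' existing fuse API, reusing a repo ID from this thread; as noted, switching LoRAs then means rebuilding and recompiling.)

```python
# Fuse the adapter into the base weights before compiling, so the compiled graph only
# sees plain linear layers. The trade-off: no hot swapping, since changing LoRAs
# requires unfusing/reloading and another compilation.
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
pipe.fuse_lora()
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
image = pipe("a photo of an astronaut riding a horse on mars, tarot card").images[0]
```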
Additional context.
This seems to have been achieved by @chengzeyi, author of the now-paused https://github.com/chengzeyi/stable-fast, but it appears to be part of the non-open-source FAL optimized inference (if you'd like to contribute this upstream, feel free!).