
Support dynamic LoRA loading with torch.compile model #9279

Open · apolinario opened this issue Aug 26, 2024 · 36 comments

@apolinario (Collaborator) commented Aug 26, 2024

Is your feature request related to a problem? Please describe.
It would be great to be able to load a LoRA into a model compiled with torch.compile.

Describe the solution you'd like.
Call load_lora_weights on a compiled pipe (ideally without triggering recompilation).

Currently, running this code:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

pipe.load_lora_weights("multimodalart/flux-tarot-v1")

It errors:

Loading adapter weights from state_dict led to unexpected keys not found in the model:  ['single_transformer_blocks.0.attn.to_k.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_k.lora_B.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_B.default_3.weight',

When a model is compiled, its state dict seems to gain an _orig_mod prefix on all keys:

odict_keys(['_orig_mod.time_text_embed.timestep_embedder.linear_1.weight',...
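
The prefix is easy to reproduce on any small module (a minimal illustration, not specific to Flux):

import torch

compiled = torch.compile(torch.nn.Linear(4, 4))
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']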

Describe alternatives you've considered.
An alternative is to fuse the LoRA into the model and then compile; however, this does not allow for hot-swapping LoRAs (a new pipeline and a new compilation are needed for every LoRA).

Additional context.
This seems to have been achieved by @chengzeyi, author of the now-paused https://github.com/chengzeyi/stable-fast; however, it appears to be part of FAL's non-open-source optimized inference stack (if you'd like to contribute this upstream, feel free!)

@apolinario changed the title from "Support dynamic LoRA loading with torch.compiled model" to "Support dynamic LoRA loading with torch.compile model" on Aug 26, 2024
@apolinario (Collaborator, Author)

Some references:

odict_keys(['_orig_mod.time_text_embed.timestep_embedder.linear_1.weight', '_orig_mod.time_text_embed.timestep_embedder.linear_1.bias'...

@sayakpaul (Member)

@apolinario thanks for the detailed thread. It would also be nice to include a small snippet that we could quickly verify. And yes, ccing @BenjaminBossan for his comments here.

I assume we could just add checks in load_lora_weights() right before we pass the adapter weights to the underlying model:

incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
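
One possible shape of such a check (a sketch only, not the actual fix in 0e7204a): unwrap the torch.compile wrapper before handing the adapter weights to PEFT so the module names line up without the _orig_mod prefix.

# Sketch: if the transformer was wrapped by torch.compile, pass the underlying
# module to PEFT so the key names match.
transformer = getattr(transformer, "_orig_mod", transformer)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)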

@BenjaminBossan (Member)

I have not dealt with hot-swapping LoRA weights on compiled models in PEFT, so I'm not surprised that it doesn't work out of the box. The PEFT experiments with compiled models all apply the compilation step after loading the PEFT weights. Maybe it's as easy as remapping the state dict and then calling set_peft_model_state_dict, but I wouldn't be surprised if there are more pitfalls related to torch.compile implementation details.

I'll make a note to look into this on the PEFT side when I have a bit of extra time. But do let me know if you make any progress.

@apolinario (Collaborator, Author)

It would also be nice to include a small snippet that we could quickly verify

Added!

@sayakpaul (Member)

Applying this change seems to work:
0e7204a

However, I get the following when doing torch.compile() with mode="max-autotune", fullgraph=True:

skipping cudagraphs due to skipping cudagraphs due to cpu device (arg25_1). Found from : 
   File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 494, in forward
    encoder_hidden_states, hidden_states = block(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 165, in forward
    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
  File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 137, in forward
    emb = self.linear(self.silu(emb))
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/peft/tuners/lora/layer.py", line 509, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling

@apolinario (Collaborator, Author) commented Aug 27, 2024

Cool, 0e7204a works here too for reduce-overhead

However, it triggers a recompilation when you load or swap a LoRA:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

prompt = "a photo of an astronaut riding a horse on mars"
#This will compile for the first time
image = pipe(prompt).images[0]

pipe.load_lora_weights("multimodalart/flux-tarot-v1")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
#This will re-compile
image = pipe(prompt).images[0]

pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
#This will re-compile again
image = pipe(prompt).images[0]

@sayakpaul (Member)

Oh okay. Let's perhaps discuss this with the PyTorch team.

@anijain2305 commented Aug 27, 2024

Recompiles are coming from this:

V0827 10:44:01.036000 3046403 torch/_dynamo/guards.py:2796] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0827 10:44:01.036000 3046403 torch/_dynamo/guards.py:2796] [0/1] [__recompiles]     - 0/0: ___check_type_id(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'], 81088688)  # emb = self.linear(self.silu(emb))  # diffusers/src/diffusers/models/normalization.py:137 in forward

On the first compile, TorchDynamo guarded on the id(type(torch.nn.Linear)). But loading weights has somehow changed the type of the linear layer (or at least the id of its class). Do you know if loading the linear layer's weights changes its class type (or if we are dynamically creating a new class)?

@sayakpaul (Member) commented Aug 27, 2024

Yeah it does.

When we're hitting

inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)

The codepath in peft is

https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/mapping.py#L223
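
For illustration, a tiny standalone sketch (assuming peft is installed; the toy module is made up) of why the guarded type id changes:

import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = Toy()
print(type(model.linear))  # <class 'torch.nn.modules.linear.Linear'>

# Injecting the adapter swaps the base layer for PEFT's LoRA wrapper,
# so id(type(...)) changes and Dynamo's type guard fails.
inject_adapter_in_model(LoraConfig(r=4, target_modules=["linear"]), model)
print(type(model.linear))  # <class 'peft.tuners.lora.layer.Linear'>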

To get around the problem, we could likely just fuse the adapter modules into the respective base model modules to prevent recompilation, but this is somewhat restrictive in terms of UX. Perhaps @apolinario can explain this a bit better. So, we wanted to see if there's any alternative we could try here.

Cc: @BenjaminBossan

@anijain2305 commented Aug 27, 2024

No easy alternative I can think of as of now. torch.compile is adding a guard on the type. In this case, it seems that the type change is benign, but in general the type change requires a recompilation.

Is it possible to somehow ensure that the class type is the same as before once all the weights have been loaded? If we can do that, there won't be any recompiles. A bad way would be to monkeypatch __class__ itself back to the old type.

@Chillee (Contributor) commented Aug 27, 2024

If you're loading a LoRA, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation upon swapping in a new LoRA.

The other thing is that if you call model.compile() (as opposed to torch.compile(model)), the state dict won't be modified.
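
A quick illustration of that difference (assuming a PyTorch version that has nn.Module.compile):

import torch

a = torch.compile(torch.nn.Linear(4, 4))
print(list(a.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']

b = torch.nn.Linear(4, 4)
b.compile()  # compiles in place, keeps the original module object
print(list(b.state_dict().keys()))  # ['weight', 'bias']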

@apolinario (Collaborator, Author) commented Aug 27, 2024

If you're loading a LoRA, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation upon swapping in a new LoRA.

Agreed. I think loading the first LoRA triggering recompilation is probably fine. However, hot-swapping different LoRAs without recompiling would be very valuable, so that dynamic apps that change LoRAs can benefit from compilation performance (waiting for a recompile every time LoRAs are swapped would rule out a live-swapping application).

@anijain2305

I see. In that case, can someone run TORCH_LOGS="guards,recompiles" and share the log? For some reason, my run fails at


pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")

@sayakpaul (Member)

I am getting you the log but recompilation upon swapping or loading a new LoRA also seems reasonable to me:

  • The new LoRA may have a different rank
  • It may have different target layers than the one initially loaded

Both of these combined should lead to different operations, I'd imagine.

@sayakpaul (Member)

@anijain2305 here you go: https://huggingface.co/datasets/sayakpaul/torchao-diffusers/blob/main/traces/regular_lora_compile.txt

I ran @apolinario's code here from the lora-compile branch of diffusers. I am on torch 2.5.0.dev20240827+cu121. nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:97:00.0 Off |                    0 |
| N/A   36C    P0              68W / 700W |      2MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

LMK if you need more information.

@BenjaminBossan (Member)

Note that when PEFT loads the 2nd LoRA adapter, no layers should be swapped. Instead, PEFT will detect that LoRA layers are already in place and will update the nn.ModuleDict of those LoRA layers to contain the newly loaded weights. However, the decomposed A and B LoRA weights are implemented as nn.Linear layers, so each newly loaded LoRA adds new layers to the nn.ModuleDict, not just new weights. I assume that this is what trips the guard.

If this is indeed the cause, the whole loading procedure probably needs to be rewritten to directly overwrite the weight.data of the first LoRA weight, without going through PEFT, in order to avoid what I just mentioned.
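
Roughly, that idea could look like this (a rough sketch with a hypothetical helper, not an existing PEFT or diffusers API):

import torch

@torch.no_grad()
def swap_lora_in_place(model, new_lora_state_dict):
    # Copy the new adapter's tensors over the already-registered adapter's tensors,
    # so the ModuleDicts (and the guards derived from them) stay untouched.
    for name, param in model.named_parameters():
        if "lora_A" in name or "lora_B" in name:
            # assumes the incoming state dict uses the same parameter names and shapes
            param.data.copy_(new_lora_state_dict[name])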

@anijain2305 commented Aug 29, 2024

Thanks @sayakpaul

These are the recompilation reasons

 [0/2] [__recompiles]     triggered by the following guard failure(s):
 [0/2] [__recompiles]     - 0/1: len(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling) == 1  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 [0/2] [__recompiles]     - 0/0: ___check_type_id(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'], 97167728)  # emb = self.linear(self.silu(emb))  # diffusers/src/diffusers/models/normalization.py:137 in forward

0/0 is for the first recompile. As we have been discussing, this is expected.

0/1 is the second recompile. And this is happening because the length of _modules['linear'].scaling has increased. Is scaling supposed to change?

More information

It seems scaling is a dictionary. On the first recompilation, its length is 1 and its key is default_0. On the second recompilation, a default_1 key has been added. From the guards, it seems default_1 is not used because there is no guard on that key's value. But given how Dynamo handles dicts, we still guard on the length of the dictionary.

First recompilation

 | +- GuardManager: source=L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling, accessed_by=DictGetItemGuardAccessor(scaling)
 | | +- DICT_LENGTH: len(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling) == 1  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 | | +- GuardManager: source=L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling['default_0'], accessed_by=DictGetItemGuardAccessor(default_0)
 | | | +- EQUALS_MATCH: L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling['default_0'] == 1.0  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward

Second recompilation - There is this default_1 key. It does not seem like it's used in the model.

 | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling, accessed_by=DictGetItemGuardAccessor(scaling)
 | | +- DICT_LENGTH: len(L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling) == 2  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 | | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_0'], accessed_by=DictGetItemGuardAccessor(default_0)
 | | | +- EQUALS_MATCH: L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_0'] == 1.0  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 | | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_1'], accessed_by=DictGetItemGuardAccessor(default_1)

Is it possible to keep the dictionary the same on re-loading?
I can investigate whether we can guard on the dict keys more lazily (i.e. only guard on the keys/values that are actually used in the model, eliminating the length guard), but that seems a little hard to do at first glance.

@sayakpaul (Member)

@anijain2305, thanks!

0/1 is the second recompile. And this is happening because the length of _modules['linear'].scaling has increased. Is scaling supposed to change?

Yes, scaling can change depending on varying alpha and rank values associated with a given LoRA checkpoint.
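
For context, PEFT derives the per-adapter scaling roughly as lora_alpha / r, so a checkpoint with a different alpha or rank yields a different value (minimal illustration):

def lora_scaling(lora_alpha: float, r: int) -> float:
    return lora_alpha / r

print(lora_scaling(16, 16))  # 1.0
print(lora_scaling(32, 16))  # 2.0 -- would be a different entry in the scaling dict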

Second recompilation - There is this default_1 key. It does not seem like it's used in the model.

Well, default_1 is the name of the adapter. All the parameters associated with that adapter will carry that key. You can verify this by doing:

...

pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
print(pipe.transformer.state_dict())

Is it possible to keep the dictionary the same on re-loading?

So, I am afraid this won't be possible because on a reload we're essentially adding things to the existing state dict of the modules affected by a given LoRA checkpoint.

Ccing @BenjaminBossan if I missed something.

@BenjaminBossan (Member)

It seems scaling is a dictionary.

Out of curiosity, would this be different if scaling were a ModuleDict?

Is it possible to keep the dictionary the same on re-loading?

I think there could be a way to replace the data from the first LoRA adapter directly with the second, instead of updating the dicts to add a separate adapter. To try this, I wanted to pass the same adapter_name when calling load_lora_weights but this runs into a guard here:

if adapter_name in getattr(transformer, "peft_config", {}):
    raise ValueError(
        f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
    )

After removing the guard, it appears that I could load the second adapter without increasing the size of the dicts. However, I'm not sure if this prevents recompilation.
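
Concretely, the attempt looks like this (illustrative; "shared" is just a placeholder adapter name, and the name-collision check has to be removed locally for the second call to go through):

pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="shared")
# ... run inference, then load the next checkpoint under the same name so PEFT
# updates the existing entries instead of registering a new default_1 adapter
pipe.load_lora_weights("davisbro/half_illustration", adapter_name="shared")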

@BenjaminBossan (Member) commented Aug 30, 2024

Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, the (edit: compiled) FluxTransformer2DModel does not match in these lines:

for component in self._lora_loadable_modules:
    model = getattr(self, component, None)
    if model is not None:
        if issubclass(model.__class__, ModelMixin):
            model.unload_lora()
        elif issubclass(model.__class__, PreTrainedModel):
            _remove_text_encoder_monkey_patch(model)

Therefore, it is kept as is. Maybe it would be better to check if hasattr(model, "unload_lora")?
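
Something along these lines (a sketch of the suggested check, not the final diffusers code):

for component in self._lora_loadable_modules:
    model = getattr(self, component, None)
    if model is None:
        continue
    # unwrap the torch.compile wrapper so the underlying class is visible
    model = getattr(model, "_orig_mod", model)
    if hasattr(model, "unload_lora"):
        model.unload_lora()
    elif issubclass(model.__class__, PreTrainedModel):
        _remove_text_encoder_monkey_patch(model)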

However, I don't think that fixing this would solve the initial issue. If the LoRA layers are completely unloaded, it means they're removed and the second adapter will create completely new LoRA layers, which I guess would always trigger a recompilation. Maybe it's better to just offload the first adapter?

@sayakpaul (Member) commented Aug 30, 2024

Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, the FluxTransformer2DModel does not match in these lines:

Could you explain more? FluxTransformer2DModel is a subclass of PeftAdapterMixin:

class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):

And PeftAdapterMixin has:

def unload_lora(self):

@BenjaminBossan (Member) commented Aug 30, 2024

It appears that the type check fails when the model is compiled. When I run this code:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="foobar")
lora_ids_0 = {name: id(module) for name, module in pipe.transformer.named_modules()}
pipe.unload_lora_weights()

The LoRA weights of the pipe.transformer are not unloaded. Checking in the debugger, unloading is not called on the transformer part.

When I remove the compilation step, the unloading works. Using if hasattr(model, "unload_lora") should still fix it. Alternatively, there needs to be a check for compiled models and then _orig_mod should be used. I could imagine that many more isinstance checks could be faulty with compiled models :-/
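
The failure mode is easy to see on a toy module (illustrative):

import torch

lin = torch.nn.Linear(4, 4)
compiled = torch.compile(lin)
print(isinstance(compiled, torch.nn.Linear))            # False -- it is an OptimizedModule wrapper
print(isinstance(compiled._orig_mod, torch.nn.Linear))  # True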

@sayakpaul (Member)

Sorry, missed it.

When I remove the compilation step, the unloading works. Using if hasattr(model, "unload_lora") should still fix it.

Thanks for this info. But I don't understand why we would still need the hasattr fix for the non-compiled model, or am I missing something?

edit: you meant for compiled models.

Alternatively, there needs to be a check for compiled models and then _orig_mod should be used. I could imagine that many more isinstance checks could be faulty with compiled models :-/

Agreed. I can follow code trails in diffusers. Would you be able to check it for peft?

@BenjaminBossan (Member)

Agreed. I can follow code trails in diffusers. Would you be able to check it for peft?

Not sure if something needs to be done on the PEFT side, or are you aware of something? If this is addressed on the diffusers side, we can check whether anything has to change for PEFT too and fix it then.

@sayakpaul (Member) commented Sep 6, 2024

I was mainly referring to the isinstance checks used in both codebases and the cases where we may have to add a check for compiled models too (or rejig the condition, like the hasattr one you mentioned).

@BenjaminBossan (Member)

I think on the level of PEFT layers, which diffusers is using, we should be good. Maybe there are other parts of PEFT where isinstance checks could be invalid for compiled models, but those should not block this issue.

@sayakpaul (Member) commented Sep 6, 2024

Fair. I will keep this thread posted!

@sayakpaul (Member) commented Sep 9, 2024

I have been testing a few options.

My testbed is the lora-compile branch in diffusers. The major changes are around what @BenjaminBossan and I discussed in the comments above.

I am using this code:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

prompt = "a photo of an astronaut riding a horse on mars"
# This will compile for the first time
image = pipe(prompt, num_inference_steps=5).images[0]

pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="first")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
# This will re-compile
image = pipe(prompt, num_inference_steps=5).images[0]

# pipe.unload_lora_weights()
pipe.set_lora_device(adapter_names=["first"], device="cpu")

pipe.load_lora_weights("davisbro/half_illustration", adapter_name="second")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt, num_inference_steps=5).images[0]

# pipe.unload_lora_weights()
pipe.set_lora_device(adapter_names=["second"], device="cpu")

pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"

# This will re-compile again
image = pipe(prompt, num_inference_steps=5).images[0]

I have tried two options:

  1. Unloading the LoRA weights by calling pipe.unload_lora_weights(). I also checked that the LoRA weights were getting unloaded correctly. Logs.
  2. Moving the currently loaded LoRA to CPU by calling pipe.set_lora_device(...). Logs.

There are still recompiles. However, the second option produces more recompiles than the first.

@anijain2305 any comments?

Cc: @BenjaminBossan since we were discussing this.

@anijain2305 commented Sep 11, 2024

These are the guards and recompilation reasons - https://gist.github.com/anijain2305/9f3654e3a25b38446d572cfe2f9b7566

So, I think the recompiles can't be avoided easily because the codepath truly changes.

From a torch.compile standpoint, after each loading we are accessing a different key in the scaling dict. And in general, accessing a different key can lead to a totally different graph, so Dynamo guards on that key-value pair.

A small example that reproduces the above scenario is this:


import torch


scaling = {}

def fn(x, key):
    return x * scaling[key]


opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)

scaling["first"] = 1
opt_fn(x, "first")

scaling["second"] = 1
opt_fn(x, "second")

scaling["third"] = 1
opt_fn(x, "third")

torch.compile will guard on the key-value pair and cause a recompile every time, because the accessed key-value pair is different in each invocation of fn.

@BenjaminBossan (Member)

@anijain2305 I'm working on a hot-swapping method. It's just a quick and dirty implementation for now, but I would like to get some early feedback on whether it's worth putting more time into or not. For this, could you please check the trace that I collected using TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt? It looks like there is no recompilation.

traces.txt

@anijain2305

@BenjaminBossan It looks good to me. As long as the scaling dictionary remains the same, we should be good as far as recompilation is concerned.


BenjaminBossan added a commit to BenjaminBossan/diffusers that referenced this issue Sep 17, 2024
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can
offload existing adapters or they can unload them (i.e. delete them).
However, they cannot "hotswap" adapters yet, i.e. substitute the weights
from one LoRA adapter with the weights of another, without the need to
create a separate LoRA adapter.

Generally, hot-swapping may not appear super useful, but when the
model is compiled, it is necessary to prevent recompilation. See huggingface#9279
for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target
exactly the same layers and the "hyper-parameters" of the two adapters
should be identical. For instance, the LoRA alpha has to be the same:
Given that we keep the alpha from the first adapter, the LoRA scaling
would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values
derived from the second adapter's config, but changing the dict will
trigger a guard for recompilation, defeating the main purpose of the
feature.
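
A compatibility check of roughly this shape could enforce that caveat up front (a hypothetical helper sketched against peft-style LoraConfig objects, not part of this PR):

def check_hotswap_compatible(old_cfg, new_cfg):
    # Hot-swapping keeps the first adapter's scaling dict, so rank and alpha must match.
    if old_cfg.r != new_cfg.r or old_cfg.lora_alpha != new_cfg.lora_alpha:
        raise ValueError("Hot-swapping requires matching rank and alpha.")
    if old_cfg.target_modules != new_cfg.target_modules:
        raise ValueError("Hot-swapping requires identical target modules.")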

I also found that compilation flags can have an impact on whether this
works or not. E.g. when passing "reduce-overhead", there will be errors
of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is
problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which
direction to take this. If this PR turns out to be useful, the
hot-swapping functions will be added to PEFT itself and can be imported
here (or there is a separate copy in diffusers to avoid the need for a
min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature,
although we don't necessarily need tests for the hot-swapping
functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other
pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before
putting more time into finalizing it.
@BenjaminBossan (Member)

I created a draft PR based on what I tested: #9453.

github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the "stale" label (Issues that haven't received updates) on Oct 12, 2024
@BenjaminBossan (Member)

Not stale, work still in progress.

github-actions bot removed the "stale" label (Issues that haven't received updates) on Oct 14, 2024
@nom commented Nov 3, 2024

@BenjaminBossan This is awesome, hope it can be checked in soon!

@nom commented Dec 29, 2024

@BenjaminBossan Any updates on this?
