feat: refine LoRA diffusers to flux conversion logic #7708

Open · wants to merge 9 commits into base: main
@@ -82,6 +82,19 @@ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
            values = get_lora_layer_values(src_layer_dict)
            layers[dst_key] = any_lora_layer_from_state_dict(values)

    def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
        if src_key in grouped_state_dict:
            src_layer_dict = grouped_state_dict.pop(src_key)
            values = get_lora_layer_values(src_layer_dict)

            for _key in values.keys():
                # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection
                # output into (shift, scale), while diffusers splits it into (scale, shift). Swap the
                # linear projection weights here so that the diffusers implementation can be used.
                scale, shift = values[_key].chunk(2, dim=0)
                values[_key] = torch.cat([shift, scale], dim=0)
Collaborator (@RyanJDick) commented:
This doesn't look right to me. If I'm understanding correctly, in the case of a vanilla LoRA layer, we should only be flipping one of the LoRA components.

The required transformation would be a bit more involved for other LoRA variants (LoHA, LoKR, etc.), so I'm fine with only supporting vanilla LoRAs. But, we should assert that the result of any_lora_layer_from_state_dict() is a LoRALayer.
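
A minimal sketch of the check being suggested here, reusing the existing LoRALayer class and any_lora_layer_from_state_dict helper (variable names are illustrative):

# Inside add_adaLN_lora_layer_if_present: only vanilla LoRA layers are supported,
# so fail loudly if the state dict produced another variant (LoHA, LoKr, ...).
layer = any_lora_layer_from_state_dict(values)
assert isinstance(layer, LoRALayer), (
    f"AdaLN shift/scale swapping is only implemented for LoRALayer, got {type(layer).__name__}."
)
layers[dst_key] = layer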

Author (@simpletrontdip) commented on Mar 7, 2025:
Hi @RyanJDick, thanks for spending time on this 👯
I have to confess it is more complex than I expected; sorry for not asking the team beforehand.

As I understand it:

# For a normal LoRA layer:
delta_W = up @ down
W = W + delta_W

# For AdaLN in diffusers:
W_prime = swap_shift_scale(W)
delta_W_prime = swap_shift_scale(delta_W)

# => We may need to add a custom LoRA layer that swaps them in `get_weight`:

class AdaLN_LoRALayer(LoRALayer):
    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Swap shift and scale before returning the real weight."""
        weight = super().get_weight(orig_weight)
        scale, shift = weight.chunk(2, dim=0)
        return torch.cat([shift, scale], dim=0)

# We need to build and return this layer in our function.

What do you think?
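
For what it's worth, the identity behind this reasoning (the shift/scale swap is a fixed row permutation, so it distributes over W + delta_W and can equivalently be applied to the `up` factor alone) can be sanity-checked with a small standalone sketch; `swap_shift_scale` below is the helper named above, not an existing function in the codebase:

import torch

def swap_shift_scale(w: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the output dimension (a fixed row permutation).
    scale, shift = w.chunk(2, dim=0)
    return torch.cat([shift, scale], dim=0)

W = torch.randn(8, 4)                            # original weight
up, down = torch.randn(8, 2), torch.randn(2, 4)  # vanilla LoRA factors
delta_W = up @ down

# Swapping the patched weight equals patching the swapped weight with the swapped delta.
assert torch.allclose(swap_shift_scale(W + delta_W), swap_shift_scale(W) + swap_shift_scale(delta_W))

# Equivalently, only the `up` LoRA component needs to be permuted.
assert torch.allclose(swap_shift_scale(delta_W), swap_shift_scale(up) @ down)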


            layers[dst_key] = any_lora_layer_from_state_dict(values)

    def add_qkv_lora_layer_if_present(
        src_keys: list[str],
        src_weight_shapes: list[tuple[int, int]],
@@ -124,8 +137,8 @@ def add_qkv_lora_layer_if_present(
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")

# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in.in_layer")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in.out_layer")

    # context_embedder -> txt_in.
    add_lora_layer_if_present("context_embedder", "txt_in")
@@ -223,6 +236,10 @@ def add_qkv_lora_layer_if_present(

    # Final layer.
    add_lora_layer_if_present("proj_out", "final_layer.linear")
    add_adaLN_lora_layer_if_present(
        'norm_out.linear',
        'final_layer.adaLN_modulation.1',
    )

    # Assert that all keys were processed.
    assert len(grouped_state_dict) == 0