feat: refine LoRA diffusers to Flux conversion logic #7708

Open · wants to merge 9 commits into base: main
18 changes: 18 additions & 0 deletions invokeai/backend/patches/layers/diffusers_ada_ln_lora_layer.py
@@ -0,0 +1,18 @@
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer


def swap_shift_scale(tensor: torch.Tensor) -> torch.Tensor:
    scale, shift = tensor.chunk(2, dim=0)
    return torch.cat([shift, scale], dim=0)


class DiffusersAdaLN_LoRALayer(LoRALayer):
    '''LoRA layer converted from a Diffusers AdaLN layer; its weight is shift/scale swapped.'''

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # The SD3 and Flux implementations of AdaLayerNormContinuous split the linear projection output into
        # (shift, scale), while diffusers splits it into (scale, shift).
        # Swap the two halves of the linear projection weight so that the Flux implementation can be used as-is.
        weight = super().get_weight(orig_weight)
        return swap_shift_scale(weight)
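
For reference, here is a small standalone sketch (not part of this diff; the toy dimensions and the plain nn.Linear standing in for the AdaLN modulation projection are assumptions) showing why swapping the two output halves of the projection weight turns diffusers' (scale, shift) ordering into the (shift, scale) ordering that the Flux implementation expects:

import torch

emb_dim, hidden = 6, 4
proj = torch.nn.Linear(emb_dim, 2 * hidden)  # stand-in for the AdaLN modulation projection
emb = torch.randn(1, emb_dim)

# Diffusers' AdaLayerNormContinuous chunks the projection output as (scale, shift).
scale, shift = proj(emb).chunk(2, dim=1)

# Swap the two halves of the weight and bias, which is what swap_shift_scale does along dim 0.
swapped = torch.nn.Linear(emb_dim, 2 * hidden)
swapped.weight.data = torch.cat(proj.weight.data.chunk(2, dim=0)[::-1], dim=0)
swapped.bias.data = torch.cat(proj.bias.data.chunk(2, dim=0)[::-1], dim=0)

# The swapped projection now yields Flux's (shift, scale) ordering.
shift_2, scale_2 = swapped(emb).chunk(2, dim=1)
assert torch.allclose(shift, shift_2) and torch.allclose(scale, scale_2)
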
8 changes: 8 additions & 0 deletions invokeai/backend/patches/layers/utils.py
@@ -10,6 +10,7 @@
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.norm_layer import NormLayer
from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
@@ -33,3 +34,10 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
        return NormLayer.from_state_dict_values(state_dict)
    else:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")


def diffusers_adaLN_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> DiffusersAdaLN_LoRALayer:
    if "lora_up.weight" not in state_dict:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

    return DiffusersAdaLN_LoRALayer.from_state_dict_values(state_dict)
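
A rough usage sketch (not part of this diff — the per-layer key names beyond the "lora_up.weight" check above, and from_state_dict_values tolerating the absence of optional keys such as alpha and bias, are assumptions):

import torch

# Hypothetical, already-extracted values for the diffusers "norm_out.linear" projection.
rank, in_features, out_features = 4, 64, 128
values = {
    "lora_up.weight": torch.randn(out_features, rank),
    "lora_down.weight": torch.randn(rank, in_features),
}

layer = diffusers_adaLN_lora_layer_from_state_dict(values)
# Assuming get_weight() composes up @ down, the returned delta has its shift/scale halves swapped for Flux.
delta = layer.get_weight(torch.zeros(out_features, in_features))
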
@@ -4,7 +4,7 @@

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, diffusers_adaLN_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

@@ -82,6 +82,12 @@ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
            values = get_lora_layer_values(src_layer_dict)
            layers[dst_key] = any_lora_layer_from_state_dict(values)

    def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
        if src_key in grouped_state_dict:
            src_layer_dict = grouped_state_dict.pop(src_key)
            values = get_lora_layer_values(src_layer_dict)
            layers[dst_key] = diffusers_adaLN_lora_layer_from_state_dict(values)

    def add_qkv_lora_layer_if_present(
        src_keys: list[str],
        src_weight_shapes: list[tuple[int, int]],
@@ -124,8 +130,8 @@ def add_qkv_lora_layer_if_present(
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")

# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in.in_layer")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in.out_layer")

# context_embedder -> txt_in.
add_lora_layer_if_present("context_embedder", "txt_in")
@@ -223,6 +229,10 @@ def add_qkv_lora_layer_if_present(

    # Final layer.
    add_lora_layer_if_present("proj_out", "final_layer.linear")
    add_adaLN_lora_layer_if_present(
        "norm_out.linear",
        "final_layer.adaLN_modulation.1",
    )

    # Assert that all keys were processed.
    assert len(grouped_state_dict) == 0
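
A note on the guidance_in change above: with the old mapping, both guidance_embedder linears were written to the same destination key, so the second add_lora_layer_if_present call silently overwrote the first entry in the layers dict and one of the two LoRA patches was dropped. Splitting the destinations into guidance_in.in_layer and guidance_in.out_layer keeps both. A trivial illustration (placeholder values, not real patch objects):

layers = {}
layers["guidance_in"] = "patch for linear_1"
layers["guidance_in"] = "patch for linear_2"  # overwrites the previous entry
assert len(layers) == 1

layers = {}
layers["guidance_in.in_layer"] = "patch for linear_1"
layers["guidance_in.out_layer"] = "patch for linear_2"
assert len(layers) == 2
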
@@ -14,6 +14,7 @@
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
@@ -283,6 +284,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
"multiple_loras",
"concatenated_lora",
"flux_control_lora",
"diffusers_adaLN_lora",
"single_lokr",
]
)
@@ -370,6 +372,16 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
        )
        input = torch.randn(1, in_features)
        return ([(lokr_layer, 0.7)], input)
    elif layer_type == "diffusers_adaLN_lora":
        lora_layer = DiffusersAdaLN_LoRALayer(
            up=torch.randn(out_features, rank),
            mid=None,
            down=torch.randn(rank, in_features),
            alpha=1.0,
            bias=torch.randn(out_features),
        )
        input = torch.randn(1, in_features)
        return ([(lora_layer, 0.7)], input)
    else:
        raise ValueError(f"Unsupported layer_type: {layer_type}")

55 changes: 55 additions & 0 deletions tests/backend/patches/layers/test_diffuser_ada_ln_lora_layer.py
@@ -0,0 +1,55 @@
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer, swap_shift_scale


def test_swap_shift_scale_for_tensor():
    """Test the swap_shift_scale function."""
    tensor = torch.Tensor([1, 2])
    expected = torch.Tensor([2, 1])

    swapped = swap_shift_scale(tensor)
    assert torch.allclose(expected, swapped)

    size = (3, 4)
    first = torch.randn(size)
    second = torch.randn(size)

    tensor = torch.concat([first, second])
    expected = torch.concat([second, first])

    swapped = swap_shift_scale(tensor)
    assert torch.allclose(expected, swapped)


def test_diffusers_adaLN_lora_layer_get_weight():
    """Test getting the weight from a DiffusersAdaLN_LoRALayer."""
    small_in_features = 4
    big_in_features = 8
    out_features = 16
    rank = 4
    alpha = 16.0

    lora = LoRALayer(
        up=torch.ones(out_features, rank),
        mid=None,
        down=torch.ones(rank, big_in_features),
        alpha=alpha,
        bias=None,
    )
    layer = DiffusersAdaLN_LoRALayer(
        up=torch.ones(out_features, rank),
        mid=None,
        down=torch.ones(rank, big_in_features),
        alpha=alpha,
        bias=None,
    )

    # Mock original weight; it is normally ignored by our LoRA layers.
    orig_weight = torch.ones(small_in_features)

    diffuser_weight = layer.get_weight(orig_weight)
    lora_weight = lora.get_weight(orig_weight)

    # The diffusers AdaLN LoRA weight should be the shift/scale-swapped plain LoRA weight.
    assert torch.allclose(diffuser_weight, swap_shift_scale(lora_weight))
