Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FLUX Redux support #7726

Merged
merged 12 commits into from
Mar 5, 2025
15 changes: 15 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
# endregion

# region Misc Field Types
Expand Down Expand Up @@ -152,6 +154,7 @@ class FieldDescriptions:
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Image-to-Image model"
vllm_model = "VLLM model"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
Expand Down Expand Up @@ -201,6 +204,7 @@ class FieldDescriptions:
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"


class ImageField(BaseModel):
Expand Down Expand Up @@ -259,6 +263,17 @@ class FluxConditioningField(BaseModel):
)


class FluxReduxConditioningField(BaseModel):
    """A FLUX Redux conditioning tensor primitive value"""

    # Reference to the Redux conditioning tensor. Consumers resolve it through its `tensor_name`.
    conditioning: TensorField = Field(description="The Redux image conditioning tensor.")
    # Optional reference to a region mask for this conditioning. Defaults to None (no mask supplied).
    mask: Optional[TensorField] = Field(
        default=None,
        description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
        "included regions should be set to True.",
    )


class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

Expand Down
60 changes: 57 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
Expand Down Expand Up @@ -46,7 +47,7 @@
pack,
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
Expand Down Expand Up @@ -103,6 +104,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
redux_conditioning: FluxReduxConditioningField | list[FluxReduxConditioningField] | None = InputField(
default=None,
description="FLUX Redux conditioning tensor.",
input=Input.Connection,
)
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
cfg_scale_start_step: int = InputField(
default=0,
Expand Down Expand Up @@ -190,11 +196,23 @@ def _run_diffusion(
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
redux_conditionings: list[FluxReduxConditioning] = self._load_redux_conditioning(
context=context,
redux_cond_field=self.redux_conditioning,
packed_height=packed_h,
packed_width=packed_w,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
pos_text_conditionings, img_seq_len=packed_h * packed_w
text_conditioning=pos_text_conditionings,
redux_conditioning=redux_conditionings,
img_seq_len=packed_h * packed_w,
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
RegionalPromptingExtension.from_text_conditioning(
text_conditioning=neg_text_conditionings, redux_conditioning=[], img_seq_len=packed_h * packed_w
)
if neg_text_conditionings
else None
)
Expand Down Expand Up @@ -400,6 +418,42 @@ def _load_text_conditioning(

return text_conditionings

def _load_redux_conditioning(
    self,
    context: InvocationContext,
    redux_cond_field: FluxReduxConditioningField | list[FluxReduxConditioningField] | None,
    packed_height: int,
    packed_width: int,
    device: torch.device,
    dtype: torch.dtype,
) -> list[FluxReduxConditioning]:
    """Load the Redux conditioning tensors (and their optional masks) referenced by `redux_cond_field`.

    Args:
        context: Invocation context used to load the stored tensors by name.
        redux_cond_field: A single Redux conditioning field, a list of them, or None.
        packed_height: Packed latent height, forwarded to the regional mask preprocessing.
        packed_width: Packed latent width, forwarded to the regional mask preprocessing.
        device: Device that the loaded tensors are moved to.
        dtype: Dtype that the Redux embeddings are cast to (also forwarded to the mask preprocessing).

    Returns:
        One FluxReduxConditioning per input field, in input order. Empty if `redux_cond_field` is None.
    """
    # Normalize to a list of FluxReduxConditioningFields.
    if redux_cond_field is None:
        return []

    redux_cond_fields = (
        [redux_cond_field] if isinstance(redux_cond_field, FluxReduxConditioningField) else redux_cond_field
    )

    redux_conditionings: list[FluxReduxConditioning] = []
    # Use a distinct loop variable so that the `redux_cond_field` parameter is not shadowed.
    for cond_field in redux_cond_fields:
        # Load the Redux conditioning tensor.
        redux_cond_data = context.tensors.load(cond_field.conditioning.tensor_name)
        # Tensor.to() returns the converted tensor rather than modifying it in place, so the result must be
        # re-assigned. Otherwise the embeddings keep whatever device/dtype they were stored with.
        redux_cond_data = redux_cond_data.to(device=device, dtype=dtype)

        # Load the mask, if provided.
        mask: Optional[torch.Tensor] = None
        if cond_field.mask is not None:
            mask = context.tensors.load(cond_field.mask.tensor_name)
            mask = mask.to(device=device)
            mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
                mask, packed_height, packed_width, dtype, device
            )

        redux_conditionings.append(FluxReduxConditioning(redux_embeddings=redux_cond_data, mask=mask))

    return redux_conditionings

@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
Expand Down
95 changes: 95 additions & 0 deletions invokeai/app/invocations/flux_redux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Optional

import torch
from PIL import Image

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxReduxConditioningField,
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice


@invocation_output("flux_redux_output")
class FluxReduxOutput(BaseInvocationOutput):
    """The conditioning output of a FLUX Redux invocation."""

    # The Redux conditioning (tensor reference plus optional mask) produced by the invocation.
    redux_cond: FluxReduxConditioningField = OutputField(
        description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
    )


# Display names of the SigLIP and FLUX Redux models.
# NOTE(review): these are not referenced in the visible part of this module; presumably they match starter model
# names registered elsewhere - confirm before renaming.
SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"


@invocation(
    "flux_redux",
    title="FLUX Redux",
    tags=["ip_adapter", "control"],
    category="ip_adapter",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxReduxInvocation(BaseInvocation):
    """Runs a FLUX Redux model to generate a conditioning tensor."""

    image: ImageField = InputField(description="The FLUX Redux image prompt.")
    mask: Optional[TensorField] = InputField(
        default=None,
        description=(
            "The bool mask associated with this FLUX Redux image prompt. Excluded regions should be set to "
            "False, included regions should be set to True."
        ),
    )
    redux_model: ModelIdentifierField = InputField(
        description="The FLUX Redux model to use.",
        title="FLUX Redux Model",
        ui_type=UIType.FluxReduxModel,
    )
    siglip_model: ModelIdentifierField = InputField(
        description="The SigLIP model to use.",
        title="SigLIP Model",
        ui_type=UIType.SigLipModel,
    )

    def invoke(self, context: InvocationContext) -> FluxReduxOutput:
        """Encode the image prompt with SigLIP, then with FLUX Redux, and save the resulting tensor."""
        pil_image = context.images.get_pil(self.image.image_name, "RGB")

        # Two-stage encoding: SigLIP image features first, then the Redux model on top of them.
        siglip_features = self._siglip_encode(context, pil_image)
        redux_tensor = self._flux_redux_encode(context, siglip_features)

        # Store the tensor and hand downstream nodes a reference to it; the mask is passed through unchanged.
        saved_name = context.tensors.save(redux_tensor)
        conditioning_field = FluxReduxConditioningField(
            conditioning=TensorField(tensor_name=saved_name),
            mask=self.mask,
        )
        return FluxReduxOutput(redux_cond=conditioning_field)

    @torch.no_grad()
    def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
        """Encode `image` with the selected SigLIP pipeline and return the encoded tensor."""
        loaded_siglip = context.models.load(self.siglip_model)
        with loaded_siglip.model_on_device() as (_, siglip_pipeline):
            assert isinstance(siglip_pipeline, SigLipPipeline)
            target_device = TorchDevice.choose_torch_device()
            target_dtype = TorchDevice.choose_torch_dtype()
            return siglip_pipeline.encode_image(x=image, device=target_device, dtype=target_dtype)

    @torch.no_grad()
    def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
        """Run the selected FLUX Redux model on the SigLIP-encoded features."""
        loaded_redux = context.models.load(self.redux_model)
        with loaded_redux.model_on_device() as (_, flux_redux):
            assert isinstance(flux_redux, FluxReduxModel)
            # Cast the input to the dtype of the model's first parameter before the forward pass.
            first_param = next(flux_redux.parameters())
            model_input = encoded_x.to(dtype=first_param.dtype)
            return flux_redux(model_input)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import torch
import torchvision

from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.flux.text_conditioning import (
FluxReduxConditioning,
FluxRegionalTextConditioning,
FluxTextConditioning,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
Expand Down Expand Up @@ -32,14 +36,19 @@ def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
return order[block_index % len(order)]

@classmethod
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
def from_text_conditioning(
cls,
text_conditioning: list[FluxTextConditioning],
redux_conditioning: list[FluxReduxConditioning],
img_seq_len: int,
):
"""Create a RegionalPromptingExtension from a list of text conditionings.

Args:
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
"""
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning, redux_conditioning)
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
Expand Down Expand Up @@ -202,6 +211,7 @@ def _prepare_restricted_attn_mask(
def _concat_regional_text_conditioning(
cls,
text_conditionings: list[FluxTextConditioning],
redux_conditionings: list[FluxReduxConditioning],
) -> FluxRegionalTextConditioning:
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
concat_t5_embeddings: list[torch.Tensor] = []
Expand All @@ -217,18 +227,27 @@ def _concat_regional_text_conditioning(
global_clip_embedding = text_conditioning.clip_embeddings
break

# Handle T5 text embeddings.
cur_t5_embedding_len = 0
for text_conditioning in text_conditionings:
concat_t5_embeddings.append(text_conditioning.t5_embeddings)

concat_t5_embedding_ranges.append(
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
)

image_masks.append(text_conditioning.mask)

cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]

# Handle Redux embeddings.
for redux_conditioning in redux_conditionings:
concat_t5_embeddings.append(redux_conditioning.redux_embeddings)
concat_t5_embedding_ranges.append(
Range(
start=cur_t5_embedding_len, end=cur_t5_embedding_len + redux_conditioning.redux_embeddings.shape[1]
)
)
image_masks.append(redux_conditioning.mask)
cur_t5_embedding_len += redux_conditioning.redux_embeddings.shape[1]

t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)

# Initialize the txt_ids tensor.
Expand Down
17 changes: 17 additions & 0 deletions invokeai/backend/flux/redux/flux_redux_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch

# This model definition is based on:
# https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/modules/image_embedders.py#L66


class FluxReduxModel(torch.nn.Module):
    """Two-layer projection: Linear (up) -> SiLU -> Linear (down).

    Args:
        redux_dim: Size of the last dimension of the input features.
        txt_in_features: Size of the last dimension of the output. The hidden layer is three times this wide.
    """

    def __init__(self, redux_dim: int = 1152, txt_in_features: int = 4096) -> None:
        super().__init__()

        hidden_features = txt_in_features * 3

        self.redux_dim = redux_dim

        # The attribute names determine the state dict keys, so they must not be renamed.
        self.redux_up = torch.nn.Linear(redux_dim, hidden_features)
        self.redux_down = torch.nn.Linear(hidden_features, txt_in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` up, apply SiLU, and project back down to `txt_in_features`."""
        hidden = self.redux_up(x)
        activated = torch.nn.functional.silu(hidden)
        return self.redux_down(activated)
11 changes: 11 additions & 0 deletions invokeai/backend/flux/redux/flux_redux_state_dict_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any, Dict


def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool:
    """Return True if the state dict's keys are exactly the four FLUX Redux weight/bias keys.

    The match is exact: a state dict with missing or additional keys is rejected.
    """
    redux_keys = {
        "redux_up.weight",
        "redux_up.bias",
        "redux_down.weight",
        "redux_down.bias",
    }
    return set(state_dict.keys()) == redux_keys
7 changes: 7 additions & 0 deletions invokeai/backend/flux/text_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ class FluxTextConditioning:
mask: torch.Tensor | None


@dataclass
class FluxReduxConditioning:
    """A loaded FLUX Redux conditioning: the Redux embeddings plus an optional regional mask."""

    # The Redux embedding tensor.
    # NOTE(review): consumers appear to treat dim 1 as the sequence dimension - confirm the expected shape.
    redux_embeddings: torch.Tensor
    # If mask is None, the prompt is a global prompt.
    mask: torch.Tensor | None


@dataclass
class FluxRegionalTextConditioning:
# Concatenated text embeddings.
Expand Down
Loading
Loading