Skip to content

Commit

Permalink
Fix typegen
Browse files Browse the repository at this point in the history
  • Loading branch information
jazzhaiku committed Mar 5, 2025
1 parent 7a0f958 commit e63e821
Show file tree
Hide file tree
Showing 2 changed files with 624 additions and 651 deletions.
51 changes: 25 additions & 26 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,8 @@ def parse(c):
return cls


class CheckpointConfigBase:
class CheckpointConfigBase(BaseModel):
"""Base class for checkpoint-style models."""

format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
Expand All @@ -393,51 +392,51 @@ class CheckpointConfigBase:
)


class DiffusersConfigBase:
class DiffusersConfigBase(BaseModel):
"""Base class for diffusers-style models."""

format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default


class LoRAConfigBase:
class LoRAConfigBase(BaseModel):
"""Base class for LoRA models."""

type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)


class T5EncoderConfigBase:
class T5EncoderConfigBase(BaseModel):
"""Base class for diffusers-style models."""

type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder


@legacy_probe
class T5EncoderConfig(ModelConfigBase, T5EncoderConfigBase):
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder


@legacy_probe
class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase, T5EncoderConfigBase):
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b


@legacy_probe
class LoRALyCORISConfig(ModelConfigBase, LoRAConfigBase):
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""

format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS


class ControlAdapterConfigBase:
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)


@legacy_probe
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
"""Model config for Control LoRA models."""

type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
Expand All @@ -446,7 +445,7 @@ class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):


@legacy_probe
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, ModelConfigBase):
"""Model config for Control LoRA models."""

type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
Expand All @@ -455,16 +454,15 @@ class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):


@legacy_probe
class LoRADiffusersConfig(ModelConfigBase, LoRAConfigBase):
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""

format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


@legacy_probe
class VAECheckpointConfig(ModelConfigBase, CheckpointConfigBase):
class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
"""Model config for standalone VAE models."""

type: Literal[ModelType.VAE] = ModelType.VAE


Expand All @@ -477,15 +475,15 @@ class VAEDiffusersConfig(ModelConfigBase):


@legacy_probe
class ControlNetDiffusersConfig(ModelConfigBase, DiffusersConfigBase, ControlAdapterConfigBase):
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""

type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


@legacy_probe
class ControlNetCheckpointConfig(ModelConfigBase, CheckpointConfigBase, ControlAdapterConfigBase):
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""

type: Literal[ModelType.ControlNet] = ModelType.ControlNet
Expand All @@ -507,7 +505,7 @@ class TextualInversionFolderConfig(ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder


class MainConfigBase:
class MainConfigBase(BaseModel):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
Expand All @@ -517,15 +515,15 @@ class MainConfigBase:


@legacy_probe
class MainCheckpointConfig(ModelConfigBase, CheckpointConfigBase, MainConfigBase):
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""

prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False


@legacy_probe
class MainBnbQuantized4bCheckpointConfig(ModelConfigBase, CheckpointConfigBase, MainConfigBase):
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""

format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
Expand All @@ -534,7 +532,7 @@ class MainBnbQuantized4bCheckpointConfig(ModelConfigBase, CheckpointConfigBase,


@legacy_probe
class MainGGUFCheckpointConfig(ModelConfigBase, CheckpointConfigBase, MainConfigBase):
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""

format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
Expand All @@ -543,18 +541,18 @@ class MainGGUFCheckpointConfig(ModelConfigBase, CheckpointConfigBase, MainConfig


@legacy_probe
class MainDiffusersConfig(ModelConfigBase, DiffusersConfigBase, MainConfigBase):
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main diffusers models."""

pass


class IPAdapterConfigBase:
class IPAdapterConfigBase(BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter


@legacy_probe
class IPAdapterInvokeAIConfig(ModelConfigBase, IPAdapterConfigBase):
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter diffusers format models."""

# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
Expand All @@ -564,16 +562,16 @@ class IPAdapterInvokeAIConfig(ModelConfigBase, IPAdapterConfigBase):


@legacy_probe
class IPAdapterCheckpointConfig(ModelConfigBase, IPAdapterConfigBase):
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter checkpoint format models."""

format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


@legacy_probe
class CLIPEmbedDiffusersConfig(ModelConfigBase, DiffusersConfigBase):
class CLIPEmbedDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Clip Embeddings."""

variant: ClipVariantType = Field(description="Clip variant for this model")
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

Expand Down Expand Up @@ -684,3 +682,4 @@ def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -
model.converted_at = timestamp
validate_hash(model.hash)
return model # type: ignore

Loading

0 comments on commit e63e821

Please sign in to comment.