Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
jazzhaiku committed Mar 4, 2025
1 parent 8fdd99a commit 6eb4a2b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 12 additions & 6 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def __init__(self, **fields):
def get_tag(cls) -> Tag:
type = cls.model_fields["type"].default.value
format = cls.model_fields["format"].default.value
variant = cls.model_fields.get("variant")
if not variant:
return Tag(f"{type}.{format}")
return Tag(f"{type}.{format}.{variant.default.value}")
return Tag(f"{type}.{format}")

@classmethod
@abstractmethod
Expand Down Expand Up @@ -559,11 +556,19 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-G Embeddings."""
variant: ClipVariantType = ClipVariantType.G

@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")


@legacy_probe
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-L Embeddings."""
variant: ClipVariantType = ClipVariantType.L
variant: ClipVariantType = ClipVariantType.L.value

@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")


@legacy_probe
Expand Down Expand Up @@ -613,7 +618,8 @@ def get_model_discriminator_value(v: Any) -> str:
type_ = v.type.value
variant_ = getattr(v, "variant", None)

if variant_:
# special case, ideally would return
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value and variant_:
return f"{type_}.{format_}.{variant_}"
return f"{type_}.{format_}"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_minimal_working_example(datadir: Path):
assert isinstance(config, MinimalConfigExample)
assert config.base == BaseModelType.StableDiffusion1
assert config.path == model_path.as_posix()
assert config.quote == "Minimal working example of a ModelConfigBase subclass"
assert config.fun_quote == "Minimal working example of a ModelConfigBase subclass"



Expand Down

0 comments on commit 6eb4a2b

Please sign in to comment.