Skip to content

Commit

Permalink
New API for model classification
Browse files Browse the repository at this point in the history
  • Loading branch information
jazzhaiku committed Mar 4, 2025
1 parent 0ad0016 commit 78b8451
Show file tree
Hide file tree
Showing 15 changed files with 446 additions and 253 deletions.
2 changes: 1 addition & 1 deletion invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None

ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
return None

Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/invocations/compel.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def log_tokenization_for_text(
usedTokens += 1

if usedTokens > 0:
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f"\n>> [TOKENLOG] Tokens {display_label or ''} ({usedTokens}):")
print(f"{tokenized}\x1b[0m")

if discarded != "":
Expand Down
19 changes: 14 additions & 5 deletions invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
ModelRepoVariant,
ModelSourceType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
Expand All @@ -49,7 +51,6 @@
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
Expand Down Expand Up @@ -182,9 +183,7 @@ def install_path(
) -> str: # noqa D102
model_path = Path(model_path)
config = config or ModelRecordChanges()
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore

if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
Expand Down Expand Up @@ -644,12 +643,22 @@ def _move_model(self, old_path: Path, new_path: Path) -> Path:
move(old_path, new_path)
return new_path

def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
    """Identify the model at `model_path` and return its configuration.

    Tries the new classification API (`ModelConfigBase.classify`) first and
    falls back to the legacy `ModelProbe` when classification raises
    `InvalidModelConfigException`.

    Args:
        model_path: Filesystem path of the model to inspect.
        config: Optional field overrides merged into the resulting config;
            a fresh `ModelRecordChanges` is used when omitted.

    Returns:
        The probed model configuration. Callers in this file assign the
        result to `AnyModelConfig` (with a `# type: ignore`), so the exact
        static return type is looser than the runtime value — confirm
        against `ModelConfigBase.classify` before tightening the annotation.
    """
    config = config or ModelRecordChanges()
    # Flatten the override fields once; both probing paths consume the
    # same dict (as **kwargs for classify, as `fields` for the legacy probe).
    overrides = config.model_dump()
    try:
        # Preferred path: the new classification API.
        return ModelConfigBase.classify(model_path, **overrides)
    except InvalidModelConfigException:
        # Fallback: legacy probe, which additionally needs the app's
        # configured hashing algorithm to compute the model hash.
        # NOTE(review): the classify failure is swallowed here; if the
        # legacy probe also fails, its exception propagates to the caller.
        return ModelProbe.probe(
            model_path=model_path, fields=overrides, hash_algo=self._app_config.hashing_algorithm
        )  # type: ignore

def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
) -> str:
config = config or ModelRecordChanges()

info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
info = info or self._probe(model_path, config)

model_path = model_path.resolve()

Expand Down
10 changes: 5 additions & 5 deletions invokeai/backend/image_util/pngwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def normalize_prompt(self):

switches = []
switches.append(f'"{opt.prompt}"')
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")
switches.append(f"-H{opt.height or t2i.height}")
switches.append(f"-C{opt.cfg_scale or t2i.cfg_scale}")
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")
switches.append(f"-H{opt.height or t2i.height}")
switches.append(f"-C{opt.cfg_scale or t2i.cfg_scale}")
switches.append(f"-A{opt.sampler_name or t2i.sampler_name}")
# to do: put model name into the t2i object
# switches.append(f'--model{t2i.model_name}')
Expand All @@ -109,7 +109,7 @@ def normalize_prompt(self):
if opt.gfpgan_strength:
switches.append(f"-G{opt.gfpgan_strength}")
if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
switches.append(f"-U {' '.join([str(u) for u in opt.upscale])}")
if opt.variation_amount > 0:
switches.append(f"-v{opt.variation_amount}")
if opt.with_variations:
Expand Down
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch

__all__ = [
Expand Down
Loading

0 comments on commit 78b8451

Please sign in to comment.