Commit

WIP

jazzhaiku committed Mar 4, 2025
1 parent be73f25 commit 8fdd99a
Showing 3 changed files with 37 additions and 29 deletions.
37 changes: 36 additions & 1 deletion invokeai/backend/model_manager/config.py
@@ -30,17 +30,20 @@

import diffusers
import onnxruntime as ort
import safetensors.torch
import torch
from diffusers.models.modeling_utils import ModelMixin
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings

logger = logging.getLogger(__name__)

@@ -199,6 +202,36 @@ class ControlAdapterDefaultSettings(BaseModel):
    model_config = ConfigDict(extra="forbid")


class ModelOnDisk:
    """A utility class representing a model stored on disk."""

    def __init__(self, path: Path):
        self.path = path
        self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
        if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            self.name = path.stem
        else:
            self.name = path.name

    def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
        if self.format_type == ModelFormat.Diffusers:
            raise NotImplementedError()

        with SilenceWarnings():
            if self.path.suffix in (".ckpt", ".pt", ".pth", ".bin"):
                # Pickle-based formats are scanned for malware before loading.
                scan_result = scan_file_path(self.path)
                if scan_result.infected_files != 0 or scan_result.scan_err:
                    raise Exception(f"The model {self.name} is potentially infected by malware. Aborting import.")
                checkpoint = torch.load(self.path, map_location="cpu")
            elif self.path.suffix == ".gguf":
                checkpoint = gguf_sd_loader(self.path, compute_dtype=torch.float32)
            else:
                checkpoint = safetensors.torch.load_file(self.path)

        state_dict = checkpoint.get("state_dict") or checkpoint
        return state_dict


class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
FAST = 0
@@ -610,3 +643,5 @@ def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -
    model.converted_at = timestamp
    validate_hash(model.hash)
    return model  # type: ignore


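As a rough usage sketch (not part of this commit), the relocated ModelOnDisk helper might be exercised as below; the checkpoint path is a hypothetical example.

```python
# Minimal sketch, assuming a local single-file checkpoint exists at the given path.
from pathlib import Path

from invokeai.backend.model_manager.config import ModelOnDisk

model = ModelOnDisk(Path("models/v1-5-pruned-emaonly.safetensors"))  # hypothetical path
print(model.name)         # "v1-5-pruned-emaonly": known extensions are stripped from the name
print(model.format_type)  # ModelFormat.Checkpoint: single files are treated as checkpoints

# Loads the checkpoint into CPU memory (pickle-based formats are malware-scanned first).
state_dict = model.lazy_load_state_dict()
print(f"{len(state_dict)} tensors loaded from {model.name}")
```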
26 changes: 0 additions & 26 deletions invokeai/backend/model_manager/model_on_disk.py

This file was deleted.

3 changes: 1 addition & 2 deletions tests/test_model_probe.py
@@ -17,7 +17,7 @@
    ModelFormat,
    ModelType,
    ModelVariantType,
    concrete_subclasses,
    concrete_subclasses, ModelOnDisk,
)
from invokeai.backend.model_manager.legacy_probe import (
CkptType,
@@ -26,7 +26,6 @@
    get_default_settings_control_adapters,
    get_default_settings_main,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch

