diff --git a/ai_diffusion/__init__.py b/ai_diffusion/__init__.py index a1021d94f..b47c7f95e 100644 --- a/ai_diffusion/__init__.py +++ b/ai_diffusion/__init__.py @@ -1,6 +1,6 @@ """Generative AI plugin for Krita""" -__version__ = "1.27.1" +__version__ = "1.28.0" import importlib.util diff --git a/ai_diffusion/cloud_client.py b/ai_diffusion/cloud_client.py index a141466ad..a88dd1dba 100644 --- a/ai_diffusion/cloud_client.py +++ b/ai_diffusion/cloud_client.py @@ -7,6 +7,7 @@ from base64 import b64encode from datetime import datetime from dataclasses import dataclass +from itertools import chain from .api import WorkflowInput, WorkflowKind from .client import Client, ClientEvent, ClientMessage, ClientModels, DeviceInfo, CheckpointInfo @@ -18,7 +19,7 @@ from .settings import PerformanceSettings, settings from .localization import translate as _ from .util import clamp, ensure, client_logger as log -from . import __version__ as plugin_version +from . import resources, __version__ as plugin_version @dataclass @@ -346,22 +347,31 @@ def _base64_size(size: int): return math.ceil(size / 3) * 4 +def _checkpoint_info(id: str, arch: Arch): + models = chain(resources.default_checkpoints, resources.deprecated_models) + res = next(m for m in models if m.id.identifier == id and m.arch == arch) + return (res.filename, CheckpointInfo(res.filename, res.arch)) + + _poll_interval = 0.5 # seconds models = ClientModels() models.checkpoints = { - "dreamshaper_8.safetensors": CheckpointInfo("dreamshaper_8.safetensors", Arch.sd15), - "realisticVisionV51_v51VAE.safetensors": CheckpointInfo( - "realisticVisionV51_v51VAE.safetensors", Arch.sd15 - ), - "flat2DAnimerge_v45Sharp.safetensors": CheckpointInfo( - "flat2DAnimerge_v45Sharp.safetensors", Arch.sd15 - ), - "juggernautXL_version6Rundiffusion.safetensors": CheckpointInfo( - "juggernautXL_version6Rundiffusion.safetensors", Arch.sdxl - ), - "zavychromaxl_v80.safetensors": CheckpointInfo("zavychromaxl_v80.safetensors", Arch.sdxl), - "flux1-schnell-fp8.safetensors": CheckpointInfo("flux1-schnell-fp8.safetensors", Arch.flux), + filename: info + for filename, info in ( + _checkpoint_info(name, arch) + for name, arch in [ + ("dreamshaper", Arch.sd15), + ("realistic-vision", Arch.sd15), + ("serenity", Arch.sd15), + ("flat2d-animerge", Arch.sd15), + ("juggernaut", Arch.sdxl), + ("realvis", Arch.sdxl), + ("zavychroma", Arch.sdxl), + ("pixelwave", Arch.sdxl), + ("flux-schnell", Arch.flux), + ] + ) } models.vae = [] models.loras = [ diff --git a/ai_diffusion/resources.py b/ai_diffusion/resources.py index a5d242e6a..436c447c9 100644 --- a/ai_diffusion/resources.py +++ b/ai_diffusion/resources.py @@ -9,7 +9,7 @@ version = "1.28.0" comfy_url = "https://github.com/comfyanonymous/ComfyUI" -comfy_version = "52810907e20e11b126642f5b4917406e7043e70a" +comfy_version = "5e29e7a488b3f48afc6c4a3cb8ed110976d0ebb8" class CustomNode(NamedTuple): @@ -46,7 +46,7 @@ class CustomNode(NamedTuple): "Inpaint Nodes", "comfyui-inpaint-nodes", "https://github.com/Acly/comfyui-inpaint-nodes", - "146a2f17b1f91eb155011ab36aa349c696b6e38b", + "422eccd86685e084b551fb7e14bc025d77a64cc2", ["INPAINT_LoadFooocusInpaint", "INPAINT_ApplyFooocusInpaint", "INPAINT_ExpandMask"], ), ] @@ -56,7 +56,7 @@ class CustomNode(NamedTuple): "GGUF", "ComfyUI-GGUF", "https://github.com/city96/ComfyUI-GGUF", - "98333480059a2ccafb4718924ebcb9cdcb9b1f43", + "8e898fad4caab59bf4144e0cf11978b893de7e54", ["UnetLoaderGGUF", "DualCLIPLoaderGGUF"], ) ] @@ -437,16 +437,16 @@ def __hash__(self): default_checkpoints = [ ModelResource( - "Realistic Vision (Photography)", - ResourceId(ResourceKind.checkpoint, Arch.sd15, "realistic-vision"), + "Serenity (SD1.5 - Photography)", + ResourceId(ResourceKind.checkpoint, Arch.sd15, "serenity"), { Path( - "models/checkpoints/realisticVisionV51_v51VAE.safetensors" - ): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors", + "models/checkpoints/serenity_v21Safetensors.safetensors" + ): "https://huggingface.co/Acly/SD-Checkpoints/resolve/main/serenity_v21Safetensors.safetensors" }, ), ModelResource( - "DreamShaper (Artwork)", + "DreamShaper (SD1.5 - Artwork)", ResourceId(ResourceKind.checkpoint, Arch.sd15, "dreamshaper"), { Path( @@ -455,7 +455,7 @@ def __hash__(self): }, ), ModelResource( - "Flat2D AniMerge (Cartoon/Anime)", + "Flat2D AniMerge (SD1.5 - Cartoon/Anime)", ResourceId(ResourceKind.checkpoint, Arch.sd15, "flat2d-animerge"), { Path( @@ -464,16 +464,16 @@ def __hash__(self): }, ), ModelResource( - "Juggernaut XL", - ResourceId(ResourceKind.checkpoint, Arch.sdxl, "juggernaut"), + "RealVis (SDXL - Photography)", + ResourceId(ResourceKind.checkpoint, Arch.sdxl, "realvis"), { Path( - "models/checkpoints/juggernautXL_version6Rundiffusion.safetensors" - ): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors" + "models/checkpoints/RealVisXL_V5.0_fp16.safetensors" + ): "https://huggingface.co/SG161222/RealVisXL_V5.0/resolve/main/RealVisXL_V5.0_fp16.safetensors" }, ), ModelResource( - "ZavyChroma XL", + "ZavyChroma (SDXL - Artwork)", ResourceId(ResourceKind.checkpoint, Arch.sdxl, "zavychroma"), { Path( @@ -481,6 +481,15 @@ def __hash__(self): ): "https://huggingface.co/misri/zavychromaxl_v80/resolve/main/zavychromaxl_v80.safetensors" }, ), + ModelResource( + "Pixelwave (SDXL - Artwork)", + ResourceId(ResourceKind.checkpoint, Arch.sdxl, "pixelwave"), + { + Path( + "models/checkpoints/pixelwave_11.safetensors" + ): "https://huggingface.co/Acly/SD-Checkpoints/resolve/main/pixelwave_11.safetensors" + }, + ), ModelResource( "Flux [dev]", ResourceId(ResourceKind.checkpoint, Arch.flux, "flux-dev"), @@ -768,6 +777,24 @@ def __hash__(self): ): "https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors", }, ), + ModelResource( + "Realistic Vision (SD1.5 - Photography)", + ResourceId(ResourceKind.checkpoint, Arch.sd15, "realistic-vision"), + { + Path( + "models/checkpoints/realisticVisionV51_v51VAE.safetensors" + ): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors", + }, + ), + ModelResource( + "Juggernaut XL (Old)", + ResourceId(ResourceKind.checkpoint, Arch.sdxl, "juggernaut"), + { + Path( + "models/checkpoints/juggernautXL_version6Rundiffusion.safetensors" + ): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors" + }, + ), ] diff --git a/ai_diffusion/styles/cinematic-photo-xl.json b/ai_diffusion/styles/cinematic-photo-xl.json index c86fa5c8f..5f51cd6c0 100644 --- a/ai_diffusion/styles/cinematic-photo-xl.json +++ b/ai_diffusion/styles/cinematic-photo-xl.json @@ -3,6 +3,7 @@ "version": 2, "architecture": "auto", "checkpoints": [ + "RealVisXL_V5.0_fp16.safetensors", "juggernautXL_juggXIByRundiffusion.safetensors", "Juggernaut_X_RunDiffusion.safetensors", "juggernautXL_v9Rundiffusionphoto2.safetensors", diff --git a/scripts/download_models.py b/scripts/download_models.py index 644281b03..aaa4bfa5e 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -112,6 +112,7 @@ async def main( checkpoints=[], controlnet=False, prefetch=False, + deprecated=False, minimal=False, recommended=False, all=False, @@ -151,6 +152,8 @@ async def main( models.update([m for m in optional_models if m.kind in kinds and m.arch in versions]) if prefetch or all: models.update(resources.prefetch_models) + if deprecated: + models.update([m for m in resources.deprecated_models if m.arch in versions]) models = models - set([m for m in models if m.id.string in exclude]) @@ -199,6 +202,7 @@ async def main( parser.add_argument("--checkpoint", action="append", choices=checkpoint_names, dest="checkpoint_list", help="download a specific checkpoint (can specify multiple times)") parser.add_argument("--upscalers", action="store_true", help="download additional upscale models") parser.add_argument("--prefetch", action="store_true", help="download models which would be automatically downloaded on first use") + parser.add_argument("--deprecated", action="store_true", help="download old models which will be removed in the near future") parser.add_argument("--retry-attempts", type=int, default=5, metavar="N", help="number of retry attempts for downloading a model") parser.add_argument("--continue-on-error", action="store_true", help="continue downloading models even if an error occurs") # fmt: on @@ -222,6 +226,7 @@ async def main( checkpoints=checkpoints, controlnet=args.controlnet, prefetch=args.prefetch, + deprecated=args.deprecated, minimal=args.minimal, recommended=args.recommended, all=args.all,