diff --git a/.eslintrc.json b/.eslintrc.json index c86dbb749..4ca2afbdb 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -22,6 +22,7 @@ "default-case":"off", "no-await-in-loop":"off", "no-bitwise":"off", + "no-continue":"off", "no-confusing-arrow":"off", "no-console":"off", "no-empty":"off", diff --git a/.pylintrc b/.pylintrc index 4a2850664..78361f7c4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -26,6 +26,7 @@ ignore-paths=/usr/lib/.*$, modules/omnigen, modules/onnx_impl, modules/pag, + modules/pixelsmith, modules/prompt_parser_xhinker.py, modules/pulid/eva_clip, modules/rife, @@ -36,6 +37,7 @@ ignore-paths=/usr/lib/.*$, modules/unipc, modules/xadapter, repositories, + extensions-builtin/Lora, extensions-builtin/sd-webui-agent-scheduler, extensions-builtin/sd-extension-chainner/nodes, extensions-builtin/sdnext-modernui/node_modules, @@ -150,6 +152,7 @@ disable=abstract-method, consider-using-min-builtin, consider-using-max-builtin, consider-using-sys-exit, + cyclic-import, dangerous-default-value, deprecated-pragma, duplicate-code, diff --git a/.ruff.toml b/.ruff.toml index a2ad0b91a..89bd1586d 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -19,6 +19,7 @@ exclude = [ "modules/meissonic", "modules/omnigen", "modules/pag", + "modules/pixelsmith", "modules/postprocess/aurasr_arch.py", "modules/prompt_parser_xhinker.py", "modules/pulid/eva_clip", @@ -31,6 +32,7 @@ exclude = [ "modules/unipc", "modules/xadapter", "repositories", + "extensions-builtin/Lora", "extensions-builtin/sd-extension-chainner/nodes", "extensions-builtin/sd-webui-agent-scheduler", "extensions-builtin/sdnext-modernui/node_modules", diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..5fe45a869 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,23 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "SD.Next VSCode Debugger", + "type": "debugpy", + "request": "launch", + "program": "launch.py", + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "env": { "USED_VSCODE_COMMAND_PICKARGS": "1" }, + "args": [ + "--uv", + "--quick", + "--debug", + "--docs", + "--api-log", + "--log", "vscode.log", + "${command:pickArgs}", + ] + } + ] +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 55ecccf30..bd558174d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,80 @@ # Change Log for SD.Next -## Update for 2025-01-02 +## Update for 2025-01-15 +### Highlights for 2025-01-15 + +Two weeks since the last release, time for an update! +This time the highlight reel is a bit shorter as this is primarily a service release, but there are still more than a few updates +*(actually, there are ~60 commits, so it's not that tiny)* + +*What's New?* +- Large [Wiki](https://github.com/vladmandic/automatic/wiki)/[Docs](https://vladmandic.github.io/sdnext-docs/) updates +- New models: **Allegro Video**, new pipelines: **PixelSmith**, updates: **Hunyuan-Video**, **LTX-Video**, **Sana 4k** +- New version for **ZLUDA** +- New features in **Detailer**, **XYZ grid**, **Sysinfo**, **Logging**, **Schedulers**, **Video save/create** +- And a ton of hotfixes...
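The `.vscode/launch.json` added above passes everything picked at launch time as a single `${command:pickArgs}` token and sets `USED_VSCODE_COMMAND_PICKARGS=1`; the `installer.py` and `modules/cmd_args.py` changes later in this diff re-split `sys.argv` with `shlex` when that flag is present. A minimal standalone sketch of that re-splitting idea (illustrative only, with a hypothetical `resolve_argv` helper, not the project's exact code):

```python
import os
import shlex
import sys

def resolve_argv(raw=None):
    # VS Code hands the picked arguments over as one shell-style string; when the
    # flag is set, re-split the joined argv so quoted values such as
    # --log "my run.log" survive as single arguments
    raw = sys.argv[1:] if raw is None else raw
    if "USED_VSCODE_COMMAND_PICKARGS" in os.environ:
        return shlex.split(" ".join(raw))
    return raw

print(resolve_argv(['--quick', '--log "vscode.log"']))  # ['--quick', '--log', 'vscode.log']
```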
+ +### Details for 2025-01-15 + +- [Wiki/Docs](https://vladmandic.github.io/sdnext-docs/): + - updated: Detailer, Install, Update, Debug, Control-HowTo, ZLUDA - [Allegro Video](https://huggingface.co/rhymes-ai/Allegro) - optimizations: full offload and quantization support - *reference values*: width 1280 height 720 frames 88 steps 100 guidance 7.5 - *note*: allegro model is really sensitive to input width/height/frames/steps and may result in completely corrupt output if those are not within expected range +- [PixelSmith](https://github.com/Thanos-DB/Pixelsmith/) + - available for SD-XL in txt2img and img2img workflows + - select from *scripts -> pixelsmith* +- [Hunyuan Video](https://github.com/Tencent/HunyuanVideo) LoRA support + - example: +- [LTX Video](https://github.com/Lightricks/LTX-Video) framewise decoding + - enabled by default, allows generating longer videos with reduced memory requirements +- [Sana 4k](https://huggingface.co/Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers) + - new Sana variation with support for directly generating 4k images + - simply select from *networks -> models -> reference* + - tip: enable vae tiling when generating very large images - **Logging**: - reverted enable debug by default - updated [debug wiki](https://github.com/vladmandic/automatic/wiki/debug) - sort logged timers by duration - allow min duration env variable for timers: `SD_MIN_TIMER=0.1` (default) - update installer messages +- **Refactor**: + - refactored progress monitoring, job updates and live preview + - improved metadata save and restore + - startup tracing and optimizations + - threading load locks on model loads + - refactor native vs legacy model loader + - video save/create +- **Schedulers**: + - [TDD](https://github.com/RedAIGC/Target-Driven-Distillation) new super-fast scheduler that can generate images in 4-8 steps + recommended for use with [TDD LoRA](https://huggingface.co/RED-AIGC/TDD/tree/main) - **Detailer**: + - add explicit detailer prompt and negative prompt - add explicit detailer steps setting -- **SysInfo**: - - update to collected data and benchmarks + - move steps, strength, prompt, negative from settings into ui params + - set/restore detailer metadata + - new [detailer wiki](https://github.com/vladmandic/automatic/wiki/Detailer) +- **Preview** + - since different TAESD versions produce different results and the latest is not necessarily the greatest + you can choose the TAESD version in settings -> live preview + also added is support for another finetuned version of TAESD [Hybrid TinyVAE](https://huggingface.co/cqyan/hybrid-sd-tinyvae-xl) +- **Video** + - all video create/save code is now unified + - add support for video formats: GIF, PNG, MP4/MP4V, MP4/AVC1, MP4/JVT3, MKV/H264, AVI/DIVX, AVI/RGBA, MJPEG/MJPG, MPG/MPG1, AVR/AVR1 + - *note*: video format support is platform dependent and not all formats may be available on all platforms + - *note*: avc1 and h264 need custom opencv due to oss licensing issues +- **ZLUDA** v3.8.7 + - new runtime compiler implementation: complex types, JIT are now available + - fast Fourier transform is implemented + - experimental BLASLt support via nightly build + - set `ZLUDA_NIGHTLY=1` to install nightly ZLUDA: newer torch such as 2.4.x (default) and 2.5.x are now available + - requirements: unofficial hipBLASLt +- **Other** + - **XYZ Grid**: add prompt search&replace options: *primary, refine, detailer, all* + - **SysInfo**: update to collected data and benchmarks - **Fixes**: - explict clear caches on model load - lock
adetailer commit: `#a89c01d` @@ -26,6 +84,21 @@ - sd35 img2img - samplers test for scale noise before using - scheduler api + - sampler create error handling + - controlnet with hires + - controlnet with batch count + - apply settings skip hidden settings + - lora diffusers method apply only once + - lora diffusers method set prompt tags and metadata + - flux support on-the-fly quantization for bnb of unet only + - control restore pipeline before running hires + - restore args after batch run + - flux controlnet + - zluda installer + - control inherit parent pipe settings + - control logging + - hf cache folder settings + - fluxfill should not require base model ## Update for 2024-12-31 diff --git a/cli/zluda-python.py b/cli/zluda-python.py index e0399d096..894489b74 100644 --- a/cli/zluda-python.py +++ b/cli/zluda-python.py @@ -27,10 +27,9 @@ def from_file(self, path): sys.path.append(os.getcwd()) from modules import zluda_installer - zluda_path = zluda_installer.get_path() - zluda_installer.install(zluda_path) - zluda_installer.make_copy(zluda_path) - zluda_installer.load(zluda_path) + zluda_installer.install() + zluda_installer.make_copy() + zluda_installer.load() import torch interpreter = Interpreter({ diff --git a/extensions-builtin/sd-webui-agent-scheduler b/extensions-builtin/sd-webui-agent-scheduler index 721a36f59..a33753321 160000 --- a/extensions-builtin/sd-webui-agent-scheduler +++ b/extensions-builtin/sd-webui-agent-scheduler @@ -1 +1 @@ -Subproject commit 721a36f59507e625c9982397c22edd7c14a0f62a +Subproject commit a33753321b914c6122df96d1dc0b5117d38af680 diff --git a/extensions-builtin/sdnext-modernui b/extensions-builtin/sdnext-modernui index 62d8a54a7..2960e3679 160000 --- a/extensions-builtin/sdnext-modernui +++ b/extensions-builtin/sdnext-modernui @@ -1 +1 @@ -Subproject commit 62d8a54a7ec24a6fbf69697da18e67035754d072 +Subproject commit 2960e36797dafb46545f5ad03364cb7003b84c7e diff --git a/html/previews.json b/html/previews.json index a4ddfa7c9..3b59d376c 100644 --- a/html/previews.json +++ b/html/previews.json @@ -11,6 +11,7 @@ "THUDM--CogVideoX-5b-I2V": "models/Reference/THUDM--CogView3-Plus-3B.jpg", "Efficient-Large-Model--Sana_1600M_1024px_BF16_diffusers": "models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", "Efficient-Large-Model--Sana_1600M_2Kpx_BF16_diffusers": "models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", + "Efficient-Large-Model--Sana_1600M_4Kpx_BF16_diffusers": "models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", "Efficient-Large-Model--Sana_600M_1024px_diffusers": "models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", "stabilityai--stable-video-diffusion-img2vid-xt-1-1": "models/Reference/stabilityai--stable-video-diffusion-img2vid-xt.jpg", "shuttleai--shuttle-3-diffusion": "models/Reference/shuttleai--shuttle-3-diffusion.jpg" diff --git a/html/reference.json b/html/reference.json index 43115c549..ca55081a9 100644 --- a/html/reference.json +++ b/html/reference.json @@ -180,19 +180,25 @@ "extras": "sampler: Default, cfg_scale: 3.5" }, - "NVLabs Sana 1.6B 2048px": { + "NVLabs Sana 1.6B 4k": { + "path": "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers", + "desc": "Sana is a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution. 
Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU.", + "preview": "Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", + "skip": true + }, + "NVLabs Sana 1.6B 2k": { "path": "Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers", "desc": "Sana is a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU.", "preview": "Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", "skip": true }, - "NVLabs Sana 1.6B 1024px": { + "NVLabs Sana 1.6B 1k": { "path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers", "desc": "Sana is a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU.", "preview": "Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", "skip": true }, - "NVLabs Sana 0.6B 512px": { + "NVLabs Sana 0.6B 0.5k": { "path": "Efficient-Large-Model/Sana_600M_512px_diffusers", "desc": "Sana is a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU.", "preview": "Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg", diff --git a/installer.py b/installer.py index 11e7f1d94..e925f2c30 100644 --- a/installer.py +++ b/installer.py @@ -8,6 +8,7 @@ import platform import subprocess import cProfile +import importlib # pylint: disable=deprecated-module class Dot(dict): # dot notation access to dictionary attributes @@ -58,6 +59,14 @@ class Dot(dict): # dot notation access to dictionary attributes # 'stable-diffusion-webui-images-browser': '27fe4a7', } + +try: + from modules.timer import init + ts = init.ts +except Exception: + ts = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment + + # setup console and file logging def setup_logging(): @@ -84,6 +93,7 @@ def emit(self, record): def get(self): return self.buffer + t_start = time.time() from functools import partial, partialmethod from logging.handlers import RotatingFileHandler from rich.theme import Theme @@ -140,6 +150,11 @@ def get(self): log.addHandler(rb) log.buffer = rb.buffer + def quiet_log(quiet: bool=False, *args, **kwargs): # pylint: disable=redefined-outer-name,keyword-arg-before-vararg + if not quiet: + log.debug(*args, **kwargs) + log.quiet = quiet_log + # overrides logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -148,6 +163,7 @@ def get(self): logging.getLogger("ControlNet").handlers = log.handlers logging.getLogger("lycoris").handlers = log.handlers # logging.getLogger("DeepSpeed").handlers = log.handlers + ts('log', t_start) def get_logfile(): @@ -201,11 +217,11 @@ def package_spec(package): # check if package is installed @lru_cache() def installed(package, friendly: str = None, reload = False, quiet = False): + t_start = time.time() ok = True try: if reload: try: - import importlib # pylint: disable=deprecated-module importlib.reload(pkg_resources) except Exception: pass @@ -237,13 +253,15 @@ def installed(package, friendly: str = None, reload = False, quiet = False): else: if not quiet: log.debug(f'Install: 
package="{p[0]}" install required') + ts('installed', t_start) return ok except Exception as e: log.error(f'Install: package="{pkgs}" {e}') + ts('installed', t_start) return False - def uninstall(package, quiet = False): + t_start = time.time() packages = package if isinstance(package, list) else [package] res = '' for p in packages: @@ -251,11 +269,13 @@ def uninstall(package, quiet = False): if not quiet: log.warning(f'Package: {p} uninstall') res += pip(f"uninstall {p} --yes --quiet", ignore=True, quiet=True, uv=False) + ts('uninstall', t_start) return res @lru_cache() def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True): + t_start = time.time() originalArg = arg arg = arg.replace('>=', '==') package = arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip() @@ -283,12 +303,14 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True): errors.append(f'pip: {package}') log.error(f'Install: {pipCmd}: {arg}') log.debug(f'Install: pip output {txt}') + ts('pip', t_start) return txt # install package using pip if not already installed @lru_cache() def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False): + t_start = time.time() res = '' if args.reinstall or args.upgrade: global quick_allowed # pylint: disable=global-statement @@ -297,16 +319,17 @@ def install(package, friendly: str = None, ignore: bool = False, reinstall: bool deps = '' if not no_deps else '--no-deps ' res = pip(f"install{' --upgrade' if not args.uv else ''} {deps}{package}", ignore=ignore, uv=package != "uv" and not package.startswith('git+')) try: - import importlib # pylint: disable=deprecated-module importlib.reload(pkg_resources) except Exception: pass + ts('install', t_start) return res # execute git command @lru_cache() def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = False): + t_start = time.time() if args.skip_git: return '' if optional: @@ -328,6 +351,7 @@ def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = Fal if 'or stash them' in txt: log.error(f'Git local changes detected: check details log="{log_file}"') log.debug(f'Git output: {txt}') + ts('git', t_start) return txt @@ -335,6 +359,7 @@ def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = Fal def branch(folder=None): # if args.experimental: # return None + t_start = time.time() if not os.path.exists(os.path.join(folder or os.curdir, '.git')): return None branches = [] @@ -357,11 +382,13 @@ def branch(folder=None): b = b.split('\n')[0].replace('*', '').strip() log.debug(f'Git submodule: {folder} / {b}') git(f'checkout {b}', folder, ignore=True, optional=True) + ts('branch', t_start) return b # update git repository def update(folder, keep_branch = False, rebase = True): + t_start = time.time() try: git('config rebase.Autostash true') except Exception: @@ -383,11 +410,13 @@ def update(folder, keep_branch = False, rebase = True): if commit is not None: res = git(f'checkout {commit}', folder) debug(f'Install update: folder={folder} branch={b} args={arg} commit={commit} {res}') + ts('update', t_start) return res # clone git repository def clone(url, folder, commithash=None): + t_start = time.time() if os.path.exists(folder): if commithash is None: update(folder) @@ -403,6 +432,7 @@ def clone(url, folder, commithash=None): git(f'clone "{url}" "{folder}"') if commithash is not None: git(f'-C "{folder}" checkout 
{commithash}') + ts('clone', t_start) def get_platform(): @@ -427,6 +457,7 @@ def get_platform(): # check python version def check_python(supported_minors=[9, 10, 11, 12], reason=None): + t_start = time.time() if args.quick: return log.info(f'Python: version={platform.python_version()} platform={platform.system()} bin="{sys.executable}" venv="{sys.prefix}"') @@ -453,59 +484,51 @@ def check_python(supported_minors=[9, 10, 11, 12], reason=None): else: git_version = git('--version', folder=None, ignore=False) log.debug(f'Git: version={git_version.replace("git version", "").strip()}') + ts('python', t_start) # check diffusers version def check_diffusers(): + t_start = time.time() if args.skip_all or args.skip_git: return - sha = '6dfaec348780c6153a4cfd03a01972a291d67f82' # diffusers commit hash + sha = 'c944f0651f679728d4ec7b6488120ac49c2f1315' # diffusers commit hash pkg = pkg_resources.working_set.by_key.get('diffusers', None) minor = int(pkg.version.split('.')[1] if pkg is not None else 0) cur = opts.get('diffusers_version', '') if minor > 0 else '' if (minor == 0) or (cur != sha): - log.info(f'Diffusers {"install" if minor == 0 else "upgrade"}: package={pkg} current={cur} target={sha}') - if minor > 0: + if minor == 0: + log.info(f'Diffusers install: commit={sha}') + else: + log.info(f'Diffusers update: package={pkg} current={cur} target={sha}') pip('uninstall --yes diffusers', ignore=True, quiet=True, uv=False) pip(f'install --upgrade git+https://github.com/huggingface/diffusers@{sha}', ignore=False, quiet=True, uv=False) global diffusers_commit # pylint: disable=global-statement diffusers_commit = sha + ts('diffusers', t_start) # check onnx version def check_onnx(): + t_start = time.time() if args.skip_all or args.skip_requirements: return if not installed('onnx', quiet=True): install('onnx', 'onnx', ignore=True) if not installed('onnxruntime', quiet=True) and not (installed('onnxruntime-gpu', quiet=True) or installed('onnxruntime-openvino', quiet=True) or installed('onnxruntime-training', quiet=True)): # allow either install('onnxruntime', 'onnxruntime', ignore=True) - - -def check_torchao(): - """ - if args.skip_all or args.skip_requirements: - return - if installed('torchao', quiet=True): - ver = package_version('torchao') - if ver != '0.5.0': - log.debug(f'Uninstall: torchao=={ver}') - pip('uninstall --yes torchao', ignore=True, quiet=True, uv=False) - for m in [m for m in sys.modules if m.startswith('torchao')]: - del sys.modules[m] - """ - return + ts('onnx', t_start) def install_cuda(): + t_start = time.time() log.info('CUDA: nVidia toolkit detected') - if not (args.skip_all or args.skip_requirements): - install('onnxruntime-gpu', 'onnxruntime-gpu', ignore=True, quiet=True) - # return os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/cu124') + ts('cuda', t_start) return os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cu124 torchvision==0.20.1+cu124 --index-url https://download.pytorch.org/whl/cu124') def install_rocm_zluda(): + t_start = time.time() if args.skip_all or args.skip_requirements: return None from modules import rocm @@ -551,6 +574,7 @@ def install_rocm_zluda(): msg += f', using agent {device.name}' log.info(msg) torch_command = '' + if sys.platform == "win32": # TODO install: enable ROCm for windows when available @@ -564,19 +588,22 @@ def install_rocm_zluda(): from modules import zluda_installer zluda_installer.set_default_agent(device) try: - if args.reinstall: + if args.reinstall or zluda_installer.is_old_zluda(): 
zluda_installer.uninstall() - zluda_path = zluda_installer.get_path() - zluda_installer.install(zluda_path) - zluda_installer.make_copy(zluda_path) + zluda_installer.install() except Exception as e: error = e log.warning(f'Failed to install ZLUDA: {e}') + if error is None: try: - zluda_installer.load(zluda_path) + if device is not None and zluda_installer.get_blaslt_enabled(): + log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}') + zluda_installer.set_blaslt_enabled(device.blaslt_supported) + zluda_installer.make_copy() + zluda_installer.load() torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(device)} torchvision --index-url https://download.pytorch.org/whl/cu118') - log.info(f'Using ZLUDA in {zluda_path}') + log.info(f'Using ZLUDA in {zluda_installer.path}') except Exception as e: error = e log.warning(f'Failed to load ZLUDA: {e}') @@ -613,7 +640,7 @@ def install_rocm_zluda(): #elif not args.experimental: # uninstall('flash-attn') - if device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled(): + if device is not None and rocm.version != "6.2" and rocm.get_blaslt_enabled(): log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}') rocm.set_blaslt_enabled(device.blaslt_supported) @@ -626,10 +653,12 @@ def install_rocm_zluda(): else: log.warning(f'ROCm: device={device.name} could not auto-detect HSA version') + ts('amd', t_start) return torch_command def install_ipex(torch_command): + t_start = time.time() check_python(supported_minors=[10,11], reason='IPEX backend requires Python 3.10 or 3.11') args.use_ipex = True # pylint: disable=attribute-defined-outside-init log.info('IPEX: Intel OneAPI toolkit detected') @@ -648,10 +677,12 @@ def install_ipex(torch_command): install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.5.0'), 'openvino', ignore=True) install('nncf==2.7.0', ignore=True, no_deps=True) # requires older pandas install(os.environ.get('ONNXRUNTIME_PACKAGE', 'onnxruntime-openvino'), 'onnxruntime-openvino', ignore=True) + ts('ipex', t_start) return torch_command def install_openvino(torch_command): + t_start = time.time() check_python(supported_minors=[9, 10, 11, 12], reason='OpenVINO backend requires Python 3.9, 3.10 or 3.11') log.info('OpenVINO: selected') if sys.platform == 'darwin': @@ -666,10 +697,12 @@ def install_openvino(torch_command): os.environ.setdefault('NEOReadDebugKeys', '1') if os.environ.get("ClDeviceGlobalMemSizeAvailablePercent", None) is None: os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100') + ts('openvino', t_start) return torch_command def install_torch_addons(): + t_start = time.time() xformers_package = os.environ.get('XFORMERS_PACKAGE', '--pre xformers') if opts.get('cross_attention_optimization', '') == 'xFormers' or args.use_xformers else 'none' triton_command = os.environ.get('TRITON_COMMAND', 'triton') if sys.platform == 'linux' else None if 'xformers' in xformers_package: @@ -695,10 +728,12 @@ def install_torch_addons(): uninstall('wandb', quiet=True) if triton_command is not None: install(triton_command, 'triton', quiet=True) + ts('addons', t_start) # check torch version def check_torch(): + t_start = time.time() if args.skip_torch: log.info('Torch: skip tests') return @@ -810,10 +845,12 @@ def check_torch(): if args.profile: pr.disable() print_profile(pr, 'Torch') + ts('torch', t_start) # check modified files def check_modified_files(): + t_start = time.time() if 
args.quick: return if args.skip_git: @@ -830,10 +867,12 @@ def check_modified_files(): log.warning(f'Modified files: {files}') except Exception: pass + ts('files', t_start) # install required packages def install_packages(): + t_start = time.time() if args.profile: pr = cProfile.Profile() pr.enable() @@ -848,6 +887,7 @@ def install_packages(): if args.profile: pr.disable( ) print_profile(pr, 'Packages') + ts('packages', t_start) # run extension installer @@ -871,6 +911,7 @@ def run_extension_installer(folder): except Exception as e: log.error(f'Extension installer exception: {e}') + # get list of all enabled extensions def list_extensions_folder(folder, quiet=False): name = os.path.basename(folder) @@ -886,6 +927,7 @@ def list_extensions_folder(folder, quiet=False): # run installer for each installed and enabled extension and optionally update them def install_extensions(force=False): + t_start = time.time() if args.profile: pr = cProfile.Profile() pr.enable() @@ -934,11 +976,13 @@ def install_extensions(force=False): if args.profile: pr.disable() print_profile(pr, 'Extensions') + ts('extensions', t_start) return '\n'.join(res) # initialize and optionally update submodules def install_submodules(force=True): + t_start = time.time() if args.profile: pr = cProfile.Profile() pr.enable() @@ -966,18 +1010,18 @@ def install_submodules(force=True): if args.profile: pr.disable() print_profile(pr, 'Submodule') + ts('submodules', t_start) return '\n'.join(res) def ensure_base_requirements(): + t_start = time.time() setuptools_version = '69.5.1' def update_setuptools(): - # print('Install base requirements') global pkg_resources, setuptools, distutils # pylint: disable=global-statement # python may ship with incompatible setuptools subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - import importlib # need to delete all references to modules to be able to reload them otherwise python will use cached version modules = [m for m in sys.modules if m.startswith('setuptools') or m.startswith('pkg_resources') or m.startswith('distutils')] for m in modules: @@ -1002,13 +1046,16 @@ def update_setuptools(): install('rich', 'rich', quiet=True) install('psutil', 'psutil', quiet=True) install('requests', 'requests', quiet=True) + ts('base', t_start) def install_optional(): + t_start = time.time() log.info('Installing optional requirements...') install('basicsr') install('gfpgan') install('clean-fid') + install('pillow-jxl-plugin==1.3.1', ignore=True) install('optimum-quanto=0.2.6', ignore=True) install('bitsandbytes==0.45.0', ignore=True) install('pynvml', ignore=True) @@ -1025,9 +1072,11 @@ def install_optional(): os.rename(scripts_dir, scripts_dir + '_gguf') except Exception: pass + ts('optional', t_start) def install_requirements(): + t_start = time.time() if args.profile: pr = cProfile.Profile() pr.enable() @@ -1051,6 +1100,7 @@ def install_requirements(): if args.profile: pr.disable() print_profile(pr, 'Requirements') + ts('requirements', t_start) # set environment variables controling the behavior of various libraries @@ -1115,14 +1165,15 @@ def check_extensions(): for f in os.listdir(extension_dir): if '.json' in f or '.csv' in f or '__pycache__' in f: continue - ts = os.path.getmtime(os.path.join(extension_dir, f)) - newest = max(newest, ts) + mtime = os.path.getmtime(os.path.join(extension_dir, f)) + newest = max(newest, mtime) newest_all = max(newest_all, newest) # 
log.debug(f'Extension version: {time.ctime(newest)} {folder}{os.pathsep}{ext}') return round(newest_all) def get_version(force=False): + t_start = time.time() global version # pylint: disable=global-statement if version is None or force: try: @@ -1157,6 +1208,7 @@ def get_version(force=False): except Exception: os.chdir(cwd) version['ui'] = 'unknown' + ts('version', t_start) return version @@ -1166,6 +1218,7 @@ def same(ver): ui = ver['ui'] if ver is not None and 'ui' in ver else 'unknown' return core == ui or (core == 'master' and ui == 'main') + t_start = time.time() if not same(ver): log.debug(f'Branch mismatch: sdnext={ver["branch"]} ui={ver["ui"]}') cwd = os.getcwd() @@ -1182,6 +1235,7 @@ def same(ver): except Exception as e: log.debug(f'Branch switch: {e}') os.chdir(cwd) + ts('ui', t_start) def check_venv(): @@ -1191,6 +1245,7 @@ def try_relpath(p): except ValueError: return p + t_start = time.time() import site pkg_path = [try_relpath(p) for p in site.getsitepackages() if os.path.exists(p)] log.debug(f'Packages: venv={try_relpath(sys.prefix)} site={pkg_path}') @@ -1210,10 +1265,12 @@ def try_relpath(p): os.unlink(fn) except Exception as e: log.error(f'Packages: site={p} invalid={f} error={e}') + ts('venv', t_start) # check version of the main repo and optionally upgrade it def check_version(offline=False, reset=True): # pylint: disable=unused-argument + t_start = time.time() if args.skip_all: return if not os.path.exists('.git'): @@ -1259,15 +1316,18 @@ def check_version(offline=False, reset=True): # pylint: disable=unused-argument log.info(f'Repository latest available {commits["commit"]["sha"]} {commits["commit"]["commit"]["author"]["date"]}') except Exception as e: log.error(f'Repository failed to check version: {e} {commits}') + ts('latest', t_start) def update_wiki(): + t_start = time.time() if args.upgrade: log.info('Updating Wiki') try: update(os.path.join(os.path.dirname(__file__), "wiki")) except Exception: log.error('Wiki update error') + ts('wiki', t_start) # check if we can run setup in quick mode @@ -1347,15 +1407,25 @@ def add_args(parser): group_log.add_argument('--docs', default=os.environ.get("SD_DOCS", False), action='store_true', help="Mount API docs, default: %(default)s") group_log.add_argument("--api-log", default=os.environ.get("SD_APILOG", True), action='store_true', help="Log all API requests") + group_nargs = parser.add_argument_group('Other') + group_nargs.add_argument('args', type=str, nargs='*') + def parse_args(parser): # command line args global args # pylint: disable=global-statement - args = parser.parse_args() + if "USED_VSCODE_COMMAND_PICKARGS" in os.environ: + import shlex + argv = shlex.split(" ".join(sys.argv[1:])) if "USED_VSCODE_COMMAND_PICKARGS" in os.environ else sys.argv[1:] + log.debug('VSCode Launch') + args = parser.parse_args(argv) + else: + args = parser.parse_args() return args def extensions_preload(parser): + t_start = time.time() if args.profile: pr = cProfile.Profile() pr.enable() @@ -1377,9 +1447,11 @@ def extensions_preload(parser): if args.profile: pr.disable() print_profile(pr, 'Preload') + ts('preload', t_start) def git_reset(folder='.'): + t_start = time.time() log.warning('Running GIT reset') global quick_allowed # pylint: disable=global-statement quick_allowed = False @@ -1395,9 +1467,11 @@ def git_reset(folder='.'): git('submodule update --init --recursive') git('submodule sync --recursive') log.info('GIT reset complete') + ts('reset', t_start) def read_options(): + t_start = time.time() global opts # pylint: 
disable=global-statement if os.path.isfile(args.config): with open(args.config, "r", encoding="utf8") as file: @@ -1407,3 +1481,4 @@ def read_options(): opts = json.loads(opts) except Exception as e: log.error(f'Error reading options file: {file} {e}') + ts('options', t_start) diff --git a/javascript/base.css b/javascript/base.css index a30a71845..cadfd1868 100644 --- a/javascript/base.css +++ b/javascript/base.css @@ -129,7 +129,7 @@ div:has(>#tab-browser-folders) { flex-grow: 0 !important; background-color: var( /* loader */ .splash { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; z-index: 1000; display: block; text-align: center; } .motd { margin-top: 2em; color: var(--body-text-color-subdued); font-family: monospace; font-variant: all-petite-caps; } -.splash-img { margin: 10% auto 0 auto; width: 512px; background-repeat: no-repeat; height: 512px; animation: color 10s infinite alternate; } +.splash-img { margin: 10% auto 0 auto; width: 512px; background-repeat: no-repeat; height: 512px; animation: color 10s infinite alternate; max-width: 80vw; background-size: contain; } .loading { color: white; position: absolute; top: 20%; left: 50%; transform: translateX(-50%); } .loader { width: 300px; height: 300px; border: var(--spacing-md) solid transparent; border-radius: 50%; border-top: var(--spacing-md) solid var(--primary-600); animation: spin 4s linear infinite; position: relative; } .loader::before, .loader::after { content: ""; position: absolute; top: 6px; bottom: 6px; left: 6px; right: 6px; border-radius: 50%; border: var(--spacing-md) solid transparent; } diff --git a/javascript/logger.js b/javascript/logger.js index 8fa812b86..08baf1165 100644 --- a/javascript/logger.js +++ b/javascript/logger.js @@ -1,4 +1,4 @@ -const timeout = 10000; +const timeout = 30000; const log = async (...msg) => { const dt = new Date(); diff --git a/javascript/progressBar.js b/javascript/progressBar.js index 27404b440..0bac99d6f 100644 --- a/javascript/progressBar.js +++ b/javascript/progressBar.js @@ -20,29 +20,36 @@ function checkPaused(state) { function setProgress(res) { const elements = ['txt2img_generate', 'img2img_generate', 'extras_generate', 'control_generate']; - const progress = (res?.progress || 0); - let job = res?.job || ''; - job = job.replace('txt2img', 'Generate').replace('img2img', 'Generate'); - const perc = res && (progress > 0) ? `${Math.round(100.0 * progress)}%` : ''; - let sec = res?.eta || 0; + const progress = res?.progress || 0; + const job = res?.job || ''; + let perc = ''; let eta = ''; - if (res?.paused) eta = 'Paused'; - else if (res?.completed || (progress > 0.99)) eta = 'Finishing'; - else if (sec === 0) eta = 'Starting'; + if (job === 'VAE') perc = 'Decode'; else { - const min = Math.floor(sec / 60); - sec %= 60; - eta = min > 0 ? `${Math.round(min)}m ${Math.round(sec)}s` : `${Math.round(sec)}s`; + perc = res && (progress > 0) && (progress < 1) ? `${Math.round(100.0 * progress)}% ` : ''; + let sec = res?.eta || 0; + if (res?.paused) eta = 'Paused'; + else if (res?.completed || (progress > 0.99)) eta = 'Finishing'; + else if (sec === 0) eta = 'Start'; + else { + const min = Math.floor(sec / 60); + sec %= 60; + eta = min > 0 ? `${Math.round(min)}m ${Math.round(sec)}s` : `${Math.round(sec)}s`; + } } document.title = `SD.Next ${perc}`; for (const elId of elements) { const el = document.getElementById(elId); if (el) { - el.innerText = (res ? `${job} ${perc} ${eta}` : 'Generate'); + const jobLabel = (res ? 
`${job} ${perc}${eta}` : 'Generate').trim(); + el.innerText = jobLabel; if (!window.waitForUiReady) { - el.style.background = res && (progress > 0) - ? `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${perc}, var(--neutral-700) ${perc})` - : 'var(--button-primary-background-fill)'; + const gradient = perc !== '' ? perc : '100%'; + if (jobLabel === 'Generate') el.style.background = 'var(--primary-500)'; + else if (jobLabel.endsWith('Decode')) continue; + else if (jobLabel.endsWith('Start') || jobLabel.endsWith('Finishing')) el.style.background = 'var(--primary-800)'; + else if (res && progress > 0 && progress < 1) el.style.background = `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${gradient}, var(--neutral-700) ${gradient})`; + else el.style.background = 'var(--primary-500)'; } } } diff --git a/javascript/sdnext.css b/javascript/sdnext.css index 2e85c6988..91ab6318d 100644 --- a/javascript/sdnext.css +++ b/javascript/sdnext.css @@ -346,7 +346,7 @@ div:has(>#tab-gallery-folders) { flex-grow: 0 !important; background-color: var( /* loader */ .splash { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; z-index: 1000; display: block; text-align: center; } .motd { margin-top: 2em; color: var(--body-text-color-subdued); font-family: monospace; font-variant: all-petite-caps; } -.splash-img { margin: 10% auto 0 auto; width: 512px; background-repeat: no-repeat; height: 512px; animation: color 10s infinite alternate; } +.splash-img { margin: 10% auto 0 auto; width: 512px; background-repeat: no-repeat; height: 512px; animation: color 10s infinite alternate; max-width: 80vw; background-size: contain; } .loading { color: white; position: absolute; top: 20%; left: 50%; transform: translateX(-50%); } .loader { width: 300px; height: 300px; border: var(--spacing-md) solid transparent; border-radius: 50%; border-top: var(--spacing-md) solid var(--primary-600); animation: spin 4s linear infinite; position: relative; } .loader::before, .loader::after { content: ""; position: absolute; top: 6px; bottom: 6px; left: 6px; right: 6px; border-radius: 50%; border: var(--spacing-md) solid transparent; } diff --git a/javascript/startup.js b/javascript/startup.js index 51e1800fc..3e121eb12 100644 --- a/javascript/startup.js +++ b/javascript/startup.js @@ -8,7 +8,6 @@ async function initStartup() { initModels(); getUIDefaults(); initPromptChecker(); - initLogMonitor(); initContextMenu(); initDragDrop(); initAccordions(); @@ -25,6 +24,7 @@ async function initStartup() { // make sure all of the ui is ready and options are loaded while (Object.keys(window.opts).length === 0) await sleep(50); executeCallbacks(uiReadyCallbacks); + initLogMonitor(); setupExtraNetworks(); // optinally wait for modern ui diff --git a/launch.py b/launch.py index e00da58c7..07fa93697 100755 --- a/launch.py +++ b/launch.py @@ -24,12 +24,20 @@ skip_install = False # parsed by some extensions +try: + from modules.timer import launch + rec = launch.record +except Exception: + rec = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment + + def init_args(): global parser, args # pylint: disable=global-statement import modules.cmd_args parser = modules.cmd_args.parser installer.add_args(parser) args, _ = parser.parse_known_args() + rec('args') def init_paths(): @@ -38,6 +46,7 @@ def init_paths(): modules.paths.register_paths() script_path = modules.paths.script_path extensions_dir = modules.paths.extensions_dir + rec('paths') def get_custom_args(): @@ -60,6 +69,7 @@ def 
get_custom_args(): ldd = os.environ.get('LD_PRELOAD', None) if ldd is not None: installer.log.debug(f'Linker flags: "{ldd}"') + rec('args') @lru_cache() @@ -71,6 +81,7 @@ def commit_hash(): # compatbility function stored_commit_hash = run(f"{git} rev-parse HEAD").strip() except Exception: stored_commit_hash = "" + rec('commit') return stored_commit_hash @@ -185,6 +196,7 @@ def start_server(immediate=True, server=None): if args.profile: pr.disable() installer.print_profile(pr, 'WebUI') + rec('server') return uvicorn, server @@ -218,7 +230,6 @@ def main(): installer.install("uv", "uv") installer.check_torch() installer.check_onnx() - installer.check_torchao() installer.check_diffusers() installer.check_modified_files() if args.reinstall: diff --git a/modules/api/helpers.py b/modules/api/helpers.py index 2d6ae8110..da826a81e 100644 --- a/modules/api/helpers.py +++ b/modules/api/helpers.py @@ -75,6 +75,13 @@ def save_image(image, fn, ext): image = image.point(lambda p: p * 0.0038910505836576).convert("RGB") exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } }) image.save(fn, format=image_format, quality=shared.opts.jpeg_quality, lossless=shared.opts.webp_lossless, exif=exif_bytes) + elif image_format == 'JXL': + if image.mode == 'I;16': + image = image.point(lambda p: p * 0.0038910505836576).convert("RGB") + elif image.mode not in {"RGB", "RGBA"}: + image = image.convert("RGBA") + exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } }) + image.save(fn, format=image_format, quality=shared.opts.jpeg_quality, lossless=shared.opts.webp_lossless, exif=exif_bytes) else: # shared.log.warning(f'Unrecognized image format: {extension} attempting save as {image_format}') image.save(fn, format=image_format, quality=shared.opts.jpeg_quality) diff --git a/modules/call_queue.py b/modules/call_queue.py index 11ba7b56e..af6f2e4d0 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -15,8 +15,8 @@ def f(*args, **kwargs): return f -def wrap_gradio_gpu_call(func, extra_outputs=None): - name = func.__name__ +def wrap_gradio_gpu_call(func, extra_outputs=None, name=None): + name = name or func.__name__ def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 018d802ee..755c31c97 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -1,4 +1,5 @@ import os +import sys import argparse from modules.paths import data_path @@ -77,7 +78,6 @@ def compatibility_args(): group_compat.add_argument("--disable-queue", default=os.environ.get("SD_DISABLEQUEUE", False), action='store_true', help=argparse.SUPPRESS) - def settings_args(opts, args): # removed args are added here as hidden in fixed format for compatbility reasons group_compat = parser.add_argument_group('Compatibility options') @@ -154,7 +154,12 @@ def settings_args(opts, args): opts.onchange("lora_dir", lambda: setattr(args, "lora_dir", opts.lora_dir)) opts.onchange("lyco_dir", lambda: setattr(args, "lyco_dir", opts.lyco_dir)) - args = parser.parse_args() + if "USED_VSCODE_COMMAND_PICKARGS" in os.environ: + import shlex + argv = shlex.split(" ".join(sys.argv[1:])) if "USED_VSCODE_COMMAND_PICKARGS" in os.environ else sys.argv[1:] + args = parser.parse_args(argv) + else: + args = 
parser.parse_args() return args diff --git a/modules/control/proc/dwpose/__init__.py b/modules/control/proc/dwpose/__init__.py index 4469e10c8..d8fdb9618 100644 --- a/modules/control/proc/dwpose/__init__.py +++ b/modules/control/proc/dwpose/__init__.py @@ -13,12 +13,14 @@ from modules.control.util import HWC3, resize_image from .draw import draw_bodypose, draw_handpose, draw_facepose checked_ok = False +busy = False def check_dependencies(): - global checked_ok # pylint: disable=global-statement + global checked_ok, busy # pylint: disable=global-statement debug = log.trace if os.environ.get('SD_DWPOSE_DEBUG', None) is not None else lambda *args, **kwargs: None packages = [ + 'termcolor', 'openmim==0.3.9', 'mmengine==0.10.4', 'mmcv==2.1.0', @@ -68,8 +70,13 @@ def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=N if not checked_ok: if not check_dependencies(): return - from .wholebody import Wholebody - self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) + Wholebody = None + try: + from .wholebody import Wholebody + except Exception as e: + log.error(f'DWPose: {e}') + if Wholebody is not None: + self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) def to(self, device): self.pose_estimation.to(device) @@ -78,6 +85,7 @@ def to(self, device): def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", min_confidence=0.3, **kwargs): if self.pose_estimation is None: log.error("DWPose: not loaded") + return None input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) input_image = HWC3(input_image) diff --git a/modules/control/processors.py b/modules/control/processors.py index 2f15c2d3f..38b5f2062 100644 --- a/modules/control/processors.py +++ b/modules/control/processors.py @@ -167,13 +167,13 @@ def load(self, processor_id: str = None, force: bool = True) -> str: self.config(processor_id) else: if not force and self.model is not None: - log.debug(f'Control Processor: id={processor_id} already loaded') + # log.debug(f'Control Processor: id={processor_id} already loaded') return '' if processor_id not in config: log.error(f'Control Processor unknown: id="{processor_id}" available={list(config)}') return f'Processor failed to load: {processor_id}' cls = config[processor_id]['class'] - log.debug(f'Control Processor loading: id="{processor_id}" class={cls.__name__}') + # log.debug(f'Control Processor loading: id="{processor_id}" class={cls.__name__}') debug(f'Control Processor config={self.load_config}') if 'DWPose' in processor_id: det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' @@ -253,6 +253,9 @@ def __call__(self, image_input: Image, mode: str = 'RGB', resize_mode: int = 0, image_resized = image_input with devices.inference_context(): image_process = self.model(image_resized, **kwargs) + if image_process is None: + log.error(f'Control Processor: id="{self.processor_id}" no image') + return image_input if isinstance(image_process, np.ndarray): if np.max(image_process) < 2: image_process = (255.0 * image_process).astype(np.uint8) diff --git a/modules/control/run.py b/modules/control/run.py index 8cecb93af..21fc76a4b 100644 --- a/modules/control/run.py +++ b/modules/control/run.py @@ -17,10 +17,11 @@ from modules.processing_class import StableDiffusionProcessingControl from modules.ui_common import infotext_to_html from modules.api import script +from 
modules.timer import process as process_timer -debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: CONTROL') +debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None +debug_log = shared.log.trace if debug else lambda *args, **kwargs: None pipe = None instance = None original_pipeline = None @@ -32,7 +33,7 @@ def restore_pipeline(): if instance is not None and hasattr(instance, 'restore'): instance.restore() if original_pipeline is not None and (original_pipeline.__class__.__name__ != shared.sd_model.__class__.__name__): - debug(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__}') + debug_log(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__}') shared.sd_model = original_pipeline pipe = None instance = None @@ -67,7 +68,10 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str shared.log.warning('Control: T2I-Adapter does not support separate init image') elif unit_type == 'controlnet' and has_models: p.extra_generation_params["Control type"] = 'ControlNet' - p.task_args['controlnet_conditioning_scale'] = control_conditioning + if shared.sd_model_type == 'f1': + p.task_args['controlnet_conditioning_scale'] = control_conditioning if isinstance(control_conditioning, list) else [control_conditioning] + else: + p.task_args['controlnet_conditioning_scale'] = control_conditioning p.task_args['control_guidance_start'] = control_guidance_start p.task_args['control_guidance_end'] = control_guidance_end p.task_args['guess_mode'] = p.guess_mode @@ -106,7 +110,7 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str p.strength = active_strength[0] pipe = shared.sd_model instance = None - debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}') + debug_log(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}') return pipe @@ -121,14 +125,14 @@ def check_active(p, unit_type, units): if u.type != unit_type: continue num_units += 1 - debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}') + debug_log(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}') if not u.enabled: if u.controlnet is not None and u.controlnet.model is not None: - debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}') + debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}') sd_models.move_model(u.controlnet.model, devices.cpu) continue if u.controlnet is not None and u.controlnet.model is not None: - debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}') + debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}') sd_models.move_model(u.controlnet.model, devices.device) if unit_type == 't2i adapter' and u.adapter.model is not None: active_process.append(u.process) @@ -173,7 +177,7 @@ def check_active(p, unit_type, units): active_process.append(u.process) shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}') active_strength.append(float(u.strength)) - debug(f'Control active: process={len(active_process)} model={len(active_model)}') + debug_log(f'Control active: process={len(active_process)} model={len(active_model)}') return active_process, active_model, 
active_strength, active_start, active_end @@ -188,7 +192,7 @@ def check_enabled(p, unit_type, units, active_model, active_strength, active_sta selected_models = None elif len(active_model) == 1: selected_models = active_model[0].model if active_model[0].model is not None else None - p.is_tile = p.is_tile or 'tile' in active_model[0].model_id.lower() + p.is_tile = p.is_tile or 'tile' in (active_model[0].model_id or '').lower() has_models = selected_models is not None control_conditioning = active_strength[0] if len(active_strength) > 0 else 1 # strength or list[strength] control_guidance_start = active_start[0] if len(active_start) > 0 else 0 @@ -210,7 +214,7 @@ def control_set(kwargs): if kwargs: global p_extra_args # pylint: disable=global-statement p_extra_args = {} - debug(f'Control extra args: {kwargs}') + debug_log(f'Control extra args: {kwargs}') for k, v in kwargs.items(): p_extra_args[k] = v @@ -222,7 +226,8 @@ def control_run(state: str = '', steps: int = 20, sampler_index: int = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, cfg_scale: float = 6.0, clip_skip: float = 1.0, image_cfg_scale: float = 6.0, diffusers_guidance_rescale: float = 0.7, pag_scale: float = 0.0, pag_adaptive: float = 0.5, cfg_end: float = 1.0, - full_quality: bool = True, detailer: bool = False, tiling: bool = False, hidiffusion: bool = False, + full_quality: bool = True, tiling: bool = False, hidiffusion: bool = False, + detailer_enabled: bool = True, detailer_prompt: str = '', detailer_negative: str = '', detailer_steps: int = 10, detailer_strength: float = 0.3, hdr_mode: int = 0, hdr_brightness: float = 0, hdr_color: float = 0, hdr_sharpen: float = 0, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 0.95, hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundry: float = 1.0, hdr_color_picker: str = None, hdr_tint_ratio: float = 0, resize_mode_before: int = 0, resize_name_before: str = 'None', resize_context_before: str = 'None', width_before: int = 512, height_before: int = 512, scale_by_before: float = 1.0, selected_scale_tab_before: int = 0, @@ -250,7 +255,7 @@ def control_run(state: str = '', u.process.override = u.override global pipe, original_pipeline # pylint: disable=global-statement - debug(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}') + debug_log(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}') if inputs is None or (type(inputs) is list and len(inputs) == 0): inputs = [None] output_images: List[Image.Image] = [] # output images @@ -286,9 +291,14 @@ def control_run(state: str = '', pag_scale = pag_scale, pag_adaptive = pag_adaptive, full_quality = full_quality, - detailer = detailer, tiling = tiling, hidiffusion = hidiffusion, + # detailer + detailer_enabled = detailer_enabled, + detailer_prompt = detailer_prompt, + detailer_negative = detailer_negative, + detailer_steps = detailer_steps, + detailer_strength = detailer_strength, # resize resize_mode = resize_mode_before if resize_name_before != 'None' else 0, resize_name = resize_name_before, @@ -393,7 +403,7 @@ def control_run(state: str = '', p.is_tile = p.is_tile and has_models pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits) - debug(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}') + debug_log(f'Control pipeline: 
class={pipe.__class__.__name__} args={vars(p)}') t1, t2, t3 = time.time(), 0, 0 status = True frame = None @@ -404,12 +414,14 @@ def control_run(state: str = '', blended_image = None # set pipeline - if pipe.__class__.__name__ != shared.sd_model.__class__.__name__: + if pipe is None: + return [], '', '', 'Pipeline not set' + elif pipe.__class__.__name__ != shared.sd_model.__class__.__name__: original_pipeline = shared.sd_model shared.sd_model = pipe sd_models.move_model(shared.sd_model, shared.device) shared.sd_model.to(dtype=devices.dtype) - debug(f'Control device={devices.device} dtype={devices.dtype}') + debug_log(f'Control device={devices.device} dtype={devices.dtype}') sd_models.copy_diffuser_options(shared.sd_model, original_pipeline) # copy options from original pipeline sd_models.set_diffuser_options(shared.sd_model) else: @@ -447,12 +459,12 @@ def control_run(state: str = '', while status: if pipe is None: # pipe may have been reset externally pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits) - debug(f'Control pipeline reinit: class={pipe.__class__.__name__}') + debug_log(f'Control pipeline reinit: class={pipe.__class__.__name__}') processed_image = None if frame is not None: inputs = [Image.fromarray(frame)] # cv2 to pil for i, input_image in enumerate(inputs): - debug(f'Control Control image: {i + 1} of {len(inputs)}') + debug_log(f'Control Control image: {i + 1} of {len(inputs)}') if shared.state.skipped: shared.state.skipped = False continue @@ -470,20 +482,20 @@ def control_run(state: str = '', continue # match init input if input_type == 1: - debug('Control Init image: same as control') + debug_log('Control Init image: same as control') init_image = input_image elif inits is None: - debug('Control Init image: none') + debug_log('Control Init image: none') init_image = None elif isinstance(inits[i], str): - debug(f'Control: init image: {inits[i]}') + debug_log(f'Control: init image: {inits[i]}') try: init_image = Image.open(inits[i]) except Exception as e: shared.log.error(f'Control: image open failed: path={inits[i]} type=init error={e}') continue else: - debug(f'Control Init image: {i % len(inits) + 1} of {len(inits)}') + debug_log(f'Control Init image: {i % len(inits) + 1} of {len(inits)}') init_image = inits[i % len(inits)] if video is not None and index % (video_skip_frames + 1) != 0: index += 1 @@ -496,18 +508,18 @@ def control_run(state: str = '', width_before, height_before = int(input_image.width * scale_by_before), int(input_image.height * scale_by_before) if input_image is not None: p.extra_generation_params["Control resize"] = f'{resize_name_before}' - debug(f'Control resize: op=before image={input_image} width={width_before} height={height_before} mode={resize_mode_before} name={resize_name_before} context="{resize_context_before}"') + debug_log(f'Control resize: op=before image={input_image} width={width_before} height={height_before} mode={resize_mode_before} name={resize_name_before} context="{resize_context_before}"') input_image = images.resize_image(resize_mode_before, input_image, width_before, height_before, resize_name_before, context=resize_context_before) if input_image is not None and init_image is not None and init_image.size != input_image.size: - debug(f'Control resize init: image={init_image} target={input_image}') + debug_log(f'Control resize init: image={init_image} target={input_image}') init_image = 
images.resize_image(resize_mode=1, im=init_image, width=input_image.width, height=input_image.height) if input_image is not None and p.override is not None and p.override.size != input_image.size: - debug(f'Control resize override: image={p.override} target={input_image}') + debug_log(f'Control resize override: image={p.override} target={input_image}') p.override = images.resize_image(resize_mode=1, im=p.override, width=input_image.width, height=input_image.height) if input_image is not None: p.width = input_image.width p.height = input_image.height - debug(f'Control: input image={input_image}') + debug_log(f'Control: input image={input_image}') processed_images = [] if mask is not None: @@ -522,7 +534,7 @@ def control_run(state: str = '', else: masked_image = input_image for i, process in enumerate(active_process): # list[image] - debug(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}') + debug_log(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}') processed_image = process( image_input=masked_image, mode='RGB', @@ -537,7 +549,7 @@ def control_run(state: str = '', processors.config[process.processor_id]['dirty'] = True # to force reload process.model = None - debug(f'Control processed: {len(processed_images)}') + debug_log(f'Control processed: {len(processed_images)}') if len(processed_images) > 0: try: if len(p.extra_generation_params["Control process"]) == 0: @@ -563,7 +575,7 @@ def control_run(state: str = '', blended_image = util.blend(blended_image) # blend all processed images into one blended_image = Image.fromarray(blended_image) if isinstance(selected_models, list) and len(processed_images) == len(selected_models): - debug(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}') + debug_log(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}') p.init_images = processed_images elif isinstance(selected_models, list) and len(processed_images) != len(selected_models): if is_generator: @@ -572,14 +584,14 @@ def control_run(state: str = '', elif selected_models is not None: p.init_images = processed_image else: - debug('Control processed: using input direct') + debug_log('Control processed: using input direct') processed_image = input_image if unit_type == 'reference' and has_models: p.ref_image = p.override or input_image p.task_args.pop('image', None) p.task_args['ref_image'] = p.ref_image - debug(f'Control: process=None image={p.ref_image}') + debug_log(f'Control: process=None image={p.ref_image}') if p.ref_image is None: if is_generator: yield terminate('Attempting reference mode but image is none') @@ -614,7 +626,7 @@ def control_run(state: str = '', if is_generator: image_txt = f'{blended_image.width}x{blended_image.height}' if blended_image is not None else 'None' msg = f'process | {index} of {frames if video is not None else len(inputs)} | {"Image" if video is None else "Frame"} {image_txt}' - debug(f'Control yield: {msg}') + debug_log(f'Control yield: {msg}') if is_generator: yield (None, blended_image, f'Control {msg}') t2 += time.time() - t2 @@ -673,7 +685,7 @@ def control_run(state: str = '', if selected_scale_tab_mask == 1: width_mask, height_mask = int(input_image.width * scale_by_mask), int(input_image.height * scale_by_mask) p.width, p.height = width_mask, height_mask - debug(f'Control resize: op=mask image={mask} width={width_mask} height={height_mask} mode={resize_mode_mask} name={resize_name_mask} 
context="{resize_context_mask}"') + debug_log(f'Control resize: op=mask image={mask} width={width_mask} height={height_mask} mode={resize_mode_mask} name={resize_name_mask} context="{resize_context_mask}"') # pipeline output = None @@ -681,9 +693,10 @@ def control_run(state: str = '', if pipe is not None: # run new pipeline if not hasattr(pipe, 'restore_pipeline') and video is None: pipe.restore_pipeline = restore_pipeline - debug(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}') - # debug(f'Control exec pipeline: p={vars(p)}') - # debug(f'Control exec pipeline: args={p.task_args} image={p.task_args.get("image", None)} control={p.task_args.get("control_image", None)} mask={p.task_args.get("mask_image", None) or p.image_mask} ref={p.task_args.get("ref_image", None)}') + shared.sd_model.restore_pipeline = restore_pipeline + debug_log(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}') + # debug_log(f'Control exec pipeline: p={vars(p)}') + # debug_log(f'Control exec pipeline: args={p.task_args} image={p.task_args.get("image", None)} control={p.task_args.get("control_image", None)} mask={p.task_args.get("mask_image", None) or p.image_mask} ref={p.task_args.get("ref_image", None)}') if sd_models.get_diffusers_task(pipe) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: # force vae back to gpu if not in txt2img mode sd_models.move_model(pipe.vae, devices.device) @@ -729,7 +742,7 @@ def control_run(state: str = '', width_after = int(output_image.width * scale_by_after) height_after = int(output_image.height * scale_by_after) if resize_mode_after != 0 and resize_name_after != 'None' and not is_grid: - debug(f'Control resize: op=after image={output_image} width={width_after} height={height_after} mode={resize_mode_after} name={resize_name_after} context="{resize_context_after}"') + debug_log(f'Control resize: op=after image={output_image} width={width_after} height={height_after} mode={resize_mode_after} name={resize_name_after} context="{resize_context_after}"') output_image = images.resize_image(resize_mode_after, output_image, width_after, height_after, resize_name_after, context=resize_context_after) output_images.append(output_image) @@ -749,14 +762,16 @@ def control_run(state: str = '', status, frame = video.read() if status: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - debug(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}') + debug_log(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}') else: status = False if video is not None: video.release() - shared.log.info(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}') + debug_log(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}') + process_timer.add('init', t1-t0) + process_timer.add('proc', t2-t1) except Exception as e: shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}') errors.display(e, 'Control') @@ -777,7 +792,7 @@ def control_run(state: str = '', p.close() restore_pipeline() - debug(f'Ready: {image_txt}') + debug_log(f'Ready: {image_txt}') html_txt = f'
Ready {image_txt}
' if image_txt != '' else '' if len(info_txt) > 0: diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py index c887aca8f..4837577fe 100644 --- a/modules/control/units/controlnet.py +++ b/modules/control/units/controlnet.py @@ -1,5 +1,6 @@ import os import time +import threading from typing import Union from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline, StableDiffusion3Pipeline, ControlNetModel from modules.control.units import detect @@ -9,8 +10,8 @@ what = 'ControlNet' -debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: CONTROL') +debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None +debug_log = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None predefined_sd15 = { 'Canny': "lllyasviel/control_v11p_sd15_canny", 'Depth': "lllyasviel/control_v11f1p_sd15_depth", @@ -91,15 +92,15 @@ "XLabs-AI HED": 'XLabs-AI/flux-controlnet-hed-diffusers' } predefined_sd3 = { - "StabilityAI Canny": 'diffusers-internal-dev/sd35-controlnet-canny-8b', - "StabilityAI Depth": 'diffusers-internal-dev/sd35-controlnet-depth-8b', - "StabilityAI Blur": 'diffusers-internal-dev/sd35-controlnet-blur-8b', - "InstantX Canny": 'InstantX/SD3-Controlnet-Canny', - "InstantX Pose": 'InstantX/SD3-Controlnet-Pose', - "InstantX Depth": 'InstantX/SD3-Controlnet-Depth', - "InstantX Tile": 'InstantX/SD3-Controlnet-Tile', - "Alimama Inpainting": 'alimama-creative/SD3-Controlnet-Inpainting', - "Alimama SoftEdge": 'alimama-creative/SD3-Controlnet-Softedge', + "StabilityAI Canny SD35": 'diffusers-internal-dev/sd35-controlnet-canny-8b', + "StabilityAI Depth SD35": 'diffusers-internal-dev/sd35-controlnet-depth-8b', + "StabilityAI Blur SD35": 'diffusers-internal-dev/sd35-controlnet-blur-8b', + "InstantX Canny SD35": 'InstantX/SD3-Controlnet-Canny', + "InstantX Pose SD35": 'InstantX/SD3-Controlnet-Pose', + "InstantX Depth SD35": 'InstantX/SD3-Controlnet-Depth', + "InstantX Tile SD35": 'InstantX/SD3-Controlnet-Tile', + "Alimama Inpainting SD35": 'alimama-creative/SD3-Controlnet-Inpainting', + "Alimama SoftEdge SD35": 'alimama-creative/SD3-Controlnet-Softedge', } variants = { 'NoobAI Canny XL': 'fp16', @@ -116,6 +117,7 @@ all_models.update(predefined_f1) all_models.update(predefined_sd3) cache_dir = 'models/control/controlnet' +load_lock = threading.Lock() def find_models(): @@ -154,7 +156,7 @@ def list_models(refresh=False): else: log.warning(f'Control {what} model list failed: unknown model type') models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(predefined_f1) + sorted(predefined_sd3) + sorted(find_models()) - debug(f'Control list {what}: path={cache_dir} models={models}') + debug_log(f'Control list {what}: path={cache_dir} models={models}') return models @@ -172,7 +174,7 @@ def __init__(self, model_id: str = None, device = None, dtype = None, load_confi def reset(self): if self.model is not None: - debug(f'Control {what} model unloaded') + debug_log(f'Control {what} model unloaded') self.model = None self.model_id = None @@ -231,78 +233,86 @@ def load_safetensors(self, model_id, model_path): self.load_config['original_config_file '] = config_path cls, config = self.get_class(model_id) if cls is None: - log.error(f'Control {what} model load failed: unknown base model') + log.error(f'Control {what} model load: unknown base model') else: self.model = cls.from_single_file(model_path, config=config, **self.load_config) 
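# Editorial sketch (not part of the patch): the hunks in controlnet.py, lite.py,
# t2iadapter.py and xs.py below all wrap the loader's load() body in a module-level
# threading.Lock so concurrent callers cannot trigger duplicate model loads.
# This is a minimal standalone illustration of that pattern under assumed,
# illustrative names (ModelLoader and the sleep stand-in are not from the repo):
import threading
import time

load_lock = threading.Lock()  # one lock per module serializes concurrent loads


class ModelLoader:
    def __init__(self):
        self.model = None
        self.model_id = None

    def load(self, model_id: str):
        with load_lock:  # only one thread may (re)load at a time
            if model_id == self.model_id and self.model is not None:
                return self.model  # already loaded by an earlier caller; skip duplicate work
            time.sleep(0.1)  # stand-in for an expensive from_pretrained() call
            self.model = object()
            self.model_id = model_id
            return self.model


# concurrent callers race to load the same model; the lock ensures it loads only once
loader = ModelLoader()
threads = [threading.Thread(target=loader.load, args=("canny",)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()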
def load(self, model_id: str = None, force: bool = True) -> str: - try: - t0 = time.time() - model_id = model_id or self.model_id - if model_id is None or model_id == 'None': - self.reset() - return - if model_id not in all_models: - log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') - return - model_path = all_models[model_id] - if model_path == '': - return - if model_path is None: - log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') - return - if 'lora' in model_id.lower(): - self.model = model_path - return - if model_id == self.model_id and not force: - log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') - return - log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') - cls, _config = self.get_class(model_id) - if model_path.endswith('.safetensors'): - self.load_safetensors(model_id, model_path) - else: - kwargs = {} - if '/bin' in model_path: - model_path = model_path.replace('/bin', '') - self.load_config['use_safetensors'] = False - if cls is None: - log.error(f'Control {what} model load failed: id="{model_id}" unknown base model') + with load_lock: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() return - if variants.get(model_id, None) is not None: - kwargs['variant'] = variants[model_id] - self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs) - if self.model is None: - return - if self.dtype is not None: - self.model.to(self.dtype) - if "ControlNet" in opts.nncf_compress_weights: - try: - log.debug(f'Control {what} model NNCF Compress: id="{model_id}"') - from installer import install - install('nncf==2.7.0', quiet=True) - from modules.sd_models_compile import nncf_compress_model - self.model = nncf_compress_model(self.model) - except Exception as e: - log.error(f'Control {what} model NNCF Compression failed: id="{model_id}" error={e}') - elif "ControlNet" in opts.optimum_quanto_weights: - try: - log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"') - model_quant.load_quanto('Load model: type=ControlNet') - from modules.sd_models_compile import optimum_quanto_model - self.model = optimum_quanto_model(self.model) - except Exception as e: - log.error(f'Control {what} model Optimum Quanto failed: id="{model_id}" error={e}') - if self.device is not None: - self.model.to(self.device) - t1 = time.time() - self.model_id = model_id - log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}') - return f'{what} loaded model: {model_id}' - except Exception as e: - log.error(f'Control {what} model load failed: id="{model_id}" error={e}') - errors.display(e, f'Control {what} load') - return f'{what} failed to load model: {model_id}' + if model_id not in all_models: + log.error(f'Control {what}: id="{model_id}" available={list(all_models)} unknown model') + return + model_path = all_models[model_id] + if model_path == '': + return + if model_path is None: + log.error(f'Control {what} model load: id="{model_id}" unknown model id') + return + if 'lora' in model_id.lower(): + self.model = model_path + return + if model_id == self.model_id and not force: + # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') + return + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') + cls, _config = self.get_class(model_id) + if 
model_path.endswith('.safetensors'): + self.load_safetensors(model_id, model_path) + else: + kwargs = {} + if '/bin' in model_path: + model_path = model_path.replace('/bin', '') + self.load_config['use_safetensors'] = False + else: + self.load_config['use_safetensors'] = True + if cls is None: + log.error(f'Control {what} model load: id="{model_id}" unknown base model') + return + if variants.get(model_id, None) is not None: + kwargs['variant'] = variants[model_id] + try: + self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs) + except Exception as e: + log.error(f'Control {what} model load: id="{model_id}" {e}') + if debug: + errors.display(e, 'Control') + if self.model is None: + return + if self.dtype is not None: + self.model.to(self.dtype) + if "ControlNet" in opts.nncf_compress_weights: + try: + log.debug(f'Control {what} model NNCF Compress: id="{model_id}"') + from installer import install + install('nncf==2.7.0', quiet=True) + from modules.sd_models_compile import nncf_compress_model + self.model = nncf_compress_model(self.model) + except Exception as e: + log.error(f'Control {what} model NNCF Compression failed: id="{model_id}" {e}') + elif "ControlNet" in opts.optimum_quanto_weights: + try: + log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"') + model_quant.load_quanto('Load model: type=ControlNet') + from modules.sd_models_compile import optimum_quanto_model + self.model = optimum_quanto_model(self.model) + except Exception as e: + log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}') + if self.device is not None: + self.model.to(self.device) + t1 = time.time() + self.model_id = model_id + log.info(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load: id="{model_id}" {e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' class ControlNetPipeline(): @@ -401,13 +411,14 @@ def __init__(self, if dtype is not None: self.pipeline = self.pipeline.to(dtype) + sd_models.copy_diffuser_options(self.pipeline, pipeline) if opts.diffusers_offload_mode == 'none': sd_models.move_model(self.pipeline, devices.device) from modules.sd_models import set_diffuser_offload set_diffuser_offload(self.pipeline, 'model') t1 = time.time() - log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') + debug_log(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') def restore(self): self.pipeline.unload_lora_weights() diff --git a/modules/control/units/lite.py b/modules/control/units/lite.py index 0c10d8d53..854e5bd56 100644 --- a/modules/control/units/lite.py +++ b/modules/control/units/lite.py @@ -1,6 +1,7 @@ import os import time from typing import Union +import threading import numpy as np from PIL import Image from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline @@ -27,6 +28,7 @@ all_models.update(predefined_sd15) all_models.update(predefined_sdxl) cache_dir = 'models/control/lite' +load_lock = threading.Lock() def find_models(): @@ -79,44 +81,45 @@ def reset(self): self.model_id = None def load(self, model_id: str = None, force: bool = True) -> str: - try: - t0 = time.time() - model_id = model_id or self.model_id - if model_id is None or model_id == 'None': - self.reset() - return - if model_id not in all_models: - log.error(f'Control {what} unknown model: 
id="{model_id}" available={list(all_models)}') - return - model_path = all_models[model_id] - if model_path == '': - return - if model_path is None: - log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') - return - if model_id == self.model_id and not force: - log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') - return - log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') - if model_path.endswith('.safetensors'): - self.model = ControlNetLLLite(model_path) - else: - import huggingface_hub as hf - folder, filename = os.path.split(model_path) - model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=cache_dir) - self.model = ControlNetLLLite(model_path) - if self.device is not None: - self.model.to(self.device) - if self.dtype is not None: - self.model.to(self.dtype) - t1 = time.time() - self.model_id = model_id - log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') - return f'{what} loaded model: {model_id}' - except Exception as e: - log.error(f'Control {what} model load failed: id="{model_id}" error={e}') - errors.display(e, f'Control {what} load') - return f'{what} failed to load model: {model_id}' + with load_lock: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() + return + if model_id not in all_models: + log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') + return + model_path = all_models[model_id] + if model_path == '': + return + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + if model_id == self.model_id and not force: + # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') + return + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') + if model_path.endswith('.safetensors'): + self.model = ControlNetLLLite(model_path) + else: + import huggingface_hub as hf + folder, filename = os.path.split(model_path) + model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=cache_dir) + self.model = ControlNetLLLite(model_path) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = time.time() + self.model_id = model_id + log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' class ControlLLitePipeline(): diff --git a/modules/control/units/t2iadapter.py b/modules/control/units/t2iadapter.py index 6e15abe3d..66abdc75e 100644 --- a/modules/control/units/t2iadapter.py +++ b/modules/control/units/t2iadapter.py @@ -1,6 +1,7 @@ import os import time from typing import Union +import threading from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import from modules.shared import log from modules import errors, sd_models @@ -43,6 +44,7 @@ all_models.update(predefined_sd15) all_models.update(predefined_sdxl) cache_dir = 
'models/control/adapter' +load_lock = threading.Lock() def list_models(refresh=False): @@ -87,45 +89,46 @@ def reset(self): self.model_id = None def load(self, model_id: str = None, force: bool = True) -> str: - try: - t0 = time.time() - model_id = model_id or self.model_id - if model_id is None or model_id == 'None': - self.reset() - return - if model_id not in all_models: - log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') - return - model_path, model_args = all_models[model_id] - self.load_config.update(model_args) - if model_path is None: - log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') - return - if model_id == self.model_id and not force: - log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') - return - log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') - if model_path.endswith('.pth') or model_path.endswith('.pt') or model_path.endswith('.safetensors') or model_path.endswith('.bin'): - from huggingface_hub import hf_hub_download - parts = model_path.split('/') - repo_id = f'{parts[0]}/{parts[1]}' - filename = '/'.join(parts[2:]) - model = hf_hub_download(repo_id, filename, **self.load_config) - self.model = T2IAdapter.from_pretrained(model, **self.load_config) - else: - self.model = T2IAdapter.from_pretrained(model_path, **self.load_config) - if self.device is not None: - self.model.to(self.device) - if self.dtype is not None: - self.model.to(self.dtype) - t1 = time.time() - self.model_id = model_id - log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') - return f'{what} loaded model: {model_id}' - except Exception as e: - log.error(f'Control {what} model load failed: id="{model_id}" error={e}') - errors.display(e, f'Control {what} load') - return f'{what} failed to load model: {model_id}' + with load_lock: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() + return + if model_id not in all_models: + log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') + return + model_path, model_args = all_models[model_id] + self.load_config.update(model_args) + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + if model_id == self.model_id and not force: + # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') + return + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') + if model_path.endswith('.pth') or model_path.endswith('.pt') or model_path.endswith('.safetensors') or model_path.endswith('.bin'): + from huggingface_hub import hf_hub_download + parts = model_path.split('/') + repo_id = f'{parts[0]}/{parts[1]}' + filename = '/'.join(parts[2:]) + model = hf_hub_download(repo_id, filename, **self.load_config) + self.model = T2IAdapter.from_pretrained(model, **self.load_config) + else: + self.model = T2IAdapter.from_pretrained(model_path, **self.load_config) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = time.time() + self.model_id = model_id + log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control 
{what} load') + return f'{what} failed to load model: {model_id}' class AdapterPipeline(): diff --git a/modules/control/units/xs.py b/modules/control/units/xs.py index ff21a5ed7..232387582 100644 --- a/modules/control/units/xs.py +++ b/modules/control/units/xs.py @@ -1,6 +1,7 @@ import os import time from typing import Union +import threading from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from modules.shared import log, opts, listdir from modules import errors, sd_models @@ -23,6 +24,7 @@ all_models.update(predefined_sd15) all_models.update(predefined_sdxl) cache_dir = 'models/control/xs' +load_lock = threading.Lock() def find_models(): @@ -75,42 +77,43 @@ def reset(self): self.model_id = None def load(self, model_id: str = None, time_embedding_mix: float = 0.0, force: bool = True) -> str: - try: - t0 = time.time() - model_id = model_id or self.model_id - if model_id is None or model_id == 'None': - self.reset() - return - if model_id not in all_models: - log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') - return - model_path = all_models[model_id] - if model_path == '': - return - if model_path is None: - log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') - return - if model_id == self.model_id and not force: - log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') - return - self.load_config['time_embedding_mix'] = time_embedding_mix - log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') - if model_path.endswith('.safetensors'): - self.model = ControlNetXSModel.from_single_file(model_path, **self.load_config) - else: - self.model = ControlNetXSModel.from_pretrained(model_path, **self.load_config) - if self.device is not None: - self.model.to(self.device) - if self.dtype is not None: - self.model.to(self.dtype) - t1 = time.time() - self.model_id = model_id - log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') - return f'{what} loaded model: {model_id}' - except Exception as e: - log.error(f'Control {what} model load failed: id="{model_id}" error={e}') - errors.display(e, f'Control {what} load') - return f'{what} failed to load model: {model_id}' + with load_lock: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() + return + if model_id not in all_models: + log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') + return + model_path = all_models[model_id] + if model_path == '': + return + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + if model_id == self.model_id and not force: + # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') + return + self.load_config['time_embedding_mix'] = time_embedding_mix + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') + if model_path.endswith('.safetensors'): + self.model = ControlNetXSModel.from_single_file(model_path, **self.load_config) + else: + self.model = ControlNetXSModel.from_pretrained(model_path, **self.load_config) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = time.time() + self.model_id = model_id + log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" 
time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' class ControlNetXSPipeline(): diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 719bd2737..f961a4351 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,11 +1,13 @@ import os import re +import threading import torch import numpy as np from PIL import Image from modules import modelloader, paths, deepbooru_model, devices, images, shared re_special = re.compile(r'([\\()])') +load_lock = threading.Lock() class DeepDanbooru: @@ -13,22 +15,23 @@ def __init__(self): self.model = None def load(self): - if self.model is not None: - return - model_path = os.path.join(paths.models_path, "DeepDanbooru") - shared.log.debug(f'Load interrogate model: type=DeepDanbooru folder="{model_path}"') - files = modelloader.load_models( - model_path=model_path, - model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt', - ext_filter=[".pt"], - download_name='model-resnet_custom_v3.pt', - ) + with load_lock: + if self.model is not None: + return + model_path = os.path.join(paths.models_path, "DeepDanbooru") + shared.log.debug(f'Load interrogate model: type=DeepDanbooru folder="{model_path}"') + files = modelloader.load_models( + model_path=model_path, + model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt', + ext_filter=[".pt"], + download_name='model-resnet_custom_v3.pt', + ) - self.model = deepbooru_model.DeepDanbooruModel() - self.model.load_state_dict(torch.load(files[0], map_location="cpu")) + self.model = deepbooru_model.DeepDanbooruModel() + self.model.load_state_dict(torch.load(files[0], map_location="cpu")) - self.model.eval() - self.model.to(devices.cpu, devices.dtype) + self.model.eval() + self.model.to(devices.cpu, devices.dtype) def start(self): self.load() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 36b5ea382..ce1896169 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -67,6 +67,8 @@ def image_from_url_text(filedata): filedata = filedata[len("data:image/webp;base64,"):] if filedata.startswith("data:image/jpeg;base64,"): filedata = filedata[len("data:image/jpeg;base64,"):] + if filedata.startswith("data:image/jxl;base64,"): + filedata = filedata[len("data:image/jxl;base64,"):] filedata = base64.decodebytes(filedata.encode('utf-8')) image = Image.open(io.BytesIO(filedata)) images.read_info_from_image(image) @@ -122,12 +124,11 @@ def connect_paste_params_buttons(): if binding.tabname not in paste_fields: debug(f"Not not registered: tab={binding.tabname}") continue + """ + # legacy code that sets width/height based on image itself instead of metadata destination_image_component = paste_fields[binding.tabname]["init_img"] - fields = paste_fields[binding.tabname]["fields"] - override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"] destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) - if binding.source_image_component and 
destination_image_component: if isinstance(binding.source_image_component, gr.Gallery): func = send_image_and_dimensions if destination_width_component else image_from_url_text @@ -142,6 +143,9 @@ def connect_paste_params_buttons(): outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], show_progress=False, ) + """ + fields = paste_fields[binding.tabname]["fields"] + override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"] if binding.source_text_component is not None and fields is not None: connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname) if binding.source_tabname is not None and fields is not None and binding.source_tabname in paste_fields: diff --git a/modules/gr_tempdir.py b/modules/gr_tempdir.py index 0ee15b314..bbe2b2192 100644 --- a/modules/gr_tempdir.py +++ b/modules/gr_tempdir.py @@ -71,6 +71,7 @@ def pil_to_temp_file(self, img: Image, dir: str, format="png") -> str: # pylint: img.already_saved_as = name size = os.path.getsize(name) shared.log.debug(f'Save temp: image="{name}" width={img.width} height={img.height} size={size}') + shared.state.image_history += 1 params = ', '.join([f'{k}: {v}' for k, v in img.info.items()]) params = params[12:] if params.startswith('parameters: ') else params with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: @@ -93,7 +94,7 @@ def cleanup_tmpdr(): for root, _dirs, files in os.walk(temp_dir, topdown=False): for name in files: _, extension = os.path.splitext(name) - if extension != ".png" and extension != ".jpg" and extension != ".webp": + if extension not in {".png", ".jpg", ".webp", ".jxl"}: continue filename = os.path.join(root, name) os.remove(filename) diff --git a/modules/history.py b/modules/history.py index 63c1669c2..3d6f672e8 100644 --- a/modules/history.py +++ b/modules/history.py @@ -47,13 +47,13 @@ def list(self): @property def selected(self): if self.index >= 0 and self.index < self.count: - index = self.index + current_index = self.index self.index = -1 else: - index = 0 - item = self.latents[index] - shared.log.debug(f'History get: index={index} time={item.ts} shape={item.latent.shape} dtype={item.latent.dtype} count={self.count}') - return item.latent.to(devices.device), index + current_index = 0 + item = self.latents[current_index] + shared.log.debug(f'History get: index={current_index} time={item.ts} shape={item.latent.shape} dtype={item.latent.dtype} count={self.count}') + return item.latent.to(devices.device), current_index def find(self, name): for i, item in enumerate(self.latents): @@ -62,6 +62,7 @@ def find(self, name): return -1 def add(self, latent, preview=None, info=None, ops=[]): + shared.state.latent_history += 1 if shared.opts.latent_history == 0: return if torch.is_tensor(latent): diff --git a/modules/images.py b/modules/images.py index 2cfbe941d..00e5969ce 100644 --- a/modules/images.py +++ b/modules/images.py @@ -15,6 +15,7 @@ from modules.images_grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import from modules.images_resize import resize_image # pylint: disable=unused-import from modules.images_namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import +from 
modules.video import save_video # pylint: disable=unused-import debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None @@ -29,6 +30,7 @@ def atomically_save_image(): Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes while True: image, filename, extension, params, exifinfo, filename_txt = save_queue.get() + shared.state.image_history += 1 with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: file.write(exifinfo) fn = filename + extension @@ -49,6 +51,7 @@ def atomically_save_image(): shared.log.info(f'Save: text="{filename_txt}" len={len(exifinfo)}') except Exception as e: shared.log.warning(f'Save failed: description={filename_txt} {e}') + # actual save if image_format == 'PNG': pnginfo_data = PngImagePlugin.PngInfo() @@ -70,6 +73,14 @@ def atomically_save_image(): save_args = { 'optimize': True, 'quality': shared.opts.jpeg_quality, 'lossless': shared.opts.webp_lossless } if shared.opts.image_metadata: save_args['exif'] = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } }) + elif image_format == 'JXL': + if image.mode == 'I;16': + image = image.point(lambda p: p * 0.0038910505836576).convert("RGB") + elif image.mode not in {"RGB", "RGBA"}: + image = image.convert("RGBA") + save_args = { 'optimize': True, 'quality': shared.opts.jpeg_quality, 'lossless': shared.opts.webp_lossless } + if shared.opts.image_metadata: + save_args['exif'] = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } }) else: save_args = { 'quality': shared.opts.jpeg_quality } try: @@ -79,6 +90,7 @@ def atomically_save_image(): errors.display(e, 'Image save') size = os.path.getsize(fn) if os.path.exists(fn) else 0 shared.log.info(f'Save: image="{fn}" type={image_format} width={image.width} height={image.height} size={size}') + if shared.opts.save_log_fn != '' and len(exifinfo) > 0: fn = os.path.join(paths.data_path, shared.opts.save_log_fn) if not fn.endswith('.json'): @@ -179,76 +191,6 @@ def save_image(image, return params.filename, filename_txt, exifinfo -def save_video_atomic(images, filename, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3): - try: - import cv2 - except Exception as e: - shared.log.error(f'Save video: cv2: {e}') - return - os.makedirs(os.path.dirname(filename), exist_ok=True) - if video_type.lower() == 'mp4': - frames = images - if interpolate > 0: - try: - import modules.rife - frames = modules.rife.interpolate(images, count=interpolate, scale=scale, pad=pad, change=change) - except Exception as e: - shared.log.error(f'RIFE interpolation: {e}') - errors.display(e, 'RIFE interpolation') - video_frames = [np.array(frame) for frame in frames] - fourcc = "mp4v" - h, w, _c = video_frames[0].shape - video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(frames)/duration, frameSize=(w, h)) - for i in range(len(video_frames)): - img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) - video_writer.write(img) - size = os.path.getsize(filename) - shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc} size={size}') - if video_type.lower() == 'gif' or video_type.lower() == 'png': - append = images.copy() - image = append.pop(0) - if loop: - append += 
append[::-1] - frames=len(append) + 1 - image.save( - filename, - save_all = True, - append_images = append, - optimize = False, - duration = 1000.0 * duration / frames, - loop = 0 if loop else 1, - ) - size = os.path.getsize(filename) - shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop} size={size}') - - -def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3, sync: bool = False): - if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none': - return None - image = images[0] - if p is not None: - seed = p.all_seeds[0] if getattr(p, 'all_seeds', None) is not None else p.seed - prompt = p.all_prompts[0] if getattr(p, 'all_prompts', None) is not None else p.prompt - namegen = FilenameGenerator(p, seed=seed, prompt=prompt, image=image) - else: - namegen = FilenameGenerator(None, seed=0, prompt='', image=image) - if filename is None and p is not None: - filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]") - filename = os.path.join(shared.opts.outdir_video, filename) - filename = namegen.sequence(filename, shared.opts.outdir_video, '') - else: - if os.pathsep not in filename: - filename = os.path.join(shared.opts.outdir_video, filename) - if not filename.lower().endswith(video_type.lower()): - filename += f'.{video_type.lower()}' - filename = namegen.sanitize(filename) - if not sync: - threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start() - else: - save_video_atomic(images, filename, video_type, duration, loop, interpolate, scale, pad, change) - return filename - - def safe_decode_string(s: bytes): remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings diff --git a/modules/img2img.py b/modules/img2img.py index 2e3eca54d..7a5a33cd0 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -139,7 +139,8 @@ def img2img(id_task: str, state: str, mode: int, sampler_index, mask_blur, mask_alpha, inpainting_fill, - full_quality, detailer, tiling, hidiffusion, + full_quality, tiling, hidiffusion, + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength, n_iter, batch_size, cfg_scale, image_cfg_scale, diffusers_guidance_rescale, @@ -158,14 +159,15 @@ def img2img(id_task: str, state: str, mode: int, hdr_mode, hdr_brightness, hdr_color, hdr_sharpen, hdr_clamp, hdr_boundary, hdr_threshold, hdr_maximize, hdr_max_center, hdr_max_boundry, hdr_color_picker, hdr_tint_ratio, enable_hr, hr_sampler_index, hr_denoising_strength, hr_resize_mode, hr_resize_context, hr_upscaler, hr_force, hr_second_pass_steps, hr_scale, hr_resize_x, hr_resize_y, refiner_steps, hr_refiner_start, refiner_prompt, refiner_negative, override_settings_texts, - *args): # pylint: disable=unused-argument + *args): + + + debug(f'img2img: {id_task}') if shared.sd_model is None: shared.log.warning('Aborted: op=img model not loaded') return [], '', '', 'Error: model not loaded' - debug(f'img2img: 
id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}||mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|detailer={detailer}|tiling={tiling}|hidiffusion={hidiffusion}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|resize_context={resize_context}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}') - if sampler_index is None: shared.log.warning('Sampler: invalid') sampler_index = 0 @@ -240,9 +242,13 @@ def img2img(id_task: str, state: str, mode: int, width=width, height=height, full_quality=full_quality, - detailer=detailer, tiling=tiling, hidiffusion=hidiffusion, + detailer_enabled=detailer_enabled, + detailer_prompt=detailer_prompt, + detailer_negative=detailer_negative, + detailer_steps=detailer_steps, + detailer_strength=detailer_strength, init_images=[image], mask=mask, mask_blur=mask_blur, diff --git a/modules/infotext.py b/modules/infotext.py index baa995c88..017691681 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -108,7 +108,7 @@ def parse(infotext): params[f"{key}-2"] = int(size.group(2)) elif isinstance(params[key], str): params[key] = val - debug(f'Param parsed: type={type(params[key])} {key}={params[key]} raw="{val}"') + debug(f'Param parsed: type={type(params[key])} "{key}"={params[key]} raw="{val}"') # check_lora(params) return params diff --git a/modules/instantir/ip_adapter/ip_adapter.py b/modules/instantir/ip_adapter/ip_adapter.py index 10f01d4f3..7f4bcbc45 100644 --- a/modules/instantir/ip_adapter/ip_adapter.py +++ b/modules/instantir/ip_adapter/ip_adapter.py @@ -169,8 +169,6 @@ def load_from_checkpoint(self, ckpt_path: str): if "latents" in state_dict["image_proj"] and "latents" in self.image_proj.state_dict(): # Check if the shapes are mismatched if state_dict["image_proj"]["latents"].shape != self.image_proj.state_dict()["latents"].shape: - print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.") - print("Removing 'latents' from checkpoint and loading the rest of the weights.") del state_dict["image_proj"]["latents"] strict_load_image_proj_model = False diff --git a/modules/interrogate.py b/modules/interrogate.py index 6cdaae0c1..1cf9aee95 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -3,6 +3,7 @@ import time from collections import namedtuple from pathlib import Path +import threading import re import torch import torch.hub # pylint: disable=ungrouped-imports @@ 
-34,6 +35,7 @@ clip_model_name = 'ViT-L/14' Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +load_lock = threading.Lock() def category_types(): @@ -97,34 +99,35 @@ def checkpoint_wrapper(self): sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale def load_blip_model(self): - self.create_fake_fairscale() - from repositories.blip import models # pylint: disable=unused-import - from repositories.blip.models import blip - import modules.modelloader as modelloader - model_path = os.path.join(paths.models_path, "BLIP") - download_name='model_base_caption_capfilt_large.pth' - shared.log.debug(f'Model interrogate load: type=BLiP model={download_name} path={model_path}') - files = modelloader.load_models( - model_path=model_path, - model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth', - ext_filter=[".pth"], - download_name=download_name, - ) - blip_model = blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member - blip_model.eval() - - return blip_model + with load_lock: + self.create_fake_fairscale() + from repositories.blip import models # pylint: disable=unused-import + from repositories.blip.models import blip + import modules.modelloader as modelloader + model_path = os.path.join(paths.models_path, "BLIP") + download_name='model_base_caption_capfilt_large.pth' + shared.log.debug(f'Model interrogate load: type=BLiP model={download_name} path={model_path}') + files = modelloader.load_models( + model_path=model_path, + model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth', + ext_filter=[".pth"], + download_name=download_name, + ) + blip_model = blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member + blip_model.eval() + return blip_model def load_clip_model(self): - shared.log.debug(f'Model interrogate load: type=CLiP model={clip_model_name} path={shared.opts.clip_models_path}') - import clip - if self.running_on_cpu: - model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.opts.clip_models_path) - else: - model, preprocess = clip.load(clip_model_name, download_root=shared.opts.clip_models_path) - model.eval() - model = model.to(devices.device) - return model, preprocess + with load_lock: + shared.log.debug(f'Model interrogate load: type=CLiP model={clip_model_name} path={shared.opts.clip_models_path}') + import clip + if self.running_on_cpu: + model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.opts.clip_models_path) + else: + model, preprocess = clip.load(clip_model_name, download_root=shared.opts.clip_models_path) + model.eval() + model = model.to(devices.device) + return model, preprocess def load(self): if self.blip_model is None: diff --git a/modules/k-diffusion b/modules/k-diffusion index 21d12c91a..8018de0b4 160000 --- a/modules/k-diffusion +++ b/modules/k-diffusion @@ -1 +1 @@ -Subproject commit 21d12c91ad4550e8fcf3308ff9fe7116b3f19a08 +Subproject commit 8018de0b43da8d66617f3ef10d3f2a41c1d78836 diff --git a/modules/linfusion/linfusion.py b/modules/linfusion/linfusion.py index 724cf2f3f..d8317e39b 100644 --- a/modules/linfusion/linfusion.py 
+++ b/modules/linfusion/linfusion.py @@ -89,9 +89,7 @@ def construct_for( pipe_name_path = pipe_name_path or pipeline._internal_dict._name_or_path # pylint: disable=protected-access pretrained_model_name_or_path = model_dict.get(pipe_name_path, None) if pretrained_model_name_or_path: - print( - f"Matching LinFusion '{pretrained_model_name_or_path}' for pipeline '{pipe_name_path}'." - ) + pass else: raise RuntimeError( f"LinFusion not found for pipeline [{pipe_name_path}], please provide the path." diff --git a/modules/loader.py b/modules/loader.py index 38d942fd4..c48afa7a9 100644 --- a/modules/loader.py +++ b/modules/loader.py @@ -72,6 +72,12 @@ logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR) timer.startup.record("diffusers") +try: + import pillow_jxl # pylint: disable=W0611,C0411 +except Exception: + pass +from PIL import Image # pylint: disable=W0611,C0411 +timer.startup.record("pillow") # patch different progress bars import tqdm as tqdm_lib # pylint: disable=C0411 diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py index e89e57b6e..448b214ac 100644 --- a/modules/lora/extra_networks_lora.py +++ b/modules/lora/extra_networks_lora.py @@ -2,7 +2,7 @@ import os import re import numpy as np -from modules.lora import networks +from modules.lora import networks, network_overrides from modules import extra_networks, shared @@ -155,17 +155,23 @@ def activate(self, p, params_list, step=0, include=[], exclude=[]): fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access debug_log(f'Load network: type=LoRA include={include} exclude={exclude} requested={requested} fn={fn}') - networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load - has_changed = self.changed(requested, include, exclude) - if has_changed: - networks.network_deactivate(include, exclude) - networks.network_activate(include, exclude) - debug_log(f'Load network: type=LoRA previous={[n.name for n in networks.previously_loaded_networks]} current={[n.name for n in networks.loaded_networks]} changed') - - if len(networks.loaded_networks) > 0 and len(networks.applied_layers) > 0 and step == 0: + force_diffusers = network_overrides.check_override() + if force_diffusers: + has_changed = False # diffusers handle their own loading + if len(exclude) == 0: + networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load only on first call + else: + networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load + has_changed = self.changed(requested, include, exclude) + if has_changed: + networks.network_deactivate(include, exclude) + networks.network_activate(include, exclude) + debug_log(f'Load network: type=LoRA previous={[n.name for n in networks.previously_loaded_networks]} current={[n.name for n in networks.loaded_networks]} changed') + + if len(networks.loaded_networks) > 0 and (len(networks.applied_layers) > 0 or force_diffusers) and step == 0: infotext(p) prompt(p) - if has_changed and len(include) == 0: # print only once + if (has_changed or force_diffusers) and len(include) == 0: # print only once shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} mode={"fuse" if shared.opts.lora_fuse_diffusers else "backup"} te={te_multipliers} unet={unet_multipliers} time={networks.timer.summary}') def deactivate(self, p): diff --git a/modules/lora/lora_convert.py b/modules/lora/lora_convert.py index 032ffa5a3..c2685aacd 100644 --- 
a/modules/lora/lora_convert.py +++ b/modules/lora/lora_convert.py @@ -205,8 +205,6 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 ) i += dims[j] - # if is_sparse: - # print(f"weight is sparse: {sds_key}") # make ai-toolkit weight ait_down_keys = [k + ".lora_down.weight" for k in ait_keys] diff --git a/modules/lora/lora_extract.py b/modules/lora/lora_extract.py index 58cd065bb..a226b6017 100644 --- a/modules/lora/lora_extract.py +++ b/modules/lora/lora_extract.py @@ -265,7 +265,7 @@ def gr_show(visible=True): auto_rank.change(fn=lambda x: gr_show(x), inputs=[auto_rank], outputs=[rank_ratio]) extract.click( - fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[]), + fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[], name='LoRA'), inputs=[filename, rank, auto_rank, rank_ratio, modules, overwrite], outputs=[status] ) diff --git a/modules/lora/network_overrides.py b/modules/lora/network_overrides.py index 5334f3c1b..81bd07ce1 100644 --- a/modules/lora/network_overrides.py +++ b/modules/lora/network_overrides.py @@ -29,12 +29,17 @@ # 'sd3', 'kandinsky', 'hunyuandit', + 'hunyuanvideo', 'auraflow', ] force_classes = [ # forced always ] +fuse_ignore = [ + 'hunyuanvideo', +] + def check_override(shorthash=''): force = False @@ -47,3 +52,6 @@ def check_override(shorthash=''): if force and shared.opts.lora_maybe_diffusers: shared.log.debug('LoRA override: force diffusers') return force + +def check_fuse(): + return shared.sd_model_type in fuse_ignore diff --git a/modules/lora/networks.py b/modules/lora/networks.py index ebc30b6dd..9e981a234 100644 --- a/modules/lora/networks.py +++ b/modules/lora/networks.py @@ -267,7 +267,8 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non failed_to_load_networks.append(name) shared.log.error(f'Load network: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed') continue - shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings) + if hasattr(shared.sd_model, 'embedding_db'): + shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings) net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier @@ -282,7 +283,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non try: t0 = time.time() shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales) - if shared.opts.lora_fuse_diffusers: + if shared.opts.lora_fuse_diffusers and not network_overrides.check_fuse(): shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling shared.sd_model.unload_lora_weights() timer.activate += time.time() - t0 diff --git a/modules/meissonic/transformer.py b/modules/meissonic/transformer.py index 43e77ddc7..543c30108 100644 --- a/modules/meissonic/transformer.py +++ b/modules/meissonic/transformer.py @@ -670,15 +670,11 @@ def __init__( self.upsample = None def forward(self, x): - # print("before,", x.shape) if self.downsample is not None: - # print('downsample') x = self.downsample(x) if self.upsample is not None: - # print('upsample') x 
= self.upsample(x) - # print("after,", x.shape) return x diff --git a/modules/model_flux.py b/modules/model_flux.py index 9b42de3e7..8fe147223 100644 --- a/modules/model_flux.py +++ b/modules/model_flux.py @@ -84,7 +84,7 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu repo_path = checkpoint_info else: repo_path = checkpoint_info.path - model_quant.load_bnb('Load model: type=T5') + model_quant.load_bnb('Load model: type=FLUX') quant = model_quant.get_quant(repo_path) try: if quant == 'fp8': @@ -203,7 +203,8 @@ def load_transformer(file_path): # triggered by opts.sd_unet change "torch_dtype": devices.dtype, "cache_dir": shared.opts.hfcache_dir, } - shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant={quant} dtype={devices.dtype}') + if quant is not None and quant != 'none': + shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}') if 'gguf' in file_path.lower(): # _transformer, _text_encoder_2 = load_flux_gguf(file_path) from modules import ggml @@ -214,25 +215,35 @@ def load_transformer(file_path): # triggered by opts.sd_unet change _transformer, _text_encoder_2 = load_flux_quanto(file_path) if _transformer is not None: transformer = _transformer - elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4' or 'Model' in shared.opts.bnb_quantization: + elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4': _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config) if _transformer is not None: transformer = _transformer elif 'nf4' in quant: # TODO flux: fix loader for civitai nf4 models from modules.model_flux_nf4 import load_flux_nf4 - _transformer, _text_encoder_2 = load_flux_nf4(file_path) + _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True) if _transformer is not None: transformer = _transformer else: - quant_args = {} - quant_args = model_quant.create_bnb_config(quant_args) + quant_args = model_quant.create_bnb_config({}) if quant_args: - model_quant.load_bnb(f'Load model: type=Sana quant={quant_args}') - if not quant_args: - quant_args = model_quant.create_ao_config(quant_args) - if quant_args: - model_quant.load_torchao(f'Load model: type=Sana quant={quant_args}') - transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args) + model_quant.load_bnb(f'Load model: type=FLUX quant={quant_args}') + shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}') + from modules.model_flux_nf4 import load_flux_nf4 + transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False) + if transformer is not None: + return transformer + quant_args = model_quant.create_ao_config({}) + if quant_args: + model_quant.load_torchao(f'Load model: type=FLUX quant={quant_args}') + shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=torchao dtype={devices.dtype}') + transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args) + if transformer is not None: + return transformer + shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=none dtype={devices.dtype}') + # TODO flux transformer from-single-file with quant + # 
shared.log.warning('Load module: type=UNet/Transformer does not support load-time quantization') + transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config) if transformer is None: shared.log.error('Failed to load UNet model') shared.opts.sd_unet = 'None' diff --git a/modules/model_flux_nf4.py b/modules/model_flux_nf4.py index d023907d6..b00c3320e 100644 --- a/modules/model_flux_nf4.py +++ b/modules/model_flux_nf4.py @@ -24,7 +24,6 @@ def _replace_with_bnb_linear( ): """ Private method that wraps the recursion for module replacement. - Returns the converted model and a boolean that indicates if the conversion has been successfull or not. """ bnb = model_quant.load_bnb('Load model: type=FLUX') @@ -106,7 +105,6 @@ def create_quantized_param( new_value = old_value.to(target_device) else: new_value = param_value.to(target_device) - new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad) module._parameters[tensor_name] = new_value # pylint: disable=protected-access return @@ -121,13 +119,8 @@ def create_quantized_param( raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") if pre_quantized: - if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and ( - param_name + ".quant_state.bitsandbytes__nf4" not in state_dict - ): - raise ValueError( - f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components." - ) - + if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict): + raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.") quantized_stats = {} for k, v in state_dict.items(): # `startswith` to counter for edge cases where `param_name` @@ -136,23 +129,20 @@ def create_quantized_param( quantized_stats[k] = v if unexpected_keys is not None and k in unexpected_keys: unexpected_keys.remove(k) - new_value = bnb.nn.Params4bit.from_prequantized( data=param_value, quantized_stats=quantized_stats, requires_grad=False, device=target_device, ) - else: new_value = param_value.to("cpu") kwargs = old_value.__dict__ new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) - module._parameters[tensor_name] = new_value # pylint: disable=protected-access -def load_flux_nf4(checkpoint_info): +def load_flux_nf4(checkpoint_info, prequantized: bool = True): transformer = None text_encoder_2 = None if isinstance(checkpoint_info, str): @@ -197,7 +187,7 @@ def load_flux_nf4(checkpoint_info): if not check_quantized_param(transformer, param_name): set_module_tensor_to_device(transformer, param_name, device=0, value=param) else: - create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True) + create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized) except Exception as e: transformer, text_encoder_2 = None, None shared.log.error(f"Load model: type=FLUX failed to load UNET: {e}") diff --git a/modules/model_stablecascade.py b/modules/model_stablecascade.py index 2a7739e55..3c3339dca 100644 --- a/modules/model_stablecascade.py +++ b/modules/model_stablecascade.py @@ -256,7 +256,7 @@ def __call__( if isinstance(self.scheduler, diffusers.DDPMWuerstchenScheduler): timesteps = timesteps[:-1] else: - if 
hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: + if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: # pylint: disable=no-member self.scheduler.config.clip_sample = False # disample sample clipping # 6. Run denoising loop diff --git a/modules/modeldata.py b/modules/modeldata.py index 4b7ec1776..deb4ac49a 100644 --- a/modules/modeldata.py +++ b/modules/modeldata.py @@ -37,6 +37,8 @@ def get_model_type(pipe): model_type = 'cogvideox' elif "Sana" in name: model_type = 'sana' + elif 'HunyuanVideoPipeline' in name: + model_type = 'hunyuanvideo' else: model_type = name return model_type diff --git a/modules/omnigen/model.py b/modules/omnigen/model.py index 3a42263d2..17d696b53 100644 --- a/modules/omnigen/model.py +++ b/modules/omnigen/model.py @@ -259,7 +259,6 @@ def cropped_pos_embed(self, height, width): left = (self.pos_embed_max_size - width) // 2 spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] - # print(top, top + height, left, left + width, spatial_pos_embed.size()) spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed diff --git a/modules/paths.py b/modules/paths.py index 9a73a9f91..134bcd2ab 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -2,18 +2,20 @@ import os import sys import json +import shlex import argparse from modules.errors import log # parse args, parse again after we have the data-dir and early-read the config file +argv = shlex.split(" ".join(sys.argv[1:])) if "USED_VSCODE_COMMAND_PICKARGS" in os.environ else sys.argv[1:] parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s") parser.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s") parser.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", None), help="Base path where all models are stored, default: %(default)s",) -cli = parser.parse_known_args()[0] -parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s") -cli = parser.parse_known_args()[0] +cli = parser.parse_known_args(argv)[0] +parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s") # twice because we want data_dir +cli = parser.parse_known_args(argv)[0] config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config) try: with open(config_path, 'r', encoding='utf8') as f: diff --git a/modules/pixelsmith/__init__.py b/modules/pixelsmith/__init__.py new file mode 100644 index 000000000..219d65ab8 --- /dev/null +++ b/modules/pixelsmith/__init__.py @@ -0,0 +1,2 @@ +from .pixelsmith_pipeline import PixelSmithXLPipeline +from .autoencoder_kl import PixelSmithVAE diff --git a/modules/pixelsmith/autoencoder_kl.py b/modules/pixelsmith/autoencoder_kl.py new file mode 100644 index 000000000..90c2f9462 --- /dev/null +++ b/modules/pixelsmith/autoencoder_kl.py @@ -0,0 +1,496 @@ +# Original: + +from typing import Dict, Optional, Tuple, Union +import 
gc +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class PixelSmithVAE(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. 
VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. 
+ """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile.to("cuda")) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + # + del row + gc.collect() + torch.cuda.empty_cache() + # + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + # + del result_row + gc.collect() + torch.cuda.empty_cache() + # + + moments = torch.cat(result_rows, dim=2) + # + del result_rows + gc.collect() + torch.cuda.empty_cache() + # + posterior = DiagonalGaussianDistribution(moments) + # + del moments + gc.collect() + torch.cuda.empty_cache() + # + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile).to("cpu") + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/modules/pixelsmith/pixelsmith_pipeline.py b/modules/pixelsmith/pixelsmith_pipeline.py new file mode 100644 index 000000000..4e04b2d92 --- /dev/null +++ b/modules/pixelsmith/pixelsmith_pipeline.py @@ -0,0 +1,1882 @@ +# Original: + +from typing import Any, Dict, List, Optional, Tuple, Union +import inspect +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F + +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor2_0, + FusedAttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .autoencoder_kl import PixelSmithVAE + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +plt.rcParams['figure.dpi'] = 300 +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" +#+# +class PAGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + + # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) + hidden_states_ptb = value + + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 
2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + + +class PixelSmithXLPipeline( + DiffusionPipeline, + #StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: PixelSmithVAE, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because 
SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + 
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
+ ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.FloatTensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. 
+ + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + #+# + + def pred_z0(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_original_sample + + def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type): + pred_z0 = self.pred_z0(latents, noise_pred, t) + pred_x0 = self.vae.decode( + pred_z0 / self.vae.config.scaling_factor, + return_dict=False, + generator=generator + )[0] + do_denormalize = [True] * pred_x0.shape[0] + pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize) + + return pred_x0 + + #+# + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
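+    # note: classifier-free guidance is additionally disabled when the UNet exposes
+    # `time_cond_proj_dim` (guidance-distilled models such as LCM), since the guidance
+    # weight is then injected through `timestep_cond` via `get_guidance_scale_embedding`
+    # instead of doubling the batch for an unconditional pass.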
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + #+# + + @property + def pag_scale(self): + return self._pag_scale + + @property + def do_adversarial_guidance(self): + return self._pag_scale > 0 + + @property + def pag_adaptive_scaling(self): + return self._pag_adaptive_scaling + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scaling > 0 + + @property + def pag_drop_rate(self): + return self._pag_drop_rate + + @property + def pag_applied_layers(self): + return self._pag_applied_layers + + @property + def pag_applied_layers_index(self): + return self._pag_applied_layers_index + #+# + + def _random_crop(self, z, i, j, patch_size): + p=patch_size//2 + return z[...,i-p:i+p, j-p:j+p] + + def get_value_coordinates(self, tensor): + value_indices = torch.nonzero(tensor == tensor.max(), as_tuple=False) + random_indices = value_indices[torch.randperm(value_indices.size(0))] + return random_indices + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + #+# + pag_scale: float = 0.0, # longer inference time if used (https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) + pag_adaptive_scaling: float = 0.0, + pag_drop_rate: float = 0.5, + pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] + pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] + #+# + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + image = None, + slider = None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. + Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding + if `do_classifier_free_guidance` is set to `True`. + If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + image('pil'): + Upscaled image from previous step + slider('int'): + Freedom of the model to be more generative or closer to the input + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. 
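+
+        Note:
+            When `image` is provided, the standard denoising loop is skipped and the PixelSmith
+            patch-based refinement loop runs instead: the noised `image` guides the first `slider`
+            timesteps, and 128x128 latent patches are denoised individually, with overlapping
+            patch borders blended to avoid visible seams.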
+ """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + #+# + self._pag_scale = pag_scale + self._pag_adaptive_scaling = pag_adaptive_scaling + self._pag_drop_rate = pag_drop_rate + self._pag_applied_layers = pag_applied_layers + self._pag_applied_layers_index = pag_applied_layers_index + #+# + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids, add_time_ids], dim=0) + #+# + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + #+# + # 10. 
Create down mid and up layer lists + if self.do_adversarial_guidance: + down_layers = [] + mid_layers = [] + up_layers = [] + for name, module in self.unet.named_modules(): + if 'attn1' in name and 'to' not in name: + layer_type = name.split('.')[0].split('_')[0] + if layer_type == 'down': + down_layers.append(module) + elif layer_type == 'mid': + mid_layers.append(module) + elif layer_type == 'up': + up_layers.append(module) + else: + raise ValueError(f"Invalid layer type: {layer_type}") + #+# + + self._num_timesteps = len(timesteps) + if image is None: + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + #+# + # #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 3) + #no + else: + latent_model_input = latents + + # change attention layer in UNet if use PAG + if self.do_adversarial_guidance: + + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor() + else: + replace_processor = PAGIdentitySelfAttnProcessor() + + if self.pag_applied_layers_index: + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError as err: + raise ValueError(f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers.") from err + elif self.pag_applied_layers: + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = replace_processor + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = replace_processor + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError as err: + raise ValueError(f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. 
If you need to specify each layer index, you can use `pag_applied_layers_index`") from err + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) + # both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb) + #+# + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + else: + noise = torch.randn(image.shape, dtype=torch.float16).to(self.unet.device) + guid_latents = self.scheduler.add_noise(image, noise, timesteps[:slider]) + guid_latents = [guid_latents[i:i+1] for i in range(guid_latents.size(0))] + latents = guid_latents[0] + + _b, _c, latent_size_h, latent_size_w=latents.shape + times=torch.ones((1,1,latent_size_h,latent_size_w)).int().to(self.device)*timesteps.max().item() + patch_size = 128 + p=patch_size//2 + + @torch.no_grad() + def create_gradient_border(mask, gradient_width=5): + """ + Needed to average overlapping patches + """ + mask = mask.float().to(self.unet.device) + inverted_mask = mask + distances = F.conv2d(inverted_mask, torch.ones(1, 1, 1, 1, device=device), padding=0) + distance_mask = distances <= gradient_width + kernel_size = gradient_width * 2 + 1 + kernel = torch.ones(1, 1, kernel_size, kernel_size, device=device) / (kernel_size ** 2) + padded_mask = F.pad(inverted_mask, (gradient_width, gradient_width, gradient_width, gradient_width), mode='reflect') + smoothed_distances = F.conv2d(padded_mask, kernel, padding=0).clamp(0, 1) + smoothed_mask = (mask + (1 - mask) * smoothed_distances * distance_mask.float()).clamp(0, 1) + return smoothed_mask + + prev_latents=latents.clone() + + while times.float().mean() >= 0: + + random_indices=self.get_value_coordinates(times[0,0])[0] + i=torch.clamp(random_indices,p,latent_size_h-p).tolist()[0] + j=torch.clamp(random_indices,p,latent_size_w-p).tolist()[1] + + # random patch cropping + sub_latents=self._random_crop(latents, i, j, patch_size) + sub_prev_latents=self._random_crop(prev_latents, i, j, patch_size) + sub_time=self._random_crop(times, i, j, patch_size) + + t = times.max() + ii = torch.where(t==timesteps)[0].item() + + if ii < slider: + sub_guid_latents = self._random_crop(guid_latents[ii], i, j, patch_size) + if ii < len(guid_latents)-1 and ii < slider: + sub_guid_latents_ahead = self._random_crop(guid_latents[ii+1], i, j, patch_size) + + print(f"\r PixelSmith progress: {(1 - times.float().mean() / timesteps.max().item()) * 100:.2f}%",end="") + + if sub_time.float().mean() > 0: + + # Compute the FFT of both sets of latents + fft_sub_latents = torch.fft.rfft2(sub_latents, dim=(-2, -1), norm='ortho') + fft_sub_guid_latents = torch.fft.rfft2(sub_guid_latents, dim=(-2, -1), norm='ortho') + # Calculate magnitude and phase for both FFTs + magnitude_latents = torch.abs(fft_sub_latents) + complex_latents = torch.exp(1j * torch.angle(fft_sub_latents)) + complex_guid_latents = torch.exp(1j * torch.angle(fft_sub_guid_latents)) + # Use the arg function to mix phases + if ii < slider: + mixed_phase = torch.angle(complex_latents + complex_guid_latents) + else: + mixed_phase = torch.angle(fft_sub_latents) + # Reconstruct the complex number using the mixed phase and the original magnitude + fft_sub_latents = magnitude_latents * torch.exp(1j * mixed_phase) + sub_latents 
= torch.fft.irfft2(fft_sub_latents, dim=(-2, -1), norm='ortho') + + # Generate random numbers for shift directions + shift_left = torch.rand(1).item() < 0.5 + shift_down = torch.rand(1).item() < 0.5 + # + d_rate = 2 + mask_first_row = torch.zeros(1, patch_size) + mask_first_row[:, ::d_rate] = 1 + mask_second_row = torch.roll(mask_first_row, shifts=1, dims=1) + for _d in range(1, d_rate): + stacked_rows = torch.concatenate((mask_first_row, mask_second_row), axis=-2) + den_mask = torch.tile(stacked_rows, (patch_size//stacked_rows.shape[0], 1)).to(self.device) + den_mask = den_mask[np.newaxis, np.newaxis, ...].to(self.unet.dtype) + den_mask = torch.roll(den_mask, shifts=(-1 if shift_down else 0, -1 if shift_left else 0), dims=(2, 3)) + + uniques=torch.unique(sub_time) + vmax=uniques[-1] + time_mask=torch.where(sub_time==vmax, 1, 0).to(self.device) + if len(uniques)>1: + sub_latents=sub_latents*time_mask+sub_prev_latents*(time_mask==0) + + #+# + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + latent_model_input = torch.cat([sub_latents] * 2) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([sub_latents] * 2) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([sub_latents] * 3) + #no + else: + latent_model_input = sub_latents + + # change attention layer in UNet if use PAG + if self.do_adversarial_guidance: + + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor() + else: + replace_processor = PAGIdentitySelfAttnProcessor() + if self.pag_applied_layers_index: + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError as err: + raise ValueError(f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers.") from err + elif self.pag_applied_layers: + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = replace_processor + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = replace_processor + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError as err: + raise ValueError(f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. 
If you need to specify each layer index, you can use `pag_applied_layers_index`") from err + #+# + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, sub_time.max().item()).to(self.unet.dtype) + + add_time_ids[:,0] = latents.shape[-2] * self.vae_scale_factor + add_time_ids[:,4] = latents.shape[-2] * self.vae_scale_factor + add_time_ids[:,1] = latents.shape[-1] * self.vae_scale_factor + add_time_ids[:,5] = latents.shape[-1] * self.vae_scale_factor + add_time_ids[:,2] = (j-64) * self.vae_scale_factor #top + add_time_ids[:,3] = (j-64) * self.vae_scale_factor #left + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + sub_time.max().item(), + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-sub_time.max().item()) + if signal_scale<0: + signal_scale = 0 + noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) + # both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-sub_time.max().item()) + if signal_scale<0: + signal_scale = 0 + noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb) + #+# + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 + try: + sub_latents = self.scheduler.step(noise_pred, sub_time.max().item(), sub_latents, **extra_step_kwargs, return_dict=False)[0] + except Exception as e: + print('PixelSmith', e) + + smoothed_time_mask = create_gradient_border(time_mask, gradient_width=10) + full_replace_mask = smoothed_time_mask == 1 + no_replace_mask = smoothed_time_mask == 0 + gradient_mask = (smoothed_time_mask > 0) & (smoothed_time_mask < 1) + + if ii(timesteps.min().item()): + next_timestep_index = (timesteps == sub_time.max()).nonzero(as_tuple=True)[0][-1] + next_timestep = timesteps[next_timestep_index + 1].item() + times[...,i-p:i+p,j-p:j+p]=torch.where(sub_time==sub_time.max(), torch.ones_like(sub_time).to(sub_time.device)*next_timestep, sub_time) + else: + times[...,i-p:i+p,j-p:j+p]=torch.where(sub_time==sub_time.max(), torch.ones_like(sub_time).to(sub_time.device)*0, sub_time) + + if torch.all(times == times.max()): + prev_latents=latents.clone() + + if times.float().mean()==0: + break + + if output_type != "latent": + image = self.vae.tiled_decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + #+# + #Change the attention layers back to original ones after PAG was applied + if self.do_adversarial_guidance: + if self.pag_applied_layers_index: + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError as err: + raise ValueError(f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers.") from err + elif self.pag_applied_layers: + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError as err: + raise ValueError(f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`") from err + #+# + + return ImagePipelineOutput(images=image) + +# %% diff --git a/modules/pixelsmith/vae.py b/modules/pixelsmith/vae.py new file mode 100644 index 000000000..98e051695 --- /dev/null +++ b/modules/pixelsmith/vae.py @@ -0,0 +1,979 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils.torch_utils import randn_tensor +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import SpatialNorm +from diffusers.models.unets.unet_2d_blocks import ( + AutoencoderTinyBlock, + UNetMidBlock2D, + get_down_block, + get_up_block, +) + + +@dataclass +class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. 
+ layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + 
) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class UpSample(nn.Module): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `UpSample` class.""" + x = torch.relu(x) + x = self.deconv(x) + return x + + +class MaskConditionEncoder(nn.Module): + """ + used in AsymmetricAutoencoderKL + """ + + def __init__( + self, + in_ch: int, + out_ch: int = 192, + res_ch: int = 768, + stride: int = 16, + ) -> None: + super().__init__() + + channels = [] + while stride > 1: + stride = stride // 2 + in_ch_ = out_ch * 2 + if out_ch > res_ch: + out_ch = res_ch + if stride == 1: + in_ch_ = res_ch + channels.append((in_ch_, out_ch)) + out_ch *= 2 + + out_channels = [] + for _in_ch, _out_ch in channels: + out_channels.append(_out_ch) + out_channels.append(channels[-1][0]) + + layers = [] + in_ch_ = in_ch + for l in range(len(out_channels)): + out_ch_ = out_channels[l] + if l == 0 or l == 1: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) + else: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) + in_ch_ = out_ch_ + + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor: + r"""The forward method of the `MaskConditionEncoder` class.""" + out = {} + for l in range(len(self.layers)): + layer = self.layers[l] + x = layer(x) + out[str(tuple(x.shape))] = x + x = torch.relu(x) + return out + + +class MaskConditionDecoder(nn.Module): + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. 
+ norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # condition encoder + self.condition_encoder = MaskConditionEncoder( + in_ch=out_channels, + out_ch=block_out_channels[0], + res_ch=block_out_channels[-1], + ) + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + z: torch.FloatTensor, + image: Optional[torch.FloatTensor] = None, + mask: Optional[torch.FloatTensor] = None, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `MaskConditionDecoder` class.""" + sample = z + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not 
None: + masked_image = (1 - mask) * image + im_x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.condition_encoder), + masked_image, + mask, + use_reentrant=False, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.condition_encoder), + masked_image, + mask, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self.condition_encoder(masked_image, mask) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = up_block(sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. 
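+ # (legacy=True applies beta to the codebook term below; legacy=False matches the original VQ-VAE loss and applies beta to the commitment term instead)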
+ def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]: + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q: torch.FloatTensor = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: 
torch.FloatTensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +class EncoderTiny(nn.Module): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) + else: + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + ) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `EncoderTiny` class.""" + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) + + return x + + +class DecoderTiny(nn.Module): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + ): + super().__init__() + + layers = [ + nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1), + get_activation(act_fn), + ] + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + if not is_final_block: + layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor)) + + conv_out_channel = num_channels if not is_final_block else out_channels + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + ) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. 
+ x = torch.tanh(x / 3) * 3 + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + x = self.layers(x) + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py index 4011147ec..85a50cc2a 100644 --- a/modules/postprocess/yolo.py +++ b/modules/postprocess/yolo.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING import os +import threading from copy import copy import numpy as np import gradio as gr @@ -16,6 +17,7 @@ 'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/eyes-v1.pt', 'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/eyes-full-v1.pt', ] +load_lock = threading.Lock() class YoloResult: @@ -151,28 +153,29 @@ def predict( return result def load(self, model_name: str = None): - from modules import modelloader - model = None - self.dependencies() - if model_name is None: - model_name = list(self.list)[0] - if model_name in self.models: - return model_name, self.models[model_name] - else: - model_url = self.list.get(model_name) - file_name = os.path.basename(model_url) - model_file = None - try: - model_file = modelloader.load_file_from_url(url=model_url, model_dir=shared.opts.yolo_dir, file_name=file_name) - if model_file is not None: - import ultralytics - model = ultralytics.YOLO(model_file) - classes = list(model.names.values()) - shared.log.info(f'Load: type=Detailer name="{model_name}" model="{model_file}" ultralytics={ultralytics.__version__} classes={classes}') - self.models[model_name] = model - return model_name, model - except Exception as e: - shared.log.error(f'Load: type=Detailer name="{model_name}" error="{e}"') + with load_lock: + from modules import modelloader + model = None + self.dependencies() + if model_name is None: + model_name = list(self.list)[0] + if model_name in self.models: + return model_name, self.models[model_name] + else: + model_url = self.list.get(model_name) + file_name = os.path.basename(model_url) + model_file = None + try: + model_file = modelloader.load_file_from_url(url=model_url, model_dir=shared.opts.yolo_dir, file_name=file_name) + if model_file is not None: + import ultralytics + model = ultralytics.YOLO(model_file) + classes = list(model.names.values()) + shared.log.info(f'Load: type=Detailer name="{model_name}" model="{model_file}" ultralytics={ultralytics.__version__} classes={classes}') + self.models[model_name] = model + return model_name, model + except Exception as e: + shared.log.error(f'Load: type=Detailer name="{model_name}" error="{e}"') return None, None def restore(self, np_image, p: processing.StableDiffusionProcessing = None): @@ -207,8 +210,8 @@ def restore(self, np_image, p: processing.StableDiffusionProcessing = None): resolution = 512 if shared.sd_model_type in ['none', 'sd', 'lcm', 'unknown'] else 1024 orig_prompt: str = orig_p.get('all_prompts', [''])[0] orig_negative: str = orig_p.get('all_negative_prompts', [''])[0] - prompt: str = orig_p.get('refiner_prompt', '') - negative: str = orig_p.get('refiner_negative', '') + prompt: str = orig_p.get('detailer_prompt', '') + negative: str = orig_p.get('detailer_negative', '') if len(prompt) == 
0: prompt = orig_prompt else: @@ -230,9 +233,9 @@ def restore(self, np_image, p: processing.StableDiffusionProcessing = None): 'n_iter': 1, 'prompt': prompt, 'negative_prompt': negative, - 'denoising_strength': shared.opts.detailer_strength, + 'denoising_strength': p.detailer_strength, 'sampler_name': orig_p.get('hr_sampler_name', 'default'), - 'steps': shared.opts.detailer_steps, + 'steps': p.detailer_steps, 'styles': [], 'inpaint_full_res': True, 'inpainting_mask_invert': 0, @@ -309,7 +312,6 @@ def ui(self, tab: str): def ui_settings_change(detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps): shared.opts.detailer_models = detailers shared.opts.detailer_classes = classes - shared.opts.detailer_strength = strength shared.opts.detailer_padding = padding shared.opts.detailer_blur = blur shared.opts.detailer_conf = min_confidence @@ -317,9 +319,8 @@ def ui_settings_change(detailers, classes, strength, padding, blur, min_confiden shared.opts.detailer_min_size = min_size shared.opts.detailer_max_size = max_size shared.opts.detailer_iou = iou - shared.opts.detailer_steps = steps shared.opts.save(shared.config_filename, silent=True) - shared.log.debug(f'Detailer settings: models={shared.opts.detailer_models} classes={shared.opts.detailer_classes} strength={shared.opts.detailer_strength} conf={shared.opts.detailer_conf} max={shared.opts.detailer_max} iou={shared.opts.detailer_iou} size={shared.opts.detailer_min_size}-{shared.opts.detailer_max_size} padding={shared.opts.detailer_padding} steps={shared.opts.detailer_steps}') + shared.log.debug(f'Detailer settings: models={detailers} classes={classes} strength={strength} conf={min_confidence} max={max_detected} iou={iou} size={min_size}-{max_size} padding={padding} steps={steps}') with gr.Accordion(open=False, label="Detailer", elem_id=f"{tab}_detailer_accordion", elem_classes=["small-accordion"], visible=shared.native): with gr.Row(): @@ -330,8 +331,12 @@ def ui_settings_change(detailers, classes, strength, padding, blur, min_confiden with gr.Row(): classes = gr.Textbox(label="Classes", placeholder="Classes", elem_id=f"{tab}_detailer_classes") with gr.Row(): - steps = gr.Slider(label="Detailer steps", elem_id=f"{tab}_detailer_steps", value=shared.opts.detailer_steps, min=0, max=99, step=1) - strength = gr.Slider(label="Detailer strength", elem_id=f"{tab}_detailer_strength", value=shared.opts.detailer_strength, minimum=0, maximum=1, step=0.01) + prompt = gr.Textbox(label="Detailer prompt", value='', placeholder='Detailer prompt', lines=2, elem_id=f"{tab}_detailer_prompt") + with gr.Row(): + negative = gr.Textbox(label="Detailer negative prompt", value='', placeholder='Detailer negative prompt', lines=2, elem_id=f"{tab}_detailer_negative") + with gr.Row(): + steps = gr.Slider(label="Detailer steps", elem_id=f"{tab}_detailer_steps", value=10, min=0, max=99, step=1) + strength = gr.Slider(label="Detailer strength", elem_id=f"{tab}_detailer_strength", value=0.3, minimum=0, maximum=1, step=0.01) with gr.Row(): max_detected = gr.Slider(label="Max detected", elem_id=f"{tab}_detailer_max", value=shared.opts.detailer_max, min=1, maximum=10, step=1) with gr.Row(): @@ -347,7 +352,6 @@ def ui_settings_change(detailers, classes, strength, padding, blur, min_confiden max_size = gr.Slider(label="Max size", elem_id=f"{tab}_detailer_max_size", value=max_size, minimum=0.0, maximum=1.0, step=0.05) detailers.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, 
max_detected, min_size, max_size, iou, steps], outputs=[]) classes.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) - strength.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) padding.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) blur.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) min_confidence.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) @@ -355,8 +359,7 @@ def ui_settings_change(detailers, classes, strength, padding, blur, min_confiden min_size.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) max_size.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) iou.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) - steps.change(fn=ui_settings_change, inputs=[detailers, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps], outputs=[]) - return enabled + return enabled, prompt, negative, steps, strength def initialize(): diff --git a/modules/processing.py b/modules/processing.py index 23fde5dee..99fb9f7f3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -53,8 +53,8 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info=None self.batch_size = max(1, p.batch_size) self.restore_faces = p.restore_faces or False self.face_restoration_model = shared.opts.face_restoration_model if p.restore_faces else None - self.detailer = p.detailer or False - self.detailer_model = shared.opts.detailer_model if p.detailer else None + self.detailer = p.detailer_enabled or False + self.detailer_model = shared.opts.detailer_model if p.detailer_enabled else None self.sd_model_hash = getattr(shared.sd_model, 'sd_model_hash', '') if model_data.sd_model is not None else '' self.seed_resize_from_w = p.seed_resize_from_w self.seed_resize_from_h = p.seed_resize_from_h @@ -280,19 +280,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: output_images = [] process_init(p) - if os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings and not shared.native: + if not shared.native and os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings: modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False) if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner): p.scripts.process(p) ema_scope_context = p.sd_model.ema_scope if not shared.native else nullcontext - shared.state.job_count = p.n_iter + if not shared.native: + shared.state.job_count = p.n_iter with devices.inference_context(), ema_scope_context(): t0 = time.time() if not hasattr(p, 'skip_init'): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) debug(f'Processing inner: args={vars(p)}') for n in range(p.n_iter): + # 
if hasattr(p, 'skip_processing'): + # continue pag.apply(p) debug(f'Processing inner: iteration={n+1}/{p.n_iter}') p.iteration = n @@ -371,7 +374,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: sample = face_restoration.restore_faces(sample, p) if sample is not None: image = Image.fromarray(sample) - if p.detailer: + if p.detailer_enabled: p.ops.append('detailer') if not p.do_not_save_samples and shared.opts.save_images_before_detailer: info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i) diff --git a/modules/processing_args.py b/modules/processing_args.py index 360928de7..81f4ae5d3 100644 --- a/modules/processing_args.py +++ b/modules/processing_args.py @@ -15,6 +15,7 @@ debug_enabled = os.environ.get('SD_DIFFUSERS_DEBUG', None) debug_log = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None +disable_pbar = os.environ.get('SD_DISABLE_PBAR', None) is not None def task_specific_kwargs(p, model): @@ -35,7 +36,8 @@ def task_specific_kwargs(p, model): elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE or is_img2img_model) and len(getattr(p, 'init_images', [])) > 0: if shared.sd_model_type == 'sdxl' and hasattr(model, 'register_to_config'): model.register_to_config(requires_aesthetics_score = False) - p.ops.append('img2img') + if 'hires' not in p.ops: + p.ops.append('img2img') task_args = { 'image': p.init_images, 'strength': p.denoising_strength, @@ -63,7 +65,7 @@ def task_specific_kwargs(p, model): elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING or is_img2img_model) and len(getattr(p, 'init_images', [])) > 0: if shared.sd_model_type == 'sdxl' and hasattr(model, 'register_to_config'): model.register_to_config(requires_aesthetics_score = False) - if p.detailer: + if p.detailer_enabled: p.ops.append('detailer') else: p.ops.append('inpaint') @@ -106,16 +108,21 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) apply_circular(p.tiling, model) if hasattr(model, "set_progress_bar_config"): - model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba') + if disable_pbar: + model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba', disable=disable_pbar) + else: + model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba') args = {} + has_vae = hasattr(model, 'vae') or (hasattr(model, 'pipe') and hasattr(model.pipe, 'vae')) if hasattr(model, 'pipe') and not hasattr(model, 'no_recurse'): # recurse model = model.pipe + has_vae = has_vae or hasattr(model, 'vae') signature = inspect.signature(type(model).__call__, follow_wrapped=True) possible = list(signature.parameters) if debug_enabled: debug_log(f'Diffusers pipeline possible: {possible}') - prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2) + prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(p, prompts, negative_prompts, prompts_2, negative_prompts_2) steps = kwargs.get("num_inference_steps", 
None) or len(getattr(p, 'timesteps', ['1'])) clip_skip = kwargs.pop("clip_skip", 1) @@ -172,6 +179,9 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t p.extra_generation_params["CHI"] = chi if not chi: args['complex_human_instruction'] = None + if 'use_resolution_binning' in possible: + args['use_resolution_binning'] = True + p.extra_generation_params["Binning"] = True if prompt_parser_diffusers.embedder is not None and not prompt_parser_diffusers.embedder.scheduled_prompt: # not scheduled so we dont need it anymore prompt_parser_diffusers.embedder = None @@ -232,7 +242,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE: args['latents'] = p.init_latent if 'output_type' in possible: - if not hasattr(model, 'vae'): + if not has_vae: kwargs['output_type'] = 'np' # only set latent if model has vae # model specific @@ -271,7 +281,10 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t args['callback'] = diffusers_callback_legacy if 'image' in kwargs: - p.init_images = kwargs['image'] if isinstance(kwargs['image'], list) else [kwargs['image']] + if isinstance(kwargs['image'], list) and isinstance(kwargs['image'][0], Image.Image): + p.init_images = kwargs['image'] + if isinstance(kwargs['image'], Image.Image): + p.init_images = [kwargs['image']] # handle remaining args for arg in kwargs: @@ -353,4 +366,14 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t shared.log.debug(f'Profile: pipeline args: {t1-t0:.2f}') if debug_enabled: debug_log(f'Diffusers pipeline args: {args}') - return args + + _args = {} + for k, v in args.items(): # pipeline may modify underlying args + if isinstance(v, Image.Image): + _args[k] = v.copy() + elif (isinstance(v, list) and len(v) > 0 and isinstance(v[0], Image.Image)): + _args[k] = [i.copy() for i in v] + else: + _args[k] = v + + return _args diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py index a9c6fbcf7..0191eff72 100644 --- a/modules/processing_callbacks.py +++ b/modules/processing_callbacks.py @@ -27,7 +27,7 @@ def prompt_callback(step, kwargs): assert prompt_embeds.shape == kwargs['prompt_embeds'].shape, f"prompt_embed shape mismatch {kwargs['prompt_embeds'].shape} {prompt_embeds.shape}" kwargs['prompt_embeds'] = prompt_embeds except Exception as e: - debug_callback(f"Callback: {e}") + debug_callback(f"Callback: type=prompt {e}") return kwargs @@ -56,8 +56,9 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} latents = kwargs.get('latents', None) if debug: debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}') - order = getattr(pipe.scheduler, "order", 1) if hasattr(pipe, 'scheduler') else 1 - shared.state.sampling_step = step // order + shared.state.step() + # order = getattr(pipe.scheduler, "order", 1) if hasattr(pipe, 'scheduler') else 1 + # shared.state.sampling_step = step // order if shared.state.interrupted or shared.state.skipped: raise AssertionError('Interrupted...') if shared.state.paused: @@ -115,11 +116,16 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} shared.state.current_latent = kwargs['latents'] shared.state.current_noise_pred = current_noise_pred - if hasattr(pipe, "scheduler") and hasattr(pipe.scheduler, "sigmas") and hasattr(pipe.scheduler, 
"step_index"): - shared.state.current_sigma = pipe.scheduler.sigmas[pipe.scheduler.step_index - 1] - shared.state.current_sigma_next = pipe.scheduler.sigmas[pipe.scheduler.step_index] + if hasattr(pipe, "scheduler") and hasattr(pipe.scheduler, "sigmas") and hasattr(pipe.scheduler, "step_index") and pipe.scheduler.step_index is not None: + try: + shared.state.current_sigma = pipe.scheduler.sigmas[pipe.scheduler.step_index-1] + shared.state.current_sigma_next = pipe.scheduler.sigmas[pipe.scheduler.step_index] + except Exception: + pass except Exception as e: shared.log.error(f'Callback: {e}') + # from modules import errors + # errors.display(e, 'Callback') if shared.cmd_opts.profile and shared.profiler is not None: shared.profiler.step() t1 = time.time() diff --git a/modules/processing_class.py b/modules/processing_class.py index 5960fea76..3eeb73622 100644 --- a/modules/processing_class.py +++ b/modules/processing_class.py @@ -52,8 +52,13 @@ def __init__(self, # other hidiffusion: bool = False, do_not_reload_embeddings: bool = False, - detailer: bool = False, restore_faces: bool = False, + # detailer + detailer_enabled: bool = False, + detailer_prompt: str = '', + detailer_negative: str = '', + detailer_steps: int = 10, + detailer_strength: float = 0.3, # hdr corrections hdr_mode: int = 0, hdr_brightness: float = 0, @@ -167,7 +172,11 @@ def __init__(self, self.full_quality = full_quality self.hidiffusion = hidiffusion self.do_not_reload_embeddings = do_not_reload_embeddings - self.detailer = detailer + self.detailer_enabled = detailer_enabled + self.detailer_prompt = detailer_prompt + self.detailer_negative = detailer_negative + self.detailer_steps = detailer_steps + self.detailer_strength = detailer_strength self.restore_faces = restore_faces self.init_images = init_images self.resize_mode = resize_mode @@ -581,7 +590,7 @@ def init_hr(self, scale = None, upscaler = None, force = False): else: self.hr_upscale_to_x, self.hr_upscale_to_y = self.hr_resize_x, self.hr_resize_y # hypertile_set(self, hr=True) - shared.state.job_count = 2 * self.n_iter + # shared.state.job_count = 2 * self.n_iter shared.log.debug(f'Control hires: upscaler="{self.hr_upscaler}" scale={scale} fixed={not use_scale} size={self.hr_upscale_to_x}x{self.hr_upscale_to_y}') diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py index d978cbe2b..c20ba85a8 100644 --- a/modules/processing_diffusers.py +++ b/modules/processing_diffusers.py @@ -6,7 +6,7 @@ import torchvision.transforms.functional as TF from PIL import Image from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, hidiffusion, timer, modelstats, extra_networks -from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled +from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled, get_job_name from modules.processing_args import set_pipeline_args from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed from modules.lora import networks @@ -47,14 +47,15 @@ def restore_state(p: processing.StableDiffusionProcessing): p.init_images = None if state == 'reprocess_detail': p.skip = ['encode', 'base', 'hires'] - p.detailer = True + p.detailer_enabled 
= True shared.log.info(f'Restore state: op={p.state} skip={p.skip}') return p def process_base(p: processing.StableDiffusionProcessing): - use_refiner_start = is_txt2img() and is_refiner_enabled(p) and not p.is_hr_pass and p.refiner_start > 0 and p.refiner_start < 1 - use_denoise_start = not is_txt2img() and p.refiner_start > 0 and p.refiner_start < 1 + txt2img = is_txt2img() + use_refiner_start = txt2img and is_refiner_enabled(p) and not p.is_hr_pass and p.refiner_start > 0 and p.refiner_start < 1 + use_denoise_start = not txt2img and p.refiner_start > 0 and p.refiner_start < 1 shared.sd_model = update_pipeline(shared.sd_model, p) update_sampler(p, shared.sd_model) @@ -76,7 +77,8 @@ def process_base(p: processing.StableDiffusionProcessing): clip_skip=p.clip_skip, desc='Base', ) - shared.state.sampling_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None) + base_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None) + shared.state.update(get_job_name(p, shared.sd_model), base_steps, 1) if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1: p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta output = None @@ -172,8 +174,8 @@ def process_hires(p: processing.StableDiffusionProcessing, output): p.ops.append('upscale') if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'): save_intermediate(p, latents=output.images, suffix="-before-hires") - shared.state.job = 'Upscale' - output.images = resize_hires(p, latents=output.images) + shared.state.update('Upscale', 0, 1) + output.images = resize_hires(p, latents=output.images) if output is not None else [] sd_hijack_hypertile.hypertile_set(p, hr=True) latent_upscale = shared.latent_upscale_modes.get(p.hr_upscaler, None) @@ -187,10 +189,9 @@ def process_hires(p: processing.StableDiffusionProcessing, output): # hires if p.hr_force and strength == 0: - shared.log.warning('HiRes skip: denoising=0') + shared.log.warning('Hires skip: denoising=0') p.hr_force = False if p.hr_force: - shared.state.job_count = 2 * p.n_iter shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) if 'Upscale' in shared.sd_model.__class__.__name__ or 'Flux' in shared.sd_model.__class__.__name__ or 'Kandinsky' in shared.sd_model.__class__.__name__: output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height) @@ -200,6 +201,7 @@ def process_hires(p: processing.StableDiffusionProcessing, output): update_sampler(p, shared.sd_model, second_pass=True) orig_denoise = p.denoising_strength p.denoising_strength = strength + orig_image = p.task_args.pop('image', None) # remove image override from hires hires_args = set_pipeline_args( p=p, model=shared.sd_model, @@ -217,8 +219,8 @@ def process_hires(p: processing.StableDiffusionProcessing, output): strength=strength, desc='Hires', ) - shared.state.job = 'HiRes' - shared.state.sampling_steps = hires_args.get('prior_num_inference_steps', None) or p.steps or hires_args.get('num_inference_steps', None) + hires_steps = hires_args.get('prior_num_inference_steps', None) or p.hr_second_pass_steps or hires_args.get('num_inference_steps', None) + shared.state.update(get_job_name(p, shared.sd_model), hires_steps, 1) try: 
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) sd_models.move_model(shared.sd_model, devices.device) @@ -243,6 +245,8 @@ def process_hires(p: processing.StableDiffusionProcessing, output): shared.log.error(f'Processing step=hires: args={hires_args} {e}') errors.display(e, 'Processing') modelstats.analyze() + if orig_image is not None: + p.task_args['image'] = orig_image p.denoising_strength = orig_denoise shared.state.job = prev_job shared.state.nextjob() @@ -255,8 +259,6 @@ def process_refine(p: processing.StableDiffusionProcessing, output): # optional refiner pass or decode if is_refiner_enabled(p): prev_job = shared.state.job - shared.state.job = 'Refine' - shared.state.job_count +=1 if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_refiner and hasattr(shared.sd_model, 'vae'): save_intermediate(p, latents=output.images, suffix="-before-refiner") if shared.opts.diffusers_move_base: @@ -306,7 +308,8 @@ def process_refine(p: processing.StableDiffusionProcessing, output): prompt_attention='fixed', desc='Refiner', ) - shared.state.sampling_steps = refiner_args.get('prior_num_inference_steps', None) or p.steps or refiner_args.get('num_inference_steps', None) + refiner_steps = refiner_args.get('prior_num_inference_steps', None) or p.steps or refiner_args.get('num_inference_steps', None) + shared.state.update(get_job_name(p, shared.sd_refiner), refiner_steps, 1) try: if 'requires_aesthetics_score' in shared.sd_refiner.config: # sdxl-model needs false and sdxl-refiner needs true shared.sd_refiner.register_to_config(requires_aesthetics_score = getattr(shared.sd_refiner, 'tokenizer', None) is None) diff --git a/modules/processing_helpers.py b/modules/processing_helpers.py index 51cbcff7f..93896a02f 100644 --- a/modules/processing_helpers.py +++ b/modules/processing_helpers.py @@ -428,11 +428,16 @@ def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler return resized_images -def fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2): +def fix_prompts(p, prompts, negative_prompts, prompts_2, negative_prompts_2): if type(prompts) is str: prompts = [prompts] if type(negative_prompts) is str: negative_prompts = [negative_prompts] + if hasattr(p, '[init_images]') and p.init_images is not None and len(p.init_images) > 1: + while len(prompts) < len(p.init_images): + prompts.append(prompts[-1]) + while len(negative_prompts) < len(p.init_images): + negative_prompts.append(negative_prompts[-1]) while len(negative_prompts) < len(prompts): negative_prompts.append(negative_prompts[-1]) while len(prompts) < len(negative_prompts): @@ -584,3 +589,26 @@ def update_sampler(p, sd_model, second_pass=False): sampler_options.append('low order') if len(sampler_options) > 0: p.extra_generation_params['Sampler options'] = '/'.join(sampler_options) + + +def get_job_name(p, model): + if hasattr(model, 'pipe'): + model = model.pipe + if hasattr(p, 'xyz'): + return 'Ignore' # xyz grid handles its own jobs + if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE: + return 'Text' + elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE: + if p.is_refiner_pass: + return 'Refiner' + elif p.is_hr_pass: + return 'Hires' + else: + return 'Image' + elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING: + if p.detailer_enabled: + return 'Detailer' + else: + return 'Inpaint' + else: + return 'Unknown' diff --git a/modules/processing_info.py 
b/modules/processing_info.py index e0fca12ae..f3c8eea81 100644 --- a/modules/processing_info.py +++ b/modules/processing_info.py @@ -37,7 +37,6 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No all_negative_prompts.append(all_negative_prompts[-1]) comment = ', '.join(comments) if comments is not None and type(comments) is list else None ops = list(set(p.ops)) - ops.reverse() args = { # basic "Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None, @@ -46,14 +45,15 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No "Seed": all_seeds[index], "Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}", "CFG scale": p.cfg_scale if p.cfg_scale > 1.0 else None, + "CFG rescale": p.diffusers_guidance_rescale if p.diffusers_guidance_rescale > 0 else None, "CFG end": p.cfg_end if p.cfg_end < 1.0 else None, "Clip skip": p.clip_skip if p.clip_skip > 1 else None, "Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None, "Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''), "Model hash": getattr(p, 'sd_model_hash', None if (not shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash), "VAE": (None if not shared.opts.add_model_name_to_info or sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD', - "Prompt2": p.refiner_prompt if len(p.refiner_prompt) > 0 else None, - "Negative2": p.refiner_negative if len(p.refiner_negative) > 0 else None, + "Refiner prompt": p.refiner_prompt if len(p.refiner_prompt) > 0 else None, + "Refiner negative": p.refiner_negative if len(p.refiner_negative) > 0 else None, "Styles": "; ".join(p.styles) if p.styles is not None and len(p.styles) > 0 else None, # sdnext "App": 'SD.Next', @@ -82,31 +82,33 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No args["Variation strength"] = p.subseed_strength if p.subseed_strength > 0 else None if 'hires' in p.ops or 'upscale' in p.ops: is_resize = p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5) + is_fixed = p.hr_resize_x > 0 or p.hr_resize_y > 0 args["Refine"] = p.enable_hr args["Hires force"] = p.hr_force args["Hires steps"] = p.hr_second_pass_steps - args["HiRes resize mode"] = p.hr_resize_mode if is_resize else None - args["HiRes resize context"] = p.hr_resize_context if p.hr_resize_mode == 5 else None + args["HiRes mode"] = p.hr_resize_mode if is_resize else None + args["HiRes context"] = p.hr_resize_context if p.hr_resize_mode == 5 else None args["Hires upscaler"] = p.hr_upscaler if is_resize else None - args["Hires scale"] = p.hr_scale if is_resize else None - args["Hires resize"] = f"{p.hr_resize_x}x{p.hr_resize_y}" if is_resize else None + if is_fixed: + args["Hires fixed"] = f"{p.hr_resize_x}x{p.hr_resize_y}" if is_resize else None + else: + args["Hires scale"] = p.hr_scale if is_resize else None args["Hires size"] = f"{p.hr_upscale_to_x}x{p.hr_upscale_to_y}" if is_resize else None - args["Denoising strength"] = p.denoising_strength - args["Hires sampler"] = p.hr_sampler_name - args["Image CFG scale"] = p.image_cfg_scale - args["CFG rescale"] = p.diffusers_guidance_rescale + args["Hires strength"] = 
p.denoising_strength + args["Hires sampler"] = p.hr_sampler_name if p.hr_sampler_name != p.sampler_name else None + args["Hires CFG scale"] = p.image_cfg_scale if 'refine' in p.ops: args["Refine"] = p.enable_hr args["Refiner"] = None if (not shared.opts.add_model_name_to_info) or (not shared.sd_refiner) or (not shared.sd_refiner.sd_checkpoint_info.model_name) else shared.sd_refiner.sd_checkpoint_info.model_name.replace(',', '').replace(':', '') - args['Image CFG scale'] = p.image_cfg_scale + args['Hires CFG scale'] = p.image_cfg_scale args['Refiner steps'] = p.refiner_steps args['Refiner start'] = p.refiner_start args["Hires steps"] = p.hr_second_pass_steps args["Hires sampler"] = p.hr_sampler_name - args["CFG rescale"] = p.diffusers_guidance_rescale - if 'img2img' in p.ops or 'inpaint' in p.ops: + if ('img2img' in p.ops or 'inpaint' in p.ops) and ('txt2img' not in p.ops and 'hires' not in p.ops): # real img2img/inpaint args["Init image size"] = f"{getattr(p, 'init_img_width', 0)}x{getattr(p, 'init_img_height', 0)}" args["Init image hash"] = getattr(p, 'init_img_hash', None) + args['Image CFG scale'] = p.image_cfg_scale args['Resize scale'] = getattr(p, 'scale_by', None) args["Mask weight"] = getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None args["Denoising strength"] = getattr(p, 'denoising_strength', None) @@ -134,6 +136,10 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No args['Size name mask'] = p.resize_name_mask if 'detailer' in p.ops: args["Detailer"] = ', '.join(shared.opts.detailer_models) + args["Detailer steps"] = p.detailer_steps + args["Detailer strength"] = p.detailer_strength + args["Detailer prompt"] = p.detailer_prompt if len(p.detailer_prompt) > 0 else None + args["Detailer negative"] = p.detailer_negative if len(p.detailer_negative) > 0 else None if 'color' in p.ops: args["Color correction"] = True # embeddings diff --git a/modules/processing_original.py b/modules/processing_original.py index 649023aae..7a0af1b04 100644 --- a/modules/processing_original.py +++ b/modules/processing_original.py @@ -27,6 +27,18 @@ def get_conds_with_caching(function, required_prompts, steps, cache): cache[0] = (required_prompts, steps) return cache[1] +def check_rollback_vae(): + if shared.cmd_opts.rollback_vae: + if not torch.cuda.is_available(): + shared.log.error("Rollback VAE functionality requires compatible GPU") + shared.cmd_opts.rollback_vae = False + elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'): + shared.log.error("Rollback VAE functionality requires Torch 2.1 or higher") + shared.cmd_opts.rollback_vae = False + elif 0 < torch.cuda.get_device_capability()[0] < 8: + shared.log.error('Rollback VAE functionality device capabilities not met') + shared.cmd_opts.rollback_vae = False + def process_original(p: processing.StableDiffusionProcessing): cached_uc = [None, None] @@ -42,6 +54,7 @@ def process_original(p: processing.StableDiffusionProcessing): for x in x_samples_ddim: devices.test_for_nans(x, "vae") except devices.NansException as e: + check_rollback_vae() if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae: shared.log.warning('Tensor with all NaNs was produced in VAE') devices.dtype_vae = torch.bfloat16 @@ -90,11 +103,11 @@ def sample_txt2img(p: processing.StableDiffusionProcessingTxt2Img, conditioning, for i, x_sample in enumerate(decoded_samples): x_sample = validate_sample(x_sample) image = 
Image.fromarray(x_sample) - bak_extra_generation_params, bak_detailer = p.extra_generation_params, p.detailer + orig_extra_generation_params, orig_detailer = p.extra_generation_params, p.detailer_enabled p.extra_generation_params = {} - p.detailer = False + p.detailer_enabled = False info = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=i) - p.extra_generation_params, p.detailer = bak_extra_generation_params, bak_detailer + p.extra_generation_params, p.detailer_enabled = orig_extra_generation_params, orig_detailer images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=info, suffix="-before-hires") if latent_scale_mode is None or p.hr_force: # non-latent upscaling shared.state.job = 'Upscale' diff --git a/modules/processing_vae.py b/modules/processing_vae.py index 04af9bab1..faaacb21e 100644 --- a/modules/processing_vae.py +++ b/modules/processing_vae.py @@ -239,6 +239,8 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None, decoded = full_vqgan_decode(latents=latents, model=model) else: decoded = taesd_vae_decode(latents=latents) + if torch.is_tensor(decoded): + decoded = 2.0 * decoded - 1.0 # typical normalized range if torch.is_tensor(decoded): if hasattr(model, 'video_processor'): diff --git a/modules/progress.py b/modules/progress.py index bc3e5500c..8b15f7e13 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -64,23 +64,16 @@ def progressapi(req: ProgressRequest): queued = req.id_task in pending_tasks completed = req.id_task in finished_tasks paused = shared.state.paused - shared.state.job_count = max(shared.state.frame_count, shared.state.job_count, shared.state.job_no) - batch_x = max(shared.state.job_no, 0) - batch_y = max(shared.state.job_count, 1) - step_x = max(shared.state.sampling_step, 0) - step_y = max(shared.state.sampling_steps, 1) - current = step_y * batch_x + step_x - total = step_y * batch_y - while total < current: - total += step_y - progress = min(1, abs(current / total) if total > 0 else 0) + step = max(shared.state.sampling_step, 0) + steps = max(shared.state.sampling_steps, 1) + progress = round(min(1, abs(step / steps) if steps > 0 else 0), 2) elapsed = time.time() - shared.state.time_start if shared.state.time_start is not None else 0 predicted = elapsed / progress if progress > 0 else None eta = predicted - elapsed if predicted is not None else None id_live_preview = req.id_live_preview live_preview = None updated = shared.state.set_current_image() - debug_log(f'Preview: job={shared.state.job} active={active} progress={current}/{total} step={shared.state.current_image_sampling_step}/{shared.state.sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} updated={updated} image={shared.state.current_image} elapsed={elapsed:.3f}') + debug_log(f'Preview: job={shared.state.job} active={active} progress={step}/{steps}/{progress} image={shared.state.current_image_sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} updated={updated} image={shared.state.current_image} elapsed={elapsed:.3f}') if not active: return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, id_live_preview=-1, debug=debug, textinfo="Queued..." 
if queued else "Waiting...") if shared.opts.live_previews_enable and (shared.state.id_live_preview != id_live_preview) and (shared.state.current_image is not None): diff --git a/modules/pulid/eva_clip/hf_model.py b/modules/pulid/eva_clip/hf_model.py index d148bbff2..0b9551993 100644 --- a/modules/pulid/eva_clip/hf_model.py +++ b/modules/pulid/eva_clip/hf_model.py @@ -222,7 +222,6 @@ def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) - print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") embeddings = getattr( self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) modules = [embeddings, *layer_list][:-unlocked_layers] diff --git a/modules/rocm.py b/modules/rocm.py index ef76a1cfa..05fb2260e 100644 --- a/modules/rocm.py +++ b/modules/rocm.py @@ -8,10 +8,6 @@ from enum import Enum -HIPBLASLT_TENSILE_LIBPATH = os.environ.get("HIPBLASLT_TENSILE_LIBPATH", None if sys.platform == "win32" # not available - else "/opt/rocm/lib/hipblaslt/library") - - def resolve_link(path_: str) -> str: if not os.path.islink(path_): return path_ @@ -55,8 +51,7 @@ class Agent: gfx_version: int arch: MicroArchitecture is_apu: bool - if sys.platform != "win32": - blaslt_supported: bool + blaslt_supported: bool @staticmethod def parse_gfx_version(name: str) -> int: @@ -83,8 +78,7 @@ def __init__(self, name: str): else: self.arch = MicroArchitecture.GCN self.is_apu = (self.gfx_version & 0xFFF0 == 0x1150) or self.gfx_version in (0x801, 0x902, 0x90c, 0x1013, 0x1033, 0x1035, 0x1036, 0x1103,) - if sys.platform != "win32": - self.blaslt_supported = os.path.exists(os.path.join(HIPBLASLT_TENSILE_LIBPATH, f"extop_{name}.co")) + self.blaslt_supported = os.path.exists(os.path.join(blaslt_tensile_libpath, f"Kernels.so-000-{name}.hsaco" if sys.platform == "win32" else f"extop_{name}.co")) def get_gfx_version(self) -> Union[str, None]: if self.gfx_version >= 0x1200: @@ -163,6 +157,7 @@ def get_agents() -> List[Agent]: return [Agent(x.split(' ')[-1].strip()) for x in spawn("hipinfo", cwd=os.path.join(path, 'bin')).split("\n") if x.startswith('gcnArchName:')] is_wsl: bool = False + version_torch = None else: def find() -> Union[str, None]: rocm_path = shutil.which("hipconfig") @@ -199,12 +194,12 @@ def load_hsa_runtime() -> None: def set_blaslt_enabled(enabled: bool) -> None: if enabled: load_library_global("/opt/rocm/lib/libhipblaslt.so") # Preload hipBLASLt. 
- os.environ["HIPBLASLT_TENSILE_LIBPATH"] = HIPBLASLT_TENSILE_LIBPATH + os.environ["HIPBLASLT_TENSILE_LIBPATH"] = blaslt_tensile_libpath else: os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0" def get_blaslt_enabled() -> bool: - return bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1"))) + return version == version_torch and bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1"))) def get_flash_attention_command(agent: Agent): if os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE": @@ -215,10 +210,12 @@ def get_flash_attention_command(agent: Agent): return os.environ.get("FLASH_ATTENTION_PACKAGE", default) is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None + version_torch = get_version_torch() path = find() +blaslt_tensile_libpath = None is_installed = False version = None -version_torch = get_version_torch() if path is not None: + blaslt_tensile_libpath = os.environ.get("HIPBLASLT_TENSILE_LIBPATH", os.path.join(path, "bin" if sys.platform == "win32" else "lib", "hipblaslt", "library")) is_installed = True version = get_version() diff --git a/modules/schedulers/scheduler_dc.py b/modules/schedulers/scheduler_dc.py index a1ccfbeba..190588855 100644 --- a/modules/schedulers/scheduler_dc.py +++ b/modules/schedulers/scheduler_dc.py @@ -820,7 +820,6 @@ def closure(ratio_param): loss.backward() optimizer.step() ratio_bound = bound_func(ratio_param) - print(f'iter [{iter_}]', ratio_bound.item(), loss.item()) torch.cuda.empty_cache() return ratio_bound.data.detach().item() diff --git a/modules/schedulers/scheduler_tdd.py b/modules/schedulers/scheduler_tdd.py new file mode 100644 index 000000000..125ef1b3b --- /dev/null +++ b/modules/schedulers/scheduler_tdd.py @@ -0,0 +1,525 @@ +from typing import Union, List, Optional, Tuple +import numpy as np +import torch +from diffusers.utils import deprecate, logging +from diffusers.configuration_utils import register_to_config +from diffusers import DPMSolverSinglestepScheduler +from diffusers.schedulers.scheduling_utils import SchedulerOutput +from diffusers.utils.torch_utils import randn_tensor +# from diffusers.schedulers.scheduling_tcd import * +# from diffusers.schedulers.scheduling_dpmsolver_singlestep import * + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TDDScheduler(DPMSolverSinglestepScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 1, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + tdd_train_step: int = 250, + special_jump: bool = False, + t_l: int = -1, + use_flow_sigmas: bool = False, + ): + self.t_l = t_l + self.special_jump = special_jump + self.tdd_train_step = tdd_train_step + if algorithm_type == "dpmsolver": + deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.sample = None + self.order_list = self.get_order_list(num_train_timesteps) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + self.num_inference_steps = num_inference_steps + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + #original_steps = self.config.original_inference_steps + if True: + original_steps=self.tdd_train_step + k = 1000 / original_steps + tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1 + else: + tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps)))) + # TCD Inference Steps Schedule + tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from tcd_origin_timesteps. 
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = tcd_origin_timesteps[inference_indices] + if self.special_jump: + if self.tdd_train_step == 50: + pass + elif self.tdd_train_step == 250: + if num_inference_steps == 5: + timesteps = np.array([999., 875., 751., 499., 251.]) + elif num_inference_steps == 6: + timesteps = np.array([999., 875., 751., 627., 499., 251.]) + elif num_inference_steps == 7: + timesteps = np.array([999., 875., 751., 627., 499., 375., 251.]) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + self.model_outputs = [None] * self.config.solver_order + self.sample = None + + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + logger.warning( + "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`." + ) + self.register_to_config(lower_order_final=True) + + if not self.config.lower_order_final and self.config.final_sigmas_type == "zero": + logger.warning( + " `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." + ) + self.register_to_config(lower_order_final=True) + + self.order_list = self.get_order_list(num_inference_steps) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def set_timesteps_s(self, eta: float = 0.0): + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + num_inference_steps = self.num_inference_steps + device = self.timesteps.device + if True: + original_steps=self.tdd_train_step + k = 1000 / original_steps + tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1 + else: + tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps)))) + # TCD Inference Steps Schedule + tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from tcd_origin_timesteps. 
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = tcd_origin_timesteps[inference_indices] + if self.special_jump: + if self.tdd_train_step == 50: + timesteps = np.array([999., 879., 759., 499., 259.]) + elif self.tdd_train_step == 250: + if num_inference_steps == 5: + timesteps = np.array([999., 875., 751., 499., 251.]) + elif num_inference_steps == 6: + timesteps = np.array([999., 875., 751., 627., 499., 251.]) + elif num_inference_steps == 7: + timesteps = np.array([999., 875., 751., 627., 499., 375., 251.]) + + timesteps_s = np.floor((1 - eta) * timesteps).astype(np.int64) + + sigmas_s = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + pass + else: + sigmas_s = np.interp(timesteps_s, np.arange(0, len(sigmas_s)), sigmas_s) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}" + ) + + sigmas_s = np.concatenate([sigmas_s, [sigma_last]]).astype(np.float32) + self.sigmas_s = torch.from_numpy(sigmas_s).to(device=device) + self.timesteps_s = torch.from_numpy(timesteps_s).to(device=device, dtype=torch.int64) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + if self.step_index == 0: + self.set_timesteps_s(eta) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + order = self.order_list[self.step_index] + + # For img2img denoising might start with order>1 which is not possible + # In this case make sure that the first two steps are both order=1 + while self.model_outputs[-order] is None: + order -= 1 + + # For single-step solvers, we use the initial value at each time with order = 1. 
+ if order == 1: + self.sample = sample + + prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + + if eta > 0: + if self.step_index != self.num_inference_steps - 1: + + alpha_prod_s = self.alphas_cumprod[self.timesteps_s[self.step_index + 1]] + alpha_prod_t_prev = self.alphas_cumprod[self.timesteps[self.step_index + 1]] + + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=prev_sample.dtype + ) + prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * prev_sample + ( + 1 - alpha_prod_t_prev / alpha_prod_s + ).sqrt() * noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s = self.sigmas_s[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + return x_t + + def singlestep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas_s[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = 
self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m1, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + order: int = None, + **kwargs, + ) -> torch.FloatTensor: + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if order == 1: + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) + elif order == 2: + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) + else: + raise ValueError(f"Order must be 1, 2, got {order}") + + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. 
+ + Returns: + `torch.FloatTensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + if self.step_index <= self.t_l: + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + return model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) diff --git a/modules/schedulers/scheduler_vdm.py b/modules/schedulers/scheduler_vdm.py index 492c30a0c..4f48db163 100644 --- a/modules/schedulers/scheduler_vdm.py +++ b/modules/schedulers/scheduler_vdm.py @@ -355,7 +355,6 @@ def step( ) # 3. 
Clip or threshold "predicted x_0" - # print({ 'timestep': timestep.item(), 'min': pred_original_sample.min().item(), 'max': pred_original_sample.max().item(), 'alpha': alpha.item(), 'sigma': sigma.item() }) if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6c91edf44..a0c85a283 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -141,7 +141,7 @@ def print_timers(): if v > 0.05: long_callbacks.append(f'{k}={v:.2f}') if len(long_callbacks) > 0: - errors.log.debug(f'Script callback init time: {" ".join(long_callbacks)}') + errors.log.debug(f'Script init: {long_callbacks}') def clear_callbacks(): diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py index 6ab396329..ec41a130f 100644 --- a/modules/sd_checkpoint.py +++ b/modules/sd_checkpoint.py @@ -210,6 +210,8 @@ def get_closet_checkpoint_match(s: str): if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1: modelloader.hf_login() found = modelloader.find_diffuser(s, full=True) + if found is None: + return None found = [f for f in found if f == s] shared.log.info(f'HF search: model="{s}" results={found}') if found is not None and len(found) == 1: @@ -262,7 +264,7 @@ def select_checkpoint(op='model'): shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"') return checkpoint_info if len(checkpoints_list) == 0: - shared.log.warning("Cannot generate without a checkpoint") + shared.log.error("No models found") shared.log.info("Set system paths to use existing folders") shared.log.info(" or use --models-dir to specify base folder with all models") shared.log.info(" or use --ckpt-dir to specify folder with sd models") diff --git a/modules/sd_detect.py b/modules/sd_detect.py index 15b22c69c..514517c8d 100644 --- a/modules/sd_detect.py +++ b/modules/sd_detect.py @@ -49,12 +49,6 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): elif (size > 20000 and size < 40000): guess = 'FLUX' # guess by name - """ - if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper(): - if shared.backend == shared.Backend.ORIGINAL: - warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB') - guess = 'Latent Consistency Model' - """ if 'instaflow' in f.lower(): guess = 'InstaFlow' if 'segmoe' in f.lower(): diff --git a/modules/sd_models.py b/modules/sd_models.py index 51d25edfa..9d0349a3f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,25 +1,22 @@ -import io import sys import time -import json import copy import inspect import logging -import contextlib import os.path from enum import Enum import diffusers import diffusers.loaders.single_file_utils -from rich import progress # pylint: disable=redefined-builtin import torch -import safetensors.torch -import accelerate -from omegaconf import OmegaConf + from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect from modules.timer import Timer, process as process_timer from modules.memstats import memory_stats from modules.modeldata import model_data from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, 
setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import +from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import +from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config # pylint: disable=unused-import +from modules.sd_models_utils import NoWatermark, get_signature, get_call, path_to_repo, patch_diffuser_config, convert_to_faketensors, read_state_dict, get_state_dict_from_checkpoint # pylint: disable=unused-import model_dir = "Stable-diffusion" @@ -33,167 +30,6 @@ debug_process = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None diffusers_version = int(diffusers.__version__.split('.')[1]) checkpoint_tiles = checkpoint_titles # legacy compatibility -should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen'] -offload_hook_instance = None - - -class NoWatermark: - def apply_watermark(self, img): - return img - - -def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument - if not os.path.isfile(checkpoint_file): - shared.log.error(f'Load dict: path="{checkpoint_file}" not a file') - return None - try: - pl_sd = None - with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f: - _, extension = os.path.splitext(checkpoint_file) - if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt: - shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}") - return None - if shared.opts.stream_load: - if extension.lower() == ".safetensors": - # shared.log.debug('Model weights loading: type=safetensors mode=buffered') - buffer = f.read() - pl_sd = safetensors.torch.load(buffer) - else: - # shared.log.debug('Model weights loading: type=checkpoint mode=buffered') - buffer = io.BytesIO(f.read()) - pl_sd = torch.load(buffer, map_location='cpu') - else: - if extension.lower() == ".safetensors": - # shared.log.debug('Model weights loading: type=safetensors mode=mmap') - pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu') - else: - # shared.log.debug('Model weights loading: type=checkpoint mode=direct') - pl_sd = torch.load(f, map_location='cpu') - sd = get_state_dict_from_checkpoint(pl_sd) - del pl_sd - except Exception as e: - errors.display(e, f'Load model: {checkpoint_file}') - sd = None - return sd - - -def get_state_dict_from_checkpoint(pl_sd): - checkpoint_dict_replacements = { - 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', - 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', - 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', - } - - def transform_checkpoint_dict_key(k): - for text, replacement in checkpoint_dict_replacements.items(): - if k.startswith(text): - k = replacement + k[len(text):] - return k - - pl_sd = pl_sd.pop("state_dict", pl_sd) - pl_sd.pop("state_dict", None) - sd = {} - for k, v in pl_sd.items(): - new_key = transform_checkpoint_dict_key(k) - if new_key is not None: - sd[new_key] = v - pl_sd.clear() - pl_sd.update(sd) - return pl_sd - - -def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): - if not os.path.isfile(checkpoint_info.filename): - return None - """ - if checkpoint_info in 
checkpoints_loaded: - shared.log.info("Load model: cache") - checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache - return checkpoints_loaded[checkpoint_info] - """ - res = read_state_dict(checkpoint_info.filename, what='model') - """ - if shared.opts.sd_checkpoint_cache > 0 and not shared.native: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = res - # clean up cache if limit is reached - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: - checkpoints_loaded.popitem(last=False) - """ - timer.record("load") - return res - - -def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer): - _pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model') - shared.log.debug(f'Load model: memory={memory_stats()}') - timer.record("hash") - if model_data.sd_dict == 'None': - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - if state_dict is None: - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - try: - model.load_state_dict(state_dict, strict=False) - except Exception as e: - shared.log.error(f'Load model: path="{checkpoint_info.filename}"') - shared.log.error(' '.join(str(e).splitlines()[:2])) - return False - del state_dict - timer.record("apply") - if shared.opts.opt_channelslast: - model.to(memory_format=torch.channels_last) - timer.record("channels") - if not shared.opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.opts.upcast_sampling and depth_model: - model.depth_model = None - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model - if shared.opts.cuda_cast_unet: - devices.dtype_unet = model.model.diffusion_model.dtype - else: - model.model.diffusion_model.to(devices.dtype_unet) - model.first_stage_model.to(devices.dtype_vae) - model.sd_model_hash = checkpoint_info.calculate_shorthash() - model.sd_model_checkpoint = checkpoint_info.filename - model.sd_checkpoint_info = checkpoint_info - model.is_sdxl = False # a1111 compatibility item - model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item - model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item - model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training - shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - sd_vae.delete_base_vae() - sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) - sd_vae.load_vae(model, vae_file, vae_source) - timer.record("vae") - return True - - -def repair_config(sd_config): - if "use_ema" not in sd_config.model.params: - sd_config.model.params.use_ema = False - if shared.opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False - if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: - sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" - # For UnCLIP-L, override the hardcoded karlo 
directory - if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params: - karlo_path = os.path.join(paths.models_path, 'karlo') - sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) - - -sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' -sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' def change_backend(): @@ -225,21 +61,21 @@ def copy_diffuser_options(new_pipe, orig_pipe): set_accelerate(new_pipe) -def set_vae_options(sd_model, vae = None, op: str = 'model'): +def set_vae_options(sd_model, vae=None, op:str='model', quiet:bool=False): if hasattr(sd_model, "vae"): if vae is not None: sd_model.vae = vae - shared.log.debug(f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"') + shared.log.quiet(quiet, f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"') if shared.opts.diffusers_vae_upcast != 'default': sd_model.vae.config.force_upcast = True if shared.opts.diffusers_vae_upcast == 'true' else False - shared.log.debug(f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}') + shared.log.quiet(quiet, f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}') if shared.opts.no_half_vae: devices.dtype_vae = torch.float32 sd_model.vae.to(devices.dtype_vae) - shared.log.debug(f'Setting {op}: component=VAE no-half=True') + shared.log.quiet(quiet, f'Setting {op}: component=VAE no-half=True') if hasattr(sd_model, "enable_vae_slicing"): if shared.opts.diffusers_vae_slicing: - shared.log.debug(f'Setting {op}: component=VAE slicing=True') + shared.log.quiet(quiet, f'Setting {op}: component=VAE slicing=True') sd_model.enable_vae_slicing() else: sd_model.disable_vae_slicing() @@ -251,18 +87,18 @@ def set_vae_options(sd_model, vae = None, op: str = 'model'): sd_model.vae.tile_latent_min_size = int(sd_model.vae.config.sample_size / (2 ** (len(sd_model.vae.config.block_out_channels) - 1))) if shared.opts.diffusers_vae_tile_overlap != 0.25: sd_model.vae.tile_overlap_factor = float(shared.opts.diffusers_vae_tile_overlap) - shared.log.debug(f'Setting {op}: component=VAE tiling=True tile={sd_model.vae.tile_sample_min_size} overlap={sd_model.vae.tile_overlap_factor}') + shared.log.quiet(quiet, f'Setting {op}: component=VAE tiling=True tile={sd_model.vae.tile_sample_min_size} overlap={sd_model.vae.tile_overlap_factor}') else: - shared.log.debug(f'Setting {op}: component=VAE tiling=True') + shared.log.quiet(quiet, f'Setting {op}: component=VAE tiling=True') sd_model.enable_vae_tiling() else: sd_model.disable_vae_tiling() if hasattr(sd_model, "vqvae"): - shared.log.debug(f'Setting {op}: component=VQVAE upcast=True') + shared.log.quiet(quiet, f'Setting {op}: component=VQVAE upcast=True') sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16 -def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True): +def set_diffuser_options(sd_model, vae=None, op:str='model', offload:bool=True, quiet:bool=False): if sd_model is None: shared.log.warning(f'{op} is not loaded') return @@ -273,19 +109,19 @@ def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True): sd_model.has_accelerate = False clear_caches() - set_vae_options(sd_model, vae, op) - set_diffusers_attention(sd_model) + set_vae_options(sd_model, vae, op, quiet) + set_diffusers_attention(sd_model, quiet) 
if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'fuse_qkv_projections'): try: sd_model.fuse_qkv_projections() - shared.log.debug(f'Setting {op}: fused-qkv=True') + shared.log.quiet(quiet, f'Setting {op}: fused-qkv=True') except Exception as e: shared.log.error(f'Setting {op}: fused-qkv=True {e}') if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'fuse_qkv_projections'): try: sd_model.transformer.fuse_qkv_projections() - shared.log.debug(f'Setting {op}: fused-qkv=True') + shared.log.quiet(quiet, f'Setting {op}: fused-qkv=True') except Exception as e: shared.log.error(f'Setting {op}: fused-qkv=True {e}') if shared.opts.diffusers_eval: @@ -299,255 +135,11 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument sd_model = sd_models_compile.torchao_quantization(sd_model) if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'): - shared.log.debug(f'Setting {op}: channels-last=True') + shared.log.quiet(quiet, f'Setting {op}: channels-last=True') sd_model.unet.to(memory_format=torch.channels_last) if offload: - set_diffuser_offload(sd_model, op) - - -def set_accelerate_to_module(model): - if hasattr(model, "pipe"): - set_accelerate_to_module(model.pipe) - if hasattr(model, "_internal_dict"): - for k in model._internal_dict.keys(): # pylint: disable=protected-access - component = getattr(model, k, None) - if isinstance(component, torch.nn.Module): - component.has_accelerate = True - - -def set_accelerate(sd_model): - sd_model.has_accelerate = True - set_accelerate_to_module(sd_model) - if hasattr(sd_model, "prior_pipe"): - set_accelerate_to_module(sd_model.prior_pipe) - if hasattr(sd_model, "decoder_pipe"): - set_accelerate_to_module(sd_model.decoder_pipe) - - -def set_diffuser_offload(sd_model, op: str = 'model'): - t0 = time.time() - if not shared.native: - shared.log.warning('Attempting to use offload with backend=original') - return - if sd_model is None: - shared.log.warning(f'{op} is not loaded') - return - if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate): - sd_model.has_accelerate = False - if shared.opts.diffusers_offload_mode == "none": - if shared.sd_model_type in should_offload: - shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model') - else: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') - if hasattr(sd_model, 'maybe_free_model_hooks'): - sd_model.maybe_free_model_hooks() - sd_model.has_accelerate = False - if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"): - try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') - if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: - shared.opts.diffusers_move_base = False - shared.opts.diffusers_move_unet = False - shared.opts.diffusers_move_refiner = False - shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled') - if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access - sd_model.enable_model_cpu_offload(device=devices.device) - else: - sd_model.maybe_free_model_hooks() - set_accelerate(sd_model) - except Exception as e: - shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} 
{e}') - if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"): - try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') - if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: - shared.opts.diffusers_move_base = False - shared.opts.diffusers_move_unet = False - shared.opts.diffusers_move_refiner = False - shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled') - if sd_model.has_accelerate: - if op == "vae": # reapply sequential offload to vae - from accelerate import cpu_offload - sd_model.vae.to("cpu") - cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access - else: - pass # do nothing if offload is already applied - else: - sd_model.enable_sequential_cpu_offload(device=devices.device) - set_accelerate(sd_model) - except Exception as e: - shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') - if shared.opts.diffusers_offload_mode == "balanced": - sd_model = apply_balanced_offload(sd_model) - process_timer.add('offload', time.time() - t0) - - -class OffloadHook(accelerate.hooks.ModelHook): - def __init__(self, checkpoint_name): - if shared.opts.diffusers_offload_max_gpu_memory > 1: - shared.opts.diffusers_offload_max_gpu_memory = 0.75 - if shared.opts.diffusers_offload_max_cpu_memory > 1: - shared.opts.diffusers_offload_max_cpu_memory = 0.75 - self.checkpoint_name = checkpoint_name - self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory - self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory - self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory - self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024) - self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024) - self.offload_map = {} - self.param_map = {} - gpu = f'{shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory:.3f}-{shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory}:{shared.gpu_memory}' - shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f}') - self.validate() - super().__init__() - - def validate(self): - if shared.opts.diffusers_offload_mode != 'balanced': - return - if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1: - shared.opts.diffusers_offload_min_gpu_memory = 0.25 - shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value') - if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1: - shared.opts.diffusers_offload_max_gpu_memory = 0.75 - shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value') - if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory: - shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory - shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset') - if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4: 
- shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory') - - def model_size(self): - return sum(self.offload_map.values()) - - def init_hook(self, module): - return module - - def pre_forward(self, module, *args, **kwargs): - if devices.normalize_device(module.device) != devices.normalize_device(devices.device): - device_index = torch.device(devices.device).index - if device_index is None: - device_index = 0 - max_memory = { device_index: self.gpu, "cpu": self.cpu } - device_map = getattr(module, "balanced_offload_device_map", None) - if device_map is None or max_memory != getattr(module, "balanced_offload_max_memory", None): - device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory) - offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__)) - module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir) - module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access - module.balanced_offload_device_map = device_map - module.balanced_offload_max_memory = max_memory - return args, kwargs - - def post_forward(self, module, output): - return output - - def detach_hook(self, module): - return module - - -def apply_balanced_offload(sd_model, exclude=[]): - global offload_hook_instance # pylint: disable=global-statement - if shared.opts.diffusers_offload_mode != "balanced": - return sd_model - t0 = time.time() - excluded = ['OmniGenPipeline'] - if sd_model.__class__.__name__ in excluded: - return sd_model - cached = True - checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None - if checkpoint_name is None: - checkpoint_name = sd_model.__class__.__name__ - if offload_hook_instance is None or offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory or offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory or checkpoint_name != offload_hook_instance.checkpoint_name: - cached = False - offload_hook_instance = OffloadHook(checkpoint_name) - - def get_pipe_modules(pipe): - if hasattr(pipe, "_internal_dict"): - modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access - else: - modules_names = get_signature(pipe).keys() - modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')] - modules = {} - for module_name in modules_names: - module_size = offload_hook_instance.offload_map.get(module_name, None) - if module_size is None: - module = getattr(pipe, module_name, None) - if not isinstance(module, torch.nn.Module): - continue - try: - module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 - param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 - except Exception as e: - shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}') - module_size = 0 - offload_hook_instance.offload_map[module_name] = module_size - offload_hook_instance.param_map[module_name] = param_num - modules[module_name] = module_size - modules = sorted(modules.items(), key=lambda x: x[1], reverse=True) - return modules - - def apply_balanced_offload_to_module(pipe): - used_gpu, used_ram = devices.torch_gc(fast=True) - if hasattr(pipe, "pipe"): - apply_balanced_offload_to_module(pipe.pipe) - if hasattr(pipe, "_internal_dict"): - keys = 
pipe._internal_dict.keys() # pylint: disable=protected-access - else: - keys = get_signature(pipe).keys() - keys = [k for k in keys if k not in exclude and not k.startswith('_')] - for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access - module = getattr(pipe, module_name, None) - if module is None: - continue - network_layer_name = getattr(module, "network_layer_name", None) - device_map = getattr(module, "balanced_offload_device_map", None) - max_memory = getattr(module, "balanced_offload_max_memory", None) - module = accelerate.hooks.remove_hook_from_module(module, recurse=True) - perc_gpu = used_gpu / shared.gpu_memory - try: - prev_gpu = used_gpu - do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu) - if do_offload: - module = module.to(devices.cpu, non_blocking=True) - used_gpu -= module_size - if not cached: - shared.log.debug(f'Model module={module_name} type={module.__class__.__name__} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} params={offload_hook_instance.param_map[module_name]:.3f} size={offload_hook_instance.offload_map[module_name]:.3f}') - debug_move(f'Offload: type=balanced op={"move" if do_offload else "skip"} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} module={module.__class__.__name__} size={module_size:.3f}') - except Exception as e: - if 'out of memory' in str(e): - devices.torch_gc(fast=True, force=True, reason='oom') - elif 'bitsandbytes' in str(e): - pass - else: - shared.log.error(f'Offload: type=balanced op=apply module={module_name} {e}') - if os.environ.get('SD_MOVE_DEBUG', None): - errors.display(e, f'Offload: type=balanced op=apply module={module_name}') - module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name) - module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True) - module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access - if network_layer_name: - module.network_layer_name = network_layer_name - if device_map and max_memory: - module.balanced_offload_device_map = device_map - module.balanced_offload_max_memory = max_memory - devices.torch_gc(fast=True, force=True, reason='offload') - - apply_balanced_offload_to_module(sd_model) - if hasattr(sd_model, "pipe"): - apply_balanced_offload_to_module(sd_model.pipe) - if hasattr(sd_model, "prior_pipe"): - apply_balanced_offload_to_module(sd_model.prior_pipe) - if hasattr(sd_model, "decoder_pipe"): - apply_balanced_offload_to_module(sd_model.decoder_pipe) - set_accelerate(sd_model) - t = time.time() - t0 - process_timer.add('offload', t) - fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access - debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}') - if not cached: - shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}') - return sd_model + set_diffuser_offload(sd_model, op, quiet) def move_model(model, device=None, force=False): @@ -647,50 +239,6 @@ def move_base(model, device): return R -def patch_diffuser_config(sd_model, model_file): - def load_config(fn, k): - model_file = os.path.splitext(fn)[0] - cfg_file = f'{model_file}_{k}.json' - try: - if os.path.exists(cfg_file): - with 
open(cfg_file, 'r', encoding='utf-8') as f: - return json.load(f) - cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json' - if os.path.exists(cfg_file): - with open(cfg_file, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception: - pass - return {} - - if sd_model is None: - return sd_model - if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower(): - if debug_load: - shared.log.debug('Model config patch: type=inpaint') - sd_model.unet.config.in_channels = 9 - if not hasattr(sd_model, '_internal_dict'): - return sd_model - for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access - component = getattr(sd_model, c, None) - if hasattr(component, 'config'): - if debug_load: - shared.log.debug(f'Model config: component={c} config={component.config}') - override = load_config(model_file, c) - updated = {} - for k, v in override.items(): - if k.startswith('_'): - continue - if v != component.config.get(k, None): - if hasattr(component.config, '__frozen'): - component.config.__frozen = False # pylint: disable=protected-access - component.config[k] = v - updated[k] = v - if updated and debug_load: - shared.log.debug(f'Model config: component={c} override={updated}') - return sd_model - - def load_diffuser_initial(diffusers_load_config, op='model'): sd_model = None checkpoint_info = None @@ -1079,16 +627,6 @@ def get_diffusers_task(pipe: diffusers.DiffusionPipeline) -> DiffusersTaskType: return DiffusersTaskType.TEXT_2_IMAGE -def get_signature(cls): - signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True) - return signature.parameters - - -def get_call(cls): - signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True) - return signature.parameters - - def switch_pipe(cls: diffusers.DiffusionPipeline, pipeline: diffusers.DiffusionPipeline = None, force = False, args = {}): """ args: @@ -1208,6 +746,7 @@ def set_diffuser_pipe(pipe, new_pipe_type): 'FluxFillPipeline', 'FluxControlPipeline', 'StableVideoDiffusionPipeline', + 'PixelSmithXLPipeline', ] has_errors = False @@ -1241,6 +780,8 @@ def set_diffuser_pipe(pipe, new_pipe_type): default_scheduler = getattr(pipe, "default_scheduler", None) image_encoder = getattr(pipe, "image_encoder", None) feature_extractor = getattr(pipe, "feature_extractor", None) + mask_processor = getattr(pipe, "mask_processor", None) + restore_pipeline = getattr(pipe, "restore_pipeline", None) if new_pipe is None: if hasattr(pipe, 'config'): # real pipeline which can be auto-switched @@ -1288,6 +829,10 @@ def set_diffuser_pipe(pipe, new_pipe_type): new_pipe.image_encoder = image_encoder if feature_extractor is not None: new_pipe.feature_extractor = feature_extractor + if mask_processor is not None: + new_pipe.mask_processor = mask_processor + if restore_pipeline is not None: + new_pipe.restore_pipeline = restore_pipeline if new_pipe.__class__.__name__ in ['FluxPipeline', 'StableDiffusion3Pipeline']: new_pipe.register_modules(image_encoder = image_encoder) new_pipe.register_modules(feature_extractor = feature_extractor) @@ -1308,7 +853,7 @@ def set_diffuser_pipe(pipe, new_pipe_type): return pipe -def set_diffusers_attention(pipe): +def set_diffusers_attention(pipe, quiet:bool=False): import diffusers.models.attention_processor as p def set_attn(pipe, attention): @@ -1339,7 +884,7 @@ def set_attn(pipe, attention): if 'ControlNet' in pipe.__class__.__name__: # do not replace attention in ControlNet pipelines return - 
shared.log.debug(f'Setting model: attention="{shared.opts.cross_attention_optimization}"') + shared.log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"') if shared.opts.cross_attention_optimization == "Disabled": pass # do nothing elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers @@ -1383,103 +928,6 @@ def get_native(pipe: diffusers.DiffusionPipeline): return size -def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): - from ldm.util import instantiate_from_config - from modules import lowvram, sd_hijack - checkpoint_info = checkpoint_info or select_checkpoint(op=op) - if checkpoint_info is None: - return - if op == 'model' or op == 'dict': - if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model - return - else: - if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model - return - shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}') - if timer is None: - timer = Timer() - current_checkpoint_info = None - if op == 'model' or op == 'dict': - if model_data.sd_model is not None: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) - current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None) - unload_model_weights(op=op) - else: - if model_data.sd_refiner is not None: - sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner) - current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) - unload_model_weights(op=op) - - if not shared.native: - from modules import sd_hijack_inpainting - sd_hijack_inpainting.do_inpainting_hijack() - - if already_loaded_state_dict is not None: - state_dict = already_loaded_state_dict - else: - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - if state_dict is None or checkpoint_config is None: - shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"') - if current_checkpoint_info is not None: - shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore') - load_model(current_checkpoint_info, None) - return - shared.log.debug(f'Model dict loaded: {memory_stats()}') - sd_config = OmegaConf.load(checkpoint_config) - repair_config(sd_config) - timer.record("config") - shared.log.debug(f'Model config loaded: {memory_stats()}') - sd_model = None - stdout = io.StringIO() - if os.environ.get('SD_LDM_DEBUG', None) is not None: - sd_model = instantiate_from_config(sd_config.model) - else: - with contextlib.redirect_stdout(stdout): - sd_model = instantiate_from_config(sd_config.model) - for line in stdout.getvalue().splitlines(): - if len(line) > 0: - shared.log.info(f'LDM: {line.strip()}') - shared.log.debug(f"Model created from config: {checkpoint_config}") - sd_model.used_config = checkpoint_config - sd_model.has_accelerate = False - timer.record("create") - ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer) - if not ok: - model_data.sd_model = sd_model - current_checkpoint_info = None - unload_model_weights(op=op) - shared.log.debug(f'Model weights 
unloaded: {memory_stats()} op={op}') - if op == 'refiner': - # shared.opts.data['sd_model_refiner'] = 'None' - shared.opts.sd_model_refiner = 'None' - return - else: - shared.log.debug(f'Model weights loaded: {memory_stats()}') - timer.record("load") - if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - move_model(sd_model, devices.device) - timer.record("move") - shared.log.debug(f'Model weights moved: {memory_stats()}') - sd_hijack.model_hijack.hijack(sd_model) - timer.record("hijack") - sd_model.eval() - if op == 'refiner': - model_data.sd_refiner = sd_model - else: - model_data.sd_model = sd_model - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model - timer.record("embeddings") - script_callbacks.model_loaded_callback(sd_model) - timer.record("callbacks") - shared.log.info(f"Model loaded in {timer.summary()}") - current_checkpoint_info = None - devices.torch_gc(force=True) - shared.log.info(f'Model load finished: {memory_stats()}') - - def reload_text_encoder(initial=False): if initial and (shared.opts.sd_text_encoder is None or shared.opts.sd_text_encoder == 'None'): return # dont unload @@ -1579,37 +1027,8 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', return sd_model -def convert_to_faketensors(tensor): - try: - fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access - if hasattr(tensor, "weight"): - tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight)) - return tensor - except Exception: - pass - return tensor - - -def disable_offload(sd_model): - from accelerate.hooks import remove_hook_from_module - if not getattr(sd_model, 'has_accelerate', False): - return - if hasattr(sd_model, "_internal_dict"): - keys = sd_model._internal_dict.keys() # pylint: disable=protected-access - else: - keys = get_signature(sd_model).keys() - for module_name in keys: # pylint: disable=protected-access - module = getattr(sd_model, module_name, None) - if isinstance(module, torch.nn.Module): - network_layer_name = getattr(module, "network_layer_name", None) - module = remove_hook_from_module(module, recurse=True) - if network_layer_name: - module.network_layer_name = network_layer_name - sd_model.has_accelerate = False - - def clear_caches(): - shared.log.debug('Cache clear') + # shared.log.debug('Cache clear') if not shared.opts.lora_legacy: from modules.lora import networks networks.loaded_networks.clear() @@ -1648,16 +1067,3 @@ def unload_model_weights(op='model'): model_data.sd_refiner = None devices.torch_gc(force=True) shared.log.debug(f'Unload weights {op}: {memory_stats()}') - - -def path_to_repo(fn: str = ''): - if isinstance(fn, CheckpointInfo): - fn = fn.name - repo_id = fn.replace('\\', '/') - if 'models--' in repo_id: - repo_id = repo_id.split('models--')[-1] - repo_id = repo_id.split('/')[0] - repo_id = repo_id.split('/') - repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id) - repo_id = repo_id.replace('models--', '').replace('--', '/') - return repo_id diff --git a/modules/sd_models_legacy.py b/modules/sd_models_legacy.py new file mode 100644 index 000000000..ec21da7b7 --- /dev/null +++ b/modules/sd_models_legacy.py @@ -0,0 +1,207 @@ +import io +import os +import sys +import contextlib + +from modules import shared + + +sd1_clip_weight = 
'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' +sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' + + +def get_checkpoint_state_dict(checkpoint_info, timer): + from modules.sd_models_utils import read_state_dict + if not os.path.isfile(checkpoint_info.filename): + return None + """ + if checkpoint_info in checkpoints_loaded: + shared.log.info("Load model: cache") + checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache + return checkpoints_loaded[checkpoint_info] + """ + res = read_state_dict(checkpoint_info.filename, what='model') + """ + if shared.opts.sd_checkpoint_cache > 0 and not shared.native: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = res + # clean up cache if limit is reached + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) + """ + timer.record("load") + return res + + +def repair_config(sd_config): + from modules import paths + if "use_ema" not in sd_config.model.params: + sd_config.model.params.use_ema = False + if shared.opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False + if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: + sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" + # For UnCLIP-L, override the hardcoded karlo directory + if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params: + karlo_path = os.path.join(paths.models_path, 'karlo') + sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + + +def load_model_weights(model, checkpoint_info, state_dict, timer): + # _pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model') + from modules.modeldata import model_data + from modules.memstats import memory_stats + from modules import devices, sd_vae + shared.log.debug(f'Load model: memory={memory_stats()}') + timer.record("hash") + if model_data.sd_dict == 'None': + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + try: + model.load_state_dict(state_dict, strict=False) + except Exception as e: + shared.log.error(f'Load model: path="{checkpoint_info.filename}"') + shared.log.error(' '.join(str(e).splitlines()[:2])) + return False + del state_dict + timer.record("apply") + if shared.opts.opt_channelslast: + import torch + model.to(memory_format=torch.channels_last) + timer.record("channels") + if not shared.opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) + if shared.opts.no_half_vae: # remove VAE from model when doing half() to prevent its weights from being converted to float16 + model.first_stage_model = None + if shared.opts.upcast_sampling and depth_model: # with don't convert the depth model weights to float16 + model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model + if shared.opts.cuda_cast_unet: + devices.dtype_unet = model.model.diffusion_model.dtype + else: + 
model.model.diffusion_model.to(devices.dtype_unet) + model.first_stage_model.to(devices.dtype_vae) + model.sd_model_hash = checkpoint_info.calculate_shorthash() + model.sd_model_checkpoint = checkpoint_info.filename + model.sd_checkpoint_info = checkpoint_info + model.is_sdxl = False # a1111 compatibility item + model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item + model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item + model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 + sd_vae.delete_base_vae() + sd_vae.clear_loaded_vae() + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) + timer.record("vae") + return True + + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): + from ldm.util import instantiate_from_config + from omegaconf import OmegaConf + from modules import devices, lowvram, sd_hijack, sd_models_config, script_callbacks + from modules.timer import Timer + from modules.memstats import memory_stats + from modules.modeldata import model_data + from modules.sd_models import unload_model_weights, move_model + from modules.sd_checkpoint import select_checkpoint + checkpoint_info = checkpoint_info or select_checkpoint(op=op) + if checkpoint_info is None: + return + if op == 'model' or op == 'dict': + if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model + return + else: + if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model + return + shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}') + if timer is None: + timer = Timer() + current_checkpoint_info = None + if op == 'model' or op == 'dict': + if model_data.sd_model is not None: + sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None) + unload_model_weights(op=op) + else: + if model_data.sd_refiner is not None: + sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner) + current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) + unload_model_weights(op=op) + + if not shared.native: + from modules import sd_hijack_inpainting + sd_hijack_inpainting.do_inpainting_hijack() + + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + if state_dict is None or checkpoint_config is None: + shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"') + if current_checkpoint_info is not None: + shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore') + load_model(current_checkpoint_info, None) + return + shared.log.debug(f'Model dict loaded: {memory_stats()}') + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + timer.record("config") + shared.log.debug(f'Model config loaded: {memory_stats()}') + sd_model = None + 
stdout = io.StringIO() + if os.environ.get('SD_LDM_DEBUG', None) is not None: + sd_model = instantiate_from_config(sd_config.model) + else: + with contextlib.redirect_stdout(stdout): + sd_model = instantiate_from_config(sd_config.model) + for line in stdout.getvalue().splitlines(): + if len(line) > 0: + shared.log.info(f'LDM: {line.strip()}') + shared.log.debug(f"Model created from config: {checkpoint_config}") + sd_model.used_config = checkpoint_config + sd_model.has_accelerate = False + timer.record("create") + ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer) + if not ok: + model_data.sd_model = sd_model + current_checkpoint_info = None + unload_model_weights(op=op) + shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}') + if op == 'refiner': + # shared.opts.data['sd_model_refiner'] = 'None' + shared.opts.sd_model_refiner = 'None' + return + else: + shared.log.debug(f'Model weights loaded: {memory_stats()}') + timer.record("load") + if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): + lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) + else: + move_model(sd_model, devices.device) + timer.record("move") + shared.log.debug(f'Model weights moved: {memory_stats()}') + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() + if op == 'refiner': + model_data.sd_refiner = sd_model + else: + model_data.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("embeddings") + script_callbacks.model_loaded_callback(sd_model) + timer.record("callbacks") + shared.log.info(f"Model loaded in {timer.summary()}") + current_checkpoint_info = None + devices.torch_gc(force=True) + shared.log.info(f'Model load finished: {memory_stats()}') diff --git a/modules/sd_models_utils.py b/modules/sd_models_utils.py new file mode 100644 index 000000000..0ff903483 --- /dev/null +++ b/modules/sd_models_utils.py @@ -0,0 +1,151 @@ +import io +import json +import inspect +import os.path +from rich import progress # pylint: disable=redefined-builtin +import torch +import safetensors.torch + +from modules import paths, shared, errors +from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import +from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import +from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config # pylint: disable=unused-import + + +class NoWatermark: + def apply_watermark(self, img): + return img + + +def get_signature(cls): + signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True) + return signature.parameters + + +def get_call(cls): + if cls is None: + return [] + signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True) + return signature.parameters + + +def path_to_repo(fn: str = ''): + if isinstance(fn, CheckpointInfo): + fn = fn.name + repo_id = fn.replace('\\', '/') + if 'models--' in repo_id: + repo_id = repo_id.split('models--')[-1] + repo_id = repo_id.split('/')[0] + repo_id = repo_id.split('/') + repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id) + repo_id 
= repo_id.replace('models--', '').replace('--', '/') + return repo_id + + +def convert_to_faketensors(tensor): + try: + fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access + if hasattr(tensor, "weight"): + tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight)) + return tensor + except Exception: + pass + return tensor + + +def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument + if not os.path.isfile(checkpoint_file): + shared.log.error(f'Load dict: path="{checkpoint_file}" not a file') + return None + try: + pl_sd = None + with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f: + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt: + shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}") + return None + if shared.opts.stream_load: + if extension.lower() == ".safetensors": + # shared.log.debug('Model weights loading: type=safetensors mode=buffered') + buffer = f.read() + pl_sd = safetensors.torch.load(buffer) + else: + # shared.log.debug('Model weights loading: type=checkpoint mode=buffered') + buffer = io.BytesIO(f.read()) + pl_sd = torch.load(buffer, map_location='cpu') + else: + if extension.lower() == ".safetensors": + # shared.log.debug('Model weights loading: type=safetensors mode=mmap') + pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu') + else: + # shared.log.debug('Model weights loading: type=checkpoint mode=direct') + pl_sd = torch.load(f, map_location='cpu') + sd = get_state_dict_from_checkpoint(pl_sd) + del pl_sd + except Exception as e: + errors.display(e, f'Load model: {checkpoint_file}') + sd = None + return sd + + +def get_state_dict_from_checkpoint(pl_sd): + checkpoint_dict_replacements = { + 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', + 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', + 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', + } + + def transform_checkpoint_dict_key(k): + for text, replacement in checkpoint_dict_replacements.items(): + if k.startswith(text): + k = replacement + k[len(text):] + return k + + pl_sd = pl_sd.pop("state_dict", pl_sd) + pl_sd.pop("state_dict", None) + sd = {} + for k, v in pl_sd.items(): + new_key = transform_checkpoint_dict_key(k) + if new_key is not None: + sd[new_key] = v + pl_sd.clear() + pl_sd.update(sd) + return pl_sd + + +def patch_diffuser_config(sd_model, model_file): + def load_config(fn, k): + model_file = os.path.splitext(fn)[0] + cfg_file = f'{model_file}_{k}.json' + try: + if os.path.exists(cfg_file): + with open(cfg_file, 'r', encoding='utf-8') as f: + return json.load(f) + cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json' + if os.path.exists(cfg_file): + with open(cfg_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception: + pass + return {} + + if sd_model is None: + return sd_model + if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower(): + sd_model.unet.config.in_channels = 9 + if not hasattr(sd_model, '_internal_dict'): + return sd_model + for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access + 
component = getattr(sd_model, c, None) + if hasattr(component, 'config'): + override = load_config(model_file, c) + updated = {} + for k, v in override.items(): + if k.startswith('_'): + continue + if v != component.config.get(k, None): + if hasattr(component.config, '__frozen'): + component.config.__frozen = False # pylint: disable=protected-access + component.config[k] = v + updated[k] = v + return sd_model diff --git a/modules/sd_offload.py b/modules/sd_offload.py new file mode 100644 index 000000000..f9d01528c --- /dev/null +++ b/modules/sd_offload.py @@ -0,0 +1,280 @@ +import os +import sys +import time +import inspect +import torch +import accelerate + +from modules import shared, devices, errors +from modules.timer import process as process_timer + + +debug_move = shared.log.trace if os.environ.get('SD_MOVE_DEBUG', None) is not None else lambda *args, **kwargs: None +should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen'] +offload_hook_instance = None + + +def get_signature(cls): + signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True) + return signature.parameters + + +def disable_offload(sd_model): + from accelerate.hooks import remove_hook_from_module + if not getattr(sd_model, 'has_accelerate', False): + return + if hasattr(sd_model, "_internal_dict"): + keys = sd_model._internal_dict.keys() # pylint: disable=protected-access + else: + keys = get_signature(sd_model).keys() + for module_name in keys: # pylint: disable=protected-access + module = getattr(sd_model, module_name, None) + if isinstance(module, torch.nn.Module): + network_layer_name = getattr(module, "network_layer_name", None) + module = remove_hook_from_module(module, recurse=True) + if network_layer_name: + module.network_layer_name = network_layer_name + sd_model.has_accelerate = False + + +def set_accelerate(sd_model): + def set_accelerate_to_module(model): + if hasattr(model, "pipe"): + set_accelerate_to_module(model.pipe) + if hasattr(model, "_internal_dict"): + for k in model._internal_dict.keys(): # pylint: disable=protected-access + component = getattr(model, k, None) + if isinstance(component, torch.nn.Module): + component.has_accelerate = True + + sd_model.has_accelerate = True + set_accelerate_to_module(sd_model) + if hasattr(sd_model, "prior_pipe"): + set_accelerate_to_module(sd_model.prior_pipe) + if hasattr(sd_model, "decoder_pipe"): + set_accelerate_to_module(sd_model.decoder_pipe) + + +def set_diffuser_offload(sd_model, op:str='model', quiet:bool=False): + t0 = time.time() + if not shared.native: + shared.log.warning('Attempting to use offload with backend=original') + return + if sd_model is None: + shared.log.warning(f'{op} is not loaded') + return + if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate): + sd_model.has_accelerate = False + if shared.opts.diffusers_offload_mode == "none": + if shared.sd_model_type in should_offload: + shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model') + else: + shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + if hasattr(sd_model, 'maybe_free_model_hooks'): + sd_model.maybe_free_model_hooks() + sd_model.has_accelerate = False + if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"): + try: + shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} 
limit={shared.opts.cuda_mem_fraction}') + if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: + shared.opts.diffusers_move_base = False + shared.opts.diffusers_move_unet = False + shared.opts.diffusers_move_refiner = False + shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled') + if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access + sd_model.enable_model_cpu_offload(device=devices.device) + else: + sd_model.maybe_free_model_hooks() + set_accelerate(sd_model) + except Exception as e: + shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') + if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"): + try: + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: + shared.opts.diffusers_move_base = False + shared.opts.diffusers_move_unet = False + shared.opts.diffusers_move_refiner = False + shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled') + if sd_model.has_accelerate: + if op == "vae": # reapply sequential offload to vae + from accelerate import cpu_offload + sd_model.vae.to("cpu") + cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access + else: + pass # do nothing if offload is already applied + else: + sd_model.enable_sequential_cpu_offload(device=devices.device) + set_accelerate(sd_model) + except Exception as e: + shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') + if shared.opts.diffusers_offload_mode == "balanced": + sd_model = apply_balanced_offload(sd_model) + process_timer.add('offload', time.time() - t0) + + +class OffloadHook(accelerate.hooks.ModelHook): + def __init__(self, checkpoint_name): + if shared.opts.diffusers_offload_max_gpu_memory > 1: + shared.opts.diffusers_offload_max_gpu_memory = 0.75 + if shared.opts.diffusers_offload_max_cpu_memory > 1: + shared.opts.diffusers_offload_max_cpu_memory = 0.75 + self.checkpoint_name = checkpoint_name + self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory + self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory + self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory + self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024) + self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024) + self.offload_map = {} + self.param_map = {} + gpu = f'{shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory:.3f}-{shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory}:{shared.gpu_memory}' + shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f}') + self.validate() + super().__init__() + + def validate(self): + if shared.opts.diffusers_offload_mode != 'balanced': + return + if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1: + shared.opts.diffusers_offload_min_gpu_memory = 0.25 + shared.log.warning(f'Offload: type=balanced op=validate: watermark 
low={shared.opts.diffusers_offload_min_gpu_memory} invalid value') + if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1: + shared.opts.diffusers_offload_max_gpu_memory = 0.75 + shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value') + if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory: + shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory + shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset') + if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4: + shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory') + + def model_size(self): + return sum(self.offload_map.values()) + + def init_hook(self, module): + return module + + def pre_forward(self, module, *args, **kwargs): + if devices.normalize_device(module.device) != devices.normalize_device(devices.device): + device_index = torch.device(devices.device).index + if device_index is None: + device_index = 0 + max_memory = { device_index: self.gpu, "cpu": self.cpu } + device_map = getattr(module, "balanced_offload_device_map", None) + if device_map is None or max_memory != getattr(module, "balanced_offload_max_memory", None): + device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory) + offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__)) + module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir) + module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access + module.balanced_offload_device_map = device_map + module.balanced_offload_max_memory = max_memory + return args, kwargs + + def post_forward(self, module, output): + return output + + def detach_hook(self, module): + return module + + +def apply_balanced_offload(sd_model, exclude=[]): + global offload_hook_instance # pylint: disable=global-statement + if shared.opts.diffusers_offload_mode != "balanced": + return sd_model + t0 = time.time() + excluded = ['OmniGenPipeline'] + if sd_model.__class__.__name__ in excluded: + return sd_model + cached = True + checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None + if checkpoint_name is None: + checkpoint_name = sd_model.__class__.__name__ + if offload_hook_instance is None or offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory or offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory or checkpoint_name != offload_hook_instance.checkpoint_name: + cached = False + offload_hook_instance = OffloadHook(checkpoint_name) + + def get_pipe_modules(pipe): + if hasattr(pipe, "_internal_dict"): + modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access + else: + modules_names = get_signature(pipe).keys() + modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')] + modules = {} + for module_name in modules_names: + module_size = offload_hook_instance.offload_map.get(module_name, None) + if module_size is None: + module = getattr(pipe, module_name, None) + if not isinstance(module, torch.nn.Module): + continue + try: + module_size = 
sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 + param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 + except Exception as e: + shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}') + module_size = 0 + offload_hook_instance.offload_map[module_name] = module_size + offload_hook_instance.param_map[module_name] = param_num + modules[module_name] = module_size + modules = sorted(modules.items(), key=lambda x: x[1], reverse=True) + return modules + + def apply_balanced_offload_to_module(pipe): + used_gpu, used_ram = devices.torch_gc(fast=True) + if hasattr(pipe, "pipe"): + apply_balanced_offload_to_module(pipe.pipe) + if hasattr(pipe, "_internal_dict"): + keys = pipe._internal_dict.keys() # pylint: disable=protected-access + else: + keys = get_signature(pipe).keys() + keys = [k for k in keys if k not in exclude and not k.startswith('_')] + for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access + module = getattr(pipe, module_name, None) + if module is None: + continue + network_layer_name = getattr(module, "network_layer_name", None) + device_map = getattr(module, "balanced_offload_device_map", None) + max_memory = getattr(module, "balanced_offload_max_memory", None) + module = accelerate.hooks.remove_hook_from_module(module, recurse=True) + perc_gpu = used_gpu / shared.gpu_memory + try: + prev_gpu = used_gpu + do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu) + if do_offload: + module = module.to(devices.cpu, non_blocking=True) + used_gpu -= module_size + if not cached: + shared.log.debug(f'Model module={module_name} type={module.__class__.__name__} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} params={offload_hook_instance.param_map[module_name]:.3f} size={offload_hook_instance.offload_map[module_name]:.3f}') + debug_move(f'Offload: type=balanced op={"move" if do_offload else "skip"} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} module={module.__class__.__name__} size={module_size:.3f}') + except Exception as e: + if 'out of memory' in str(e): + devices.torch_gc(fast=True, force=True, reason='oom') + elif 'bitsandbytes' in str(e): + pass + else: + shared.log.error(f'Offload: type=balanced op=apply module={module_name} {e}') + if os.environ.get('SD_MOVE_DEBUG', None): + errors.display(e, f'Offload: type=balanced op=apply module={module_name}') + module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name) + module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True) + module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access + if network_layer_name: + module.network_layer_name = network_layer_name + if device_map and max_memory: + module.balanced_offload_device_map = device_map + module.balanced_offload_max_memory = max_memory + devices.torch_gc(fast=True, force=True, reason='offload') + + apply_balanced_offload_to_module(sd_model) + if hasattr(sd_model, "pipe"): + apply_balanced_offload_to_module(sd_model.pipe) + if hasattr(sd_model, "prior_pipe"): + apply_balanced_offload_to_module(sd_model.prior_pipe) + if hasattr(sd_model, "decoder_pipe"): + apply_balanced_offload_to_module(sd_model.decoder_pipe) + set_accelerate(sd_model) + t = 
time.time() - t0 + process_timer.add('offload', t) + fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access + debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}') + if not cached: + shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}') + return sd_model diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index b01b847d8..62bc4f55c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -59,7 +59,7 @@ def single_sample_to_image(sample, approximation=None): except Exception: pass x_sample = sd_vae_taesd.decode(sample) - x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range + # x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range elif shared.sd_model_type == 'sc' and approximation != 3: x_sample = sd_vae_stablecascade.decode(sample) elif approximation == 0: # Simple @@ -67,19 +67,24 @@ def single_sample_to_image(sample, approximation=None): elif approximation == 1: # Approximate x_sample = sd_vae_approx.nn_approximation(sample) * 0.5 + 0.5 if shared.sd_model_type == "sdxl": - x_sample = x_sample[[2,1,0], :, :] # BGR to RGB + x_sample = x_sample[[2, 1, 0], :, :] # BGR to RGB elif approximation == 3: # Full VAE x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] else: warn_once(f"Unknown latent decode type: {approximation}") return Image.new(mode="RGB", size=(512, 512)) try: - if x_sample.shape[0] > 4: - return Image.new(mode="RGB", size=(512, 512)) - if x_sample.dtype == torch.bfloat16: - x_sample.to(torch.float16) - transform = T.ToPILImage() - image = transform(x_sample) + if isinstance(x_sample, Image.Image): + image = x_sample + else: + if x_sample.shape[0] > 4 or x_sample.shape[0] == 4: + return Image.new(mode="RGB", size=(512, 512)) + if x_sample.dtype == torch.bfloat16: + x_sample = x_sample.to(torch.float16) + if len(x_sample.shape) == 4: + x_sample = x_sample[0] + transform = T.ToPILImage() + image = transform(x_sample) except Exception as e: warn_once(f'Preview: {e}') image = Image.new(mode="RGB", size=(512, 512)) diff --git a/modules/sd_samplers_diffusers.py b/modules/sd_samplers_diffusers.py index 6e5f37cd7..6b2de72aa 100644 --- a/modules/sd_samplers_diffusers.py +++ b/modules/sd_samplers_diffusers.py @@ -7,8 +7,8 @@ from modules.sd_samplers_common import SamplerData, flow_models -debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: SAMPLER') +debug = os.environ.get('SD_SAMPLER_DEBUG', None) is not None +debug_log = shared.log.trace if debug else lambda *args, **kwargs: None try: from diffusers import ( @@ -48,6 +48,7 @@ errors.display(e, 'Samplers') try: from modules.schedulers.scheduler_tcd import TCDScheduler # pylint: disable=ungrouped-imports + from modules.schedulers.scheduler_tdd import TDDScheduler # pylint: disable=ungrouped-imports from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler # pylint: disable=ungrouped-imports from modules.schedulers.scheduler_vdm import VDMScheduler # pylint: disable=ungrouped-imports from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler # pylint: disable=ungrouped-imports @@ -98,7 +99,8 @@ 'VDM Solver': { 'clip_sample_range': 2.0, }, 'LCM': { 'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': "scaled_linear", 'set_alpha_to_one': True, 
'rescale_betas_zero_snr': False, 'thresholding': False, 'timestep_spacing': 'linspace' }, 'TCD': { 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'beta_schedule': 'scaled_linear' }, - 'UFOGen': {}, + 'TDD': { }, + 'UFOGen': { }, 'BDIA DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False, 'gamma': 1.0 }, 'PNDM': { 'skip_prk_steps': False, 'set_alpha_to_one': False, 'steps_offset': 0, 'timestep_spacing': 'linspace' }, @@ -157,6 +159,7 @@ SamplerData('LCM', lambda model: DiffusionSampler('LCM', LCMScheduler, model), [], {}), SamplerData('TCD', lambda model: DiffusionSampler('TCD', TCDScheduler, model), [], {}), + SamplerData('TDD', lambda model: DiffusionSampler('TDD', TDDScheduler, model), [], {}), SamplerData('UFOGen', lambda model: DiffusionSampler('UFOGen', UFOGenScheduler, model), [], {}), SamplerData('Same as primary', None, [], {}), @@ -175,17 +178,17 @@ def __init__(self, name, constructor, model, **kwargs): model.default_scheduler = copy.deepcopy(model.scheduler) for key, value in config.get('All', {}).items(): # apply global defaults self.config[key] = value - debug(f'Sampler: all="{self.config}"') + debug_log(f'Sampler: all="{self.config}"') if hasattr(model.default_scheduler, 'scheduler_config'): # find model defaults orig_config = model.default_scheduler.scheduler_config else: orig_config = model.default_scheduler.config - debug(f'Sampler: diffusers="{self.config}"') - debug(f'Sampler: original="{orig_config}"') + debug_log(f'Sampler: diffusers="{self.config}"') + debug_log(f'Sampler: original="{orig_config}"') for key, value in orig_config.items(): # apply model defaults if key in self.config: self.config[key] = value - debug(f'Sampler: default="{self.config}"') + debug_log(f'Sampler: default="{self.config}"') for key, value in config.get(name, {}).items(): # apply diffusers per-scheduler defaults self.config[key] = value for key, value in kwargs.items(): # apply user args, if any @@ -264,15 +267,22 @@ def __init__(self, name, constructor, model, **kwargs): if key not in possible: # shared.log.warning(f'Sampler: sampler="{name}" config={self.config} invalid={key}') del self.config[key] - debug(f'Sampler: name="{name}"') - debug(f'Sampler: config={self.config}') - debug(f'Sampler: signature={possible}') - # shared.log.debug(f'Sampler: sampler="{name}" config={self.config}') - sampler = constructor(**self.config) + debug_log(f'Sampler: name="{name}"') + debug_log(f'Sampler: config={self.config}') + debug_log(f'Sampler: signature={possible}') + # shared.log.debug_log(f'Sampler: sampler="{name}" config={self.config}') + try: + sampler = constructor(**self.config) + except Exception as e: + shared.log.error(f'Sampler: sampler="{name}" {e}') + if debug: + errors.display(e, 'Samplers') + self.sampler = None + return accept_sigmas = "sigmas" in set(inspect.signature(sampler.set_timesteps).parameters.keys()) accepts_timesteps = "timesteps" in set(inspect.signature(sampler.set_timesteps).parameters.keys()) accept_scale_noise = hasattr(sampler, "scale_noise") - debug(f'Sampler: sampler="{name}" sigmas={accept_sigmas} timesteps={accepts_timesteps}') + debug_log(f'Sampler: sampler="{name}" sigmas={accept_sigmas} timesteps={accepts_timesteps}') if ('Flux' in model.__class__.__name__) and (not accept_sigmas): shared.log.warning(f'Sampler: sampler="{name}" does not accept sigmas') self.sampler = None @@ -286,5 +296,5 @@ def 
__init__(self, name, constructor, model, **kwargs): if not hasattr(self.sampler, 'dc_ratios'): pass # self.sampler.dc_ratios = self.sampler.cascade_polynomial_regression(test_CFG=6.0, test_NFE=10, cpr_path='tmp/sd2.1.npy') - # shared.log.debug(f'Sampler: class="{self.sampler.__class__.__name__}" config={self.sampler.config}') + # shared.log.debug_log(f'Sampler: class="{self.sampler.__class__.__name__}" config={self.sampler.config}') self.sampler.name = name diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 4507ee3c8..c8a1b882f 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -5,217 +5,132 @@ https://github.com/madebyollin/taesd """ import os +import threading from PIL import Image import torch -import torch.nn as nn from modules import devices, paths -taesd_models = { - 'sd-decoder': None, - 'sd-encoder': None, - 'sdxl-decoder': None, - 'sdxl-encoder': None, - 'sd3-decoder': None, - 'sd3-encoder': None, - 'f1-decoder': None, - 'f1-encoder': None, +TAESD_MODELS = { + 'TAESD 1.3 Mocha Croissant': { 'fn': 'taesd_13_', 'uri': 'https://github.com/madebyollin/taesd/raw/7f572ca629c9b0d3c9f71140e5f501e09f9ea280', 'model': None }, + 'TAESD 1.2 Chocolate-Dipped Shortbread': { 'fn': 'taesd_12_', 'uri': 'https://github.com/madebyollin/taesd/raw/8909b44e3befaa0efa79c5791e4fe1c4d4f7884e', 'model': None }, + 'TAESD 1.1 Fruit Loops': { 'fn': 'taesd_11_', 'uri': 'https://github.com/madebyollin/taesd/raw/3e8a8a2ab4ad4079db60c1c7dc1379b4cc0c6b31', 'model': None }, + 'TAESD 1.0': { 'fn': 'taesd_10_', 'uri': 'https://github.com/madebyollin/taesd/raw/88012e67cf0454e6d90f98911fe9d4aef62add86', 'model': None }, +} +CQYAN_MODELS = { + 'Hybrid-Tiny SD': { + 'sd': { 'repo': 'cqyan/hybrid-sd-tinyvae', 'model': None }, + 'sdxl': { 'repo': 'cqyan/hybrid-sd-tinyvae-xl', 'model': None }, + }, + 'Hybrid-Small SD': { + 'sd': { 'repo': 'cqyan/hybrid-sd-small-vae', 'model': None }, + 'sdxl': { 'repo': 'cqyan/hybrid-sd-small-vae-xl', 'model': None }, + }, } -previous_warnings = False - - -def conv(n_in, n_out, **kwargs): - return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) - -class Clamp(nn.Module): - def forward(self, x): - return torch.tanh(x / 3) * 3 -class Block(nn.Module): - def __init__(self, n_in, n_out): - super().__init__() - self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) - self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() - self.fuse = nn.ReLU() - def forward(self, x): - return self.fuse(self.conv(x) + self.skip(x)) +prev_warnings = False +prev_cls = '' +prev_type = '' +prev_model = '' +lock = threading.Lock() -def Encoder(latent_channels=4): - return nn.Sequential( - conv(3, 64), Block(64, 64), - conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, latent_channels), - ) -def Decoder(latent_channels=4): +def warn_once(msg): from modules import shared - if shared.opts.live_preview_taesd_layers == 1: - return nn.Sequential( - Clamp(), conv(latent_channels, 64), nn.ReLU(), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Identity(), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Identity(), conv(64, 64, bias=False), - Block(64, 64), conv(64, 3), - ) - 
elif shared.opts.live_preview_taesd_layers == 2: - return nn.Sequential( - Clamp(), conv(latent_channels, 64), nn.ReLU(), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Identity(), conv(64, 64, bias=False), - Block(64, 64), conv(64, 3), - ) - else: - return nn.Sequential( - Clamp(), conv(latent_channels, 64), nn.ReLU(), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), conv(64, 3), - ) - - -class TAESD(nn.Module): # pylint: disable=abstract-method - latent_magnitude = 3 - latent_shift = 0.5 - - def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth", latent_channels=None): - """Initialize pretrained TAESD on the given device from the given checkpoints.""" - super().__init__() - if latent_channels is None: - latent_channels = self.guess_latent_channels(str(decoder_path), str(encoder_path)) - self.encoder = Encoder(latent_channels) - self.decoder = Decoder(latent_channels) - if encoder_path is not None: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"), strict=False) - if decoder_path is not None: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"), strict=False) - - def guess_latent_channels(self, decoder_path, encoder_path): - """guess latent channel count based on encoder filename""" - if "taef1" in encoder_path or "taef1" in decoder_path: - return 16 - if "taesd3" in encoder_path or "taesd3" in decoder_path: - return 16 - return 4 - - @staticmethod - def scale_latents(x): - """raw latents -> [0, 1]""" - return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) + global prev_warnings # pylint: disable=global-statement + if not prev_warnings: + prev_warnings = True + shared.log.error(f'Decode: type="taesd" variant="{shared.opts.taesd_variant}": {msg}') + return Image.new('RGB', (8, 8), color = (0, 0, 0)) - @staticmethod - def unscale_latents(x): - """[0, 1] -> raw latents""" - return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) - -def download_model(model_path): - model_name = os.path.basename(model_path) - model_url = f'https://github.com/madebyollin/taesd/raw/main/{model_name}' - if not os.path.exists(model_path): - from modules.shared import log - os.makedirs(os.path.dirname(model_path), exist_ok=True) - log.info(f'Downloading TAESD decoder: {model_path}') - torch.hub.download_url_to_file(model_url, model_path) - - -def model(model_class = 'sd', model_type = 'decoder'): - vae = taesd_models[f'{model_class}-{model_type}'] - if vae is None: - model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_{model_type}.pth") - download_model(model_path) - if os.path.exists(model_path): - from modules.shared import log - taesd_models[f'{model_class}-{model_type}'] = TAESD(decoder_path=model_path, encoder_path=None) if model_type == 'decoder' else TAESD(encoder_path=model_path, decoder_path=None) - vae = taesd_models[f'{model_class}-{model_type}'] - vae.eval() - vae.to(devices.device, devices.dtype_vae) - log.info(f"Load VAE-TAESD: model={model_path}") - else: - raise FileNotFoundError(f'TAESD model 
not found: {model_path}') - if vae is None: +def get_model(model_type = 'decoder'): + global prev_cls, prev_type, prev_model # pylint: disable=global-statement + from modules import shared + cls = shared.sd_model_type + if cls == 'ldm': + cls = 'sd' + folder = os.path.join(paths.models_path, "TAESD") + os.makedirs(folder, exist_ok=True) + if 'sd' not in cls and 'f1' not in cls: + warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') return None + if shared.opts.taesd_variant.startswith('TAESD'): + cfg = TAESD_MODELS[shared.opts.taesd_variant] + if (cls == prev_cls) and (model_type == prev_type) and (shared.opts.taesd_variant == prev_model) and (cfg['model'] is not None): + return cfg['model'] + fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth') + if not os.path.exists(fn): + uri = cfg['uri'] + '/tae' + cls + '_' + model_type + '.pth' + try: + shared.log.info(f'Decode: type="taesd" variant="{shared.opts.taesd_variant}": uri="{uri}" fn="{fn}" download') + torch.hub.download_url_to_file(uri, fn) + except Exception as e: + warn_once(f'download uri={uri} {e}') + if os.path.exists(fn): + prev_cls = cls + prev_type = model_type + prev_model = shared.opts.taesd_variant + shared.log.debug(f'Decode: type="taesd" variant="{shared.opts.taesd_variant}" fn="{fn}" load') + from modules.taesd.taesd import TAESD + TAESD_MODELS[shared.opts.taesd_variant]['model'] = TAESD(decoder_path=fn if model_type=='decoder' else None, encoder_path=fn if model_type=='encoder' else None) + return TAESD_MODELS[shared.opts.taesd_variant]['model'] + elif shared.opts.taesd_variant.startswith('Hybrid'): + cfg = CQYAN_MODELS[shared.opts.taesd_variant].get(cls, None) + if (cls == prev_cls) and (model_type == prev_type) and (shared.opts.taesd_variant == prev_model) and (cfg['model'] is not None): + return cfg['model'] + if cfg is None: + warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') + return None + repo = cfg['repo'] + prev_cls = cls + prev_type = model_type + prev_model = shared.opts.taesd_variant + shared.log.debug(f'Decode: type="taesd" variant="{shared.opts.taesd_variant}" id="{repo}" load') + dtype = devices.dtype_vae if devices.dtype_vae != torch.bfloat16 else torch.float16 # taesd does not support bf16 + if 'tiny' in repo: + from diffusers.models import AutoencoderTiny + vae = AutoencoderTiny.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir, torch_dtype=dtype) + else: + from modules.taesd.hybrid_small import AutoencoderSmall + vae = AutoencoderSmall.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir, torch_dtype=dtype) + vae = vae.to(devices.device, dtype=dtype) + CQYAN_MODELS[shared.opts.taesd_variant][cls]['model'] = vae + return vae else: - return vae.decoder if model_type == 'decoder' else vae.encoder + warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') + return None def decode(latents): - global previous_warnings # pylint: disable=global-statement - from modules import shared - model_class = shared.sd_model_type - if model_class == 'ldm': - model_class = 'sd' - dtype = devices.dtype_vae if devices.dtype_vae != torch.bfloat16 else torch.float16 # taesd does not support bf16 - if 'sd' not in model_class and 'f1' not in model_class: - if not previous_warnings: - previous_warnings = True - shared.log.warning(f'TAESD unsupported model type: {model_class}') - # return Image.new('RGB', (8, 8), color = (0, 0, 0)) - return latents - vae = taesd_models.get(f'{model_class}-decoder', None) - if vae is None: - 
model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_decoder.pth") - download_model(model_path) - if os.path.exists(model_path): - taesd_models[f'{model_class}-decoder'] = TAESD(decoder_path=model_path, encoder_path=None) - shared.log.debug(f'VAE load: type=taesd model="{model_path}"') - vae = taesd_models[f'{model_class}-decoder'] - vae.decoder.to(devices.device, dtype) - else: - shared.log.error(f'VAE load: type=taesd model="{model_path}" not found') + with lock: + from modules import shared + vae = get_model(model_type='decoder') + if vae is None or max(latents.shape) > 256: # safetey check of large tensors return latents - if vae is None: - return latents - try: - size = max(latents.shape[-1], latents.shape[-2]) - if size > 256: - return latents - with devices.inference_context(): - latents = latents.detach().clone().to(devices.device, dtype) - if len(latents.shape) == 3: - latents = latents.unsqueeze(0) - image = vae.decoder(latents).clamp(0, 1).detach() - image = 2.0 * image - 1.0 # typical normalized range except for preview which runs denormalization - return image[0] - elif len(latents.shape) == 4: - image = vae.decoder(latents).clamp(0, 1).detach() - image = 2.0 * image - 1.0 # typical normalized range except for preview which runs denormalization - return image - else: - if not previous_warnings: - shared.log.error(f'TAESD decode unsupported latent type: {latents.shape}') - previous_warnings = True - return latents - except Exception as e: - if not previous_warnings: - shared.log.error(f'VAE decode taesd: {e}') - previous_warnings = True - return latents + try: + with devices.inference_context(): + tensor = latents.unsqueeze(0) if len(latents.shape) == 3 else latents + tensor = tensor.half().detach().clone().to(devices.device, dtype=vae.dtype) + if shared.opts.taesd_variant.startswith('TAESD'): + image = vae.decoder(tensor).clamp(0, 1).detach() + return image[0] + else: + image = vae.decode(tensor, return_dict=False)[0] + image = (image / 2.0 + 0.5).clamp(0, 1).detach() + return image + except Exception as e: + return warn_once(f'decode {e}') def encode(image): - global previous_warnings # pylint: disable=global-statement - from modules import shared - model_class = shared.sd_model_type - if model_class == 'ldm': - model_class = 'sd' - if 'sd' not in model_class and 'f1' not in model_class: - if not previous_warnings: - previous_warnings = True - shared.log.warning(f'TAESD unsupported model type: {model_class}') - return Image.new('RGB', (8, 8), color = (0, 0, 0)) - vae = taesd_models[f'{model_class}-encoder'] - if vae is None: - model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_encoder.pth") - download_model(model_path) - if os.path.exists(model_path): - shared.log.debug(f'VAE load: type=taesd model="{model_path}"') - taesd_models[f'{model_class}-encoder'] = TAESD(encoder_path=model_path, decoder_path=None) - vae = taesd_models[f'{model_class}-encoder'] - vae.encoder.to(devices.device, devices.dtype_vae) - # image = vae.scale_latents(image) - latents = vae.encoder(image) - return latents.detach() + with lock: + vae = get_model(model_type='encoder') + if vae is None: + return image + try: + with devices.inference_context(): + latents = vae.encoder(image) + return latents.detach() + except Exception as e: + return warn_once(f'encode {e}') diff --git a/modules/shared.py b/modules/shared.py index 6150017d0..b092c27b5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,15 +84,7 @@ dir_timestamps = {} dir_cache = {} max_workers = 8 
-if os.environ.get("SD_HFCACHEDIR", None) is not None: - hfcache_dir = os.environ.get("SD_HFCACHEDIR") -if os.environ.get("HF_HUB_CACHE", None) is not None: - hfcache_dir = os.environ.get("HF_HUB_CACHE") -elif os.environ.get("HF_HUB", None) is not None: - hfcache_dir = os.path.join(os.environ.get("HF_HUB"), '.cache') -else: - hfcache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub') - os.environ["HF_HUB_CACHE"] = hfcache_dir +default_hfcache_dir = os.environ.get("SD_HFCACHEDIR", None) or os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub') class Backend(Enum): @@ -211,14 +203,12 @@ def default(obj): # early select backend -default_backend = 'diffusers' early_opts = readfile(cmd_opts.config, silent=True) -early_backend = early_opts.get('sd_backend', default_backend) -backend = Backend.DIFFUSERS if early_backend.lower() == 'diffusers' else Backend.ORIGINAL +early_backend = early_opts.get('sd_backend', 'diffusers') +backend = Backend.ORIGINAL if early_backend.lower() == 'original' else Backend.DIFFUSERS if cmd_opts.backend is not None: # override with args - backend = Backend.DIFFUSERS if cmd_opts.backend.lower() == 'diffusers' else Backend.ORIGINAL + backend = Backend.ORIGINAL if cmd_opts.backend.lower() == 'original' else Backend.DIFFUSERS if cmd_opts.use_openvino: # override for openvino - backend = Backend.DIFFUSERS from modules.intel.openvino import get_device_list as get_openvino_device_list # pylint: disable=ungrouped-imports elif cmd_opts.use_ipex or devices.has_xpu(): from modules.intel.ipex import ipex_init @@ -226,15 +216,14 @@ def default(obj): if not ok: log.error(f'IPEX initialization failed: {e}') elif cmd_opts.use_directml: - name = 'directml' from modules.dml import directml_init ok, e = directml_init() if not ok: log.error(f'DirectML initialization failed: {e}') devices.backend = devices.get_backend(cmd_opts) devices.device = devices.get_optimal_device() -cpu_memory = round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2) mem_stat = memory_stats() +cpu_memory = round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2) gpu_memory = mem_stat['gpu']['total'] if "gpu" in mem_stat else 0 native = backend == Backend.DIFFUSERS if not files_cache.do_cache_folders: @@ -438,14 +427,14 @@ def get_default_modes(): if gpu_memory <= 4: cmd_opts.lowvram = True default_offload_mode = "sequential" - log.info(f"Device detect: memory={gpu_memory:.1f} optimization=lowvram") + log.info(f"Device detect: memory={gpu_memory:.1f} default=sequential optimization=lowvram") # elif gpu_memory <= 8: # cmd_opts.medvram = True # default_offload_mode = "model" # log.info(f"Device detect: memory={gpu_memory:.1f} optimization=medvram") else: default_offload_mode = "balanced" - log.info(f"Device detect: memory={gpu_memory:.1f} optimization=balanced") + log.info(f"Device detect: memory={gpu_memory:.1f} default=balanced") elif cmd_opts.medvram: default_offload_mode = "balanced" elif cmd_opts.lowvram: @@ -475,7 +464,7 @@ def get_default_modes(): startup_offload_mode, startup_cross_attention, startup_sdp_options = get_default_modes() options_templates.update(options_section(('sd', "Models & Loading"), { - "sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["diffusers", "original"] }), + "sd_backend": OptionInfo('diffusers', "Execution backend", gr.Radio, {"choices": ['diffusers', 'original'] }), "diffusers_pipeline": OptionInfo('Autodetect', 'Model pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()), 
"visible": native}), "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", DropdownEditable, lambda: {"choices": list_checkpoint_titles()}, refresh=refresh_checkpoints), "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints), @@ -513,7 +502,7 @@ def get_default_modes(): "diffusers_vae_tile_overlap": OptionInfo(0.25, "VAE tile overlap", gr.Slider, {"minimum": 0, "maximum": 0.95, "step": 0.05 }), "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}), "nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox), - "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"), + "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values", gr.Checkbox, {"visible": not native}), })) options_templates.update(options_section(('text_encoder', "Text Encoder"), { @@ -665,11 +654,13 @@ def get_default_modes(): })) options_templates.update(options_section(('system-paths', "System Paths"), { + "clean_temp_dir_at_start": OptionInfo(True, "Cleanup temporary folder on startup"), "models_paths_sep_options": OptionInfo("
<h2>Models Paths</h2>
", "", gr.HTML), - "models_dir": OptionInfo('models', "Base path where all models are stored", folder=True), + "models_dir": OptionInfo('models', "Root model folder", folder=True), + "model_paths_sep_options": OptionInfo("
<h2>Paths for specific models</h2>
", "", gr.HTML), "ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True), "diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Huggingface models", folder=True), - "hfcache_dir": OptionInfo(hfcache_dir, "Folder for Huggingface cache", folder=True), + "hfcache_dir": OptionInfo(default_hfcache_dir, "Folder for Huggingface cache", folder=True), "vae_dir": OptionInfo(os.path.join(paths.models_path, 'VAE'), "Folder with VAE files", folder=True), "unet_dir": OptionInfo(os.path.join(paths.models_path, 'UNET'), "Folder with UNET files", folder=True), "te_dir": OptionInfo(os.path.join(paths.models_path, 'Text-encoder'), "Folder with Text encoder files", folder=True), @@ -689,19 +680,18 @@ def get_default_modes(): "swinir_models_path": OptionInfo(os.path.join(paths.models_path, 'SwinIR'), "Folder with SwinIR models", folder=True), "ldsr_models_path": OptionInfo(os.path.join(paths.models_path, 'LDSR'), "Folder with LDSR models", folder=True), "clip_models_path": OptionInfo(os.path.join(paths.models_path, 'CLIP'), "Folder with CLIP models", folder=True), - "other_paths_sep_options": OptionInfo("
<h2>Other paths</h2>
", "", gr.HTML), - "openvino_cache_path": OptionInfo('cache', "Directory for OpenVINO cache", folder=True), - "accelerate_offload_path": OptionInfo('cache/accelerate', "Directory for disk offload with Accelerate", folder=True), - "onnx_cached_models_path": OptionInfo(os.path.join(paths.models_path, 'ONNX', 'cache'), "Folder with ONNX cached models", folder=True), - "onnx_temp_dir": OptionInfo(os.path.join(paths.models_path, 'ONNX', 'temp'), "Directory for ONNX conversion and Olive optimization process", folder=True), + "other_paths_sep_options": OptionInfo("
<h2>Cache folders</h2>
", "", gr.HTML), "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True), - "clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"), + "accelerate_offload_path": OptionInfo('cache/accelerate', "Folder for disk offload", folder=True), + "openvino_cache_path": OptionInfo('cache', "Folder for OpenVINO cache", folder=True), + "onnx_cached_models_path": OptionInfo(os.path.join(paths.models_path, 'ONNX', 'cache'), "Folder for ONNX cached models", folder=True), + "onnx_temp_dir": OptionInfo(os.path.join(paths.models_path, 'ONNX', 'temp'), "Folder for ONNX conversion", folder=True), })) options_templates.update(options_section(('saving-images', "Image Options"), { "keep_incomplete": OptionInfo(True, "Keep incomplete images"), "samples_save": OptionInfo(True, "Save all generated images"), - "samples_format": OptionInfo('jpg', 'File format', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2"]}), + "samples_format": OptionInfo('jpg', 'File format', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2", "jxl"]}), "jpeg_quality": OptionInfo(90, "Image quality", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "img_max_size_mp": OptionInfo(1000, "Maximum image size (MP)", gr.Slider, {"minimum": 100, "maximum": 2000, "step": 1}), "webp_lossless": OptionInfo(False, "WebP lossless compression"), @@ -716,7 +706,7 @@ def get_default_modes(): "save_log_fn": OptionInfo("", "Append image info JSON file", component_args=hide_dirs), "image_sep_grid": OptionInfo("

<h2>Grid Options</h2>

", "", gr.HTML), "grid_save": OptionInfo(True, "Save all generated image grids"), - "grid_format": OptionInfo('jpg', 'File format', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2"]}), + "grid_format": OptionInfo('jpg', 'File format', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2", "jxl"]}), "n_rows": OptionInfo(-1, "Row count", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "grid_background": OptionInfo("#000000", "Grid background color", gr.ColorPicker, {}), "font": OptionInfo("", "Font file"), @@ -778,14 +768,16 @@ def get_default_modes(): "autolaunch": OptionInfo(False, "Autolaunch browser upon startup"), "font_size": OptionInfo(14, "Font size", gr.Slider, {"minimum": 8, "maximum": 32, "step": 1, "visible": True}), "aspect_ratios": OptionInfo("1:1, 4:3, 3:2, 16:9, 16:10, 21:9, 2:3, 3:4, 9:16, 10:16, 9:21", "Allowed aspect ratios"), + "logmonitor_show": OptionInfo(True, "Show log view"), + "logmonitor_refresh_period": OptionInfo(5000, "Log view update period", gr.Slider, {"minimum": 0, "maximum": 30000, "step": 25}), "motd": OptionInfo(False, "Show MOTD"), "compact_view": OptionInfo(False, "Compact view"), "return_grid": OptionInfo(True, "Show grid in results"), "return_mask": OptionInfo(False, "Inpainting include greyscale mask in results"), "return_mask_composite": OptionInfo(False, "Inpainting include masked composite in results"), "disable_weights_auto_swap": OptionInfo(True, "Do not change selected model when reading generation parameters"), - "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), - "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface", gr.Checkbox, {"visible": False}), + "send_size": OptionInfo(False, "Send size when sending prompt or image to another interface", gr.Checkbox, {"visible": False}), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", gr.Dropdown, lambda: {"multiselect":True, "choices": list(opts.data_labels.keys())}), })) @@ -793,10 +785,10 @@ def get_default_modes(): "show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approximate", "Live preview method", gr.Radio, {"choices": ["Simple", "Approximate", "TAESD", "Full VAE"]}), "live_preview_refresh_period": OptionInfo(500, "Progress update period", gr.Slider, {"minimum": 0, "maximum": 5000, "step": 25}), - "live_preview_taesd_layers": OptionInfo(3, "TAESD decode layers", gr.Slider, {"minimum": 1, "maximum": 3, "step": 1}), + "taesd_variant": OptionInfo(shared_items.sd_taesd_items()[0], "TAESD variant", gr.Dropdown, {"choices": shared_items.sd_taesd_items()}), + "taesd_layers": OptionInfo(3, "TAESD decode layers", gr.Slider, {"minimum": 1, "maximum": 3, "step": 1}), "live_preview_downscale": OptionInfo(True, "Downscale high resolution live previews"), - "logmonitor_show": OptionInfo(True, "Show log view"), - "logmonitor_refresh_period": OptionInfo(5000, "Log view update period", gr.Slider, {"minimum": 0, "maximum": 30000, "step": 25}), + "notification_audio_enable": OptionInfo(False, "Play a notification upon completion"), "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True), })) @@ -865,8 +857,6 @@ def get_default_modes(): "detailer_max_size": OptionInfo(1.0, "Max object size", 
gr.Slider, {"minimum": 0.1, "maximum": 1, "step": 0.05, "visible": False}), "detailer_padding": OptionInfo(20, "Item padding", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}), "detailer_blur": OptionInfo(10, "Item edge blur", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}), - "detailer_steps": OptionInfo(10, "Detailer steps", gr.Slider, {"minimum": 0, "maximum": 99, "step": 1, "visible": False}), - "detailer_strength": OptionInfo(0.5, "Detailer strength", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False}), "detailer_models": OptionInfo(['face-yolo8n'], "Detailer models", gr.Dropdown, lambda: {"multiselect":True, "choices": list(yolo.list), "visible": False}), "detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"), "detailer_augment": OptionInfo(True, "Detailer use model augment"), @@ -1011,6 +1001,7 @@ class Options: data_labels = options_templates filename = None typemap = {int: float} + debug = os.environ.get('SD_CONFIG_DEBUG', None) is not None def __init__(self): self.data = {k: v.default for k, v in self.data_labels.items()} @@ -1024,6 +1015,8 @@ def __setattr__(self, key, value): # pylint: disable=inconsistent-return-stateme if cmd_opts.hide_ui_dir_config and key in restricted_opts: log.warning(f'Settings key is restricted: {key}') return + if self.debug: + log.trace(f'Settings set: {key}={value}') self.data[key] = value return return super(Options, self).__setattr__(key, value) # pylint: disable=super-with-arguments @@ -1077,17 +1070,13 @@ def save_atomic(self, filename=None, silent=False): log.warning(f'Setting: fn="{filename}" save disabled') return try: - # output = json.dumps(self.data, indent=2) diff = {} unused_settings = [] - if os.environ.get('SD_CONFIG_DEBUG', None) is not None: + if self.debug: log.debug('Settings: user') for k, v in self.data.items(): log.trace(f' Config: item={k} value={v} default={self.data_labels[k].default if k in self.data_labels else None}') - log.debug('Settings: defaults') - for k in self.data_labels.keys(): - log.trace(f' Setting: item={k} default={self.data_labels[k].default}') for k, v in self.data.items(): if k in self.data_labels: @@ -1101,6 +1090,8 @@ def save_atomic(self, filename=None, silent=False): if not k.startswith('uiux_'): unused_settings.append(k) writefile(diff, filename, silent=silent) + if self.debug: + log.trace(f'Settings save: {diff}') if len(unused_settings) > 0: log.debug(f"Settings: unused={unused_settings}") except Exception as err: diff --git a/modules/shared_items.py b/modules/shared_items.py index 17b7ce1ee..5c1e3aebb 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -8,6 +8,10 @@ def sd_vae_items(): return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) +def sd_taesd_items(): + import modules.sd_vae_taesd + return list(modules.sd_vae_taesd.TAESD_MODELS.keys()) + list(modules.sd_vae_taesd.CQYAN_MODELS.keys()) + def refresh_vae_list(): import modules.sd_vae modules.sd_vae.refresh_vae_list() diff --git a/modules/shared_state.py b/modules/shared_state.py index 9f56b14e0..b4c92bb65 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -1,10 +1,18 @@ import os +import sys import time import datetime from modules.errors import log, display +debug_output = os.environ.get('SD_STATE_DEBUG', None) + + class State: + job_history = [] + task_history = [] + image_history = 0 + latent_history = 0 skipped = False interrupted = False paused = False @@ -14,7 +22,7 @@ class State: frame_count = 0 
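# Usage note for the new debug switches: both are plain environment variables read
# once at import time. SD_CONFIG_DEBUG enables the per-key trace logging added to
# Options.__setattr__ and Options.save_atomic above; SD_STATE_DEBUG (the
# debug_output flag above) enables the State step/begin/end/update tracing below.
# Assuming the usual launch.py entry point, a one-off run could look like:
#   SD_CONFIG_DEBUG=1 SD_STATE_DEBUG=1 python launch.py --debug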
total_jobs = 0 job_timestamp = '0' - sampling_step = 0 + _sampling_step = 0 sampling_steps = 0 current_latent = None current_noise_pred = None @@ -32,29 +40,48 @@ class State: need_restart = False server_start = time.time() oom = False - debug_output = os.environ.get('SD_STATE_DEBUG', None) def __str__(self) -> str: - return f'State: job={self.job} {self.job_no}/{self.job_count} step={self.sampling_step}/{self.sampling_steps} skipped={self.skipped} interrupted={self.interrupted} paused={self.paused} info={self.textinfo}' + status = ' ' + status += 'skipped ' if self.skipped else '' + status += 'interrupted ' if self.interrupted else '' + status += 'paused ' if self.paused else '' + status += 'restart ' if self.need_restart else '' + status += 'oom ' if self.oom else '' + status += 'api ' if self.api else '' + fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}' # pylint: disable=protected-access + return f'State: ts={self.job_timestamp} job={self.job} jobs={self.job_no+1}/{self.job_count}/{self.total_jobs} step={self.sampling_step}/{self.sampling_steps} preview={self.preview_job}/{self.id_live_preview}/{self.current_image_sampling_step} status="{status.strip()}" fn={fn}' + + @property + def sampling_step(self): + return self._sampling_step + + @sampling_step.setter + def sampling_step(self, value): + self._sampling_step = value + if debug_output: + log.trace(f'State step: {self}') def skip(self): - log.debug('Requested skip') + log.debug('State: skip requested') self.skipped = True def interrupt(self): - log.debug('Requested interrupt') + log.debug('State: interrupt requested') self.interrupted = True def pause(self): self.paused = not self.paused - log.debug(f'Requested {"pause" if self.paused else "continue"}') + log.debug(f'State: {"pause" if self.paused else "continue"} requested') def nextjob(self): import modules.devices self.do_set_current_image() self.job_no += 1 - self.sampling_step = 0 + # self.sampling_step = 0 self.current_image_sampling_step = 0 + if debug_output: + log.trace(f'State next: {self}') modules.devices.torch_gc() def dict(self): @@ -104,6 +131,7 @@ def status(self): def begin(self, title="", api=None): import modules.devices + self.job_history.append(title) self.total_jobs += 1 self.current_image = None self.current_image_sampling_step = 0 @@ -115,19 +143,20 @@ def begin(self, title="", api=None): self.interrupted = False self.preview_job = -1 self.job = title - self.job_count = -1 - self.frame_count = -1 + self.job_count = 0 + self.frame_count = 0 self.job_no = 0 self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") self.paused = False - self.sampling_step = 0 + self._sampling_step = 0 + self.sampling_steps = 0 self.skipped = False self.textinfo = None self.prediction_type = "epsilon" self.api = api or self.api self.time_start = time.time() - if self.debug_output: - log.debug(f'State begin: {self.job}') + if debug_output: + log.trace(f'State begin: {self}') modules.devices.torch_gc() def end(self, api=None): @@ -136,6 +165,8 @@ def end(self, api=None): # fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access # log.debug(f'Access state.end: {fn}') # pylint: disable=protected-access self.time_start = time.time() + if debug_output: + log.trace(f'State end: {self}') self.job = "" self.job_count = 0 self.job_no = 0 @@ -147,6 +178,24 @@ def end(self, api=None): self.api = api or self.api modules.devices.torch_gc() + def step(self, step:int=1): + self.sampling_step += step + + 
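# Illustrative sketch of the new progress accounting (assumes a live State instance
# named `state`; the comments on the right follow the step()/update() bodies):
#   state.update('Text', steps=20, jobs=2)   # sampling_steps += 40, job_count += 2
#   state.step()                             # sampling_step += 1 via the property setter
#   done = state.sampling_step / max(1, state.sampling_steps)
# 'Grid' replaces the totals instead of accumulating them and 'Ignore' is a no-op,
# per the update() implementation that follows.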
def update(self, job:str, steps:int=0, jobs:int=0): + self.task_history.append(job) + # self._sampling_step = 0 + if job == 'Ignore': + return + elif job == 'Grid': + self.sampling_steps = steps + self.job_count = jobs + else: + self.sampling_steps += steps * jobs + self.job_count += jobs + self.job = job + if debug_output: + log.trace(f'State update: {self} steps={steps} jobs={jobs}') + def set_current_image(self): if self.job == 'VAE' or self.job == 'Upscale': # avoid generating preview while vae is running return False diff --git a/modules/taesd/hybrid_small.py b/modules/taesd/hybrid_small.py new file mode 100644 index 000000000..a59b0b4d7 --- /dev/null +++ b/modules/taesd/hybrid_small.py @@ -0,0 +1,506 @@ +# pylint: disable=no-member,unused-argument,attribute-defined-outside-init + +# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. 
This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + encoder_block_out_channels: Tuple[int] = None, + decoder_block_out_channels: Tuple[int] = None, + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + ): + super().__init__() + + if encoder_block_out_channels is not None or decoder_block_out_channels is not None: + if encoder_block_out_channels is None: + raise NotImplementedError + if decoder_block_out_channels is None: + raise NotImplementedError + + else: + encoder_block_out_channels = block_out_channels + decoder_block_out_channels = block_out_channels + self.config.encoder_block_out_channels = self.config.decoder_block_out_channels = block_out_channels + + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=encoder_block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=decoder_block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.encoder_block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/modules/taesd/taesd.py b/modules/taesd/taesd.py new file mode 100644 index 000000000..8e391a8fb --- /dev/null +++ b/modules/taesd/taesd.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +from modules import devices + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + +def Encoder(latent_channels=4): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, latent_channels), + ) + +def Decoder(latent_channels=4): + from modules import shared + if shared.opts.taesd_layers == 1: + return nn.Sequential( + Clamp(), conv(latent_channels, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Identity(), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Identity(), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + elif shared.opts.taesd_layers == 2: + return nn.Sequential( + Clamp(), conv(latent_channels, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Identity(), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + else: + return nn.Sequential( + Clamp(), conv(latent_channels, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + + +class TAESD(nn.Module): # pylint: disable=abstract-method + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path=None, decoder_path=None, latent_channels=None): + super().__init__() + self.dtype = devices.dtype_vae if devices.dtype_vae != torch.bfloat16 else torch.float16 # taesd does not support bf16 + if latent_channels is None: + latent_channels = self.guess_latent_channels(str(decoder_path), str(encoder_path)) + self.encoder = Encoder(latent_channels) + self.decoder = Decoder(latent_channels) + if encoder_path is not None: + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"), strict=False) + self.encoder.eval() + self.encoder = self.encoder.to(devices.device, dtype=self.dtype) + if decoder_path is not None: + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"), strict=False) + self.decoder.eval() + self.decoder = self.decoder.to(devices.device, 
dtype=self.dtype) + + def guess_latent_channels(self, decoder_path, encoder_path): + return 16 if ("f1" in encoder_path or "f1" in decoder_path) or ("sd3" in encoder_path or "sd3" in decoder_path) else 4 + + @staticmethod + def scale_latents(x): + return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) # raw latents -> [0, 1] + + @staticmethod + def unscale_latents(x): + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) # [0, 1] -> raw latents diff --git a/modules/timer.py b/modules/timer.py index 977206e06..69107e605 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -25,9 +25,12 @@ def elapsed(self, reset=True): def add(self, name, t): if name not in self.records: - self.records[name] = t - else: - self.records[name] += t + self.records[name] = 0 + self.records[name] += t + + def ts(self, name, t): + elapsed = time.time() - t + self.add(name, elapsed) def record(self, category=None, extra_time=0, reset=True): e = self.elapsed(reset) @@ -41,6 +44,8 @@ def record(self, category=None, extra_time=0, reset=True): def summary(self, min_time=default_min_time, total=True): if self.profile: min_time = -1 + if self.total <= 0: + self.total = sum(self.records.values()) res = f"total={self.total:.2f} " if total else '' additions = [x for x in self.records.items() if x[1] >= min_time] additions = sorted(additions, key=lambda x: x[1], reverse=True) @@ -49,11 +54,14 @@ def summary(self, min_time=default_min_time, total=True): res += " ".join([f"{category}={time_taken:.2f}" for category, time_taken in additions]) return res + def get_total(self): + return sum(self.records.values()) + def dct(self, min_time=default_min_time): if self.profile: res = {k: round(v, 4) for k, v in self.records.items()} res = {k: round(v, 2) for k, v in self.records.items() if v >= min_time} - res = {k: v for k, v in sorted(res.items(), key=lambda x: x[1], reverse=True)} # noqa: C416 + res = {k: v for k, v in sorted(res.items(), key=lambda x: x[1], reverse=True)} # noqa: C416 # pylint: disable=unnecessary-comprehension return res def reset(self): @@ -61,3 +69,5 @@ def reset(self): startup = Timer() process = Timer() +launch = Timer() +init = Timer() diff --git a/modules/todo/todo_merge.py b/modules/todo/todo_merge.py index bfbff5621..77840d6ed 100644 --- a/modules/todo/todo_merge.py +++ b/modules/todo/todo_merge.py @@ -25,7 +25,6 @@ def init_generator(device: torch.device, fallback: torch.Generator = None): """ Forks the current default random generator given device. 
""" - print(f"init_generator device = {device}") if device.type == "cpu": return torch.Generator(device="cpu").set_state(torch.get_rng_state()) elif device.type == "cuda": diff --git a/modules/txt2img.py b/modules/txt2img.py index e82c744a2..e2cf7af77 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -11,7 +11,8 @@ def txt2img(id_task, state, prompt, negative_prompt, prompt_styles, steps, sampler_index, hr_sampler_index, - full_quality, detailer, tiling, hidiffusion, + full_quality, tiling, hidiffusion, + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength, n_iter, batch_size, cfg_scale, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, cfg_end, clip_skip, @@ -24,7 +25,7 @@ def txt2img(id_task, state, override_settings_texts, *args): - debug(f'txt2img: id_task={id_task}|prompt={prompt}|negative={negative_prompt}|styles={prompt_styles}|steps={steps}|sampler_index={sampler_index}|hr_sampler_index={hr_sampler_index}|full_quality={full_quality}|detailer={detailer}|tiling={tiling}|hidiffusion={hidiffusion}|batch_count={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|clip_skip={clip_skip}|seed={seed}|subseed={subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|height={height}|width={width}|enable_hr={enable_hr}|denoising_strength={denoising_strength}|hr_resize_mode={hr_resize_mode}|hr_resize_context={hr_resize_context}|hr_scale={hr_scale}|hr_upscaler={hr_upscaler}|hr_force={hr_force}|hr_second_pass_steps={hr_second_pass_steps}|hr_resize_x={hr_resize_x}|hr_resize_y={hr_resize_y}|image_cfg_scale={image_cfg_scale}|diffusers_guidance_rescale={diffusers_guidance_rescale}|refiner_steps={refiner_steps}|refiner_start={refiner_start}|refiner_prompt={refiner_prompt}|refiner_negative={refiner_negative}|override_settings={override_settings_texts}') + debug(f'txt2img: {id_task}') if shared.sd_model is None: shared.log.warning('Aborted: op=txt model not loaded') @@ -64,7 +65,11 @@ def txt2img(id_task, state, width=width, height=height, full_quality=full_quality, - detailer=detailer, + detailer_enabled=detailer_enabled, + detailer_prompt=detailer_prompt, + detailer_negative=detailer_negative, + detailer_steps=detailer_steps, + detailer_strength=detailer_strength, tiling=tiling, hidiffusion=hidiffusion, enable_hr=enable_hr, diff --git a/modules/ui.py b/modules/ui.py index 4490bbf8c..8e9a9db85 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -16,6 +16,7 @@ mimetypes.init() mimetypes.add_type('application/javascript', '.js') mimetypes.add_type('image/webp', '.webp') +mimetypes.add_type('image/jxl', '.jxl') log = shared.log opts = shared.opts cmd_opts = shared.cmd_opts @@ -234,6 +235,8 @@ def run_settings(*args): for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp == dummy_component or value=='dummy': continue + if getattr(comp, 'visible', True) is False: + continue if not opts.same_type(value, opts.data_labels[key].default): log.error(f'Setting bad value: {key}={value} expecting={type(opts.data_labels[key].default).__name__}') continue diff --git a/modules/ui_control.py b/modules/ui_control.py index c12a03130..f577ad4c2 100644 --- a/modules/ui_control.py +++ b/modules/ui_control.py @@ -168,16 +168,11 @@ def create_ui(_blocks: gr.Blocks=None): with gr.Row(): video_skip_frames = gr.Slider(minimum=0, maximum=100, step=1, label='Skip input frames', value=0, elem_id="control_video_skip_frames") with gr.Row(): - video_type = 
gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="control_video_type") - video_duration = gr.Slider(label='Duration', minimum=0.25, maximum=300, step=0.25, value=2, visible=False, elem_id="control_video_duration") - with gr.Row(): - video_loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="control_video_loop") - video_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False, elem_id="control_video_pad") - video_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False, elem_id="control_video_interpolate") - video_type.change(fn=helpers.video_type_change, inputs=[video_type], outputs=[video_duration, video_loop, video_pad, video_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, video_duration, video_loop, video_pad, video_interpolate = create_video_inputs() enable_hr, hr_sampler_index, hr_denoising_strength, hr_resize_mode, hr_resize_context, hr_upscaler, hr_force, hr_second_pass_steps, hr_scale, hr_resize_x, hr_resize_y, refiner_steps, refiner_start, refiner_prompt, refiner_negative = ui_sections.create_hires_inputs('control') - detailer = shared.yolo.ui('control') + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength = shared.yolo.ui('control') with gr.Row(): override_settings = ui_common.create_override_inputs('control') @@ -567,7 +562,8 @@ def create_ui(_blocks: gr.Blocks=None): prompt, negative, styles, steps, sampler_index, seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, - cfg_scale, clip_skip, image_cfg_scale, guidance_rescale, pag_scale, pag_adaptive, cfg_end, full_quality, detailer, tiling, hidiffusion, + cfg_scale, clip_skip, image_cfg_scale, guidance_rescale, pag_scale, pag_adaptive, cfg_end, full_quality, tiling, hidiffusion, + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength, hdr_mode, hdr_brightness, hdr_color, hdr_sharpen, hdr_clamp, hdr_boundary, hdr_threshold, hdr_maximize, hdr_max_center, hdr_max_boundry, hdr_color_picker, hdr_tint_ratio, resize_mode_before, resize_name_before, resize_context_before, width_before, height_before, scale_by_before, selected_scale_tab_before, resize_mode_after, resize_name_after, resize_context_after, width_after, height_after, scale_by_after, selected_scale_tab_after, @@ -649,27 +645,37 @@ def create_ui(_blocks: gr.Blocks=None): (cfg_end, "CFG end"), (clip_skip, "Clip skip"), (image_cfg_scale, "Image CFG scale"), + (image_cfg_scale, "Hires CFG scale"), (guidance_rescale, "CFG rescale"), (full_quality, "Full quality"), - (detailer, "Detailer"), (tiling, "Tiling"), (hidiffusion, "HiDiffusion"), + # detailer + (detailer_enabled, "Detailer"), + (detailer_prompt, "Detailer prompt"), + (detailer_negative, "Detailer negative"), + (detailer_steps, "Detailer steps"), + (detailer_strength, "Detailer strength"), # second pass (enable_hr, "Second pass"), (enable_hr, "Refine"), - (hr_sampler_index, "Hires sampler"), (denoising_strength, "Denoising strength"), + (denoising_strength, "Hires strength"), + (hr_sampler_index, "Hires sampler"), + (hr_resize_mode, "Hires mode"), + (hr_resize_context, "Hires context"), (hr_upscaler, "Hires upscaler"), (hr_force, "Hires force"), (hr_second_pass_steps, "Hires steps"), (hr_scale, "Hires upscale"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), + (hr_scale, "Hires scale"), + (hr_resize_x, "Hires fixed-1"), + (hr_resize_y, 
"Hires fixed-2"), # refiner (refiner_start, "Refiner start"), (refiner_steps, "Refiner steps"), - (refiner_prompt, "Prompt2"), - (refiner_negative, "Negative2"), + (refiner_prompt, "refiner prompt"), + (refiner_negative, "Refiner negative"), # pag (pag_scale, "PAG scale"), (pag_adaptive, "PAG adaptive"), diff --git a/modules/ui_control_helpers.py b/modules/ui_control_helpers.py index 1eeae23c2..553b5fbac 100644 --- a/modules/ui_control_helpers.py +++ b/modules/ui_control_helpers.py @@ -6,8 +6,8 @@ gr_height = None max_units = shared.opts.control_max_units -debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: CONTROL') +debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None +debug_log = shared.log.trace if debug else lambda *args, **kwargs: None # state variables busy = False # used to synchronize select_input and generate_click @@ -127,7 +127,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i busy = False # debug('Control input: none') return [gr.Tabs.update(), None, ''] - debug(f'Control select input: source={selected_input} init={init_image} type={init_type} mode={input_mode}') + debug_log(f'Control select input: source={selected_input} init={init_image} type={init_type} mode={input_mode}') input_type = type(selected_input) input_mask = None status = 'Control input | Unknown' @@ -168,7 +168,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i res = [gr.Tabs.update(selected='out-gallery'), input_mask, status] else: # unknown input_source = None - shared.log.debug(f'Control input: type={input_type} input={input_source}') + debug_log(f'Control input: type={input_type} input={input_source}') # init inputs: optional if init_type == 0: # Control only input_init = None @@ -176,22 +176,13 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i input_init = None elif init_type == 2: # Separate init image input_init = [init_image] - debug(f'Control select input: source={input_source} init={input_init} mask={input_mask} mode={input_mode}') + debug_log(f'Control select input: source={input_source} init={input_init} mask={input_mask} mode={input_mode}') busy = False return res -def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - - def copy_input(mode_from, mode_to, input_image, input_resize, input_inpaint): - debug(f'Control transfter input: from={mode_from} to={mode_to} image={input_image} resize={input_resize} inpaint={input_inpaint}') + debug_log(f'Control transfter input: from={mode_from} to={mode_to} image={input_image} resize={input_resize} inpaint={input_inpaint}') def getimg(ctrl): if ctrl is None: return None diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fed39da27..769f47581 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -331,7 +331,7 @@ def find_preview_file(self, path): return 'html/card-no-preview.png' if os.path.join('models', 'Reference') in path: return path - exts = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"] + exts = ["jpg", "jpeg", "png", "webp", "tiff", "jp2", "jxl"] reference_path = os.path.abspath(os.path.join('models', 'Reference')) files = list(files_cache.list_files(reference_path, ext_filter=exts, recursive=False)) if shared.opts.diffusers_dir 
in path: @@ -360,7 +360,7 @@ def update_all_previews(self, items): t0 = time.time() reference_path = os.path.abspath(os.path.join('models', 'Reference')) possible_paths = list(set([os.path.dirname(item['filename']) for item in items] + [reference_path])) - exts = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"] + exts = ["jpg", "jpeg", "png", "webp", "tiff", "jp2", "jxl"] all_previews = list(files_cache.list_files(*possible_paths, ext_filter=exts, recursive=False)) all_previews_fn = [os.path.basename(x) for x in all_previews] for item in items: @@ -685,7 +685,7 @@ def fn_save_img(image): return image def fn_delete_img(_image): - preview_extensions = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"] + preview_extensions = ["jpg", "jpeg", "png", "webp", "tiff", "jp2", "jxl"] fn = os.path.splitext(ui.last_item.filename)[0] for file in [f'{fn}{mid}{ext}' for ext in preview_extensions for mid in ['.thumb.', '.preview.', '.']]: if os.path.exists(file): diff --git a/modules/ui_img2img.py b/modules/ui_img2img.py index 45c901c6d..0b59e1b09 100644 --- a/modules/ui_img2img.py +++ b/modules/ui_img2img.py @@ -133,7 +133,7 @@ def fn_img_composite_change(img, img_composite): full_quality, tiling, hidiffusion, cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, cfg_end = ui_sections.create_advanced_inputs('img2img') hdr_mode, hdr_brightness, hdr_color, hdr_sharpen, hdr_clamp, hdr_boundary, hdr_threshold, hdr_maximize, hdr_max_center, hdr_max_boundry, hdr_color_picker, hdr_tint_ratio = ui_sections.create_correction_inputs('img2img') enable_hr, hr_sampler_index, hr_denoising_strength, hr_resize_mode, hr_resize_context, hr_upscaler, hr_force, hr_second_pass_steps, hr_scale, hr_resize_x, hr_resize_y, refiner_steps, hr_refiner_start, refiner_prompt, refiner_negative = ui_sections.create_hires_inputs('txt2img') - detailer = shared.yolo.ui('img2img') + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength = shared.yolo.ui('img2img') # with gr.Group(elem_id="inpaint_controls", visible=False) as inpaint_controls: with gr.Accordion(open=False, label="Mask", elem_classes=["small-accordion"], elem_id="img2img_mask_group") as inpaint_controls: @@ -174,7 +174,8 @@ def select_img2img_tab(tab): sampler_index, mask_blur, mask_alpha, inpainting_fill, - full_quality, detailer, tiling, hidiffusion, + full_quality, tiling, hidiffusion, + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength, batch_count, batch_size, cfg_scale, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, cfg_end, @@ -193,7 +194,7 @@ def select_img2img_tab(tab): override_settings, ] img2img_dict = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', ''], name='Image'), _js="submit_img2img", inputs= img2img_args + img2img_script_inputs, outputs=[ @@ -253,19 +254,45 @@ def select_img2img_tab(tab): (seed, "Seed"), (subseed, "Variation seed"), (subseed_strength, "Variation strength"), - # denoise - (denoising_strength, "Denoising strength"), - (refiner_start, "Refiner start"), # advanced (cfg_scale, "CFG scale"), (cfg_end, "CFG end"), (image_cfg_scale, "Image CFG scale"), + (image_cfg_scale, "Hires CFG scale"), (clip_skip, "Clip skip"), (diffusers_guidance_rescale, "CFG rescale"), (full_quality, "Full quality"), - (detailer, "Detailer"), (tiling, "Tiling"), (hidiffusion, "HiDiffusion"), + # detailer + (detailer_enabled, 
"Detailer"), + (detailer_prompt, "Detailer prompt"), + (detailer_negative, "Detailer negative"), + (detailer_steps, "Detailer steps"), + (detailer_strength, "Detailer strength"), + # second pass + (enable_hr, "Second pass"), + (enable_hr, "Refine"), + (denoising_strength, "Denoising strength"), + (denoising_strength, "Hires strength"), + (hr_sampler_index, "Hires sampler"), + (hr_resize_mode, "Hires mode"), + (hr_resize_context, "Hires context"), + (hr_upscaler, "Hires upscaler"), + (hr_force, "Hires force"), + (hr_second_pass_steps, "Hires steps"), + (hr_scale, "Hires upscale"), + (hr_scale, "Hires scale"), + (hr_resize_x, "Hires fixed-1"), + (hr_resize_y, "Hires fixed-2"), + # refiner + (refiner_start, "Refiner start"), + (refiner_steps, "Refiner steps"), + (refiner_prompt, "refiner prompt"), + (refiner_negative, "Refiner negative"), + # pag + (pag_scale, "PAG scale"), + (pag_adaptive, "PAG adaptive"), # inpaint (mask_blur, "Mask blur"), (mask_alpha, "Mask alpha"), diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 6e398603a..01cf2e97c 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -4,7 +4,7 @@ from modules.ui_components import ToolButton -debug = os.environ.get('SD_UI_DEBUG', None) +debug_ui = os.environ.get('SD_UI_DEBUG', None) class UiLoadsave: @@ -46,7 +46,7 @@ def apply_field(obj, field, condition=None, init_field=None): setattr(obj, field, saved_value) if init_field is not None: init_field(saved_value) - if debug and key in self.component_mapping and not key.startswith('customscript'): + if debug_ui and key in self.component_mapping and not key.startswith('customscript'): errors.log.warning(f'UI duplicate: key="{key}" id={getattr(obj, "elem_id", None)} class={getattr(obj, "elem_classes", None)}') if field == 'value' and key not in self.component_mapping: self.component_mapping[key] = x diff --git a/modules/ui_models.py b/modules/ui_models.py index 7ab8b0d07..5d5b452e2 100644 --- a/modules/ui_models.py +++ b/modules/ui_models.py @@ -290,7 +290,7 @@ def preset_choices(sdxl): beta_apply_preset.click(fn=load_presets, inputs=[beta_preset, beta_preset_lambda], outputs=[beta_base, beta_in_blocks, beta_mid_block, beta_out_blocks, tabs]) modelmerger_merge.click( - fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), + fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)], name='Models'), _js='modelmerger', inputs=[ dummy_component, diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index e9ac2d72a..6e12339e7 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -129,7 +129,7 @@ def create_ui(): ) submit.click( _js="submit_postprocessing", - fn=call_queue.wrap_gradio_gpu_call(submit_process, extra_outputs=[None, '']), + fn=call_queue.wrap_gradio_gpu_call(submit_process, extra_outputs=[None, ''], name='Postprocess'), inputs=[ tab_index, extras_image, diff --git a/modules/ui_sections.py b/modules/ui_sections.py index 839fa9025..a9bbe4ef2 100644 --- a/modules/ui_sections.py +++ b/modules/ui_sections.py @@ -63,21 +63,6 @@ def parse_style(styles): def ar_change(ar, width, height): - """ - if ar == 'AR': - return gr.update(interactive=True), gr.update(interactive=True) - try: - (w, h) = [float(x) for x in ar.split(':')] - except Exception as e: - shared.log.warning(f"Invalid aspect ratio: {ar} {e}") - return gr.update(interactive=True), gr.update(interactive=True) - if w > h: - return gr.update(interactive=True, value=width), 
gr.update(interactive=False, value=int(width * h / w)) - elif w < h: - return gr.update(interactive=False, value=int(height * w / h)), gr.update(interactive=True, value=height) - else: - return gr.update(interactive=True, value=width), gr.update(interactive=False, value=width) - """ if ar == 'AR': return gr.update(), gr.update() try: @@ -146,6 +131,26 @@ def create_seed_inputs(tab, reuse_visible=True): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w +def create_video_inputs(): + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type in ['GIF', 'PNG']), + gr.update(visible=video_type not in ['None', 'GIF', 'PNG']), + gr.update(visible=video_type not in ['None', 'GIF', 'PNG']), + ] + with gr.Column(): + video_codecs = ['None', 'GIF', 'PNG', 'MP4/MP4V', 'MP4/AVC1', 'MP4/JVT3', 'MKV/H264', 'AVI/DIVX', 'AVI/RGBA', 'MJPEG/MJPG', 'MPG/MPG1', 'AVR/AVR1'] + video_type = gr.Dropdown(label='Video type', choices=video_codecs, value='None') + with gr.Column(): + video_duration = gr.Slider(label='Duration', minimum=0.25, maximum=300, step=0.25, value=2, visible=False) + video_loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="control_video_loop") + video_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + video_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[video_duration, video_loop, video_pad, video_interpolate]) + return video_type, video_duration, video_loop, video_pad, video_interpolate + + def create_cfg_inputs(tab): with gr.Row(): cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='Guidance scale', value=6.0, elem_id=f"{tab}_cfg_scale") diff --git a/modules/ui_txt2img.py b/modules/ui_txt2img.py index e2886e901..64d2906f8 100644 --- a/modules/ui_txt2img.py +++ b/modules/ui_txt2img.py @@ -47,7 +47,7 @@ def create_ui(): full_quality, tiling, hidiffusion, _cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, _cfg_end = ui_sections.create_advanced_inputs('txt2img', base=False) hdr_mode, hdr_brightness, hdr_color, hdr_sharpen, hdr_clamp, hdr_boundary, hdr_threshold, hdr_maximize, hdr_max_center, hdr_max_boundry, hdr_color_picker, hdr_tint_ratio = ui_sections.create_correction_inputs('txt2img') enable_hr, hr_sampler_index, denoising_strength, hr_resize_mode, hr_resize_context, hr_upscaler, hr_force, hr_second_pass_steps, hr_scale, hr_resize_x, hr_resize_y, refiner_steps, refiner_start, refiner_prompt, refiner_negative = ui_sections.create_hires_inputs('txt2img') - detailer = shared.yolo.ui('txt2img') + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength = shared.yolo.ui('txt2img') override_settings = ui_common.create_override_inputs('txt2img') state = gr.Textbox(value='', visible=False) @@ -64,7 +64,8 @@ def create_ui(): dummy_component, state, txt2img_prompt, txt2img_negative_prompt, txt2img_prompt_styles, steps, sampler_index, hr_sampler_index, - full_quality, detailer, tiling, hidiffusion, + full_quality, tiling, hidiffusion, + detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength, batch_count, batch_size, cfg_scale, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, cfg_end, clip_skip, @@ -77,7 +78,7 @@ def create_ui(): override_settings, ] 
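# The former single `detailer` toggle is now five separate components returned by
# shared.yolo.ui('txt2img'); they are passed positionally in txt2img_args above,
# consumed by modules.txt2img.txt2img() and forwarded into the processing call as
# detailer_enabled / detailer_prompt / detailer_negative / detailer_steps /
# detailer_strength (see the txt2img.py hunk earlier). The same components are also
# registered as infotext paste fields below ('Detailer', 'Detailer prompt',
# 'Detailer negative', 'Detailer steps', 'Detailer strength'), which is how the
# values get restored from saved image metadata.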
txt2img_dict = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', ''], name='Text'), _js="submit_txt2img", inputs=txt2img_args + txt2img_script_inputs, outputs=[ @@ -119,32 +120,37 @@ def create_ui(): (cfg_end, "CFG end"), (clip_skip, "Clip skip"), (image_cfg_scale, "Image CFG scale"), + (image_cfg_scale, "Hires CFG scale"), (diffusers_guidance_rescale, "CFG rescale"), (full_quality, "Full quality"), - (detailer, "Detailer"), (tiling, "Tiling"), (hidiffusion, "HiDiffusion"), + # detailer + (detailer_enabled, "Detailer"), + (detailer_prompt, "Detailer prompt"), + (detailer_negative, "Detailer negative"), + (detailer_steps, "Detailer steps"), + (detailer_strength, "Detailer strength"), # second pass (enable_hr, "Second pass"), (enable_hr, "Refine"), (denoising_strength, "Denoising strength"), + (denoising_strength, "Hires strength"), (hr_sampler_index, "Hires sampler"), - (hr_resize_mode, "Hires resize mode"), - (hr_resize_context, "Hires resize context"), + (hr_resize_mode, "Hires mode"), + (hr_resize_context, "Hires context"), (hr_upscaler, "Hires upscaler"), (hr_force, "Hires force"), (hr_second_pass_steps, "Hires steps"), (hr_scale, "Hires upscale"), (hr_scale, "Hires scale"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - (hr_resize_x, "Hires size-1"), - (hr_resize_y, "Hires size-2"), + (hr_resize_x, "Hires fixed-1"), + (hr_resize_y, "Hires fixed-2"), # refiner (refiner_start, "Refiner start"), (refiner_steps, "Refiner steps"), - (refiner_prompt, "Prompt2"), - (refiner_negative, "Negative2"), + (refiner_prompt, "refiner prompt"), + (refiner_negative, "Refiner negative"), # pag (pag_scale, "PAG scale"), (pag_adaptive, "PAG adaptive"), diff --git a/modules/video.py b/modules/video.py new file mode 100644 index 000000000..d9e40a27f --- /dev/null +++ b/modules/video.py @@ -0,0 +1,85 @@ +import os +import threading +import numpy as np +from modules import shared, errors +from modules.images_namegen import FilenameGenerator # pylint: disable=unused-import + + +def interpolate_frames(images, count: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3): + if images is None: + return [] + if not isinstance(images, list): + images = [images] + if count > 0: + try: + import modules.rife + frames = modules.rife.interpolate(images, count=count, scale=scale, pad=pad, change=change) + if len(frames) > 0: + images = frames + except Exception as e: + shared.log.error(f'RIFE interpolation: {e}') + errors.display(e, 'RIFE interpolation') + return [np.array(image) for image in images] + + +def save_video_atomic(images, filename, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3): + try: + import cv2 + except Exception as e: + shared.log.error(f'Save video: cv2: {e}') + return + os.makedirs(os.path.dirname(filename), exist_ok=True) + if video_type.lower() in ['gif', 'png']: + append = images.copy() + image = append.pop(0) + if loop: + append += append[::-1] + frames=len(append) + 1 + image.save( + filename, + save_all = True, + append_images = append, + optimize = False, + duration = 1000.0 * duration / frames, + loop = 0 if loop else 1, + ) + size = os.path.getsize(filename) + shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop} size={size}') + elif video_type.lower() != 'none': + frames = interpolate_frames(images, 
count=interpolate, scale=scale, pad=pad, change=change) + fourcc = "mp4v" + h, w, _c = frames[0].shape + video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(frames)/duration, frameSize=(w, h)) + for i in range(len(frames)): + img = cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + size = os.path.getsize(filename) + shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc} size={size}') + + +def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3, sync: bool = False): + if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none': + return None + image = images[0] + if p is not None: + seed = p.all_seeds[0] if getattr(p, 'all_seeds', None) is not None else p.seed + prompt = p.all_prompts[0] if getattr(p, 'all_prompts', None) is not None else p.prompt + namegen = FilenameGenerator(p, seed=seed, prompt=prompt, image=image) + else: + namegen = FilenameGenerator(None, seed=0, prompt='', image=image) + if filename is None and p is not None: + filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]") + filename = os.path.join(shared.opts.outdir_video, filename) + filename = namegen.sequence(filename, shared.opts.outdir_video, '') + else: + if os.pathsep not in filename: + filename = os.path.join(shared.opts.outdir_video, filename) + ext = video_type.lower().split('/')[0] if '/' in video_type else video_type.lower() + if not filename.lower().endswith(ext): + filename += f'.{ext}' + filename = namegen.sanitize(filename) + if not sync: + threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start() + else: + save_video_atomic(images, filename, video_type, duration, loop, interpolate, scale, pad, change) + return filename diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py index 0f42a5448..872127cf1 100644 --- a/modules/zluda_hijacks.py +++ b/modules/zluda_hijacks.py @@ -9,30 +9,6 @@ def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-buil return torch.return_types.topk((values.to(device), indices.to(device),)) -_fft_fftn = torch.fft.fftn -def fft_fftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin - return _fft_fftn(input.cpu(), *args, **kwargs).to(input.device) - - -_fft_ifftn = torch.fft.ifftn -def fft_ifftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin - return _fft_ifftn(input.cpu(), *args, **kwargs).to(input.device) - - -_fft_rfftn = torch.fft.rfftn -def fft_rfftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin - return _fft_rfftn(input.cpu(), *args, **kwargs).to(input.device) - - -def jit_script(f, *_, **__): # experiment / provide dummy graph - f.graph = torch._C.Graph() # pylint: disable=protected-access - return f - - def do_hijack(): torch.version.hip = rocm.version torch.topk = topk - torch.fft.fftn = fft_fftn - torch.fft.ifftn = fft_ifftn - torch.fft.rfftn = fft_rfftn - torch.jit.script = jit_script diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index 5e43a6635..8888a56f4 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py 
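As a companion to the unified video path, a hedged sketch of how a script's run() might hand the values collected by create_video_inputs() to the save_video() helper added in modules/video.py above. The run() wrapper and the process_images() call are placeholders; keyword names follow the helper's signature in this diff, and real scripts may pass extra arguments such as filename or sync.

# Sketch only: video.save_video() is the new helper from modules/video.py in this diff; everything
# else here is an illustrative stand-in for a concrete script.
from modules import processing, video


def run(p, video_type, duration, loop, pad, interpolate):
    processed = processing.process_images(p)
    # save_video() skips 'None'/single-frame results and handles GIF/PNG via PIL,
    # the remaining formats via cv2.VideoWriter, optional RIFE interpolation,
    # and threaded (non-blocking) writing unless sync=True
    video.save_video(p, processed.images, video_type=video_type, duration=duration,
                     loop=loop, interpolate=interpolate, pad=pad)
    return processed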
@@ -1,5 +1,6 @@ import os import sys +import site import ctypes import shutil import zipfile @@ -11,15 +12,17 @@ DLL_MAPPING = { 'cublas.dll': 'cublas64_11.dll', 'cusparse.dll': 'cusparse64_11.dll', + 'cufft.dll': 'cufft64_10.dll', + 'cufftw.dll': 'cufftw64_10.dll', 'nvrtc.dll': 'nvrtc64_112_0.dll', } -HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', f'hiprtc{"".join([v.zfill(2) for v in rocm.version.split(".")])}.dll'] +HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'hipfft.dll',] ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',) -default_agent: Union[rocm.Agent, None] = None - -def get_path() -> str: - return os.path.abspath(os.environ.get('ZLUDA', '.zluda')) +path = os.path.abspath(os.environ.get('ZLUDA', '.zluda')) +default_agent: Union[rocm.Agent, None] = None +nightly = os.environ.get("ZLUDA_NIGHTLY", "0") == "1" +hipBLASLt_enabled = os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll")) and os.path.exists(rocm.blaslt_tensile_libpath) and ((not os.path.exists(path) and nightly) or os.path.exists(os.path.join(path, 'cublasLt.dll'))) def set_default_agent(agent: rocm.Agent): @@ -27,46 +30,72 @@ def set_default_agent(agent: rocm.Agent): default_agent = agent -def install(zluda_path: os.PathLike) -> None: - if os.path.exists(zluda_path): +def is_old_zluda() -> bool: # ZLUDA<3.8.7 + return not os.path.exists(os.path.join(path, "cufftw.dll")) + + +def install() -> None: + if os.path.exists(path): return - commit = os.environ.get("ZLUDA_HASH", "1b6e012d8f2404840b524e2abae12cb91e1ac01d") - if rocm.version == "6.1": - commit = "c0804ca624963aab420cb418412b1c7fbae3454b" - urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', '_zluda') + platform = "windows" + commit = os.environ.get("ZLUDA_HASH", "c4994b3093e02231339d22e12be08418b2af781f") + if nightly: + platform = "nightly-" + platform + urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-{platform}-rocm{rocm.version[0]}-amd64.zip', '_zluda') with zipfile.ZipFile('_zluda', 'r') as archive: infos = archive.infolist() for info in infos: if not info.is_dir(): info.filename = os.path.basename(info.filename) - archive.extract(info, '.zluda') + archive.extract(info, path) os.remove('_zluda') def uninstall() -> None: - if os.path.exists('.zluda'): - shutil.rmtree('.zluda') + if os.path.exists(path): + shutil.rmtree(path) + + +def set_blaslt_enabled(enabled: bool): + global hipBLASLt_enabled # pylint: disable=global-statement + hipBLASLt_enabled = enabled -def make_copy(zluda_path: os.PathLike) -> None: +def get_blaslt_enabled() -> bool: + return hipBLASLt_enabled + + +def link_or_copy(src: os.PathLike, dst: os.PathLike): + try: + os.link(src, dst) + except Exception: + shutil.copyfile(src, dst) + + +def make_copy() -> None: for k, v in DLL_MAPPING.items(): - if not os.path.exists(os.path.join(zluda_path, v)): - try: - os.link(os.path.join(zluda_path, k), os.path.join(zluda_path, v)) - except Exception: - shutil.copyfile(os.path.join(zluda_path, k), os.path.join(zluda_path, v)) + if not os.path.exists(os.path.join(path, v)): + link_or_copy(os.path.join(path, k), os.path.join(path, v)) + + if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')): + link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll')) -def load(zluda_path: os.PathLike) -> None: +def load() -> None: os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1" + os.environ["ZLUDA_NVRTC_LIB"] = 
os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll") for v in HIPSDK_TARGETS: ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v)) for v in ZLUDA_TARGETS: - ctypes.windll.LoadLibrary(os.path.join(zluda_path, v)) + ctypes.windll.LoadLibrary(os.path.join(path, v)) for v in DLL_MAPPING.values(): - ctypes.windll.LoadLibrary(os.path.join(zluda_path, v)) + ctypes.windll.LoadLibrary(os.path.join(path, v)) + + if hipBLASLt_enabled: + ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'hipblaslt.dll')) + ctypes.windll.LoadLibrary(os.path.join(path, 'cublasLt64_11.dll')) def conceal(): import torch # pylint: disable=unused-import @@ -86,7 +115,7 @@ def _join_rocm_home(*paths) -> str: def get_default_torch_version(agent: Optional[rocm.Agent]) -> str: if agent is not None: if agent.arch in (rocm.MicroArchitecture.RDNA, rocm.MicroArchitecture.CDNA,): - return "2.3.1" + return "2.4.1" if hipBLASLt_enabled else "2.3.1" elif agent.arch == rocm.MicroArchitecture.GCN: return "2.2.1" - return "2.3.1" + return "2.4.1" if hipBLASLt_enabled else "2.3.1" diff --git a/package.json b/package.json index 64115587d..95f8899bf 100644 --- a/package.json +++ b/package.json @@ -17,7 +17,7 @@ }, "scripts": { "venv": "source venv/bin/activate", - "start": "python launch.py --debug --experimental", + "start": "python launch.py --debug", "ruff": "ruff check", "eslint": "eslint javascript/ extensions-builtin/sdnext-modernui/javascript/", "pylint": "pylint *.py modules/ extensions-builtin/", diff --git a/requirements.txt b/requirements.txt index a67ad1ba2..c8ed07a40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ invisible-watermark pi-heif # versioned -safetensors==0.4.5 +safetensors==0.5.0 tensordict==0.1.2 peft==0.14.0 httpx==0.24.1 diff --git a/scripts/allegrovideo.py b/scripts/allegrovideo.py index 675c6fb06..b340e4962 100644 --- a/scripts/allegrovideo.py +++ b/scripts/allegrovideo.py @@ -38,14 +38,6 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - with gr.Row(): gr.HTML('  Allegro Video
') with gr.Row(): @@ -53,13 +45,8 @@ def video_type_change(video_type): with gr.Row(): override_scheduler = gr.Checkbox(label='Override scheduler', value=True) with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) - with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() return [num_frames, override_scheduler, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] def run(self, p: processing.StableDiffusionProcessing, num_frames, override_scheduler, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument diff --git a/scripts/animatediff.py b/scripts/animatediff.py index f44c85bb7..a704baae7 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -197,14 +197,6 @@ def show(self, is_img2img): def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - with gr.Row(): gr.HTML("  AnimateDiff
") with gr.Row(): @@ -217,9 +209,6 @@ def video_type_change(video_type): strength = gr.Slider(label='Strength', minimum=0.0, maximum=2.0, step=0.05, value=1.0) with gr.Row(): latent_mode = gr.Checkbox(label='Latent mode', value=True, visible=False) - with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) with gr.Accordion('FreeInit', open=False): with gr.Row(): fi_method = gr.Dropdown(label='Method', choices=['none', 'butterworth', 'ideal', 'gaussian'], value='none') @@ -231,10 +220,8 @@ def video_type_change(video_type): fi_spatial = gr.Slider(label='Spatial frequency', minimum=0.0, maximum=1.0, step=0.05, value=0.25) fi_temporal = gr.Slider(label='Temporal frequency', minimum=0.0, maximum=1.0, step=0.05, value=0.25) with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() return [adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, override_scheduler, fi_method, fi_iters, fi_order, fi_spatial, fi_temporal] def run(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, override_scheduler, fi_method, fi_iters, fi_order, fi_spatial, fi_temporal): # pylint: disable=arguments-differ, unused-argument diff --git a/scripts/cogvideo.py b/scripts/cogvideo.py index e689a5e3f..7b7a557f8 100644 --- a/scripts/cogvideo.py +++ b/scripts/cogvideo.py @@ -29,14 +29,6 @@ def show(self, is_img2img): def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - with gr.Row(): gr.HTML("  CogVideoX
") with gr.Row(): @@ -48,18 +40,13 @@ def video_type_change(video_type): with gr.Row(): offload = gr.Dropdown(label='Offload', choices=['none', 'balanced', 'model', 'sequential'], value='balanced') override = gr.Checkbox(label='Override resolution', value=True) - with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=30, step=0.25, value=8, visible=False) with gr.Accordion('Optional init image or video', open=False): with gr.Row(): image = gr.Image(value=None, label='Image', type='pil', source='upload', width=256, height=256) video = gr.Video(value=None, label='Video', source='upload', width=256, height=256) with gr.Row(): - loop = gr.Checkbox(label='Loop', value=True, visible=False) - pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, loop, pad, interpolate = create_video_inputs() return [model, sampler, frames, guidance, offload, override, video_type, duration, loop, pad, interpolate, image, video] def load(self, model): diff --git a/scripts/flux_tools.py b/scripts/flux_tools.py index 50904eedb..ff5367cd8 100644 --- a/scripts/flux_tools.py +++ b/scripts/flux_tools.py @@ -45,10 +45,6 @@ def run(self, p: processing.StableDiffusionProcessing, tool: str = 'None', promp global redux_pipe, processor_canny, processor_depth # pylint: disable=global-statement if tool is None or tool == 'None': return - supported_model_list = ['f1'] - if shared.sd_model_type not in supported_model_list: - shared.log.warning(f'{title}: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') - return None image = getattr(p, 'init_images', None) if image is None or len(image) == 0: shared.log.error(f'{title}: tool={tool} no init_images') @@ -60,6 +56,10 @@ def run(self, p: processing.StableDiffusionProcessing, tool: str = 'None', promp t0 = time.time() if tool == 'Redux': + supported_model_list = ['f1'] + if shared.sd_model_type not in supported_model_list: + shared.log.warning(f'{title}: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') + return None # pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev", revision="refs/pr/8", torch_dtype=torch.bfloat16).to("cuda") shared.log.debug(f'{title}: tool={tool} prompt={prompt}') if redux_pipe is None: diff --git a/scripts/hunyuanvideo.py b/scripts/hunyuanvideo.py index 0e99a26a9..874e20fc0 100644 --- a/scripts/hunyuanvideo.py +++ b/scripts/hunyuanvideo.py @@ -61,14 +61,6 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - with gr.Row(): gr.HTML('  Hunyuan Video
') with gr.Row(): @@ -79,13 +71,8 @@ def video_type_change(video_type): with gr.Row(): template = gr.TextArea(label='Prompt processor', lines=3, value=default_template) with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) - with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() return [num_frames, tile_frames, override_scheduler, template, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] def run(self, p: processing.StableDiffusionProcessing, num_frames, tile_frames, override_scheduler, template, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument diff --git a/scripts/image2video.py b/scripts/image2video.py index ad6615f67..6ef27412c 100644 --- a/scripts/image2video.py +++ b/scripts/image2video.py @@ -21,15 +21,6 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - - def video_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - def model_change(model_name): model = next(m for m in MODELS if m['name'] == model_name) return gr.update(value=model['info']), gr.update(visible=model_name == 'PIA'), gr.update(visible=model_name == 'VGen') @@ -40,9 +31,6 @@ def model_change(model_name): model_info = gr.HTML() with gr.Row(): num_frames = gr.Slider(label='Frames', minimum=0, maximum=50, step=1, value=16) - with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) with gr.Accordion('FreeInit', open=False, visible=False) as fi_accordion: with gr.Row(): fi_method = gr.Dropdown(label='Method', choices=['none', 'butterworth', 'ideal', 'gaussian'], value='none') @@ -58,11 +46,9 @@ def model_change(model_name): vg_chunks = gr.Slider(label='Decode chunks', minimum=0.1, maximum=1.0, step=0.1, value=0.5) vg_fps = gr.Slider(label='Change rate', minimum=0.1, maximum=1.0, step=0.1, value=0.5) with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() model_name.change(fn=model_change, inputs=[model_name], outputs=[model_info, fi_accordion, vgen_accordion]) - video_type.change(fn=video_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) return [model_name, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, 
fi_method, fi_iters, fi_order, fi_spatial, fi_temporal, vg_chunks, vg_fps] def run(self, p: processing.StableDiffusionProcessing, model_name, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, fi_method, fi_iters, fi_order, fi_spatial, fi_temporal, vg_chunks, vg_fps): # pylint: disable=arguments-differ, unused-argument diff --git a/scripts/ltxvideo.py b/scripts/ltxvideo.py index e6464970d..3f6cad08e 100644 --- a/scripts/ltxvideo.py +++ b/scripts/ltxvideo.py @@ -66,13 +66,6 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] def model_change(model): return gr.update(visible=model == 'custom') @@ -90,13 +83,8 @@ def model_change(model): with gr.Row(): model_custom = gr.Textbox(value='', label='Path to model file', visible=False) with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) - with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() model.change(fn=model_change, inputs=[model], outputs=[model_custom]) return [model, model_custom, decode, sampler, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, teacache_enable, teacache_threshold] @@ -160,6 +148,7 @@ def run(self, p: processing.StableDiffusionProcessing, model, model_custom, deco shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) shared.sd_model.vae.enable_slicing() shared.sd_model.vae.enable_tiling() + shared.sd_model.vae.use_framewise_decoding = True devices.torch_gc(force=True) shared.sd_model.transformer.cnt = 0 diff --git a/scripts/mochivideo.py b/scripts/mochivideo.py index f85616a5e..cb2950eda 100644 --- a/scripts/mochivideo.py +++ b/scripts/mochivideo.py @@ -17,26 +17,13 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - with gr.Row(): gr.HTML('  Mochi.1 Video
') with gr.Row(): num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=45) with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) - with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() return [num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] def run(self, p: processing.StableDiffusionProcessing, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument diff --git a/scripts/pixelsmith.py b/scripts/pixelsmith.py new file mode 100644 index 000000000..a9509081c --- /dev/null +++ b/scripts/pixelsmith.py @@ -0,0 +1,78 @@ +import gradio as gr +from PIL import Image +from modules import scripts, processing, shared, sd_models, devices, images + + +class Script(scripts.Script): + def __init__(self): + super().__init__() + self.orig_pipe = None + self.orig_vae = None + self.vae = None + + def title(self): + return 'PixelSmith' + + def show(self, is_img2img): + return shared.native + + def ui(self, _is_img2img): # ui elements + with gr.Row(): + gr.HTML('  PixelSmith
') + with gr.Row(): + slider = gr.Slider(label="Slider", value=20, minimum=0, maximum=100, step=1) + return [slider] + + def encode(self, p: processing.StableDiffusionProcessing, image: Image.Image): + if image is None: + return None + import numpy as np + import torch + if p.width is None or p.width == 0: + p.width = int(8 * (image.width * p.scale_by // 8)) + if p.height is None or p.height == 0: + p.height = int(8 * (image.height * p.scale_by // 8)) + image = images.resize_image(p.resize_mode, image, p.width, p.height, upscaler_name=p.resize_name, context=p.resize_context) + tensor = np.array(image).astype(np.float16) / 255.0 + tensor = tensor[None].transpose(0, 3, 1, 2) + # image = image.transpose(0, 3, 1, 2) + tensor = torch.from_numpy(tensor).to(device=devices.device, dtype=devices.dtype) + tensor = 2.0 * tensor - 1.0 + with devices.inference_context(): + latent = shared.sd_model.vae.tiled_encode(tensor) + latent = shared.sd_model.vae.config.scaling_factor * latent.latent_dist.sample() + shared.log.info(f'PixelSmith encode: image={image} latent={latent.shape} width={p.width} height={p.height} vae={shared.sd_model.vae.__class__.__name__}') + return latent + + + def run(self, p: processing.StableDiffusionProcessing, slider: int = 20): # pylint: disable=arguments-differ + supported_model_list = ['sdxl'] + if shared.sd_model_type not in supported_model_list: + shared.log.warning(f'PixelSmith: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') + from modules.pixelsmith import PixelSmithXLPipeline, PixelSmithVAE + self.orig_pipe = shared.sd_model + self.orig_vae = shared.sd_model.vae + if self.vae is None: + self.vae = PixelSmithVAE.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=devices.dtype).to(devices.device) + shared.sd_model = sd_models.switch_pipe(PixelSmithXLPipeline, shared.sd_model) + shared.sd_model.vae = self.vae + shared.sd_model.vae.enable_tiling() + p.extra_generation_params["PixelSmith"] = f'Slider={slider}' + p.sampler_name = 'DDIM' + p.task_args['slider'] = slider + # p.task_args['output_type'] = 'pil' + if hasattr(p, 'init_images') and p.init_images is not None and len(p.init_images) > 0: + p.task_args['image'] = self.encode(p, p.init_images[0]) + p.init_images = None + shared.log.info(f'PixelSmith apply: slider={slider} class={shared.sd_model.__class__.__name__} vae={shared.sd_model.vae.__class__.__name__}') + # processed = processing.process_images(p) + + def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, slider): # pylint: disable=unused-argument + if self.orig_pipe is None: + return processed + if shared.sd_model.__class__.__name__ == 'PixelSmithXLPipeline': + shared.sd_model = self.orig_pipe + shared.sd_model.vae = self.orig_vae + self.orig_pipe = None + self.orig_vae = None + return processed diff --git a/scripts/prompt_enhance.py b/scripts/prompt_enhance.py index a02aea608..0ad17bba4 100644 --- a/scripts/prompt_enhance.py +++ b/scripts/prompt_enhance.py @@ -2,6 +2,7 @@ import time import random +import threading from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import gradio as gr from modules import shared, scripts, devices, processing @@ -9,6 +10,7 @@ repo_id = "gokaygokay/Flux-Prompt-Enhance" num_return_sequences = 5 +load_lock = threading.Lock() class Script(scripts.Script): @@ -31,11 +33,12 @@ def show(self, is_img2img): return shared.native def load(self): - if self.tokenizer is None: - self.tokenizer = 
AutoTokenizer.from_pretrained('gokaygokay/Flux-Prompt-Enhance', cache_dir=shared.opts.hfcache_dir) - if self.model is None: - shared.log.info(f'Prompt enhance: model="{repo_id}"') - self.model = AutoModelForSeq2SeqLM.from_pretrained('gokaygokay/Flux-Prompt-Enhance', cache_dir=shared.opts.hfcache_dir).to(device=devices.cpu, dtype=devices.dtype) + with load_lock: + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained('gokaygokay/Flux-Prompt-Enhance', cache_dir=shared.opts.hfcache_dir) + if self.model is None: + shared.log.info(f'Prompt enhance: model="{repo_id}"') + self.model = AutoModelForSeq2SeqLM.from_pretrained('gokaygokay/Flux-Prompt-Enhance', cache_dir=shared.opts.hfcache_dir).to(device=devices.cpu, dtype=devices.dtype) def enhance(self, prompt, auto_apply: bool = False, temperature: float = 0.7, repetition_penalty: float = 1.2, max_length: int = 128): self.load() diff --git a/scripts/pulid_ext.py b/scripts/pulid_ext.py index ee08e348b..4f2890e52 100644 --- a/scripts/pulid_ext.py +++ b/scripts/pulid_ext.py @@ -3,10 +3,8 @@ import time import contextlib import gradio as gr -import numpy as np from PIL import Image from modules import shared, devices, errors, scripts, processing, processing_helpers, sd_models -from modules.api.api import decode_base64_to_image debug = os.environ.get('SD_PULID_DEBUG', None) is not None @@ -59,12 +57,16 @@ def fun(p, x, xs): # pylint: disable=unused-argument xyz_classes.axis_options.append(option) + def decode_image(self, b64): + from modules.api.api import decode_base64_to_image + return decode_base64_to_image(b64) + def load_images(self, files): uploaded_images.clear() for file in files or []: try: if isinstance(file, str): - image = decode_base64_to_image(file) + image = self.decode_image(file) elif isinstance(file, Image.Image): image = file elif isinstance(file, dict) and 'name' in file: @@ -113,16 +115,17 @@ def run( version: str = 'v1.1' ): # pylint: disable=arguments-differ, unused-argument images = [] + import numpy as np try: if gallery is None or (isinstance(gallery, list) and len(gallery) == 0): images = getattr(p, 'pulid_images', uploaded_images) - images = [decode_base64_to_image(image) if isinstance(image, str) else image for image in images] + images = [self.decode_image(image) if isinstance(image, str) else image for image in images] elif isinstance(gallery[0], dict): images = [Image.open(f['name']) for f in gallery] elif isinstance(gallery, str): - images = [decode_base64_to_image(gallery)] + images = [self.decode_image(gallery)] elif isinstance(gallery[0], str): - images = [decode_base64_to_image(f) for f in gallery] + images = [self.decode_image(f) for f in gallery] else: images = gallery images = [np.array(image) for image in images] diff --git a/scripts/stablevideodiffusion.py b/scripts/stablevideodiffusion.py index c1283e1b6..59cba1e96 100644 --- a/scripts/stablevideodiffusion.py +++ b/scripts/stablevideodiffusion.py @@ -23,14 +23,6 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - with gr.Row(): gr.HTML('  Stable Video Diffusion
') with gr.Row(): @@ -46,13 +38,8 @@ def video_type_change(video_type): with gr.Row(): override_resolution = gr.Checkbox(label='Override resolution', value=True) with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) - with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() return [model, num_frames, override_resolution, min_guidance_scale, max_guidance_scale, decode_chunk_size, motion_bucket_id, noise_aug_strength, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] def run(self, p: processing.StableDiffusionProcessing, model, num_frames, override_resolution, min_guidance_scale, max_guidance_scale, decode_chunk_size, motion_bucket_id, noise_aug_strength, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument @@ -75,7 +62,7 @@ def run(self, p: processing.StableDiffusionProcessing, model, num_frames, overri if model_name != model_loaded or c != 'StableVideoDiffusionPipeline': shared.opts.sd_model_checkpoint = model_path sd_models.reload_model_weights() - shared.sd_model = shared.sd_model.to(torch.float32) # TODO svd: runs in fp32 causing dtype mismatch + shared.sd_model = shared.sd_model.to(torch.float32) # TODO svd: runs in fp32 due to dtype mismatch # set params if override_resolution: diff --git a/scripts/text2video.py b/scripts/text2video.py index c7b3d1c05..a58340f60 100644 --- a/scripts/text2video.py +++ b/scripts/text2video.py @@ -31,14 +31,6 @@ def show(self, is_img2img): # return signature is array of gradio components def ui(self, _is_img2img): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - def model_info_change(model_name): if model_name == 'None': return gr.update(value='') @@ -57,13 +49,8 @@ def model_info_change(model_name): use_default = gr.Checkbox(label='Use defaults', value=True) num_frames = gr.Slider(label='Frames', minimum=1, maximum=50, step=1, value=0) with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) - with gr.Row(): - gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) - mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + from modules.ui_sections import create_video_inputs + video_type, duration, gif_loop, mp4_pad, mp4_interpolate = create_video_inputs() return [model_name, use_default, num_frames, video_type, duration, 
gif_loop, mp4_pad, mp4_interpolate] def run(self, p: processing.StableDiffusionProcessing, model_name, use_default, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index bb067ea21..38ee177e1 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -15,6 +15,7 @@ from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing # pylint: disable=no-name-in-module, unused-import from modules import shared, errors, scripts, images, processing from modules.ui_components import ToolButton +from modules.ui_sections import create_video_inputs import modules.ui_symbols as symbols @@ -64,23 +65,8 @@ def ui(self, is_img2img): create_video = gr.Checkbox(label='Create video', value=False, elem_id=self.elem_id("xyz_create_video"), container=False) with gr.Row(visible=False) as ui_video: - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - - with gr.Column(): - video_type = gr.Dropdown(label='Video type', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - with gr.Column(): - video_duration = gr.Slider(label='Duration', minimum=0.25, maximum=300, step=0.25, value=2, visible=False) - video_loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="control_video_loop") - video_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - video_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[video_duration, video_loop, video_pad, video_interpolate]) - create_video.change(fn=lambda x: gr.update(visible=x), inputs=[create_video], outputs=[ui_video]) + video_type, video_duration, video_loop, video_pad, video_interpolate = create_video_inputs() + create_video.change(fn=lambda x: gr.update(visible=x), inputs=[create_video], outputs=[ui_video]) with gr.Row(): margin_size = gr.Slider(label="Grid margins", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) @@ -253,6 +239,7 @@ def fix_axis_seeds(axis_opt, axis_list): ys = fix_axis_seeds(y_opt, ys) zs = fix_axis_seeds(z_opt, zs) + total_jobs = len(xs) * len(ys) * len(zs) if x_opt.label == 'Steps': total_steps = sum(xs) * len(ys) * len(zs) elif y_opt.label == 'Steps': @@ -260,7 +247,7 @@ def fix_axis_seeds(axis_opt, axis_list): elif z_opt.label == 'Steps': total_steps = sum(zs) * len(xs) * len(ys) else: - total_steps = p.steps * len(xs) * len(ys) * len(zs) + total_steps = p.steps * total_jobs if isinstance(p, processing.StableDiffusionProcessingTxt2Img) and p.enable_hr: if x_opt.label == "Hires steps": total_steps += sum(xs) * len(ys) * len(zs) @@ -269,10 +256,12 @@ def fix_axis_seeds(axis_opt, axis_list): elif z_opt.label == "Hires steps": total_steps += sum(zs) * len(xs) * len(ys) elif 
p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) + total_steps += p.hr_second_pass_steps * total_jobs else: total_steps *= 2 total_steps *= p.n_iter + shared.state.update('Grid', total_steps, total_jobs * p.n_iter) + image_cell_count = p.n_iter * p.batch_size shared.log.info(f"XYZ grid: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} shape={len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}") AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) diff --git a/scripts/xyz_grid_classes.py b/scripts/xyz_grid_classes.py index cc70d68f8..2b1fa2d59 100644 --- a/scripts/xyz_grid_classes.py +++ b/scripts/xyz_grid_classes.py @@ -1,4 +1,4 @@ -from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing, str_permutations # pylint: disable=no-name-in-module, unused-import +from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt_primary, apply_prompt_refine, apply_prompt_detailer, apply_prompt_all, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing, str_permutations # pylint: disable=no-name-in-module, unused-import from modules import shared, shared_items, sd_samplers, ipadapter, sd_models, sd_vae, sd_unet @@ -93,7 +93,10 @@ def __exit__(self, exc_type, exc_value, tb): AxisOption("[Model] Refiner", str, apply_refiner, cost=0.8, fmt=format_value_add_label, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list)), AxisOption("[Model] Text encoder", str, apply_te, cost=0.7, choices=shared_items.sd_te_items), AxisOption("[Model] Dictionary", str, apply_dict, fmt=format_value_add_label, cost=0.9, choices=lambda: ['None'] + list(sd_models.checkpoints_list)), - AxisOption("[Prompt] Search & replace", str, apply_prompt, fmt=format_value_add_label), + AxisOption("[Prompt] Search & replace", str, apply_prompt_primary, fmt=format_value_add_label), + AxisOption("[Prompt] Search & replace refine", str, apply_prompt_refine, fmt=format_value_add_label), + AxisOption("[Prompt] Search & replace detailer", str, apply_prompt_detailer, fmt=format_value_add_label), + AxisOption("[Prompt] Search & replace all", str, apply_prompt_all, fmt=format_value_add_label), AxisOption("[Prompt] Prompt order", str_permutations, apply_order, fmt=format_value_join_list), AxisOption("[Prompt] Prompt parser", str, apply_setting("prompt_attention"), choices=lambda: ["native", "compel", "xhinker", "a1111", "fixed"]), AxisOption("[Network] LoRA", str, apply_lora, cost=0.5, choices=list_lora), diff --git a/scripts/xyz_grid_draw.py b/scripts/xyz_grid_draw.py index cd7eb8d1f..bac96bb8b 100644 --- a/scripts/xyz_grid_draw.py +++ b/scripts/xyz_grid_draw.py @@ -10,7 +10,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, 
cell, draw_legend z_texts = [[images.GridAnnotation(z)] for z in z_labels] list_size = (len(xs) * len(ys) * len(zs)) processed_result = None - shared.state.job_count = list_size * p.n_iter + t0 = time.time() i = 0 @@ -22,7 +22,6 @@ def process_cell(x, y, z, ix, iy, iz): def index(ix, iy, iz): return ix + iy * len(xs) + iz * len(xs) * len(ys) - shared.state.job = 'Grid' p0 = time.time() processed: processing.Processed = cell(x, y, z, ix, iy, iz) p1 = time.time() @@ -63,7 +62,7 @@ def index(ix, iy, iz): cell_mode = processed_result.images[0].mode cell_size = processed_result.images[0].size processed_result.images[idx] = Image.new(cell_mode, cell_size) - return + shared.state.nextjob() if first_axes_processed == 'x': for ix, x in enumerate(xs): @@ -129,5 +128,6 @@ def index(ix, iy, iz): processed_result.infotexts.insert(0, processed_result.infotexts[0]) t2 = time.time() - shared.log.info(f'XYZ grid complete: images={list_size} size={grid.size if grid is not None else None} time={t1-t0:.2f} save={t2-t1:.2f}') + shared.log.info(f'XYZ grid complete: images={list_size} results={len(processed_result.images)}size={grid.size if grid is not None else None} time={t1-t0:.2f} save={t2-t1:.2f}') + p.skip_processing = True return processed_result diff --git a/scripts/xyz_grid_on.py b/scripts/xyz_grid_on.py index 0abd0fa1c..2a8f22aba 100644 --- a/scripts/xyz_grid_on.py +++ b/scripts/xyz_grid_on.py @@ -14,11 +14,12 @@ from scripts.xyz_grid_draw import draw_xyz_grid # pylint: disable=no-name-in-module from modules import shared, errors, scripts, images, processing from modules.ui_components import ToolButton +from modules.ui_sections import create_video_inputs import modules.ui_symbols as symbols active = False -cache = None +xyz_results_cache = None debug = shared.log.trace if os.environ.get('SD_XYZ_DEBUG', None) is not None else lambda *args, **kwargs: None @@ -70,23 +71,8 @@ def ui(self, is_img2img): create_video = gr.Checkbox(label='Create video', value=False, elem_id=self.elem_id("xyz_create_video"), container=False) with gr.Row(visible=False) as ui_video: - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] - - with gr.Column(): - video_type = gr.Dropdown(label='Video type', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') - with gr.Column(): - video_duration = gr.Slider(label='Duration', minimum=0.25, maximum=300, step=0.25, value=2, visible=False) - video_loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="control_video_loop") - video_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) - video_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[video_duration, video_loop, video_pad, video_interpolate]) - create_video.change(fn=lambda x: gr.update(visible=x), inputs=[create_video], outputs=[ui_video]) + video_type, video_duration, video_loop, video_pad, video_interpolate = create_video_inputs() + create_video.change(fn=lambda x: gr.update(visible=x), inputs=[create_video], outputs=[ui_video]) with gr.Row(): margin_size = gr.Slider(label="Grid margins", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) @@ -188,8 +174,8 @@ def process(self, p, include_time, include_text, margin_size, create_video, 
video_type, video_duration, video_loop, video_pad, video_interpolate, ): # pylint: disable=W0221 - global active, cache # pylint: disable=W0603 - cache = None + global active, xyz_results_cache # pylint: disable=W0603 + xyz_results_cache = None if not enabled or active: return active = True @@ -266,6 +252,7 @@ def fix_axis_seeds(axis_opt, axis_list): ys = fix_axis_seeds(y_opt, ys) zs = fix_axis_seeds(z_opt, zs) + total_jobs = len(xs) * len(ys) * len(zs) if x_opt.label == 'Steps': total_steps = sum(xs) * len(ys) * len(zs) elif y_opt.label == 'Steps': @@ -273,8 +260,8 @@ def fix_axis_seeds(axis_opt, axis_list): elif z_opt.label == 'Steps': total_steps = sum(zs) * len(xs) * len(ys) else: - total_steps = p.steps * len(xs) * len(ys) * len(zs) - if isinstance(p, processing.StableDiffusionProcessingTxt2Img) and p.enable_hr: + total_steps = p.steps * total_jobs + if p.enable_hr: if x_opt.label == "Hires steps": total_steps += sum(xs) * len(ys) * len(zs) elif y_opt.label == "Hires steps": @@ -282,10 +269,16 @@ def fix_axis_seeds(axis_opt, axis_list): elif z_opt.label == "Hires steps": total_steps += sum(zs) * len(xs) * len(ys) elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) + total_steps += p.hr_second_pass_steps * total_jobs else: total_steps *= 2 + if p.detailer_enabled: + total_steps += p.detailer_steps * total_jobs + total_steps *= p.n_iter + total_jobs *= p.n_iter + shared.state.update('Grid', total_steps, total_jobs) + image_cell_count = p.n_iter * p.batch_size shared.log.info(f"XYZ grid start: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} shape={len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}") AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) @@ -360,7 +353,7 @@ def cell(x, y, z, ix, iy, iz): return processed with SharedSettingsStackHelper(): - processed = draw_xyz_grid( + processed: processing.Processed = draw_xyz_grid( p, xs=xs, ys=ys, @@ -418,19 +411,15 @@ def cell(x, y, z, ix, iy, iz): p.do_not_save_samples = True p.disable_extra_networks = True active = False - cache = processed + xyz_results_cache = processed return processed def process_images(self, p, *args): # pylint: disable=W0221, W0613 - if hasattr(cache, 'used'): - cache.images.clear() - cache.used = False - elif cache is not None and len(cache.images) > 0: - cache.used = True + if xyz_results_cache is not None and len(xyz_results_cache.images) > 0: p.restore_faces = False - p.detailer = False + p.detailer_enabled = False p.color_corrections = None - p.scripts = None - return cache + # p.scripts = None + return xyz_results_cache return None diff --git a/scripts/xyz_grid_shared.py b/scripts/xyz_grid_shared.py index f9bf26c67..efd70724b 100644 --- a/scripts/xyz_grid_shared.py +++ b/scripts/xyz_grid_shared.py @@ -62,18 +62,37 @@ def apply_seed(p, x, xs): shared.log.debug(f'XYZ grid apply seed: {x}') -def apply_prompt(p, x, xs): +def apply_prompt(positive, negative, p, x, xs): for s in xs: - if s in p.prompt: - shared.log.debug(f'XYZ grid apply prompt: "{s}"="{x}"') - p.prompt = p.prompt.replace(s, x) - if s in p.negative_prompt: - shared.log.debug(f'XYZ grid apply negative: "{s}"="{x}"') - p.negative_prompt = p.negative_prompt.replace(s, x) + shared.log.debug(f'XYZ grid apply prompt: fields={positive}/{negative} "{s}"="{x}"') + orig_positive = getattr(p, positive) + orig_negative = getattr(p, negative) + if s in orig_positive: + setattr(p, positive, orig_positive.replace(s, x)) + if s in orig_negative: + setattr(p, negative, orig_negative.replace(s, 
x)) + + +def apply_prompt_primary(p, x, xs): + apply_prompt('prompt', 'negative_prompt', p, x, xs) p.all_prompts = None p.all_negative_prompts = None +def apply_prompt_refine(p, x, xs): + apply_prompt('refiner_prompt', 'refiner_negative', p, x, xs) + + +def apply_prompt_detailer(p, x, xs): + apply_prompt('detailer_prompt', 'detailer_negative', p, x, xs) + + +def apply_prompt_all(p, x, xs): + apply_prompt('prompt', 'negative_prompt', p, x, xs) + apply_prompt('refiner_prompt', 'refiner_negative', p, x, xs) + apply_prompt('detailer_prompt', 'detailer_negative', p, x, xs) + + def apply_order(p, x, xs): token_order = [] for token in x: @@ -251,7 +270,7 @@ def apply_detailer(p, opt, x): p.detailer_model = 'GFPGAN' else: is_active = opt in ('true', 'yes', 'y', '1') - p.detailer = is_active + p.detailer_enabled = is_active shared.log.debug(f'XYZ grid apply face-restore: "{x}"') diff --git a/webui.py b/webui.py index 3c8361ecf..f104bbee8 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,7 @@ import io import os import sys +import time import glob import signal import asyncio @@ -10,13 +11,9 @@ from threading import Thread import modules.hashes import modules.loader -import torch # pylint: disable=wrong-import-order -from modules import timer, errors, paths # pylint: disable=unused-import + from installer import log, git_commit, custom_excepthook -# import ldm.modules.encoders.modules # pylint: disable=unused-import, wrong-import-order -from modules import shared, extensions, gr_tempdir, modelloader # pylint: disable=ungrouped-imports -from modules import extra_networks, ui_extra_networks # pylint: disable=ungrouped-imports -from modules.paths import create_paths +from modules import timer, paths, shared, extensions, gr_tempdir, modelloader from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=unused-import import modules.devices import modules.sd_checkpoint @@ -32,23 +29,28 @@ import modules.txt2img import modules.img2img import modules.upscaler +import modules.extra_networks +import modules.ui_extra_networks import modules.textual_inversion.textual_inversion import modules.hypernetworks.hypernetwork import modules.script_callbacks -from modules.api.middleware import setup_middleware -from modules.shared import cmd_opts, opts # pylint: disable=unused-import +import modules.api.middleware + +if not modules.loader.initialized: + timer.startup.record("libraries") + import modules.sd_hijack # runs conditional load of ldm if not shared.native + timer.startup.record("ldm") +modules.loader.initialized = True sys.excepthook = custom_excepthook local_url = None state = shared.state backend = shared.backend -if not modules.loader.initialized: - timer.startup.record("libraries") -if cmd_opts.server_name: - server_name = cmd_opts.server_name +if shared.cmd_opts.server_name: + server_name = shared.cmd_opts.server_name else: - server_name = "0.0.0.0" if cmd_opts.listen else None + server_name = "0.0.0.0" if shared.cmd_opts.listen else None fastapi_args = { "version": f'0.0.{git_commit}', "title": "SD.Next", @@ -59,30 +61,12 @@ # "redoc_url": "/redocs" if cmd_opts.docs else None, } -import modules.sd_hijack -timer.startup.record("ldm") -modules.loader.initialized = True - - -def check_rollback_vae(): - if shared.cmd_opts.rollback_vae: - if not torch.cuda.is_available(): - log.error("Rollback VAE functionality requires compatible GPU") - shared.cmd_opts.rollback_vae = False - elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'): - 
log.error("Rollback VAE functionality requires Torch 2.1 or higher") - shared.cmd_opts.rollback_vae = False - elif 0 < torch.cuda.get_device_capability()[0] < 8: - log.error('Rollback VAE functionality device capabilities not met') - shared.cmd_opts.rollback_vae = False - def initialize(): log.debug('Initializing') modules.sd_checkpoint.init_metadata() modules.hashes.init_cache() - check_rollback_vae() log.debug(f'Huggingface cache: path="{shared.opts.hfcache_dir}"') @@ -135,20 +119,20 @@ def initialize(): shared.reload_hypernetworks() timer.startup.record("hypernetworks") - ui_extra_networks.initialize() - ui_extra_networks.register_pages() - extra_networks.initialize() - extra_networks.register_default_extra_networks() + modules.ui_extra_networks.initialize() + modules.ui_extra_networks.register_pages() + modules.extra_networks.initialize() + modules.extra_networks.register_default_extra_networks() timer.startup.record("networks") - if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None: + if shared.cmd_opts.tls_keyfile is not None and shared.cmd_opts.tls_certfile is not None: try: - if not os.path.exists(cmd_opts.tls_keyfile): + if not os.path.exists(shared.cmd_opts.tls_keyfile): log.error("Invalid path to TLS keyfile given") - if not os.path.exists(cmd_opts.tls_certfile): - log.error(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + if not os.path.exists(shared.cmd_opts.tls_certfile): + log.error(f"Invalid path to TLS certfile: '{shared.cmd_opts.tls_certfile}'") except TypeError: - cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + shared.cmd_opts.tls_keyfile = shared.cmd_opts.tls_certfile = None log.error("TLS setup invalid, running webui without TLS") else: log.info("Running with TLS") @@ -156,6 +140,7 @@ def initialize(): # make the program just exit at ctrl+c without waiting for anything def sigint_handler(_sig, _frame): + log.trace(f'State history: uptime={round(time.time() - shared.state.server_start)} jobs={len(shared.state.job_history)} tasks={len(shared.state.task_history)} latents={shared.state.latent_history} images={shared.state.image_history}') log.info('Exiting') try: for f in glob.glob("*.lock"): @@ -169,16 +154,16 @@ def sigint_handler(_sig, _frame): def load_model(): if not shared.opts.sd_checkpoint_autoload and shared.cmd_opts.ckpt is None: - log.debug('Model auto load disabled') + log.info('Model auto load disabled') else: shared.state.begin('Load') thread_model = Thread(target=lambda: shared.sd_model) thread_model.start() thread_refiner = Thread(target=lambda: shared.sd_refiner) thread_refiner.start() - shared.state.end() thread_model.join() thread_refiner.join() + shared.state.end() timer.startup.record("checkpoint") shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='model')), call=False) shared.opts.onchange("sd_model_refiner", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='refiner')), call=False) @@ -229,7 +214,7 @@ def start_common(): log.info(f'Using data path: {shared.cmd_opts.data_dir}') if shared.cmd_opts.models_dir is not None and len(shared.cmd_opts.models_dir) > 0 and shared.cmd_opts.models_dir != 'models': log.info(f'Models path: {shared.cmd_opts.models_dir}') - create_paths(shared.opts) + paths.create_paths(shared.opts) async_policy() initialize() try: @@ -249,20 +234,20 @@ def start_ui(): timer.startup.record("before-ui") shared.demo = modules.ui.create_ui(timer.startup) timer.startup.record("ui") - if cmd_opts.disable_queue: + if 
@@ -249,20 +234,20 @@ def start_ui():
     timer.startup.record("before-ui")
     shared.demo = modules.ui.create_ui(timer.startup)
     timer.startup.record("ui")
-    if cmd_opts.disable_queue:
+    if shared.cmd_opts.disable_queue:
         log.info('Server queues disabled')
         shared.demo.progress_tracking = False
     else:
         shared.demo.queue(concurrency_count=64)
 
     gradio_auth_creds = []
-    if cmd_opts.auth:
-        gradio_auth_creds += [x.strip() for x in cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()]
-    if cmd_opts.auth_file:
-        if not os.path.exists(cmd_opts.auth_file):
-            log.error(f"Invalid path to auth file: '{cmd_opts.auth_file}'")
+    if shared.cmd_opts.auth:
+        gradio_auth_creds += [x.strip() for x in shared.cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()]
+    if shared.cmd_opts.auth_file:
+        if not os.path.exists(shared.cmd_opts.auth_file):
+            log.error(f"Invalid path to auth file: '{shared.cmd_opts.auth_file}'")
         else:
-            with open(cmd_opts.auth_file, 'r', encoding="utf8") as file:
+            with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
                 for line in file.readlines():
                     gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
     if len(gradio_auth_creds) > 0:
@@ -271,19 +256,19 @@
     global local_url # pylint: disable=global-statement
     stdout = io.StringIO()
     allowed_paths = [os.path.dirname(__file__)]
-    if cmd_opts.data_dir is not None and os.path.isdir(cmd_opts.data_dir):
-        allowed_paths.append(cmd_opts.data_dir)
-    if cmd_opts.allowed_paths is not None:
-        allowed_paths += [p for p in cmd_opts.allowed_paths if os.path.isdir(p)]
+    if shared.cmd_opts.data_dir is not None and os.path.isdir(shared.cmd_opts.data_dir):
+        allowed_paths.append(shared.cmd_opts.data_dir)
+    if shared.cmd_opts.allowed_paths is not None:
+        allowed_paths += [p for p in shared.cmd_opts.allowed_paths if os.path.isdir(p)]
     shared.log.debug(f'Root paths: {allowed_paths}')
     with contextlib.redirect_stdout(stdout):
         app, local_url, share_url = shared.demo.launch( # app is FastAPI(Starlette) instance
-            share=cmd_opts.share,
+            share=shared.cmd_opts.share,
             server_name=server_name,
-            server_port=cmd_opts.port if cmd_opts.port != 7860 else None,
-            ssl_keyfile=cmd_opts.tls_keyfile,
-            ssl_certfile=cmd_opts.tls_certfile,
-            ssl_verify=not cmd_opts.tls_selfsign,
+            server_port=shared.cmd_opts.port if shared.cmd_opts.port != 7860 else None,
+            ssl_keyfile=shared.cmd_opts.tls_keyfile,
+            ssl_certfile=shared.cmd_opts.tls_certfile,
+            ssl_verify=not shared.cmd_opts.tls_selfsign,
             debug=False,
             auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
             prevent_thread_lock=True,
@@ -293,24 +278,24 @@
             favicon_path='html/favicon.svg',
             allowed_paths=allowed_paths,
             app_kwargs=fastapi_args,
-            _frontend=True and cmd_opts.share,
+            _frontend=True and shared.cmd_opts.share,
         )
-    if cmd_opts.data_dir is not None:
-        gr_tempdir.register_tmp_file(shared.demo, os.path.join(cmd_opts.data_dir, 'x'))
+    if shared.cmd_opts.data_dir is not None:
+        gr_tempdir.register_tmp_file(shared.demo, os.path.join(shared.cmd_opts.data_dir, 'x'))
     shared.log.info(f'Local URL: {local_url}')
-    if cmd_opts.docs:
+    if shared.cmd_opts.docs:
         shared.log.info(f'API Docs: {local_url[:-1]}/docs') # pylint: disable=unsubscriptable-object
         shared.log.info(f'API ReDocs: {local_url[:-1]}/redocs') # pylint: disable=unsubscriptable-object
     if share_url is not None:
         shared.log.info(f'Share URL: {share_url}')
     # shared.log.debug(f'Gradio functions: registered={len(shared.demo.fns)}')
     shared.demo.server.wants_restart = False
-    setup_middleware(app, cmd_opts)
+    modules.api.middleware.setup_middleware(app, shared.cmd_opts)
 
-    if cmd_opts.subpath:
+    if shared.cmd_opts.subpath:
         import gradio
-        gradio.mount_gradio_app(app, shared.demo, path=f"/{cmd_opts.subpath}")
-        shared.log.info(f'Redirector mounted: /{cmd_opts.subpath}')
+        gradio.mount_gradio_app(app, shared.demo, path=f"/{shared.cmd_opts.subpath}")
+        shared.log.info(f'Redirector mounted: /{shared.cmd_opts.subpath}')
 
     timer.startup.record("launch")
@@ -318,7 +303,7 @@
     shared.api = create_api(app)
     timer.startup.record("api")
 
-    ui_extra_networks.init_api(app)
+    modules.ui_extra_networks.init_api(app)
 
     modules.script_callbacks.app_started_callback(shared.demo, app)
     timer.startup.record("app-started")
@@ -343,7 +328,7 @@ def webui(restart=False):
     modules.sd_models.write_metadata()
     load_model()
     shared.opts.save(shared.config_filename)
-    if cmd_opts.profile:
+    if shared.cmd_opts.profile:
         for k, v in modules.script_callbacks.callback_map.items():
             shared.log.debug(f'Registered callbacks: {k}={len(v)} {[c.script for c in v]}')
     debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
@@ -354,7 +339,15 @@
     for m in modules.scripts.postprocessing_scripts_data:
         debug(f' {m}')
     modules.script_callbacks.print_timers()
-    log.info(f"Startup time: {timer.startup.summary()}")
+
+    if shared.cmd_opts.profile:
+        log.info(f"Launch time: {timer.launch.summary(min_time=0)}")
+        log.info(f"Installer time: {timer.init.summary(min_time=0)}")
+        log.info(f"Startup time: {timer.startup.summary(min_time=0)}")
+    else:
+        timer.startup.add('launch', timer.launch.get_total())
+        timer.startup.add('installer', timer.launch.get_total())
+        log.info(f"Startup time: {timer.startup.summary()}")
     timer.startup.reset()
 
     if not restart:
@@ -364,8 +357,8 @@
                 continue
             logger.handlers = log.handlers
     # autolaunch only on initial start
-    if (shared.opts.autolaunch or cmd_opts.autolaunch) and local_url is not None:
-        cmd_opts.autolaunch = False
+    if (shared.opts.autolaunch or shared.cmd_opts.autolaunch) and local_url is not None:
+        shared.cmd_opts.autolaunch = False
         shared.log.info('Launching browser')
         import webbrowser
         webbrowser.open(local_url, new=2, autoraise=True)
@@ -380,7 +373,7 @@ def api_only():
     start_common()
     from fastapi import FastAPI
     app = FastAPI(**fastapi_args)
-    setup_middleware(app, cmd_opts)
+    modules.api.middleware.setup_middleware(app, shared.cmd_opts)
     shared.api = create_api(app)
     shared.api.wants_restart = False
     modules.script_callbacks.app_started_callback(None, app)
@@ -391,7 +384,7 @@
 
 if __name__ == "__main__":
-    if cmd_opts.api_only:
+    if shared.cmd_opts.api_only:
         api_only()
     else:
         webui()
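As a usage note on the credential handling in `start_ui` above, a small standalone sketch of how comma-separated `user:password` pairs become the tuples passed to Gradio's `auth` parameter; the credential values are placeholders:

```python
# Placeholder credentials, mimicking what --auth or one line of an auth file may contain
auth_arg = 'admin:secret, guest:guest'

gradio_auth_creds = [x.strip() for x in auth_arg.strip('"').replace('\n', '').split(',') if x.strip()]
auth = [tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None
print(auth)  # [('admin', 'secret'), ('guest', 'guest')]
```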
diff --git a/wiki b/wiki
index 7bd8f8200..7f072b554 160000
--- a/wiki
+++ b/wiki
@@ -1 +1 @@
-Subproject commit 7bd8f82008007bc7a766ca47b4dc1a54470397df
+Subproject commit 7f072b554c6ee2edadc33879e5b4bdbfa48e6282