From c5df80197e16e5aa788387d7850becfed09a24a6 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 10 May 2023 15:12:02 -0400 Subject: [PATCH] new extension manager --- .gitignore | 5 +- .../stable-diffusion-webui-images-browser | 2 +- installer.py | 7 +- javascript/aspectRatioOverlay.js | 213 ++- javascript/contextMenus.js | 352 ++--- javascript/dragdrop.js | 4 +- javascript/edit-attention.js | 238 +-- javascript/extensions.js | 73 +- javascript/extraNetworks.js | 340 ++--- javascript/style.css | 18 +- javascript/textualInversion.js | 23 +- modules/call_queue.py | 11 +- modules/extensions.py | 39 +- modules/shared.py | 2 +- modules/ui_extensions.py | 463 +++--- scripts/custom_code.py | 180 +-- scripts/img2imgalt.py | 352 ++--- scripts/loopback.py | 280 ++-- scripts/outpainting_mk_2.py | 563 ++++--- scripts/poor_mans_outpainting.py | 289 ++-- scripts/postprocessing_codeformer.py | 70 +- scripts/postprocessing_gfpgan.py | 64 +- scripts/postprocessing_upscale.py | 268 ++-- scripts/prompt_matrix.py | 217 ++- scripts/prompts_from_file.py | 342 ++--- scripts/sd_upscale.py | 200 ++- scripts/xyz_grid.py | 1350 ++++++++--------- 27 files changed, 2964 insertions(+), 3001 deletions(-) diff --git a/.gitignore b/.gitignore index 3382bf80d..1c4049a7b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,15 @@ # defaults __pycache__ -/params.txt /cache.json /config.json -/ui-config.json +/params.txt /setup.log /styles.csv +/ui-config.json /user.css /webui-user.bat /webui-user.sh +/html/extensions.json /javascript/themes.json node_modules pnpm-lock.yaml diff --git a/extensions-builtin/stable-diffusion-webui-images-browser b/extensions-builtin/stable-diffusion-webui-images-browser index c751a9eee..5f9e5ad49 160000 --- a/extensions-builtin/stable-diffusion-webui-images-browser +++ b/extensions-builtin/stable-diffusion-webui-images-browser @@ -1 +1 @@ -Subproject commit c751a9eeedef738fd0a731db1dd65690690b6161 +Subproject commit 5f9e5ad49035b6237f6c20f528d9313bb41e6f75 diff 
--git a/installer.py b/installer.py index be164ca73..84b1fe746 100644 --- a/installer.py +++ b/installer.py @@ -318,8 +318,9 @@ def run_extension_installer(folder): # get list of all enabled extensions def list_extensions(folder): - if opts.get('disable_all_extensions', 'none') != 'none': - log.debug('Disabled extensions: all') + disabled_extensions = opts.get('disable_all_extensions', 'none') + if disabled_extensions != 'none': + log.debug(f'Disabled extensions: {disabled_extensions}') return [] disabled_extensions = set(opts.get('disabled_extensions', [])) if len(disabled_extensions) > 0: @@ -350,7 +351,7 @@ def install_extensions(): if not args.skip_extensions: run_extension_installer(os.path.join(folder, ext)) log.info(f'Extensions enabled: {extensions_enabled}') - if (len(extensions_duplicates) > 0): + if len(extensions_duplicates) > 0: log.warning(f'Extensions duplicates: {extensions_duplicates}') diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index faec76d54..055d5bd7f 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -1,107 +1,106 @@ -let currentWidth = null; -let currentHeight = null; -let arFrameTimeout = setTimeout(() => {}, 0); - -function dimensionChange(e, is_width, is_height) { - if (is_width) { - currentWidth = e.target.value * 1.0; - } - if (is_height) { - currentHeight = e.target.value * 1.0; - } - - const inImg2img = gradioApp().querySelector('#tab_img2img').style.display == 'block'; - - if (!inImg2img) { - return; - } - - let targetElement = null; - - const tabIndex = get_tab_index('mode_img2img'); - if (tabIndex == 0) { // img2img - targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); - } else if (tabIndex == 1) { // Sketch - targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); - } else if (tabIndex == 2) { // Inpaint - targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); 
- } else if (tabIndex == 3) { // Inpaint sketch - targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); - } - - if (targetElement) { - let arPreviewRect = gradioApp().querySelector('#imageARPreview'); - if (!arPreviewRect) { - arPreviewRect = document.createElement('div'); - arPreviewRect.id = 'imageARPreview'; - gradioApp().appendChild(arPreviewRect); - } - - const viewportOffset = targetElement.getBoundingClientRect(); - - viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight); - - scaledx = targetElement.naturalWidth * viewportscale; - scaledy = targetElement.naturalHeight * viewportscale; - - cleintRectTop = (viewportOffset.top + window.scrollY); - cleintRectLeft = (viewportOffset.left + window.scrollX); - cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2); - cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2); - - viewRectTop = cleintRectCentreY - (scaledy / 2); - viewRectLeft = cleintRectCentreX - (scaledx / 2); - arRectWidth = scaledx; - arRectHeight = scaledy; - - arscale = Math.min(arRectWidth / currentWidth, arRectHeight / currentHeight); - arscaledx = currentWidth * arscale; - arscaledy = currentHeight * arscale; - - arRectTop = cleintRectCentreY - (arscaledy / 2); - arRectLeft = cleintRectCentreX - (arscaledx / 2); - arRectWidth = arscaledx; - arRectHeight = arscaledy; - - arPreviewRect.style.top = `${arRectTop}px`; - arPreviewRect.style.left = `${arRectLeft}px`; - arPreviewRect.style.width = `${arRectWidth}px`; - arPreviewRect.style.height = `${arRectHeight}px`; - - clearTimeout(arFrameTimeout); - arFrameTimeout = setTimeout(() => { - arPreviewRect.style.display = 'none'; - }, 2000); - - arPreviewRect.style.display = 'block'; - } -} - -onUiUpdate(() => { - const arPreviewRect = gradioApp().querySelector('#imageARPreview'); - if (arPreviewRect) { - arPreviewRect.style.display = 'none'; - } - const 
tabImg2img = gradioApp().querySelector('#tab_img2img'); - if (tabImg2img) { - const inImg2img = tabImg2img.style.display == 'block'; - if (inImg2img) { - const inputs = gradioApp().querySelectorAll('input'); - inputs.forEach((e) => { - const is_width = e.parentElement.id == 'img2img_width'; - const is_height = e.parentElement.id == 'img2img_height'; - - if ((is_width || is_height) && !e.classList.contains('scrollwatch')) { - e.addEventListener('input', (e) => { dimensionChange(e, is_width, is_height); }); - e.classList.add('scrollwatch'); - } - if (is_width) { - currentWidth = e.value * 1.0; - } - if (is_height) { - currentHeight = e.value * 1.0; - } - }); - } - } -}); +let currentWidth = null; +let currentHeight = null; +let arFrameTimeout = setTimeout(() => {}, 0); + +function dimensionChange(e, is_width, is_height) { + if (is_width) { + currentWidth = e.target.value * 1.0; + } + if (is_height) { + currentHeight = e.target.value * 1.0; + } + + const inImg2img = gradioApp().querySelector('#tab_img2img').style.display === 'block'; + + if (!inImg2img) { + return; + } + + let targetElement = null; + + const tabIndex = get_tab_index('mode_img2img'); + if (tabIndex === 0) { // img2img + targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); + } else if (tabIndex === 1) { // Sketch + targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); + } else if (tabIndex === 2) { // Inpaint + targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); + } else if (tabIndex === 3) { // Inpaint sketch + targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); + } + + if (targetElement) { + let arPreviewRect = gradioApp().querySelector('#imageARPreview'); + if (!arPreviewRect) { + arPreviewRect = document.createElement('div'); + arPreviewRect.id = 'imageARPreview'; + gradioApp().appendChild(arPreviewRect); + } + + const viewportOffset = 
targetElement.getBoundingClientRect(); + + viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight); + + scaledx = targetElement.naturalWidth * viewportscale; + scaledy = targetElement.naturalHeight * viewportscale; + + cleintRectTop = (viewportOffset.top + window.scrollY); + cleintRectLeft = (viewportOffset.left + window.scrollX); + cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2); + cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2); + + viewRectTop = cleintRectCentreY - (scaledy / 2); + viewRectLeft = cleintRectCentreX - (scaledx / 2); + arRectWidth = scaledx; + arRectHeight = scaledy; + + arscale = Math.min(arRectWidth / currentWidth, arRectHeight / currentHeight); + arscaledx = currentWidth * arscale; + arscaledy = currentHeight * arscale; + + arRectTop = cleintRectCentreY - (arscaledy / 2); + arRectLeft = cleintRectCentreX - (arscaledx / 2); + arRectWidth = arscaledx; + arRectHeight = arscaledy; + + arPreviewRect.style.top = `${arRectTop}px`; + arPreviewRect.style.left = `${arRectLeft}px`; + arPreviewRect.style.width = `${arRectWidth}px`; + arPreviewRect.style.height = `${arRectHeight}px`; + + clearTimeout(arFrameTimeout); + arFrameTimeout = setTimeout(() => { + arPreviewRect.style.display = 'none'; + }, 2000); + arPreviewRect.style.display = 'block'; + } +} + +onUiUpdate(() => { + const arPreviewRect = gradioApp().querySelector('#imageARPreview'); + if (arPreviewRect) { + arPreviewRect.style.display = 'none'; + } + const tabImg2img = gradioApp().querySelector('#tab_img2img'); + if (tabImg2img) { + const inImg2img = tabImg2img.style.display === 'block'; + if (inImg2img) { + const inputs = gradioApp().querySelectorAll('input'); + inputs.forEach((e) => { + const is_width = e.parentElement.id === 'img2img_width'; + const is_height = e.parentElement.id === 'img2img_height'; + + if ((is_width || is_height) && !e.classList.contains('scrollwatch')) 
{ + e.addEventListener('input', (e) => { dimensionChange(e, is_width, is_height); }); + e.classList.add('scrollwatch'); + } + if (is_width) { + currentWidth = e.value * 1.0; + } + if (is_height) { + currentHeight = e.value * 1.0; + } + }); + } + } +}); diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 39d880165..2296ba297 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -1,176 +1,176 @@ -contextMenuInit = function () { - let eventListenerApplied = false; - const menuSpecs = new Map(); - - const uid = function () { - return Date.now().toString(36) + Math.random().toString(36).substr(2); - }; - - function showContextMenu(event, element, menuEntries) { - const posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; - const posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; - - const oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - - const tabButton = uiCurrentTab; - const baseStyle = window.getComputedStyle(tabButton); - - const contextMenu = document.createElement('nav'); - contextMenu.id = 'context-menu'; - contextMenu.style.background = baseStyle.background; - contextMenu.style.color = baseStyle.color; - contextMenu.style.fontFamily = baseStyle.fontFamily; - contextMenu.style.top = `${posy}px`; - contextMenu.style.left = `${posx}px`; - - const contextMenuList = document.createElement('ul'); - contextMenuList.className = 'context-menu-items'; - contextMenu.append(contextMenuList); - - menuEntries.forEach((entry) => { - const contextMenuEntry = document.createElement('a'); - contextMenuEntry.innerHTML = entry.name; - contextMenuEntry.addEventListener('click', (e) => { - entry.func(); - }); - contextMenuList.append(contextMenuEntry); - }); - - gradioApp().appendChild(contextMenu); - - const menuWidth = contextMenu.offsetWidth + 4; - const menuHeight = contextMenu.offsetHeight + 4; - - const windowWidth = 
window.innerWidth; - const windowHeight = window.innerHeight; - - if ((windowWidth - posx) < menuWidth) { - contextMenu.style.left = `${windowWidth - menuWidth}px`; - } - - if ((windowHeight - posy) < menuHeight) { - contextMenu.style.top = `${windowHeight - menuHeight}px`; - } - } - - function appendContextMenuOption(targetElementSelector, entryName, entryFunction) { - currentItems = menuSpecs.get(targetElementSelector); - - if (!currentItems) { - currentItems = []; - menuSpecs.set(targetElementSelector, currentItems); - } - const newItem = { - id: `${targetElementSelector}_${uid()}`, - name: entryName, - func: entryFunction, - isNew: true, - }; - - currentItems.push(newItem); - return newItem.id; - } - - function removeContextMenuOption(uid) { - menuSpecs.forEach((v, k) => { - let index = -1; - v.forEach((e, ei) => { if (e.id == uid) { index = ei; } }); - if (index >= 0) { - v.splice(index, 1); - } - }); - } - - function addContextMenuEventListener() { - if (eventListenerApplied) { - return; - } - gradioApp().addEventListener('click', (e) => { - const source = e.composedPath()[0]; - if (source.id && source.id.indexOf('check_progress') > -1) { - return; - } - - const oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - }); - gradioApp().addEventListener('contextmenu', (e) => { - const oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - menuSpecs.forEach((v, k) => { - if (e.composedPath()[0].matches(k)) { - showContextMenu(e, e.composedPath()[0], v); - e.preventDefault(); - } - }); - }); - eventListenerApplied = true; - } - - return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]; -}; - -initResponse = contextMenuInit(); -appendContextMenuOption = initResponse[0]; -removeContextMenuOption = initResponse[1]; -addContextMenuEventListener = initResponse[2]; - -(function () { - // Start example Context Menu Items - const generateOnRepeat = function 
(genbuttonid, interruptbuttonid) { - const genbutton = gradioApp().querySelector(genbuttonid); - const busy = document.getElementById('progressbar')?.style.display == 'block'; - if (!busy) { - genbutton.click(); - } - clearInterval(window.generateOnRepeatInterval); - window.generateOnRepeatInterval = setInterval( - () => { - const busy = document.getElementById('progressbar')?.style.display == 'block'; - if (!busy) { - genbutton.click(); - } - }, - 500, - ); - }; - - appendContextMenuOption('#txt2img_generate', 'Generate forever', () => { - generateOnRepeat('#txt2img_generate', '#txt2img_interrupt'); - }); - appendContextMenuOption('#img2img_generate', 'Generate forever', () => { - generateOnRepeat('#img2img_generate', '#img2img_interrupt'); - }); - - const cancelGenerateForever = function () { - clearInterval(window.generateOnRepeatInterval); - }; - - appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever); - appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever); - appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever); - appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever); - - appendContextMenuOption( - '#roll', - 'Roll three', - () => { - const rollbutton = get_uiCurrentTabContent().querySelector('#roll'); - setTimeout(() => { rollbutton.click(); }, 100); - setTimeout(() => { rollbutton.click(); }, 200); - setTimeout(() => { rollbutton.click(); }, 300); - }, - ); -}()); -// End example Context Menu Items - -onUiUpdate(() => { - addContextMenuEventListener(); -}); +contextMenuInit = function () { + let eventListenerApplied = false; + const menuSpecs = new Map(); + + const uid = function () { + return Date.now().toString(36) + Math.random().toString(36).substr(2); + }; + + function showContextMenu(event, element, menuEntries) { + const posx = event.clientX + document.body.scrollLeft + 
document.documentElement.scrollLeft; + const posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; + + const oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); + } + + const tabButton = uiCurrentTab; + const baseStyle = window.getComputedStyle(tabButton); + + const contextMenu = document.createElement('nav'); + contextMenu.id = 'context-menu'; + contextMenu.style.background = baseStyle.background; + contextMenu.style.color = baseStyle.color; + contextMenu.style.fontFamily = baseStyle.fontFamily; + contextMenu.style.top = `${posy}px`; + contextMenu.style.left = `${posx}px`; + + const contextMenuList = document.createElement('ul'); + contextMenuList.className = 'context-menu-items'; + contextMenu.append(contextMenuList); + + menuEntries.forEach((entry) => { + const contextMenuEntry = document.createElement('a'); + contextMenuEntry.innerHTML = entry.name; + contextMenuEntry.addEventListener('click', (e) => { + entry.func(); + }); + contextMenuList.append(contextMenuEntry); + }); + + gradioApp().appendChild(contextMenu); + + const menuWidth = contextMenu.offsetWidth + 4; + const menuHeight = contextMenu.offsetHeight + 4; + + const windowWidth = window.innerWidth; + const windowHeight = window.innerHeight; + + if ((windowWidth - posx) < menuWidth) { + contextMenu.style.left = `${windowWidth - menuWidth}px`; + } + + if ((windowHeight - posy) < menuHeight) { + contextMenu.style.top = `${windowHeight - menuHeight}px`; + } + } + + function appendContextMenuOption(targetElementSelector, entryName, entryFunction) { + currentItems = menuSpecs.get(targetElementSelector); + + if (!currentItems) { + currentItems = []; + menuSpecs.set(targetElementSelector, currentItems); + } + const newItem = { + id: `${targetElementSelector}_${uid()}`, + name: entryName, + func: entryFunction, + isNew: true, + }; + + currentItems.push(newItem); + return newItem.id; + } + + function removeContextMenuOption(uid) { + 
menuSpecs.forEach((v, k) => { + let index = -1; + v.forEach((e, ei) => { if (e.id === uid) { index = ei; } }); + if (index >= 0) { + v.splice(index, 1); + } + }); + } + + function addContextMenuEventListener() { + if (eventListenerApplied) { + return; + } + gradioApp().addEventListener('click', (e) => { + const source = e.composedPath()[0]; + if (source.id && source.id.indexOf('check_progress') > -1) { + return; + } + + const oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); + } + }); + gradioApp().addEventListener('contextmenu', (e) => { + const oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); + } + menuSpecs.forEach((v, k) => { + if (e.composedPath()[0].matches(k)) { + showContextMenu(e, e.composedPath()[0], v); + e.preventDefault(); + } + }); + }); + eventListenerApplied = true; + } + + return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]; +}; + +initResponse = contextMenuInit(); +appendContextMenuOption = initResponse[0]; +removeContextMenuOption = initResponse[1]; +addContextMenuEventListener = initResponse[2]; + +(function () { + // Start example Context Menu Items + const generateOnRepeat = function (genbuttonid, interruptbuttonid) { + const genbutton = gradioApp().querySelector(genbuttonid); + const busy = document.getElementById('progressbar')?.style.display === 'block'; + if (!busy) { + genbutton.click(); + } + clearInterval(window.generateOnRepeatInterval); + window.generateOnRepeatInterval = setInterval( + () => { + const busy = document.getElementById('progressbar')?.style.display === 'block'; + if (!busy) { + genbutton.click(); + } + }, + 500, + ); + }; + + appendContextMenuOption('#txt2img_generate', 'Generate forever', () => { + generateOnRepeat('#txt2img_generate', '#txt2img_interrupt'); + }); + appendContextMenuOption('#img2img_generate', 'Generate forever', () => { + generateOnRepeat('#img2img_generate', '#img2img_interrupt'); + 
}); + + const cancelGenerateForever = function () { + clearInterval(window.generateOnRepeatInterval); + }; + + appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever); + appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever); + appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever); + appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever); + + appendContextMenuOption( + '#roll', + 'Roll three', + () => { + const rollbutton = get_uiCurrentTabContent().querySelector('#roll'); + setTimeout(() => { rollbutton.click(); }, 100); + setTimeout(() => { rollbutton.click(); }, 200); + setTimeout(() => { rollbutton.click(); }, 300); + }, + ); +}()); +// End example Context Menu Items + +onUiUpdate(() => { + addContextMenuEventListener(); +}); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 27fa4341d..55e229e89 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -47,7 +47,7 @@ function dropReplaceImage(imgWrap, files) { window.document.addEventListener('dragover', (e) => { const target = e.composedPath()[0]; const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap && target.placeholder && target.placeholder.indexOf('Prompt') == -1) return; + if (!imgWrap && target.placeholder && target.placeholder.indexOf('Prompt') === -1) return; e.stopPropagation(); e.preventDefault(); e.dataTransfer.dropEffect = 'copy'; @@ -56,7 +56,7 @@ window.document.addEventListener('dragover', (e) => { window.document.addEventListener('drop', (e) => { const target = e.composedPath()[0]; if (!target.placeholder) return; - if (target.placeholder.indexOf('Prompt') == -1) return; + if (target.placeholder.indexOf('Prompt') === -1) return; const imgWrap = target.closest('[data-testid="image"]'); if (!imgWrap) return; e.stopPropagation(); diff --git a/javascript/edit-attention.js 
b/javascript/edit-attention.js index 87ec52617..594baae67 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,119 +1,119 @@ -function keyupEditAttention(event) { - const target = event.originalTarget || event.composedPath()[0]; - if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return; - if (!(event.metaKey || event.ctrlKey)) return; - - const isPlus = event.key == 'ArrowUp'; - const isMinus = event.key == 'ArrowDown'; - if (!isPlus && !isMinus) return; - - let { selectionStart } = target; - let { selectionEnd } = target; - let text = target.value; - - function selectCurrentParenthesisBlock(OPEN, CLOSE) { - if (selectionStart !== selectionEnd) return false; - - // Find opening parenthesis around current cursor - const before = text.substring(0, selectionStart); - let beforeParen = before.lastIndexOf(OPEN); - if (beforeParen == -1) return false; - let beforeParenClose = before.lastIndexOf(CLOSE); - while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); - beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); - } - - // Find closing parenthesis around current cursor - const after = text.substring(selectionStart); - let afterParen = after.indexOf(CLOSE); - if (afterParen == -1) return false; - let afterParenOpen = after.indexOf(OPEN); - while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(CLOSE, afterParen + 1); - afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); - } - if (beforeParen === -1 || afterParen === -1) return false; - - // Set the selection to the text between the parenthesis - const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); - const lastColon = parenContent.lastIndexOf(':'); - selectionStart = beforeParen + 1; - selectionEnd = selectionStart + lastColon; - target.setSelectionRange(selectionStart, selectionEnd); - return true; - } - - function 
selectCurrentWord() { - if (selectionStart !== selectionEnd) return false; - const delimiters = `${opts.keyedit_delimiters} \r\n\t`; - - // seek backward until to find beggining - while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { - selectionStart--; - } - - // seek forward to find end - while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) { - selectionEnd++; - } - - target.setSelectionRange(selectionStart, selectionEnd); - return true; - } - - // If the user hasn't selected anything, let's select their current parenthesis block or word - if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { - selectCurrentWord(); - } - - event.preventDefault(); - - closeCharacter = ')'; - delta = opts.keyedit_precision_attention; - - if (selectionStart > 0 && text[selectionStart - 1] == '<') { - closeCharacter = '>'; - delta = opts.keyedit_precision_extra; - } else if (selectionStart == 0 || text[selectionStart - 1] != '(') { - // do not include spaces at the end - while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { - selectionEnd -= 1; - } - if (selectionStart == selectionEnd) { - return; - } - - text = `${text.slice(0, selectionStart)}(${text.slice(selectionStart, selectionEnd)}:1.0)${text.slice(selectionEnd)}`; - - selectionStart += 1; - selectionEnd += 1; - } - - end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); - if (isNaN(weight)) return; - - weight += isPlus ? 
delta : -delta; - weight = parseFloat(weight.toPrecision(12)); - if (String(weight).length == 1) weight += '.0'; - - if (closeCharacter == ')' && weight == 1) { - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); - selectionStart--; - selectionEnd--; - } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); - } - - target.focus(); - target.value = text; - target.selectionStart = selectionStart; - target.selectionEnd = selectionEnd; - - updateInput(target); -} - -addEventListener('keydown', (event) => { - keyupEditAttention(event); -}); +function keyupEditAttention(event) { + const target = event.originalTarget || event.composedPath()[0]; + if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return; + if (!(event.metaKey || event.ctrlKey)) return; + + const isPlus = event.key === 'ArrowUp'; + const isMinus = event.key === 'ArrowDown'; + if (!isPlus && !isMinus) return; + + let { selectionStart } = target; + let { selectionEnd } = target; + let text = target.value; + + function selectCurrentParenthesisBlock(OPEN, CLOSE) { + if (selectionStart !== selectionEnd) return false; + + // Find opening parenthesis around current cursor + const before = text.substring(0, selectionStart); + let beforeParen = before.lastIndexOf(OPEN); + if (beforeParen === -1) return false; + let beforeParenClose = before.lastIndexOf(CLOSE); + while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { + beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); + beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); + } + + // Find closing parenthesis around current cursor + const after = text.substring(selectionStart); + let afterParen = after.indexOf(CLOSE); + if (afterParen === -1) return false; + let afterParenOpen = after.indexOf(OPEN); + while (afterParenOpen !== -1 && afterParen > afterParenOpen) { + afterParen = after.indexOf(CLOSE, afterParen + 
1); + afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); + } + if (beforeParen === -1 || afterParen === -1) return false; + + // Set the selection to the text between the parenthesis + const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); + const lastColon = parenContent.lastIndexOf(':'); + selectionStart = beforeParen + 1; + selectionEnd = selectionStart + lastColon; + target.setSelectionRange(selectionStart, selectionEnd); + return true; + } + + function selectCurrentWord() { + if (selectionStart !== selectionEnd) return false; + const delimiters = `${opts.keyedit_delimiters} \r\n\t`; + + // seek backward until to find beggining + while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { + selectionStart--; + } + + // seek forward to find end + while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) { + selectionEnd++; + } + + target.setSelectionRange(selectionStart, selectionEnd); + return true; + } + + // If the user hasn't selected anything, let's select their current parenthesis block or word + if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { + selectCurrentWord(); + } + + event.preventDefault(); + + closeCharacter = ')'; + delta = opts.keyedit_precision_attention; + + if (selectionStart > 0 && text[selectionStart - 1] === '<') { + closeCharacter = '>'; + delta = opts.keyedit_precision_extra; + } else if (selectionStart === 0 || text[selectionStart - 1] != '(') { + // do not include spaces at the end + while (selectionEnd > selectionStart && text[selectionEnd - 1] === ' ') { + selectionEnd -= 1; + } + if (selectionStart === selectionEnd) { + return; + } + + text = `${text.slice(0, selectionStart)}(${text.slice(selectionStart, selectionEnd)}:1.0)${text.slice(selectionEnd)}`; + + selectionStart += 1; + selectionEnd += 1; + } + + end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + weight = parseFloat(text.slice(selectionEnd + 
1, selectionEnd + 1 + end)); + if (isNaN(weight)) return; + + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if (String(weight).length === 1) weight += '.0'; + + if (closeCharacter === ')' && weight === 1) { + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); + selectionStart--; + selectionEnd--; + } else { + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + } + + target.focus(); + target.value = text; + target.selectionStart = selectionStart; + target.selectionEnd = selectionEnd; + + updateInput(target); +} + +addEventListener('keydown', (event) => { + keyupEditAttention(event); +}); diff --git a/javascript/extensions.js b/javascript/extensions.js index 2bcaa50fc..5f8049d95 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -1,32 +1,41 @@ -function extensions_apply(_a, _b, disable_all){ - var disable = [] - var update = [] - gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ - if(x.name.startsWith("enable_") && ! x.checked) disable.push(x.name.substr(7)) - if(x.name.startsWith("update_") && x.checked) update.push(x.name.substr(7)) - }) - restart_reload() - return [JSON.stringify(disable), JSON.stringify(update), disable_all] -} - -function extensions_check(_, _){ - var disable = [] - gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ - if(x.name.startsWith("enable_") && ! x.checked) disable.push(x.name.substr(7)) - }) - gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ - x.innerHTML = "Loading..." - }) - var id = randomId() - requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, null, null, false) - return [id, JSON.stringify(disable)] -} - -function install_extension_from_index(button, url){ - button.disabled = "disabled" - button.value = "Installing..." 
- textarea = gradioApp().querySelector('#extension_to_install textarea') - textarea.value = url - updateInput(textarea) - gradioApp().querySelector('#install_extension_button').click() -} +function extensions_apply(_a, _b, disable_all) { + const disable = []; + const update = []; + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach((x) => { + if (x.name.startsWith('enable_') && !x.checked) disable.push(x.name.substr(7)); + if (x.name.startsWith('update_') && x.checked) update.push(x.name.substr(7)); + }); + restart_reload(); + return [JSON.stringify(disable), JSON.stringify(update), disable_all]; +} + +function extensions_check(_a, _b) { + const disable = []; + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach((x) => { + if (x.name.startsWith('enable_') && !x.checked) disable.push(x.name.substr(7)); + }); + gradioApp().querySelectorAll('#extensions .extension_status').forEach((x) => { + x.innerHTML = 'Loading...'; + }); + const id = randomId(); + // requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, null, null, false); + return [id, JSON.stringify(disable)]; +} + +function install_extension(button, url) { + button.disabled = 'disabled'; + button.value = 'Installing...'; + const textarea = gradioApp().querySelector('#extension_to_install textarea'); + textarea.value = url; + updateInput(textarea); + gradioApp().querySelector('#install_extension_button').click(); +} + +function uninstall_extension(button, url) { + button.disabled = 'disabled'; + button.value = 'Uninstalling...'; + const textarea = gradioApp().querySelector('#extension_to_install textarea'); + textarea.value = url; + updateInput(textarea); + gradioApp().querySelector('#uninstall_extension_button').click(); +} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 9b2a610f7..5b9254cd5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,170 +1,170 @@ -function 
setupExtraNetworksForTab(tabname) { - gradioApp().querySelector(`#${tabname}_extra_tabs`).classList.add('extra-networks'); - const tabs = gradioApp().querySelector(`#${tabname}_extra_tabs > div`); - const search = gradioApp().querySelector(`#${tabname}_extra_search textarea`); - const refresh = gradioApp().getElementById(`${tabname}_extra_refresh`); - const descriptInput = gradioApp().getElementById(`${tabname}_description_input`); - const close = gradioApp().getElementById(`${tabname}_extra_close`); - search.classList.add('search'); - tabs.appendChild(search); - tabs.appendChild(refresh); - tabs.appendChild(close); - tabs.appendChild(descriptInput); - search.addEventListener('input', (evt) => { - searchTerm = search.value.toLowerCase(); - gradioApp().querySelectorAll(`#${tabname}_extra_tabs div.card`).forEach((elem) => { - text = `${elem.querySelector('.name').textContent.toLowerCase()} ${elem.querySelector('.search_term').textContent.toLowerCase()}`; - elem.style.display = text.indexOf(searchTerm) == -1 ? 
'none' : ''; - }); - }); -} - -const activePromptTextarea = {}; - -function setupExtraNetworks() { - setupExtraNetworksForTab('txt2img'); - setupExtraNetworksForTab('img2img'); - function registerPrompt(tabname, id) { - const textarea = gradioApp().querySelector(`#${id} > label > textarea`); - if (!activePromptTextarea[tabname]) activePromptTextarea[tabname] = textarea; - textarea.addEventListener('focus', () => { - activePromptTextarea[tabname] = textarea; - }); - } - registerPrompt('txt2img', 'txt2img_prompt'); - registerPrompt('txt2img', 'txt2img_neg_prompt'); - registerPrompt('img2img', 'img2img_prompt'); - registerPrompt('img2img', 'img2img_neg_prompt'); -} - -onUiLoaded(setupExtraNetworks); -const re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; -const re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; - -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - let m = text.match(re_extranet); - if (!m) return false; - const partToSearch = m[1]; - let replaced = false; - const newTextareaText = textarea.value.replaceAll(re_extranet_g, (found, index) => { - m = found.match(re_extranet); - if (m[1] == partToSearch) { - replaced = true; - return ''; - } - return found; - }); - if (replaced) { - textarea.value = newTextareaText; - return true; - } - return false; -} - -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - const textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector(`#${tabname}_prompt > label > textarea`); - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; - updateInput(textarea); -} - -function saveCardPreview(event, tabname, filename) { - const textarea = gradioApp().querySelector(`#${tabname}_preview_filename > label > textarea`); - const button = gradioApp().getElementById(`${tabname}_save_preview`); - textarea.value = filename; - updateInput(textarea); - button.click(); - event.stopPropagation(); - event.preventDefault(); -} - -function saveCardDescription(event, tabname, filename, descript) { - const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); - const button = gradioApp().getElementById(`${tabname}_save_description`); - const description = gradioApp().getElementById(`${tabname}_description_input`); - textarea.value = filename; - description.value = descript; - updateInput(textarea); - button.click(); - event.stopPropagation(); - event.preventDefault(); -} - -function readCardDescription(event, tabname, filename, descript, extraPage, cardName) { - const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); - const description_textarea = gradioApp().querySelector(`#${tabname}_description_input > label > textarea`); - const button = gradioApp().getElementById(`${tabname}_read_description`); - textarea.value = filename; - description_textarea.value = descript; - updateInput(textarea); - updateInput(description_textarea); - button.click(); - event.stopPropagation(); - event.preventDefault(); -} - -function extraNetworksSearchButton(tabs_id, event) { - searchTextarea = gradioApp().querySelector(`#${tabs_id} > div > textarea`); - button = event.target; - text = button.classList.contains('search-all') ? 
'' : button.textContent.trim(); - searchTextarea.value = text; - updateInput(searchTextarea); -} - -let globalPopup = null; -let globalPopupInner = null; -function popup(contents) { - if (!globalPopup) { - globalPopup = document.createElement('div'); - globalPopup.onclick = function () { globalPopup.style.display = 'none'; }; - globalPopup.classList.add('global-popup'); - const close = document.createElement('div'); - close.classList.add('global-popup-close'); - close.onclick = function () { globalPopup.style.display = 'none'; }; - close.title = 'Close'; - globalPopup.appendChild(close); - globalPopupInner = document.createElement('div'); - globalPopupInner.onclick = function (event) { event.stopPropagation(); return false; }; - globalPopupInner.classList.add('global-popup-inner'); - globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); - } - globalPopupInner.innerHTML = ''; - globalPopupInner.appendChild(contents); - globalPopup.style.display = 'flex'; -} - -function readCardMetadata(event, extraPage, cardName) { - requestGet('./sd_extra_networks/metadata', { page: extraPage, item: cardName }, (data) => { - if (data && data.metadata) { - elem = document.createElement('pre'); - elem.classList.add('popup-metadata'); - elem.textContent = data.metadata; - popup(elem); - } - }, () => {}); - event.stopPropagation(); - event.preventDefault(); -} - -function requestGet(url, data, handler, errorHandler) { - const xhr = new XMLHttpRequest(); - const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&'); - xhr.open('GET', `${url}?${args}`, true); - xhr.onreadystatechange = function () { - if (xhr.readyState === 4) { - if (xhr.status === 200) { - try { - const js = JSON.parse(xhr.responseText); - handler(js); - } catch (error) { - console.error(error); - errorHandler(); - } - } else { - errorHandler(); - } - } - }; - const js = JSON.stringify(data); - xhr.send(js); -} +function 
setupExtraNetworksForTab(tabname) { + gradioApp().querySelector(`#${tabname}_extra_tabs`).classList.add('extra-networks'); + const tabs = gradioApp().querySelector(`#${tabname}_extra_tabs > div`); + const search = gradioApp().querySelector(`#${tabname}_extra_search textarea`); + const refresh = gradioApp().getElementById(`${tabname}_extra_refresh`); + const descriptInput = gradioApp().getElementById(`${tabname}_description_input`); + const close = gradioApp().getElementById(`${tabname}_extra_close`); + search.classList.add('search'); + tabs.appendChild(search); + tabs.appendChild(refresh); + tabs.appendChild(close); + tabs.appendChild(descriptInput); + search.addEventListener('input', (evt) => { + searchTerm = search.value.toLowerCase(); + gradioApp().querySelectorAll(`#${tabname}_extra_tabs div.card`).forEach((elem) => { + text = `${elem.querySelector('.name').textContent.toLowerCase()} ${elem.querySelector('.search_term').textContent.toLowerCase()}`; + elem.style.display = text.indexOf(searchTerm) == -1 ? 
'none' : ''; + }); + }); +} + +const activePromptTextarea = {}; + +function setupExtraNetworks() { + setupExtraNetworksForTab('txt2img'); + setupExtraNetworksForTab('img2img'); + function registerPrompt(tabname, id) { + const textarea = gradioApp().querySelector(`#${id} > label > textarea`); + if (!activePromptTextarea[tabname]) activePromptTextarea[tabname] = textarea; + textarea.addEventListener('focus', () => { + activePromptTextarea[tabname] = textarea; + }); + } + registerPrompt('txt2img', 'txt2img_prompt'); + registerPrompt('txt2img', 'txt2img_neg_prompt'); + registerPrompt('img2img', 'img2img_prompt'); + registerPrompt('img2img', 'img2img_neg_prompt'); +} + +onUiLoaded(setupExtraNetworks); +const re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; +const re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; + +function tryToRemoveExtraNetworkFromPrompt(textarea, text) { + let m = text.match(re_extranet); + if (!m) return false; + const partToSearch = m[1]; + let replaced = false; + const newTextareaText = textarea.value.replaceAll(re_extranet_g, (found, index) => { + m = found.match(re_extranet); + if (m[1] == partToSearch) { + replaced = true; + return ''; + } + return found; + }); + if (replaced) { + textarea.value = newTextareaText; + return true; + } + return false; +} + +function cardClicked(tabname, textToAdd, allowNegativePrompt) { + const textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector(`#${tabname}_prompt > label > textarea`); + if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + updateInput(textarea); +} + +function saveCardPreview(event, tabname, filename) { + const textarea = gradioApp().querySelector(`#${tabname}_preview_filename > label > textarea`); + const button = gradioApp().getElementById(`${tabname}_save_preview`); + textarea.value = filename; + updateInput(textarea); + button.click(); + event.stopPropagation(); + event.preventDefault(); +} + +function saveCardDescription(event, tabname, filename, descript) { + const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); + const button = gradioApp().getElementById(`${tabname}_save_description`); + const description = gradioApp().getElementById(`${tabname}_description_input`); + textarea.value = filename; + description.value = descript; + updateInput(textarea); + button.click(); + event.stopPropagation(); + event.preventDefault(); +} + +function readCardDescription(event, tabname, filename, descript, extraPage, cardName) { + const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); + const description_textarea = gradioApp().querySelector(`#${tabname}_description_input > label > textarea`); + const button = gradioApp().getElementById(`${tabname}_read_description`); + textarea.value = filename; + description_textarea.value = descript; + updateInput(textarea); + updateInput(description_textarea); + button.click(); + event.stopPropagation(); + event.preventDefault(); +} + +function extraNetworksSearchButton(tabs_id, event) { + searchTextarea = gradioApp().querySelector(`#${tabs_id} > div > textarea`); + button = event.target; + text = button.classList.contains('search-all') ? 
'' : button.textContent.trim(); + searchTextarea.value = text; + updateInput(searchTextarea); +} + +let globalPopup = null; +let globalPopupInner = null; +function popup(contents) { + if (!globalPopup) { + globalPopup = document.createElement('div'); + globalPopup.onclick = function () { globalPopup.style.display = 'none'; }; + globalPopup.classList.add('global-popup'); + const close = document.createElement('div'); + close.classList.add('global-popup-close'); + close.onclick = function () { globalPopup.style.display = 'none'; }; + close.title = 'Close'; + globalPopup.appendChild(close); + globalPopupInner = document.createElement('div'); + globalPopupInner.onclick = function (event) { event.stopPropagation(); return false; }; + globalPopupInner.classList.add('global-popup-inner'); + globalPopup.appendChild(globalPopupInner); + gradioApp().appendChild(globalPopup); + } + globalPopupInner.innerHTML = ''; + globalPopupInner.appendChild(contents); + globalPopup.style.display = 'flex'; +} + +function readCardMetadata(event, extraPage, cardName) { + requestGet('./sd_extra_networks/metadata', { page: extraPage, item: cardName }, (data) => { + if (data && data.metadata) { + elem = document.createElement('pre'); + elem.classList.add('popup-metadata'); + elem.textContent = data.metadata; + popup(elem); + } + }, () => {}); + event.stopPropagation(); + event.preventDefault(); +} + +function requestGet(url, data, handler, errorHandler) { + const xhr = new XMLHttpRequest(); + const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&'); + xhr.open('GET', `${url}?${args}`, true); + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + try { + const js = JSON.parse(xhr.responseText); + handler(js); + } catch (error) { + console.error(error); + errorHandler(); + } + } else { + errorHandler(); + } + } + }; + const js = JSON.stringify(data); + xhr.send(js); +} diff --git 
a/javascript/style.css b/javascript/style.css index 0f77632cb..04902104c 100644 --- a/javascript/style.css +++ b/javascript/style.css @@ -505,11 +505,25 @@ div#extras_scale_to_tab div.form{ font-size: 95%; } -#available_extensions .info{ +#extensions .name{ + font-size: 1.1rem +} + +#extensions .type{ + opacity: 0.5; + font-size: 90%; + text-align: center; +} + +#extensions .version{ + opacity: 0.7; +} + +#extensions .info{ margin: 0; } -#available_extensions .date_added{ +#extensions .date{ opacity: 0.85; font-size: 90%; } diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 8a50ec601..db73b03d9 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -1,13 +1,10 @@ - - - -function start_training_textual_inversion(){ - gradioApp().querySelector('#ti_error').innerHTML='' - var id = randomId() - const onProgress = (progress) => gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; - // requestProgress(id_task, progressbarContainer, gallery, atEnd = null, onProgress = null, once = false) { - requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), null, onProgress, false) - var res = args_to_array(arguments) - res[0] = id - return res -} +function start_training_textual_inversion() { + gradioApp().querySelector('#ti_error').innerHTML='' + var id = randomId() + const onProgress = (progress) => gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; + // requestProgress(id_task, progressbarContainer, gallery, atEnd = null, onProgress = null, once = false) { + requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), null, onProgress, false) + var res = args_to_array(arguments) + res[0] = id + return res +} diff --git a/modules/call_queue.py b/modules/call_queue.py index 2ea136a19..3af8eca08 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -14,9 +14,7 @@ def 
wrap_queued_call(func): def f(*args, **kwargs): with queue_lock: res = func(*args, **kwargs) - return res - return f @@ -38,9 +36,9 @@ def f(*args, **kwargs): progress.finish_task(id_task) shared.state.end() return res - return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) + def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats @@ -73,21 +71,17 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): if extra_outputs_array is None: extra_outputs_array = [None, ''] res = extra_outputs_array + [f"
{html.escape(type(e).__name__+': '+str(e))}
"] - shared.state.skipped = False shared.state.interrupted = False shared.state.job_count = 0 - if not add_stats: return tuple(res) - elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_s:.2f}s" if elapsed_m > 0: elapsed_text = f"{elapsed_m}m "+elapsed_text - if run_memmon: mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} active_peak = mem_stats['active_peak'] @@ -97,9 +91,6 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): vram_html = f" |

GPU active {active_peak} MB reserved {reserved_peak} MB | System peak {sys_peak} MB total {sys_total} MB

" else: vram_html = '' - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - return tuple(res) - return f diff --git a/modules/extensions.py b/modules/extensions.py index b0d78259a..524830020 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,12 +1,11 @@ import os -import time +from datetime import datetime import git - from modules import shared, errors from modules.paths_internal import extensions_dir, extensions_builtin_dir -extensions = [] +extensions = [] if not os.path.exists(extensions_dir): os.makedirs(extensions_dir) @@ -14,7 +13,7 @@ def active(): if shared.opts.disable_all_extensions == "all": return [] - elif shared.opts.disable_all_extensions == "extra": + elif shared.opts.disable_all_extensions == "user": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -23,6 +22,7 @@ def active(): class Extension: def __init__(self, name, path, enabled=True, is_builtin=False): self.name = name + self.git_name = '' self.path = path self.enabled = enabled self.status = '' @@ -31,48 +31,48 @@ def __init__(self, name, path, enabled=True, is_builtin=False): self.commit_hash = '' self.commit_date = None self.version = '' + self.description = '' self.branch = None self.remote = None self.have_info_from_repo = False + self.mtime = 0 + self.ctime = 0 def read_info_from_repo(self): - if self.is_builtin or self.have_info_from_repo: + if self.have_info_from_repo: return - self.have_info_from_repo = True - repo = None + self.mtime = datetime.fromtimestamp(os.path.getmtime(self.path)).isoformat() + 'Z' + self.ctime = datetime.fromtimestamp(os.path.getctime(self.path)).isoformat() + 'Z' try: if os.path.exists(os.path.join(self.path, ".git")): repo = git.Repo(self.path) except Exception as e: errors.display(e, f'github info from {self.path}') - if repo is None or repo.bare: self.remote = None else: try: self.status = 'unknown' + self.git_name = repo.remotes.origin.url.split('.git')[0].split('/')[-1] + self.description = repo.description self.remote = 
next(repo.remote().urls, None) head = repo.head.commit self.commit_date = repo.head.commit.committed_date - ts = time.asctime(time.gmtime(self.commit_date)) if repo.active_branch: self.branch = repo.active_branch.name self.commit_hash = head.hexsha - self.version = f'{self.commit_hash[:8]} ({ts})' - + self.version = f"

{self.commit_hash[:8]}

{datetime.fromtimestamp(self.commit_date).strftime('%a %b%d %Y %H:%M')}

" except Exception as ex: shared.log.error(f"Failed reading extension data from Git repository: {self.name}: {ex}") self.remote = None def list_files(self, subdir, extension): from modules import scripts - dirpath = os.path.join(self.path, subdir) if not os.path.isdir(dirpath): return [] - res = [] for filename in sorted(os.listdir(dirpath)): priority = '50' @@ -80,9 +80,7 @@ def list_files(self, subdir, extension): with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f: priority = str(f.read().strip()) res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority)) - res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] - return res def check_updates(self): @@ -92,7 +90,6 @@ def check_updates(self): self.can_update = True self.status = "new commits" return - try: origin = repo.rev_parse('origin') if repo.head.commit != origin: @@ -103,7 +100,6 @@ def check_updates(self): self.can_update = False self.status = "unknown (remote error)" return - self.can_update = False self.status = "latest" @@ -119,19 +115,15 @@ def fetch_and_reset_hard(self, commit='origin'): def list_extensions(): extensions.clear() - if not os.path.isdir(extensions_dir): return - - if shared.opts.disable_all_extensions == "all" or shared.opts.disable_all_extensions == "extra": - shared.log.warning("Option set: Disable all extensions") - + if shared.opts.disable_all_extensions == "all" or shared.opts.disable_all_extensions == "user": + shared.log.warning(f"Option set: Disable extensions: {shared.opts.disable_all_extensions}") extension_paths = [] extension_names = [] for dirname in [extensions_builtin_dir, extensions_dir]: if not os.path.isdir(dirname): return - for extension_dirname in sorted(os.listdir(dirname)): path = os.path.join(dirname, extension_dirname) if not os.path.isdir(path): @@ -141,7 +133,6 @@ def list_extensions(): continue extension_names.append(extension_dirname) 
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) - for dirname, path, is_builtin in extension_paths: extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) extensions.append(extension) diff --git a/modules/shared.py b/modules/shared.py index 509f314ea..f82ace1c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -463,7 +463,7 @@ def refresh_themes(): options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable these extensions"), - "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}), + "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "user", "all"]}), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 69de248b9..585fd34aa 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,16 +1,51 @@ import json import os.path -import time import shutil import errno import html +from datetime import datetime import git import gradio as gr from modules import extensions, shared, paths, errors from modules.call_queue import wrap_gradio_gpu_call -available_extensions = {"extensions": []} -STYLE_PRIMARY = ' style="color: var(--primary-400)"' + +extensions_index = "https://vladmandic.github.io/sd-data/pages/extensions.json" +hide_tags = ["localization"] +extensions_list = [] + + +def update_extension_list(): + global extensions_list # pylint: disable=global-statement + try: + with open(os.path.join(paths.script_path, "html", "extensions.json"), "r", encoding="utf-8") as f: + extensions_list = json.loads(f.read()) + shared.log.debug(f'Extensions list loaded: {os.path.join(paths.script_path, 
"html", "extensions.json")}') + except: + shared.log.debug(f'Extensions list failed to load: {os.path.join(paths.script_path, "html", "extensions.json")}') + found = [] + for ext in extensions_list: + installed = [extension for extension in extensions.extensions if extension.git_name == ext['name'] or extension.name == ext['name']] + if len(installed) > 0: + found.append(installed[0]) + not_matched = [extension for extension in extensions.extensions if extension not in found] + for ext in not_matched: + ext.read_info_from_repo() + entry = { + "name": ext.name or "", + "description": ext.description or "", + "url": ext.remote or "", + "tags": [], + "stars": 0, + "issues": 0, + "commits": 0, + "size": 0, + "long": ext.git_name or ext.name or "", + "added": ext.ctime, + "created": ext.ctime, + "updated": ext.mtime, + } + extensions_list.append(entry) def check_access(): @@ -19,6 +54,7 @@ def check_access(): def apply_and_restart(disable_list, update_list, disable_all): check_access() + shared.log.debug(f'Extensions apply: disable={disable_list} update={update_list}') disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" update = json.loads(update_list) @@ -37,18 +73,16 @@ def apply_and_restart(disable_list, update_list, disable_all): shared.restart_server(restart=True) -def check_updates(_id_task, disable_list): +def check_updates(_id_task, disable_list, search_text, sort_column): check_access() - disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" - exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.log.info(f'Extensions update check: update={len(exts)} disabled={len(disable_list)}') shared.state.job_count = len(exts) - for ext in exts: + shared.log.debug(f'Extensions update: {ext.name}') shared.state.textinfo = ext.name - try: ext.check_updates() except 
FileNotFoundError as e: @@ -56,10 +90,8 @@ def check_updates(_id_task, disable_list): raise except Exception: errors.display(e, f'extensions check update: {ext.name}') - shared.state.nextjob() - - return extension_table(), "" + return refresh_extensions_list_from_data(search_text, sort_column), "Update complete, please restart the server" def make_commit_link(commit_hash, remote, text=None): @@ -72,84 +104,26 @@ def make_commit_link(commit_hash, remote, text=None): return text -def extension_table(): - code = f""" - - - - - - - - - - - - """ - - for ext in extensions.extensions: - ext.read_info_from_repo() - - remote = f"""{html.escape(ext.remote or '')}""" - - if ext.can_update: - ext_status = f"""""" - else: - ext_status = ext.status - - style = "" - if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all": - style = STYLE_PRIMARY - - version_link = ext.version - if ext.commit_hash and ext.remote: - version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version) - - code += f""" - - - - - - {ext_status} - - """ - - code += """ - -
ExtensionTypeURLVersionUpdate
{html.escape(ext.name)}{"system" if ext.is_builtin else 'user'}{remote}{version_link}
- """ - - return code - - def normalize_git_url(url): if url is None: return "" - url = url.replace(".git", "") return url -def install_extension_from_url(dirname, url, branch_name=None): +def install_extension_from_url(dirname, url, branch_name, search_text, sort_column): check_access() - assert url, 'No URL specified' - if dirname is None or dirname == "": *parts, last_part = url.split('/') # pylint: disable=unused-variable last_part = normalize_git_url(last_part) dirname = last_part - target_dir = os.path.join(extensions.extensions_dir, dirname) shared.log.info(f'Installing extension: {url} into {target_dir}') assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' - normalized_url = normalize_git_url(url) assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' - tmpdir = os.path.join(paths.data_path, "tmp", dirname) - try: shutil.rmtree(tmpdir, True) if not branch_name: @@ -170,226 +144,235 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.move(tmpdir, target_dir) else: raise err - from launch import run_extension_installer run_extension_installer(target_dir) extensions.list_extensions() - return [extension_table(), html.escape(f"Installed into {target_dir}")] + return [refresh_extensions_list_from_data(search_text, sort_column), html.escape(f"Extension {url} installed into {target_dir}")] finally: shutil.rmtree(tmpdir, True) -def install_extension_from_index(url, hide_tags, sort_column, filter_text): - ext_table, message = install_extension_from_url(None, url) - - code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) +def install_extension(extension_to_install, search_text, sort_column): + shared.log.info(f'Extension install: {extension_to_install}') + code, message = install_extension_from_url(None, extension_to_install, None, search_text, sort_column) + return code, message - 
return code, ext_table, message, '' +def uninstall_extension(extension_path, search_text, sort_column): + shared.log.info(f'Extension uninstall: {extension_path}') + ext = [extension for extension in extensions.extensions if extension.path == extension_path] + if len(ext) > 0 and os.path.isdir(extension_path): + try: + shutil.rmtree(extension_path, ignore_errors=False) + except Exception as e: + shared.log.warning(f'Extension uninstall failed: {extension_path} {e}') + extensions.extensions = [extension for extension in extensions.extensions if extension.path != extension_path] + update_extension_list() + code = refresh_extensions_list_from_data(search_text, sort_column) + # return code, ext_table, message + return code, f"Uninstalled {extension_path}" -def refresh_available_extensions(url, hide_tags, sort_column): - global available_extensions # pylint: disable=global-statement +def refresh_extensions_list(search_text, sort_column): + global extensions_list # pylint: disable=global-statement import urllib.request - with urllib.request.urlopen(url) as response: - text = response.read() - - available_extensions = json.loads(text) - - code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) - - return url, code, gr.CheckboxGroup.update(choices=tags), '', '' - - -def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text): - code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) - - return code, '' - - -def search_extensions(filter_text, hide_tags, sort_column): - code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) - - return code, '' - - -sort_ordering = [ - # (reverse, order_by_function) - (True, lambda x: x.get('added', 'z')), - (False, lambda x: x.get('added', 'z')), - (False, lambda x: x.get('name', 'z')), - (True, lambda x: x.get('name', 'z')), - (False, lambda x: 'z'), -] - - -def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): - 
extlist = available_extensions["extensions"] - installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} - - tags = available_extensions.get("tags", {}) - tags_to_hide = set(hide_tags) - hidden = 0 - - code = f""" - - + try: + with urllib.request.urlopen(extensions_index) as response: + text = response.read() + extensions_list = json.loads(text) + with open(os.path.join(paths.script_path, "html", "extensions.json"), "w", encoding="utf-8") as outfile: + json_object = json.dumps(extensions_list, indent=2) + outfile.write(json_object) + shared.log.debug(f'Updated extensions list: {len(extensions_list)} {extensions_index} {outfile}') + except Exception as e: + shared.log.warning(f'Updated extensions list failed: {extensions_index} {e}') + code = refresh_extensions_list_from_data(search_text, sort_column) + return code, f'Extensions list: {len(extensions.extensions)} registered | {len(extensions_list)} available' + + +def search_extensions(search_text, sort_column): + code = refresh_extensions_list_from_data(search_text, sort_column) + return code, f'Search complete: {search_text} {sort_column}' + + +def refresh_extensions_list_from_data(search_text, sort_column): + shared.log.debug(f'Extensions manager: refresh list search={search_text} sort={sort_column}') + code = """ +
+ + + + + + + + + + - + + + - - """ - - sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0] - - for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): - name = ext.get("name", "noname") - added = ext.get('added', 'unknown') - url = ext.get("url", None) + """ + sort_ordering = { + "default": (True, lambda x: x.get('sort_string', '')), + "updated": (True, lambda x: x.get('updated', '2000-01-01T00:00')), + "created": (False, lambda x: x.get('created', '2000-01-01T00:00')), + "name": (False, lambda x: x.get('name', '').lower()), + "enabled": (False, lambda x: x.get('sort_enabled', '').lower()), + "size": (True, lambda x: x.get('size', 0)), + "stars": (True, lambda x: x.get('stars', 0)), + "commits": (True, lambda x: x.get('commits', 0)), + "issues": (True, lambda x: x.get('issues', 0)), + } + for ext in extensions_list: + extension = [extension for extension in extensions.extensions if extension.git_name == ext['name'] or extension.name == ext['name']] + if len(extension) > 0: + extension[0].read_info_from_repo() + ext['installed'] = len(extension) > 0 + ext['commit_date'] = extension[0].commit_date if len(extension) > 0 else 1577836800 + ext['is_builtin'] = extension[0].is_builtin if len(extension) > 0 else False + ext['version'] = extension[0].version if len(extension) > 0 else '' + ext['enabled'] = extension[0].enabled if len(extension) > 0 else '' + ext['path'] = extension[0].path if len(extension) > 0 else '' + ext['sort_string'] = f"{'1' if ext['is_builtin'] else '0'}{'1' if ext['installed'] else '0'}{ext.get('updated', '2000-01-01T00:00')}" + ext['sort_enabled'] = f"{'1' if ext['enabled'] else '0'}{'1' if ext['is_builtin'] else '0'}{'1' if ext['installed'] else '0'}{ext.get('updated', '2000-01-01T00:00')}" + sort_reverse, sort_function = sort_ordering[sort_column] + + def dt(x: str): + val = ext.get(x, None) + if val is not None: + return datetime.fromisoformat(val[:-1]).strftime('%a %b%d %Y 
%H:%M') + else: + return "N/A" + + for ext in sorted(extensions_list, key=sort_function, reverse=sort_reverse): + name = ext.get("name", "unknown") + added = dt('added') + created = dt('created') + updated = dt('updated') + url = ext.get('url', None) + size = ext.get('size', 0) + stars = ext.get('stars', 0) + issues = ext.get('issues', 0) + commits = ext.get('commits', 0) description = ext.get("description", "") - extension_tags = ext.get("tags", []) - - if url is None: - continue - - existing = installed_extension_urls.get(normalize_git_url(url), None) - extension_tags = extension_tags + ["installed"] if existing else extension_tags - - if len([x for x in extension_tags if x in tags_to_hide]) > 0: - hidden += 1 + installed = ext.get("installed", False) + enabled = ext.get("enabled", False) + path = ext.get("path", "") + commit_date = ext.get('commit_date', 1577836800) or 1577836800 + update_available = installed & (datetime.utcfromtimestamp(commit_date + 60 * 60) < datetime.fromisoformat(ext.get('updated', '2000-01-01T00:00:00.000Z')[:-1])) + tags = ext.get("tags", []) + tags_string = ' '.join(tags) + tags = tags + ["installed"] if installed else tags + if len([x for x in tags if x in hide_tags]) > 0: continue - - if filter_text and filter_text.strip(): - if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower(): - hidden += 1 + if search_text and search_text.strip(): + if search_text.lower() not in html.escape(name).lower() and search_text.lower() not in html.escape(description).lower() and search_text.lower() not in html.escape(tags_string).lower(): continue - - install_code = f"""""" - - tags_text = ", ".join([f"{x}" for x in extension_tags]) + version_code = '' + type_code = '' + install_code = '' + enabled_code = '' + if installed: + type_code = f"""
{"SYSTEM" if ext['is_builtin'] else 'USER'}
""" + version_code = f"""
{ext['version']}
""" + enabled_code = f"""""" + if not ext['is_builtin']: + install_code = f"""""" + else: + install_code = f"""""" + tags_text = ", ".join([f"{x}" for x in tags]) code += f""" - - + {enabled_code} + + + + - - - """ - - for tag in [x for x in extension_tags if x not in tags]: - tags[tag] = tag - - code += """ - -
Enabled Extension DescriptionActionTypeCurrent version
{html.escape(name)}
{tags_text}
{html.escape(description)}

Added: {html.escape(added)}

{html.escape(name)}
{tags_text}
{html.escape(description)} +

Created {html.escape(created)} | Added {html.escape(added)} | Updated {html.escape(updated)}

+

Stars {html.escape(str(stars))} | Size {html.escape(str(size))} | Commits {html.escape(str(commits))} | Issues {html.escape(str(issues))}

+
{type_code}{version_code} {install_code}
- """ - - if hidden > 0: - code += f"

Extension hidden: {hidden}

" - - return code, list(tags) + """ + code += "" + return code def create_ui(): import modules.ui - with gr.Blocks(analytics_enabled=False) as ui: + extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "user", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all", visible=False) + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) with gr.Tabs(elem_id="tabs_extensions"): - with gr.TabItem("Installed", id="installed"): - + with gr.TabItem("Manage Extensions", id="manage"): with gr.Row(elem_id="extensions_installed_top"): - apply = gr.Button(value="Apply & restart", variant="primary") - check = gr.Button(value="Check for updates") - extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all") - extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) - extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) - - txt = "" - if shared.opts.disable_all_extensions != "none": - txt = """ - - "Disable all extensions" was set, change it to "none" to load all extensions again - - """ - info = gr.HTML(txt) - extensions_table = gr.HTML(lambda: extension_table()) - + extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) + install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + uninstall_extension_button = gr.Button(elem_id="uninstall_extension_button", visible=False) + with gr.Column(scale=4): + search_text = gr.Text(label="Search") + info = gr.HTML('Note: After any operation such as install/uninstall or enable/disable, please restart the server') + with gr.Column(scale=1): + sort_column = 
gr.Dropdown(value="default", label="Sort by", choices=["default", "updated", "created", "name", "size", "stars", "commits", "issues"], multiselect=False) + with gr.Column(scale=1): + refresh_extensions_button = gr.Button(value="Refresh extension list", variant="primary") + check = gr.Button(value="Update installed extensions", variant="primary") + apply = gr.Button(value="Apply changes & restart server", variant="primary") + update_extension_list() + extensions_table = gr.HTML(refresh_extensions_list_from_data(search_text.value, sort_column.value)) + check.click( + fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), + _js="extensions_check", + inputs=[info, extensions_disabled_list, search_text, sort_column], + outputs=[extensions_table, info], + ) apply.click( fn=apply_and_restart, _js="extensions_apply", inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all], outputs=[], ) - - check.click( - fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), - _js="extensions_check", - inputs=[info, extensions_disabled_list], + refresh_extensions_button.click( + fn=modules.ui.wrap_gradio_call(refresh_extensions_list, extra_outputs=[gr.update(), gr.update()]), + inputs=[search_text, sort_column], outputs=[extensions_table, info], ) - - with gr.TabItem("Available", id="available"): - with gr.Row(): - refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") - available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False) - extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) - install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) - - with gr.Row(): - hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", 
"installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") - - with gr.Row(): - search_extensions_text = gr.Text(label="Search").style(container=False) - - install_result = gr.HTML() - available_extensions_table = gr.HTML() - - refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), - inputs=[available_extensions_index, hide_tags, sort_column], - outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text], - ) - install_extension_button.click( - fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text], - outputs=[available_extensions_table, extensions_table, install_result], + fn=modules.ui.wrap_gradio_call(install_extension, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[extension_to_install, search_text, sort_column], + outputs=[extensions_table, info], ) - - search_extensions_text.change( - fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]), - inputs=[search_extensions_text, hide_tags, sort_column], - outputs=[available_extensions_table, install_result], + uninstall_extension_button.click( + fn=modules.ui.wrap_gradio_call(uninstall_extension, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[extension_to_install, search_text, sort_column], + outputs=[extensions_table, info], ) - - hide_tags.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags, sort_column, search_extensions_text], - outputs=[available_extensions_table, install_result] + search_text.change( + fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update(), 
gr.update()]), + inputs=[search_text, sort_column], + outputs=[extensions_table, info], ) - sort_column.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags, sort_column, search_extensions_text], - outputs=[available_extensions_table, install_result] + fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update(), gr.update()]), + inputs=[search_text, sort_column], + outputs=[extensions_table, info], ) - - with gr.TabItem("Install from URL", id="install_from_url"): + with gr.TabItem("Manual install", id="install_from_url"): install_url = gr.Text(label="URL for extension's git repository") install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") install_button = gr.Button(value="Install", variant="primary") - install_result = gr.HTML(elem_id="extension_install_result") - + info = gr.HTML(elem_id="extension_info") install_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), - inputs=[install_dirname, install_url, install_branch], - outputs=[extensions_table, install_result], + inputs=[install_dirname, install_url, install_branch, search_text, sort_column], + outputs=[extensions_table, info], ) - return ui diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 4071d86d8..2dd036f2d 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -1,90 +1,90 @@ -import modules.scripts as scripts -import gradio as gr -import ast -import copy - -from modules.processing import Processed -from modules.shared import opts, cmd_opts, state - - -def convertExpr2Expression(expr): - expr.lineno = 0 - expr.col_offset = 0 - result = ast.Expression(expr.value, lineno=0, col_offset = 0) - - return result - - -def exec_with_return(code, module): - """ - like exec() but can return values - 
https://stackoverflow.com/a/52361938/5862977 - """ - code_ast = ast.parse(code) - - init_ast = copy.deepcopy(code_ast) - init_ast.body = code_ast.body[:-1] - - last_ast = copy.deepcopy(code_ast) - last_ast.body = code_ast.body[-1:] - - exec(compile(init_ast, "", "exec"), module.__dict__) - if type(last_ast.body[0]) == ast.Expr: - return eval(compile(convertExpr2Expression(last_ast.body[0]), "", "eval"), module.__dict__) - else: - exec(compile(last_ast, "", "exec"), module.__dict__) - - -class Script(scripts.Script): - - def title(self): - return "Custom code" - - def show(self, is_img2img): - return cmd_opts.allow_code - - def ui(self, is_img2img): - example = """from modules.processing import process_images - -p.width = 768 -p.height = 768 -p.batch_size = 2 -p.steps = 10 - -return process_images(p) -""" - - - code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code")) - indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level")) - - return [code, indent_level] - - def run(self, p, code, indent_level): - assert cmd_opts.allow_code, '--allow-code option must be enabled' - - display_result_data = [[], -1, ""] - - def display(imgs, s=display_result_data[1], i=display_result_data[2]): - display_result_data[0] = imgs - display_result_data[1] = s - display_result_data[2] = i - - from types import ModuleType - module = ModuleType("testmodule") - module.__dict__.update(globals()) - module.p = p - module.display = display - - indent = " " * indent_level - indented = code.replace('\n', '\n' + indent) - body = f"""def __webuitemp__(): -{indent}{indented} -__webuitemp__()""" - - result = exec_with_return(body, module) - - if isinstance(result, Processed): - return result - - return Processed(p, *display_result_data) +import copy +import ast +import gradio as gr +import modules.scripts as scripts + +from modules.processing import Processed +from modules.shared import opts, cmd_opts, state # 
pylint: disable=unused-import + + +def convertExpr2Expression(expr): + expr.lineno = 0 + expr.col_offset = 0 + result = ast.Expression(expr.value, lineno=0, col_offset = 0) + + return result + + +def exec_with_return(code, module): + """ + like exec() but can return values + https://stackoverflow.com/a/52361938/5862977 + """ + code_ast = ast.parse(code) + + init_ast = copy.deepcopy(code_ast) + init_ast.body = code_ast.body[:-1] + + last_ast = copy.deepcopy(code_ast) + last_ast.body = code_ast.body[-1:] + + exec(compile(init_ast, "", "exec"), module.__dict__) + if type(last_ast.body[0]) == ast.Expr: + return eval(compile(convertExpr2Expression(last_ast.body[0]), "", "eval"), module.__dict__) + else: + exec(compile(last_ast, "", "exec"), module.__dict__) + + +class Script(scripts.Script): + + def title(self): + return "Custom code" + + def show(self, is_img2img): + return cmd_opts.allow_code + + def ui(self, is_img2img): + example = """from modules.processing import process_images + +p.width = 768 +p.height = 768 +p.batch_size = 2 +p.steps = 10 + +return process_images(p) +""" + + + code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code")) + indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level")) + + return [code, indent_level] + + def run(self, p, code, indent_level): + assert cmd_opts.allow_code, '--allow-code option must be enabled' + + display_result_data = [[], -1, ""] + + def display(imgs, s=display_result_data[1], i=display_result_data[2]): + display_result_data[0] = imgs + display_result_data[1] = s + display_result_data[2] = i + + from types import ModuleType + module = ModuleType("testmodule") + module.__dict__.update(globals()) + module.p = p + module.display = display + + indent = " " * indent_level + indented = code.replace('\n', '\n' + indent) + body = f"""def __webuitemp__(): +{indent}{indented} +__webuitemp__()""" + + result = exec_with_return(body, module) + 
+ if isinstance(result, Processed): + return result + + return Processed(p, *display_result_data) diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 3066100a2..ffadedbd9 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -1,176 +1,176 @@ -from collections import namedtuple -import numpy as np -from tqdm import trange -import torch -import k_diffusion as K -import gradio as gr -import modules.scripts as scripts -from modules import processing, shared, sd_samplers, sd_samplers_common - - -def find_noise_for_image(p, cond, uncond, cfg_scale, steps): - x = p.init_latent - s_in = x.new_ones([x.shape[0]]) - if shared.sd_model.parameterization == "v": - dnw = K.external.CompVisVDenoiser(shared.sd_model) - skip = 1 - else: - dnw = K.external.CompVisDenoiser(shared.sd_model) - skip = 0 - sigmas = dnw.get_sigmas(steps).flip(0) - shared.state.sampling_steps = steps - - for i in trange(1, len(sigmas)): - shared.state.sampling_step += 1 - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigmas[i] * s_in] * 2) - cond_in = torch.cat([uncond, cond]) - image_conditioning = torch.cat([p.image_conditioning] * 2) - cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} - c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]] - t = dnw.sigma_to_t(sigma_in) - eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) - denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) - denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale - d = (x - denoised) / sigmas[i] - dt = sigmas[i] - sigmas[i - 1] - x = x + d * dt - sd_samplers_common.store_latent(x) - # This shouldn't be necessary, but solved some VRAM issues - del x_in, sigma_in, cond_in, c_out, c_in, t, - del eps, denoised_uncond, denoised_cond, denoised, d, dt - - shared.state.nextjob() - return x / x.std() - - -Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", 
"original_negative_prompt", "sigma_adjustment"]) - - -# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736 -def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): - x = p.init_latent - s_in = x.new_ones([x.shape[0]]) - if shared.sd_model.parameterization == "v": - dnw = K.external.CompVisVDenoiser(shared.sd_model) - skip = 1 - else: - dnw = K.external.CompVisDenoiser(shared.sd_model) - skip = 0 - sigmas = dnw.get_sigmas(steps).flip(0) - - shared.state.sampling_steps = steps - - for i in trange(1, len(sigmas)): - shared.state.sampling_step += 1 - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) - cond_in = torch.cat([uncond, cond]) - image_conditioning = torch.cat([p.image_conditioning] * 2) - cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} - c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]] - if i == 1: - t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) - else: - t = dnw.sigma_to_t(sigma_in) - eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) - denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) - denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale - if i == 1: - d = (x - denoised) / (2 * sigmas[i]) - else: - d = (x - denoised) / sigmas[i - 1] - dt = sigmas[i] - sigmas[i - 1] - x = x + d * dt - sd_samplers_common.store_latent(x) - # This shouldn't be necessary, but solved some VRAM issues - del x_in, sigma_in, cond_in, c_out, c_in, t, - del eps, denoised_uncond, denoised_cond, denoised, d, dt - - shared.state.nextjob() - return x / sigmas[-1] - - -class Script(scripts.Script): - def __init__(self): - self.cache = None - - def title(self): - return "Alternative" - - def show(self, is_img2img): - return is_img2img - - def ui(self, is_img2img): - info = gr.Markdown(''' - * `CFG Scale` should be 2 or lower. 
- ''') - override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler")) - override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt")) - original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt")) - original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt")) - override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps")) - st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st")) - override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength")) - cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg")) - randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness")) - sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) - - return [ - info, - override_sampler, - override_prompt, original_prompt, original_negative_prompt, - override_steps, st, - override_strength, - cfg, randomness, sigma_adjustment, - ] - - def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment): - # Override - if override_sampler: - p.sampler_name = "Euler" - if override_prompt: - p.prompt = original_prompt - p.negative_prompt = original_negative_prompt - if override_steps: - p.steps = st - if override_strength: - p.denoising_strength = 1.0 - - def 
sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): - lat = (p.init_latent.cpu().numpy() * 10).astype(int) - same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \ - and self.cache.original_prompt == original_prompt \ - and self.cache.original_negative_prompt == original_negative_prompt \ - and self.cache.sigma_adjustment == sigma_adjustment - same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 - if same_everything: - rec_noise = self.cache.noise - else: - shared.state.job_count += 1 - cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) - uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt]) - if sigma_adjustment: - rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st) - else: - rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) - self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) - - rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) - combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) - sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) - sigmas = sampler.model_wrap.get_sigmas(p.steps) - noise_dt = combined_noise - (p.init_latent / sigmas[0]) - p.seed = p.seed + 1 - return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) - - p.sample = sample_extra - p.extra_generation_params["Decode prompt"] = original_prompt - p.extra_generation_params["Decode negative prompt"] = original_negative_prompt - p.extra_generation_params["Decode CFG 
scale"] = cfg - p.extra_generation_params["Decode steps"] = st - p.extra_generation_params["Randomness"] = randomness - p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment - processed = processing.process_images(p) - - return processed +from collections import namedtuple +import numpy as np +from tqdm import trange +import torch +import k_diffusion as K +import gradio as gr +import modules.scripts as scripts +from modules import processing, shared, sd_samplers, sd_samplers_common + + +def find_noise_for_image(p, cond, uncond, cfg_scale, steps): + x = p.init_latent + s_in = x.new_ones([x.shape[0]]) + if shared.sd_model.parameterization == "v": + dnw = K.external.CompVisVDenoiser(shared.sd_model) + skip = 1 + else: + dnw = K.external.CompVisDenoiser(shared.sd_model) + skip = 0 + sigmas = dnw.get_sigmas(steps).flip(0) + shared.state.sampling_steps = steps + + for i in trange(1, len(sigmas)): + shared.state.sampling_step += 1 + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + image_conditioning = torch.cat([p.image_conditioning] * 2) + cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]] + t = dnw.sigma_to_t(sigma_in) + eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale + d = (x - denoised) / sigmas[i] + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + sd_samplers_common.store_latent(x) + # This shouldn't be necessary, but solved some VRAM issues + del x_in, sigma_in, cond_in, c_out, c_in, t, + del eps, denoised_uncond, denoised_cond, denoised, d, dt + + shared.state.nextjob() + return x / x.std() + + +Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", 
"sigma_adjustment"]) + + +# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736 +def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): + x = p.init_latent + s_in = x.new_ones([x.shape[0]]) + if shared.sd_model.parameterization == "v": + dnw = K.external.CompVisVDenoiser(shared.sd_model) + skip = 1 + else: + dnw = K.external.CompVisDenoiser(shared.sd_model) + skip = 0 + sigmas = dnw.get_sigmas(steps).flip(0) + + shared.state.sampling_steps = steps + + for i in trange(1, len(sigmas)): + shared.state.sampling_step += 1 + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + image_conditioning = torch.cat([p.image_conditioning] * 2) + cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]] + if i == 1: + t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) + else: + t = dnw.sigma_to_t(sigma_in) + eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale + if i == 1: + d = (x - denoised) / (2 * sigmas[i]) + else: + d = (x - denoised) / sigmas[i - 1] + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + sd_samplers_common.store_latent(x) + # This shouldn't be necessary, but solved some VRAM issues + del x_in, sigma_in, cond_in, c_out, c_in, t, + del eps, denoised_uncond, denoised_cond, denoised, d, dt + + shared.state.nextjob() + return x / sigmas[-1] + + +class Script(scripts.Script): + def __init__(self): + self.cache = None + + def title(self): + return "Alternative" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + info = gr.Markdown(''' + * `CFG Scale` should be 2 or lower. 
+ ''') + override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler")) + override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt")) + original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt")) + original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt")) + override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps")) + st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st")) + override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength")) + cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg")) + randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness")) + sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) + + return [ + info, + override_sampler, + override_prompt, original_prompt, original_negative_prompt, + override_steps, st, + override_strength, + cfg, randomness, sigma_adjustment, + ] + + def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment): # pylint: disable=arguments-differ + # Override + if override_sampler: + p.sampler_name = "Euler" + if override_prompt: + p.prompt = original_prompt + p.negative_prompt = original_negative_prompt + if override_steps: + p.steps = st + if override_strength: + 
p.denoising_strength = 1.0 + + def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): # pylint: disable=unused-argument + lat = (p.init_latent.cpu().numpy() * 10).astype(int) + same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \ + and self.cache.original_prompt == original_prompt \ + and self.cache.original_negative_prompt == original_negative_prompt \ + and self.cache.sigma_adjustment == sigma_adjustment + same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 + if same_everything: + rec_noise = self.cache.noise + else: + shared.state.job_count += 1 + cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) + uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt]) + if sigma_adjustment: + rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st) + else: + rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) + self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) + + rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) + combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) + sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) + sigmas = sampler.model_wrap.get_sigmas(p.steps) + noise_dt = combined_noise - (p.init_latent / sigmas[0]) + p.seed = p.seed + 1 + return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) + + p.sample = sample_extra + p.extra_generation_params["Decode prompt"] = original_prompt + p.extra_generation_params["Decode negative prompt"] 
= original_negative_prompt + p.extra_generation_params["Decode CFG scale"] = cfg + p.extra_generation_params["Decode steps"] = st + p.extra_generation_params["Randomness"] = randomness + p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment + processed = processing.process_images(p) + + return processed diff --git a/scripts/loopback.py b/scripts/loopback.py index d3065fe6b..5ce5e8271 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -1,140 +1,140 @@ -import math - -import gradio as gr -import modules.scripts as scripts -from modules import deepbooru, images, processing, shared -from modules.processing import Processed -from modules.shared import opts, state - - -class Script(scripts.Script): - def title(self): - return "Loopback" - - def show(self, is_img2img): - return is_img2img - - def ui(self, is_img2img): - loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) - final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength")) - denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear") - append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None") - - return [loops, final_denoising_strength, denoising_curve, append_interrogation] - - def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation): - processing.fix_seed(p) - batch_count = p.n_iter - p.extra_generation_params = { - "Final denoising strength": final_denoising_strength, - "Denoising curve": denoising_curve - } - - p.batch_size = 1 - p.n_iter = 1 - - info = None - initial_seed = None - initial_info = None - initial_denoising_strength = p.denoising_strength - - grids = [] - all_images = [] - original_init_image = p.init_images - original_prompt = p.prompt - 
original_inpainting_fill = p.inpainting_fill - state.job_count = loops * batch_count - - initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] - - def calculate_denoising_strength(loop): - strength = initial_denoising_strength - - if loops == 1: - return strength - - progress = loop / (loops - 1) - if denoising_curve == "Aggressive": - strength = math.sin((progress) * math.pi * 0.5) - elif denoising_curve == "Lazy": - strength = 1 - math.cos((progress) * math.pi * 0.5) - else: - strength = progress - - change = (final_denoising_strength - initial_denoising_strength) * strength - return initial_denoising_strength + change - - history = [] - - for n in range(batch_count): - # Reset to original init image at the start of each batch - p.init_images = original_init_image - - # Reset to original denoising strength - p.denoising_strength = initial_denoising_strength - - last_image = None - - for i in range(loops): - p.n_iter = 1 - p.batch_size = 1 - p.do_not_save_grid = True - - if opts.img2img_color_correction: - p.color_corrections = initial_color_corrections - - if append_interrogation != "None": - p.prompt = original_prompt + ", " if original_prompt != "" else "" - if append_interrogation == "CLIP": - p.prompt += shared.interrogator.interrogate(p.init_images[0]) - elif append_interrogation == "DeepBooru": - p.prompt += deepbooru.model.tag(p.init_images[0]) - - state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}" - - processed = processing.process_images(p) - - # Generation cancelled. - if state.interrupted: - break - - if initial_seed is None: - initial_seed = processed.seed - initial_info = processed.info - - p.seed = processed.seed + 1 - p.denoising_strength = calculate_denoising_strength(i + 1) - - if state.skipped: - break - - last_image = processed.images[0] - p.init_images = [last_image] - p.inpainting_fill = 1 # Set "masked content" to "original" for next loop. 
- - if batch_count == 1: - history.append(last_image) - all_images.append(last_image) - - if batch_count > 1 and not state.skipped and not state.interrupted: - history.append(last_image) - all_images.append(last_image) - - p.inpainting_fill = original_inpainting_fill - - if state.interrupted: - break - - if len(history) > 1: - grid = images.image_grid(history, rows=1) - if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) - - if opts.return_grid: - grids.append(grid) - - all_images = grids + all_images - - processed = Processed(p, all_images, initial_seed, initial_info) - - return processed +import math + +import gradio as gr +import modules.scripts as scripts +from modules import deepbooru, images, processing, shared +from modules.processing import Processed +from modules.shared import opts, state + + +class Script(scripts.Script): + def title(self): + return "Loopback" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) + final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength")) + denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear") + append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None") + + return [loops, final_denoising_strength, denoising_curve, append_interrogation] + + def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation): # pylint: disable=arguments-differ + processing.fix_seed(p) + batch_count = p.n_iter + p.extra_generation_params = { + "Final denoising strength": final_denoising_strength, + 
"Denoising curve": denoising_curve + } + + p.batch_size = 1 + p.n_iter = 1 + + info = None + initial_seed = None + initial_info = None + initial_denoising_strength = p.denoising_strength + + grids = [] + all_images = [] + original_init_image = p.init_images + original_prompt = p.prompt + original_inpainting_fill = p.inpainting_fill + state.job_count = loops * batch_count + + initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] + + def calculate_denoising_strength(loop): + strength = initial_denoising_strength + + if loops == 1: + return strength + + progress = loop / (loops - 1) + if denoising_curve == "Aggressive": + strength = math.sin((progress) * math.pi * 0.5) + elif denoising_curve == "Lazy": + strength = 1 - math.cos((progress) * math.pi * 0.5) + else: + strength = progress + + change = (final_denoising_strength - initial_denoising_strength) * strength + return initial_denoising_strength + change + + history = [] + + for n in range(batch_count): + # Reset to original init image at the start of each batch + p.init_images = original_init_image + + # Reset to original denoising strength + p.denoising_strength = initial_denoising_strength + + last_image = None + + for i in range(loops): + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + + if opts.img2img_color_correction: + p.color_corrections = initial_color_corrections + + if append_interrogation != "None": + p.prompt = original_prompt + ", " if original_prompt != "" else "" + if append_interrogation == "CLIP": + p.prompt += shared.interrogator.interrogate(p.init_images[0]) + elif append_interrogation == "DeepBooru": + p.prompt += deepbooru.model.tag(p.init_images[0]) + + state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}" + + processed = processing.process_images(p) + + # Generation cancelled. 
+ if state.interrupted: + break + + if initial_seed is None: + initial_seed = processed.seed + initial_info = processed.info + + p.seed = processed.seed + 1 + p.denoising_strength = calculate_denoising_strength(i + 1) + + if state.skipped: + break + + last_image = processed.images[0] + p.init_images = [last_image] + p.inpainting_fill = 1 # Set "masked content" to "original" for next loop. + + if batch_count == 1: + history.append(last_image) + all_images.append(last_image) + + if batch_count > 1 and not state.skipped and not state.interrupted: + history.append(last_image) + all_images.append(last_image) + + p.inpainting_fill = original_inpainting_fill + + if state.interrupted: + break + + if len(history) > 1: + grid = images.image_grid(history, rows=1) + if opts.grid_save: + images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) + + if opts.return_grid: + grids.append(grid) + + all_images = grids + all_images + + processed = Processed(p, all_images, initial_seed, initial_info) + + return processed diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index d375ff764..f546fd94e 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -1,283 +1,280 @@ -import math - -import numpy as np -import skimage - -import modules.scripts as scripts -import gradio as gr -from PIL import Image, ImageDraw - -from modules import images, processing, devices -from modules.processing import Processed, process_images -from modules.shared import opts, cmd_opts, state - - -# this function is taken from https://github.com/parlance-zz/g-diffuser-bot -def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05): - # helper fft routines that keep ortho normalization and auto-shift before and after fft - def _fft2(data): - if data.ndim > 2: # has channels - out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), 
dtype=np.complex128) - for c in range(data.shape[2]): - c_data = data[:, :, c] - out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho") - out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c]) - else: # one channel - out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) - out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho") - out_fft[:, :] = np.fft.ifftshift(out_fft[:, :]) - - return out_fft - - def _ifft2(data): - if data.ndim > 2: # has channels - out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) - for c in range(data.shape[2]): - c_data = data[:, :, c] - out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho") - out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c]) - else: # one channel - out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) - out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho") - out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :]) - - return out_ifft - - def _get_gaussian_window(width, height, std=3.14, mode=0): - window_scale_x = float(width / min(width, height)) - window_scale_y = float(height / min(width, height)) - - window = np.zeros((width, height)) - x = (np.arange(width) / width * 2. - 1.) * window_scale_x - for y in range(height): - fy = (y / height * 2. - 1.) * window_scale_y - if mode == 0: - window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std) - else: - window[:, y] = (1 / ((x ** 2 + 1.) 
* (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian - - return window - - def _get_masked_window_rgb(np_mask_grey, hardness=1.): - np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) - if hardness != 1.: - hardened = np_mask_grey[:] ** hardness - else: - hardened = np_mask_grey[:] - for c in range(3): - np_mask_rgb[:, :, c] = hardened[:] - return np_mask_rgb - - width = _np_src_image.shape[0] - height = _np_src_image.shape[1] - num_channels = _np_src_image.shape[2] - - np_src_image = _np_src_image[:] * (1. - np_mask_rgb) - np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.) - img_mask = np_mask_grey > 1e-6 - ref_mask = np_mask_grey < 1e-3 - - windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey)) - windowed_image /= np.max(windowed_image) - windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color - - src_fft = _fft2(windowed_image) # get feature statistics from masked src img - src_dist = np.absolute(src_fft) - src_phase = src_fft / src_dist - - # create a generator with a static seed to make outpainting deterministic / only follow global seed - rng = np.random.default_rng(0) - - noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise - noise_rgb = rng.random((width, height, num_channels)) - noise_grey = (np.sum(noise_rgb, axis=2) / 3.) - noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter - for c in range(num_channels): - noise_rgb[:, :, c] += (1. 
- color_variation) * noise_grey - - noise_fft = _fft2(noise_rgb) - for c in range(num_channels): - noise_fft[:, :, c] *= noise_window - noise_rgb = np.real(_ifft2(noise_fft)) - shaped_noise_fft = _fft2(noise_rgb) - shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping - - brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now - contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. - - # scikit-image is used for histogram matching, very convenient! - shaped_noise = np.real(_ifft2(shaped_noise_fft)) - shaped_noise -= np.min(shaped_noise) - shaped_noise /= np.max(shaped_noise) - shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1) - shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb - - matched_noise = shaped_noise[:] - - return np.clip(matched_noise, 0., 1.) - - - -class Script(scripts.Script): - def title(self): - return "Outpainting" - - def show(self, is_img2img): - return is_img2img - - def ui(self, is_img2img): - if not is_img2img: - return None - - info = gr.HTML("

Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8

") - - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) - color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - - return [info, pixels, mask_blur, direction, noise_q, color_variation] - - def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): - initial_seed_and_info = [None, None] - - process_width = p.width - process_height = p.height - - p.mask_blur = mask_blur*4 - p.inpaint_full_res = False - p.inpainting_fill = 1 - p.do_not_save_samples = True - p.do_not_save_grid = True - - left = pixels if "left" in direction else 0 - right = pixels if "right" in direction else 0 - up = pixels if "up" in direction else 0 - down = pixels if "down" in direction else 0 - - init_img = p.init_images[0] - target_w = math.ceil((init_img.width + left + right) / 64) * 64 - target_h = math.ceil((init_img.height + up + down) / 64) * 64 - - if left > 0: - left = left * (target_w - init_img.width) // (left + right) - - if right > 0: - right = target_w - init_img.width - left - - if up > 0: - up = up * (target_h - init_img.height) // (up + down) - - if down > 0: - down = target_h - init_img.height - up - - def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): - is_horiz = is_left or is_right - is_vert = is_top or is_bottom - pixels_horiz = expand_pixels if is_horiz else 0 - pixels_vert = expand_pixels if is_vert else 0 - - 
images_to_process = [] - output_images = [] - for n in range(count): - res_w = init[n].width + pixels_horiz - res_h = init[n].height + pixels_vert - process_res_w = math.ceil(res_w / 64) * 64 - process_res_h = math.ceil(res_h / 64) * 64 - - img = Image.new("RGB", (process_res_w, process_res_h)) - img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) - mask = Image.new("RGB", (process_res_w, process_res_h), "white") - draw = ImageDraw.Draw(mask) - draw.rectangle(( - expand_pixels + mask_blur if is_left else 0, - expand_pixels + mask_blur if is_top else 0, - mask.width - expand_pixels - mask_blur if is_right else res_w, - mask.height - expand_pixels - mask_blur if is_bottom else res_h, - ), fill="black") - - np_image = (np.asarray(img) / 255.0).astype(np.float64) - np_mask = (np.asarray(mask) / 255.0).astype(np.float64) - noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) - output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")) - - target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width - target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height - p.width = target_width if is_horiz else img.width - p.height = target_height if is_vert else img.height - - crop_region = ( - 0 if is_left else output_images[n].width - target_width, - 0 if is_top else output_images[n].height - target_height, - target_width if is_left else output_images[n].width, - target_height if is_top else output_images[n].height, - ) - mask = mask.crop(crop_region) - p.image_mask = mask - - image_to_process = output_images[n].crop(crop_region) - images_to_process.append(image_to_process) - - p.init_images = images_to_process - - latent_mask = Image.new("RGB", (p.width, p.height), "white") - draw = ImageDraw.Draw(latent_mask) - draw.rectangle(( - expand_pixels + mask_blur * 2 if is_left else 0, - expand_pixels + mask_blur * 2 if is_top else 
0, - mask.width - expand_pixels - mask_blur * 2 if is_right else res_w, - mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h, - ), fill="black") - p.latent_mask = latent_mask - - proc = process_images(p) - - if initial_seed_and_info[0] is None: - initial_seed_and_info[0] = proc.seed - initial_seed_and_info[1] = proc.info - - for n in range(count): - output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height)) - output_images[n] = output_images[n].crop((0, 0, res_w, res_h)) - - return output_images - - batch_count = p.n_iter - batch_size = p.batch_size - p.n_iter = 1 - state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)) - all_processed_images = [] - - for i in range(batch_count): - imgs = [init_img] * batch_size - state.job = f"Batch {i + 1} out of {batch_count}" - - if left > 0: - imgs = expand(imgs, batch_size, left, is_left=True) - if right > 0: - imgs = expand(imgs, batch_size, right, is_right=True) - if up > 0: - imgs = expand(imgs, batch_size, up, is_top=True) - if down > 0: - imgs = expand(imgs, batch_size, down, is_bottom=True) - - all_processed_images += imgs - - all_images = all_processed_images - - combined_grid_image = images.image_grid(all_processed_images) - unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple - if opts.return_grid and not unwanted_grid_because_of_img_count: - all_images = [combined_grid_image] + all_processed_images - - res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1]) - - if opts.samples_save: - for img in all_processed_images: - images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p) - - if opts.grid_save and not unwanted_grid_because_of_img_count: - images.save_image(combined_grid_image, 
p.outpath_grids, "grid", res.seed, p.prompt, opts.samples_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p) - - return res +import math +import numpy as np +import skimage +import gradio as gr +from PIL import Image, ImageDraw +import modules.scripts as scripts +from modules import images +from modules.processing import Processed, process_images +from modules.shared import opts, state + + +# this function is taken from https://github.com/parlance-zz/g-diffuser-bot +def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05): + # helper fft routines that keep ortho normalization and auto-shift before and after fft + def _fft2(data): + if data.ndim > 2: # has channels + out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho") + out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c]) + else: # one channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho") + out_fft[:, :] = np.fft.ifftshift(out_fft[:, :]) + + return out_fft + + def _ifft2(data): + if data.ndim > 2: # has channels + out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho") + out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c]) + else: # one channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho") + out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :]) + + return out_ifft + + def _get_gaussian_window(width, height, std=3.14, mode=0): + window_scale_x = float(width / min(width, height)) + window_scale_y = float(height / min(width, height)) + + window = 
np.zeros((width, height)) + x = (np.arange(width) / width * 2. - 1.) * window_scale_x + for y in range(height): + fy = (y / height * 2. - 1.) * window_scale_y + if mode == 0: + window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std) + else: + window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian + + return window + + def _get_masked_window_rgb(np_mask_grey, hardness=1.): + np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) + if hardness != 1.: + hardened = np_mask_grey[:] ** hardness + else: + hardened = np_mask_grey[:] + for c in range(3): + np_mask_rgb[:, :, c] = hardened[:] + return np_mask_rgb + + width = _np_src_image.shape[0] + height = _np_src_image.shape[1] + num_channels = _np_src_image.shape[2] + + np_src_image = _np_src_image[:] * (1. - np_mask_rgb) + np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3. + img_mask = np_mask_grey > 1e-6 + ref_mask = np_mask_grey < 1e-3 + + windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey)) + windowed_image /= np.max(windowed_image) + windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color + + src_fft = _fft2(windowed_image) # get feature statistics from masked src img + src_dist = np.absolute(src_fft) + src_phase = src_fft / src_dist + + # create a generator with a static seed to make outpainting deterministic / only follow global seed + rng = np.random.default_rng(0) + + noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise + noise_rgb = rng.random((width, height, num_channels)) + noise_grey = np.sum(noise_rgb, axis=2) / 3. + noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter + for c in range(num_channels): + noise_rgb[:, :, c] += (1. 
- color_variation) * noise_grey + + noise_fft = _fft2(noise_rgb) + for c in range(num_channels): + noise_fft[:, :, c] *= noise_window + noise_rgb = np.real(_ifft2(noise_fft)) + shaped_noise_fft = _fft2(noise_rgb) + shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping + + brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now + contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. + + # scikit-image is used for histogram matching, very convenient! + shaped_noise = np.real(_ifft2(shaped_noise_fft)) + shaped_noise -= np.min(shaped_noise) + shaped_noise /= np.max(shaped_noise) + shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1) + shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb + + matched_noise = shaped_noise[:] + + return np.clip(matched_noise, 0., 1.) + + + +class Script(scripts.Script): + def title(self): + return "Outpainting" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + info = gr.HTML("

Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8

") + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) + + return [info, pixels, mask_blur, direction, noise_q, color_variation] + + def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): + initial_seed_and_info = [None, None] + + process_width = p.width + process_height = p.height + + p.mask_blur = mask_blur*4 + p.inpaint_full_res = False + p.inpainting_fill = 1 + p.do_not_save_samples = True + p.do_not_save_grid = True + + left = pixels if "left" in direction else 0 + right = pixels if "right" in direction else 0 + up = pixels if "up" in direction else 0 + down = pixels if "down" in direction else 0 + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + left + right) / 64) * 64 + target_h = math.ceil((init_img.height + up + down) / 64) * 64 + + if left > 0: + left = left * (target_w - init_img.width) // (left + right) + + if right > 0: + right = target_w - init_img.width - left + + if up > 0: + up = up * (target_h - init_img.height) // (up + down) + + if down > 0: + down = target_h - init_img.height - up + + def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): + is_horiz = is_left or is_right + is_vert = is_top or is_bottom + pixels_horiz = expand_pixels if is_horiz else 0 + pixels_vert = expand_pixels if is_vert else 0 + + 
images_to_process = [] + output_images = [] + for n in range(count): + res_w = init[n].width + pixels_horiz + res_h = init[n].height + pixels_vert + process_res_w = math.ceil(res_w / 64) * 64 + process_res_h = math.ceil(res_h / 64) * 64 + + img = Image.new("RGB", (process_res_w, process_res_h)) + img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) + mask = Image.new("RGB", (process_res_w, process_res_h), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + expand_pixels + mask_blur if is_left else 0, + expand_pixels + mask_blur if is_top else 0, + mask.width - expand_pixels - mask_blur if is_right else res_w, + mask.height - expand_pixels - mask_blur if is_bottom else res_h, + ), fill="black") + + np_image = (np.asarray(img) / 255.0).astype(np.float64) + np_mask = (np.asarray(mask) / 255.0).astype(np.float64) + noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) + output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")) + + target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width + target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height + p.width = target_width if is_horiz else img.width + p.height = target_height if is_vert else img.height + + crop_region = ( + 0 if is_left else output_images[n].width - target_width, + 0 if is_top else output_images[n].height - target_height, + target_width if is_left else output_images[n].width, + target_height if is_top else output_images[n].height, + ) + mask = mask.crop(crop_region) + p.image_mask = mask + + image_to_process = output_images[n].crop(crop_region) + images_to_process.append(image_to_process) + + p.init_images = images_to_process + + latent_mask = Image.new("RGB", (p.width, p.height), "white") + draw = ImageDraw.Draw(latent_mask) + draw.rectangle(( + expand_pixels + mask_blur * 2 if is_left else 0, + expand_pixels + mask_blur * 2 if is_top else 
0, + mask.width - expand_pixels - mask_blur * 2 if is_right else res_w, + mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h, + ), fill="black") + p.latent_mask = latent_mask + + proc = process_images(p) + + if initial_seed_and_info[0] is None: + initial_seed_and_info[0] = proc.seed + initial_seed_and_info[1] = proc.info + + for n in range(count): + output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height)) + output_images[n] = output_images[n].crop((0, 0, res_w, res_h)) + + return output_images + + batch_count = p.n_iter + batch_size = p.batch_size + p.n_iter = 1 + state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)) + all_processed_images = [] + + for i in range(batch_count): + imgs = [init_img] * batch_size + state.job = f"Batch {i + 1} out of {batch_count}" + + if left > 0: + imgs = expand(imgs, batch_size, left, is_left=True) + if right > 0: + imgs = expand(imgs, batch_size, right, is_right=True) + if up > 0: + imgs = expand(imgs, batch_size, up, is_top=True) + if down > 0: + imgs = expand(imgs, batch_size, down, is_bottom=True) + + all_processed_images += imgs + + all_images = all_processed_images + + combined_grid_image = images.image_grid(all_processed_images) + unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple + if opts.return_grid and not unwanted_grid_because_of_img_count: + all_images = [combined_grid_image] + all_processed_images + + res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1]) + + if opts.samples_save: + for img in all_processed_images: + images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p) + + if opts.grid_save and not unwanted_grid_because_of_img_count: + images.save_image(combined_grid_image, 
p.outpath_grids, "grid", res.seed, p.prompt, opts.samples_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p) + + return res diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index e80478b7c..c14fca6a6 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -1,146 +1,143 @@ -import math - -import modules.scripts as scripts -import gradio as gr -from PIL import Image, ImageDraw - -from modules import images, processing, devices -from modules.processing import Processed, process_images -from modules.shared import opts, cmd_opts, state - - -class Script(scripts.Script): - def title(self): - return "Outpainting alternative" - - def show(self, is_img2img): - return is_img2img - - def ui(self, is_img2img): - if not is_img2img: - return None - - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - - return [pixels, mask_blur, inpainting_fill, direction] - - def run(self, p, pixels, mask_blur, inpainting_fill, direction): - initial_seed = None - initial_info = None - - p.mask_blur = mask_blur * 2 - p.inpainting_fill = inpainting_fill - p.inpaint_full_res = False - - left = pixels if "left" in direction else 0 - right = pixels if "right" in direction else 0 - up = pixels if "up" in direction else 0 - down = pixels if "down" in direction else 0 - - init_img = p.init_images[0] - target_w = math.ceil((init_img.width + left + right) / 64) * 
64 - target_h = math.ceil((init_img.height + up + down) / 64) * 64 - - if left > 0: - left = left * (target_w - init_img.width) // (left + right) - if right > 0: - right = target_w - init_img.width - left - - if up > 0: - up = up * (target_h - init_img.height) // (up + down) - - if down > 0: - down = target_h - init_img.height - up - - img = Image.new("RGB", (target_w, target_h)) - img.paste(init_img, (left, up)) - - mask = Image.new("L", (img.width, img.height), "white") - draw = ImageDraw.Draw(mask) - draw.rectangle(( - left + (mask_blur * 2 if left > 0 else 0), - up + (mask_blur * 2 if up > 0 else 0), - mask.width - right - (mask_blur * 2 if right > 0 else 0), - mask.height - down - (mask_blur * 2 if down > 0 else 0) - ), fill="black") - - latent_mask = Image.new("L", (img.width, img.height), "white") - latent_draw = ImageDraw.Draw(latent_mask) - latent_draw.rectangle(( - left + (mask_blur//2 if left > 0 else 0), - up + (mask_blur//2 if up > 0 else 0), - mask.width - right - (mask_blur//2 if right > 0 else 0), - mask.height - down - (mask_blur//2 if down > 0 else 0) - ), fill="black") - - devices.torch_gc() - - grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) - grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) - grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) - - p.n_iter = 1 - p.batch_size = 1 - p.do_not_save_grid = True - p.do_not_save_samples = True - - work = [] - work_mask = [] - work_latent_mask = [] - work_results = [] - - for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): - for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): - x, w = tiledata[0:2] - - if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: - continue - - work.append(tiledata[2]) - work_mask.append(tiledata_mask[2]) - 
work_latent_mask.append(tiledata_latent_mask[2]) - - batch_count = len(work) - print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") - - state.job_count = batch_count - - for i in range(batch_count): - p.init_images = [work[i]] - p.image_mask = work_mask[i] - p.latent_mask = work_latent_mask[i] - - state.job = f"Batch {i + 1} out of {batch_count}" - processed = process_images(p) - - if initial_seed is None: - initial_seed = processed.seed - initial_info = processed.info - - p.seed = processed.seed + 1 - work_results += processed.images - - - image_index = 0 - for y, h, row in grid.tiles: - for tiledata in row: - x, w = tiledata[0:2] - - if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: - continue - - tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) - image_index += 1 - - combined_image = images.combine_grid(grid) - - if opts.samples_save: - images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p) - - processed = Processed(p, [combined_image], initial_seed, initial_info) - - return processed - +import math +import gradio as gr +from PIL import Image, ImageDraw +import modules.scripts as scripts +from modules import images, devices +from modules.processing import Processed, process_images +from modules.shared import opts, state + + +class Script(scripts.Script): + def title(self): + return "Outpainting alternative" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + inpainting_fill = gr.Radio(label='Masked content', 
choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + + return [pixels, mask_blur, inpainting_fill, direction] + + def run(self, p, pixels, mask_blur, inpainting_fill, direction): + initial_seed = None + initial_info = None + + p.mask_blur = mask_blur * 2 + p.inpainting_fill = inpainting_fill + p.inpaint_full_res = False + + left = pixels if "left" in direction else 0 + right = pixels if "right" in direction else 0 + up = pixels if "up" in direction else 0 + down = pixels if "down" in direction else 0 + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + left + right) / 64) * 64 + target_h = math.ceil((init_img.height + up + down) / 64) * 64 + + if left > 0: + left = left * (target_w - init_img.width) // (left + right) + if right > 0: + right = target_w - init_img.width - left + + if up > 0: + up = up * (target_h - init_img.height) // (up + down) + + if down > 0: + down = target_h - init_img.height - up + + img = Image.new("RGB", (target_w, target_h)) + img.paste(init_img, (left, up)) + + mask = Image.new("L", (img.width, img.height), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + left + (mask_blur * 2 if left > 0 else 0), + up + (mask_blur * 2 if up > 0 else 0), + mask.width - right - (mask_blur * 2 if right > 0 else 0), + mask.height - down - (mask_blur * 2 if down > 0 else 0) + ), fill="black") + + latent_mask = Image.new("L", (img.width, img.height), "white") + latent_draw = ImageDraw.Draw(latent_mask) + latent_draw.rectangle(( + left + (mask_blur//2 if left > 0 else 0), + up + (mask_blur//2 if up > 0 else 0), + mask.width - right - (mask_blur//2 if right > 0 else 0), + mask.height - down - (mask_blur//2 if down > 0 else 0) + ), fill="black") + + devices.torch_gc() + + grid = 
images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + work_mask = [] + work_latent_mask = [] + work_results = [] + + for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): + for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + + work.append(tiledata[2]) + work_mask.append(tiledata_mask[2]) + work_latent_mask.append(tiledata_latent_mask[2]) + + batch_count = len(work) + print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") + + state.job_count = batch_count + + for i in range(batch_count): + p.init_images = [work[i]] + p.image_mask = work_mask[i] + p.latent_mask = work_latent_mask[i] + + state.job = f"Batch {i + 1} out of {batch_count}" + processed = process_images(p) + + if initial_seed is None: + initial_seed = processed.seed + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p) + + 
processed = Processed(p, [combined_image], initial_seed, initial_info) + + return processed diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py index 4bc2c3696..251443642 100644 --- a/scripts/postprocessing_codeformer.py +++ b/scripts/postprocessing_codeformer.py @@ -1,36 +1,34 @@ -from PIL import Image -import numpy as np - -from modules import scripts_postprocessing, codeformer_model -import gradio as gr - -from modules.ui_components import FormRow - - -class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): - name = "CodeFormer" - order = 3000 - - def ui(self): - with FormRow(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer visibility", value=1.0, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer weight (0 = max), 1 = min)", value=0.2, elem_id="extras_codeformer_weight") - - return { - "codeformer_visibility": codeformer_visibility, - "codeformer_weight": codeformer_weight, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): - if codeformer_visibility == 0: - return - - restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(pp.image, res, codeformer_visibility) - - pp.image = res - pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) - pp.info["CodeFormer weight"] = round(codeformer_weight, 3) +from PIL import Image +import numpy as np +import gradio as gr +from modules import scripts_postprocessing, codeformer_model +from modules.ui_components import FormRow + + +class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): + name = "CodeFormer" + order = 3000 + + def ui(self): + with FormRow(): + codeformer_visibility = 
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer visibility", value=1.0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer weight (0 = max), 1 = min)", value=0.2, elem_id="extras_codeformer_weight") + + return { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0: + return + + restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(pp.image, res, codeformer_visibility) + + pp.image = res + pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) + pp.info["CodeFormer weight"] = round(codeformer_weight, 3) diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py index d854f3f77..7cea285d7 100644 --- a/scripts/postprocessing_gfpgan.py +++ b/scripts/postprocessing_gfpgan.py @@ -1,33 +1,31 @@ -from PIL import Image -import numpy as np - -from modules import scripts_postprocessing, gfpgan_model -import gradio as gr - -from modules.ui_components import FormRow - - -class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): - name = "GFPGAN" - order = 2000 - - def ui(self): - with FormRow(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") - - return { - "gfpgan_visibility": gfpgan_visibility, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): - if gfpgan_visibility == 0: - return - - restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(pp.image, res, 
gfpgan_visibility) - - pp.image = res - pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) +from PIL import Image +import numpy as np +import gradio as gr +from modules import scripts_postprocessing, gfpgan_model +from modules.ui_components import FormRow + + +class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): + name = "GFPGAN" + order = 2000 + + def ui(self): + with FormRow(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + + return { + "gfpgan_visibility": gfpgan_visibility, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): + if gfpgan_visibility == 0: + return + + restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(pp.image, res, gfpgan_visibility) + + pp.image = res + pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 43df74aff..bc04030af 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -1,134 +1,134 @@ -from PIL import Image -import numpy as np -import gradio as gr -from modules import scripts_postprocessing, shared -from modules.ui_components import FormRow, ToolButton -from modules.ui import switch_values_symbol - -upscale_cache = {} - - -class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): - name = "Upscale" - order = 1000 - - def ui(self): - selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated - - with gr.Column(): - with FormRow(): - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") 
- - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: - with FormRow(): - with gr.Row(elem_id="upscaling_column_size", scale=4): - upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") - upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with FormRow(): - extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - - with FormRow(): - extras_upscaler_2 = gr.Dropdown(label='Secondary Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") - - upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) - tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) - tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) - - return { - "upscale_mode": selected_tab, - "upscale_by": upscaling_resize, - "upscale_to_width": upscaling_resize_w, - "upscale_to_height": upscaling_resize_h, - "upscale_crop": upscaling_crop, - "upscaler_1_name": extras_upscaler_1, - "upscaler_2_name": extras_upscaler_2, - "upscaler_2_visibility": extras_upscaler_2_visibility, - } - - def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): - if upscale_mode == 1: - upscale_by = 
max(upscale_to_width/image.width, upscale_to_height/image.height) - info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" - else: - info["Postprocess upscale by"] = upscale_by - - cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) - cached_image = upscale_cache.pop(cache_key, None) - - if cached_image is not None: - image = cached_image - else: - image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) - - upscale_cache[cache_key] = image - if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: - upscale_cache.pop(next(iter(upscale_cache), None), None) - - if upscale_mode == 1 and upscale_crop: - cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) - cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) - image = cropped - info["Postprocess crop to"] = f"{image.width}x{image.height}" - - return image - - def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): # pylint: disable=arguments-differ - if upscaler_1_name == "None": - upscaler_1_name = None - - upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) - assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' - - if not upscaler1: - return - - if upscaler_2_name == "None": - upscaler_2_name = None - - upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) - assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' - - upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, 
upscale_crop) - pp.info["Postprocess upscaler"] = upscaler1.name - - if upscaler2 and upscaler_2_visibility > 0: - second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) - upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) - - pp.info["Postprocess upscaler 2"] = upscaler2.name - - pp.image = upscaled_image - - def image_changed(self): - upscale_cache.clear() - - -class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): - name = "Simple Upscale" - order = 900 - - def ui(self): - with FormRow(): - upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) - - return { - "upscale_by": upscale_by, - "upscaler_name": upscaler_name, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): # pylint: disable=arguments-differ - if upscaler_name is None or upscaler_name == "None": - return - - upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) - assert upscaler1, f'could not find upscaler named {upscaler_name}' - - pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) - pp.info["Postprocess upscaler"] = upscaler1.name +from PIL import Image +import numpy as np +import gradio as gr +from modules import scripts_postprocessing, shared +from modules.ui_components import FormRow, ToolButton +from modules.ui import switch_values_symbol + +upscale_cache = {} + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Upscale" + order = 1000 + + def ui(self): + selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated + + with gr.Column(): + with FormRow(): + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale 
by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + with gr.Row(elem_id="upscaling_column_size", scale=4): + upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") + upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Secondary Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } + + 
def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): + if upscale_mode == 1: + upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) + info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" + else: + info["Postprocess upscale by"] = upscale_by + + cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + cached_image = upscale_cache.pop(cache_key, None) + + if cached_image is not None: + image = cached_image + else: + image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) + + upscale_cache[cache_key] = image + if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: + upscale_cache.pop(next(iter(upscale_cache), None), None) + + if upscale_mode == 1 and upscale_crop: + cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) + cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) + image = cropped + info["Postprocess crop to"] = f"{image.width}x{image.height}" + + return image + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): # pylint: disable=arguments-differ + if upscaler_1_name == "None": + upscaler_1_name = None + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) + assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' + + if not upscaler1: + return + + if upscaler_2_name == "None": + upscaler_2_name = None + + upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) + assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler 
named {upscaler_2_name}' + + upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + pp.info["Postprocess upscaler"] = upscaler1.name + + if upscaler2 and upscaler_2_visibility > 0: + second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) + + pp.info["Postprocess upscaler 2"] = upscaler2.name + + pp.image = upscaled_image + + def image_changed(self): + upscale_cache.clear() + + +class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): + name = "Simple Upscale" + order = 900 + + def ui(self): + with FormRow(): + upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) + + return { + "upscale_by": upscale_by, + "upscaler_name": upscaler_name, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): # pylint: disable=arguments-differ + if upscaler_name is None or upscaler_name == "None": + return + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) + assert upscaler1, f'could not find upscaler named {upscaler_name}' + + pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) + pp.info["Postprocess upscaler"] = upscaler1.name diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index e9b115170..6772074ad 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -1,111 +1,106 @@ -import math -from collections import namedtuple -from copy import copy -import random - -import modules.scripts as scripts -import gradio as gr - -from modules import images -from modules.processing import process_images, Processed 
-from modules.shared import opts, cmd_opts, state -import modules.sd_samplers - - -def draw_xy_grid(xs, ys, x_label, y_label, cell): - res = [] - - ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] - hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] - - first_processed = None - - state.job_count = len(xs) * len(ys) - - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed = cell(x, y) - if first_processed is None: - first_processed = processed - - res.append(processed.images[0]) - - grid = images.image_grid(res, rows=len(ys)) - grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - - first_processed.images = [grid] - - return first_processed - - -class Script(scripts.Script): - def title(self): - return "Prompt matrix" - - def ui(self, is_img2img): - gr.HTML('
') - with gr.Row(): - with gr.Column(): - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) - different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) - with gr.Column(): - prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive") - variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma") - with gr.Column(): - margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - - return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size] - - def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size): - modules.processing.fix_seed(p) - # Raise error if promp type is not positive or negative - if prompt_type not in ["positive", "negative"]: - raise ValueError(f"Unknown prompt type {prompt_type}") - # Raise error if variations delimiter is not comma or space - if variations_delimiter not in ["comma", "space"]: - raise ValueError(f"Unknown variations delimiter {variations_delimiter}") - - prompt = p.prompt if prompt_type == "positive" else p.negative_prompt - original_prompt = prompt[0] if type(prompt) == list else prompt - positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt - - delimiter = ", " if variations_delimiter == "comma" else " " - - all_prompts = [] - prompt_matrix_parts = original_prompt.split("|") - combination_count = 2 ** (len(prompt_matrix_parts) - 1) - for combination_num in range(combination_count): - selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] - - if put_at_start: - selected_prompts = selected_prompts + [prompt_matrix_parts[0]] - 
else: - selected_prompts = [prompt_matrix_parts[0]] + selected_prompts - - all_prompts.append(delimiter.join(selected_prompts)) - - p.n_iter = math.ceil(len(all_prompts) / p.batch_size) - p.do_not_save_grid = True - - print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") - - if prompt_type == "positive": - p.prompt = all_prompts - else: - p.negative_prompt = all_prompts - p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] - p.prompt_for_display = positive_prompt - processed = process_images(p) - - grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) - grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size) - processed.images.insert(0, grid) - processed.index_of_first_image = 1 - processed.infotexts.insert(0, processed.infotexts[0]) - - if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) - - return processed +import math +import gradio as gr +import modules.scripts as scripts +from modules import images +from modules.processing import process_images +from modules.shared import opts, state +import modules.sd_samplers + + +def draw_xy_grid(xs, ys, x_label, y_label, cell): + res = [] + + ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] + hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] + + first_processed = None + + state.job_count = len(xs) * len(ys) + + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed = cell(x, y) + if first_processed is None: + first_processed = processed + + res.append(processed.images[0]) + + grid = images.image_grid(res, rows=len(ys)) + grid = images.draw_grid_annotations(grid, res[0].width, 
res[0].height, hor_texts, ver_texts) + + first_processed.images = [grid] + + return first_processed + + +class Script(scripts.Script): + def title(self): + return "Prompt matrix" + + def ui(self, is_img2img): + gr.HTML('
') + with gr.Row(): + with gr.Column(): + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) + with gr.Column(): + prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive") + variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma") + with gr.Column(): + margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + + return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size] + + def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size): + modules.processing.fix_seed(p) + # Raise error if promp type is not positive or negative + if prompt_type not in ["positive", "negative"]: + raise ValueError(f"Unknown prompt type {prompt_type}") + # Raise error if variations delimiter is not comma or space + if variations_delimiter not in ["comma", "space"]: + raise ValueError(f"Unknown variations delimiter {variations_delimiter}") + + prompt = p.prompt if prompt_type == "positive" else p.negative_prompt + original_prompt = prompt[0] if type(prompt) == list else prompt + positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt + + delimiter = ", " if variations_delimiter == "comma" else " " + + all_prompts = [] + prompt_matrix_parts = original_prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts) - 1) + for combination_num in range(combination_count): + selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] + + if put_at_start: + selected_prompts = selected_prompts + [prompt_matrix_parts[0]] + 
else: + selected_prompts = [prompt_matrix_parts[0]] + selected_prompts + + all_prompts.append(delimiter.join(selected_prompts)) + + p.n_iter = math.ceil(len(all_prompts) / p.batch_size) + p.do_not_save_grid = True + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") + + if prompt_type == "positive": + p.prompt = all_prompts + else: + p.negative_prompt = all_prompts + p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] + p.prompt_for_display = positive_prompt + processed = process_images(p) + + grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) + grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size) + processed.images.insert(0, grid) + processed.index_of_first_image = 1 + processed.infotexts.insert(0, processed.infotexts[0]) + + if opts.grid_save: + images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) + + return processed diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index fe30d4b07..138eb775c 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,174 +1,168 @@ -import copy -import math -import os -import random -import sys -import modules.shared as shared -import shlex - -import modules.scripts as scripts -import gradio as gr - -from modules import sd_samplers, errors -from modules.processing import Processed, process_images -from PIL import Image -from modules.shared import opts, cmd_opts, state - - -def process_string_tag(tag): - return tag - - -def process_int_tag(tag): - return int(tag) - - -def process_float_tag(tag): - return float(tag) - - -def process_boolean_tag(tag): - return True if (tag == "true") else False - - -prompt_tags = { - "sd_model": None, - "outpath_samples": 
process_string_tag, - "outpath_grids": process_string_tag, - "prompt_for_display": process_string_tag, - "prompt": process_string_tag, - "negative_prompt": process_string_tag, - "styles": process_string_tag, - "seed": process_int_tag, - "subseed_strength": process_float_tag, - "subseed": process_int_tag, - "seed_resize_from_h": process_int_tag, - "seed_resize_from_w": process_int_tag, - "sampler_index": process_int_tag, - "sampler_name": process_string_tag, - "batch_size": process_int_tag, - "n_iter": process_int_tag, - "steps": process_int_tag, - "cfg_scale": process_float_tag, - "width": process_int_tag, - "height": process_int_tag, - "restore_faces": process_boolean_tag, - "tiling": process_boolean_tag, - "do_not_save_samples": process_boolean_tag, - "do_not_save_grid": process_boolean_tag -} - - -def cmdargs(line): - args = shlex.split(line) - pos = 0 - res = {} - - while pos < len(args): - arg = args[pos] - - assert arg.startswith("--"), f'must start with "--": {arg}' - assert pos+1 < len(args), f'missing argument for command line option {arg}' - - tag = arg[2:] - - if tag == "prompt" or tag == "negative_prompt": - pos += 1 - prompt = args[pos] - pos += 1 - while pos < len(args) and not args[pos].startswith("--"): - prompt += " " - prompt += args[pos] - pos += 1 - res[tag] = prompt - continue - - - func = prompt_tags.get(tag, None) - assert func, f'unknown commandline option: {arg}' - - val = args[pos+1] - if tag == "sampler_name": - val = sd_samplers.samplers_map.get(val.lower(), None) - - res[tag] = func(val) - - pos += 2 - - return res - - -def load_prompt_file(file): - if file is None: - lines = [] - else: - lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] - - return None, "\n".join(lines), gr.update(lines=7) - - -class Script(scripts.Script): - def title(self): - return "Prompts from file" - - def ui(self, is_img2img): - checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, 
elem_id=self.elem_id("checkbox_iterate")) - checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) - - prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) - file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) - - file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) - - # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n. - # We don't shrink back to 1, because that causes the control to ignore [enter], and it may - # be unclear to the user that shift-enter is needed. - prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) - return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] - - def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): - lines = [x.strip() for x in prompt_txt.splitlines()] - lines = [x for x in lines if len(x) > 0] - - job_count = 0 - jobs = [] - - for line in lines: - if "--" in line: - try: - args = cmdargs(line) - except Exception as e: - errors.display(e, f'parsing prompts: {line}') - args = {"prompt": line} - else: - args = {"prompt": line} - - job_count += args.get("n_iter", p.n_iter) - - jobs.append(args) - - print(f"Will process {len(lines)} lines in {job_count} jobs.") - if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1: - p.seed = int(random.randrange(4294967294)) - - state.job_count = job_count - - images = [] - all_prompts = [] - infotexts = [] - for n, args in enumerate(jobs): - state.job = f"{state.job_no + 1} out of {state.job_count}" - - copy_p = copy.copy(p) - for k, v in args.items(): - setattr(copy_p, k, v) - - proc = process_images(copy_p) - images += proc.images - - if checkbox_iterate: - p.seed = p.seed + (p.batch_size * p.n_iter) - all_prompts += proc.all_prompts 
- infotexts += proc.infotexts - - return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts) +import copy +import random +import shlex +import gradio as gr +from PIL import Image +import modules.scripts as scripts +from modules import sd_samplers, errors +from modules.processing import Processed, process_images +from modules.shared import state + + +def process_string_tag(tag): + return tag + + +def process_int_tag(tag): + return int(tag) + + +def process_float_tag(tag): + return float(tag) + + +def process_boolean_tag(tag): + return True if (tag == "true") else False + + +prompt_tags = { + "sd_model": None, + "outpath_samples": process_string_tag, + "outpath_grids": process_string_tag, + "prompt_for_display": process_string_tag, + "prompt": process_string_tag, + "negative_prompt": process_string_tag, + "styles": process_string_tag, + "seed": process_int_tag, + "subseed_strength": process_float_tag, + "subseed": process_int_tag, + "seed_resize_from_h": process_int_tag, + "seed_resize_from_w": process_int_tag, + "sampler_index": process_int_tag, + "sampler_name": process_string_tag, + "batch_size": process_int_tag, + "n_iter": process_int_tag, + "steps": process_int_tag, + "cfg_scale": process_float_tag, + "width": process_int_tag, + "height": process_int_tag, + "restore_faces": process_boolean_tag, + "tiling": process_boolean_tag, + "do_not_save_samples": process_boolean_tag, + "do_not_save_grid": process_boolean_tag +} + + +def cmdargs(line): + args = shlex.split(line) + pos = 0 + res = {} + + while pos < len(args): + arg = args[pos] + + assert arg.startswith("--"), f'must start with "--": {arg}' + assert pos+1 < len(args), f'missing argument for command line option {arg}' + + tag = arg[2:] + + if tag == "prompt" or tag == "negative_prompt": + pos += 1 + prompt = args[pos] + pos += 1 + while pos < len(args) and not args[pos].startswith("--"): + prompt += " " + prompt += args[pos] + pos += 1 + res[tag] = prompt + continue + + + func = 
prompt_tags.get(tag, None) + assert func, f'unknown commandline option: {arg}' + + val = args[pos+1] + if tag == "sampler_name": + val = sd_samplers.samplers_map.get(val.lower(), None) + + res[tag] = func(val) + + pos += 2 + + return res + + +def load_prompt_file(file): + if file is None: + lines = [] + else: + lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] + + return None, "\n".join(lines), gr.update(lines=7) + + +class Script(scripts.Script): + def title(self): + return "Prompts from file" + + def ui(self, is_img2img): + checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) + checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + + prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) + file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) + + file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) + + # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n. + # We don't shrink back to 1, because that causes the control to ignore [enter], and it may + # be unclear to the user that shift-enter is needed. 
+ prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) + return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] + + def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): + lines = [x.strip() for x in prompt_txt.splitlines()] + lines = [x for x in lines if len(x) > 0] + + job_count = 0 + jobs = [] + + for line in lines: + if "--" in line: + try: + args = cmdargs(line) + except Exception as e: + errors.display(e, f'parsing prompts: {line}') + args = {"prompt": line} + else: + args = {"prompt": line} + + job_count += args.get("n_iter", p.n_iter) + + jobs.append(args) + + print(f"Will process {len(lines)} lines in {job_count} jobs.") + if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1: + p.seed = int(random.randrange(4294967294)) + + state.job_count = job_count + + images = [] + all_prompts = [] + infotexts = [] + for n, args in enumerate(jobs): + state.job = f"{state.job_no + 1} out of {state.job_count}" + + copy_p = copy.copy(p) + for k, v in args.items(): + setattr(copy_p, k, v) + + proc = process_images(copy_p) + images += proc.images + + if checkbox_iterate: + p.seed = p.seed + (p.batch_size * p.n_iter) + all_prompts += proc.all_prompts + infotexts += proc.infotexts + + return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 332d76d91..2ce97a6c4 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -1,101 +1,99 @@ -import math - -import modules.scripts as scripts -import gradio as gr -from PIL import Image - -from modules import processing, shared, sd_samplers, images, devices -from modules.processing import Processed -from modules.shared import opts, cmd_opts, state - - -class Script(scripts.Script): - def title(self): - return "SD upscale" - - def show(self, is_img2img): - return is_img2img - - def ui(self, is_img2img): - info = 
gr.HTML("

Will upscale the image by the selected scale factor; use width and height sliders to set tile size

") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) - scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) - - return [info, overlap, upscaler_index, scale_factor] - - def run(self, p, _, overlap, upscaler_index, scale_factor): - if isinstance(upscaler_index, str): - upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) - processing.fix_seed(p) - upscaler = shared.sd_upscalers[upscaler_index] - - p.extra_generation_params["SD upscale overlap"] = overlap - p.extra_generation_params["SD upscale upscaler"] = upscaler.name - - initial_info = None - seed = p.seed - - init_img = p.init_images[0] - init_img = images.flatten(init_img, opts.img2img_background_color) - - if upscaler.name != "None": - img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) - else: - img = init_img - - devices.torch_gc() - - grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) - - batch_size = p.batch_size - upscale_count = p.n_iter - p.n_iter = 1 - p.do_not_save_grid = True - p.do_not_save_samples = True - - work = [] - - for y, h, row in grid.tiles: - for tiledata in row: - work.append(tiledata[2]) - - batch_count = math.ceil(len(work) / batch_size) - state.job_count = batch_count * upscale_count - - print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") - - result_images = [] - for n in range(upscale_count): - start_seed = seed + n - p.seed = start_seed - - work_results = [] - for i in range(batch_count): - p.batch_size = batch_size - p.init_images = work[i * 
batch_size:(i + 1) * batch_size] - - state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" - processed = processing.process_images(p) - - if initial_info is None: - initial_info = processed.info - - p.seed = processed.seed + 1 - work_results += processed.images - - image_index = 0 - for y, h, row in grid.tiles: - for tiledata in row: - tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) - image_index += 1 - - combined_image = images.combine_grid(grid) - result_images.append(combined_image) - - if opts.samples_save: - images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) - - processed = Processed(p, result_images, seed, initial_info) - - return processed +import math +import gradio as gr +from PIL import Image +import modules.scripts as scripts +from modules import processing, shared, images, devices +from modules.processing import Processed +from modules.shared import opts, state + + +class Script(scripts.Script): + def title(self): + return "SD upscale" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + info = gr.HTML("

Will upscale the image by the selected scale factor; use width and height sliders to set tile size

") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) + + return [info, overlap, upscaler_index, scale_factor] + + def run(self, p, _, overlap, upscaler_index, scale_factor): + if isinstance(upscaler_index, str): + upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) + processing.fix_seed(p) + upscaler = shared.sd_upscalers[upscaler_index] + + p.extra_generation_params["SD upscale overlap"] = overlap + p.extra_generation_params["SD upscale upscaler"] = upscaler.name + + initial_info = None + seed = p.seed + + init_img = p.init_images[0] + init_img = images.flatten(init_img, opts.img2img_background_color) + + if upscaler.name != "None": + img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) + else: + img = init_img + + devices.torch_gc() + + grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) + + batch_size = p.batch_size + upscale_count = p.n_iter + p.n_iter = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + + for y, h, row in grid.tiles: + for tiledata in row: + work.append(tiledata[2]) + + batch_count = math.ceil(len(work) / batch_size) + state.job_count = batch_count * upscale_count + + print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") + + result_images = [] + for n in range(upscale_count): + start_seed = seed + n + p.seed = start_seed + + work_results = [] + for i in range(batch_count): + p.batch_size = batch_size + p.init_images = work[i * 
batch_size:(i + 1) * batch_size] + + state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" + processed = processing.process_images(p) + + if initial_info is None: + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + result_images.append(combined_image) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) + + processed = Processed(p, result_images, seed, initial_info) + + return processed diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 41f226848..58ab7507f 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -1,675 +1,675 @@ - # pylint: disable=unused-argument, attribute-defined-outside-init - -import re -import csv -import random -from collections import namedtuple -from copy import copy -from itertools import permutations, chain -from io import StringIO -from PIL import Image -import numpy as np -import gradio as gr -import modules.scripts as scripts -import modules.shared as shared -from modules import images, sd_samplers, processing, sd_models, sd_vae -from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img -from modules.ui_components import ToolButton - -fill_values_symbol = "\U0001f4d2" # 📒 -AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) - - -def apply_field(field): - def fun(p, x, xs): - setattr(p, field, x) - return fun - - -def apply_prompt(p, x, xs): - if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: - shared.log.warning(f"XYZ grid: prompt S/R did not find {xs[0]} in prompt or negative prompt.") - else: - p.prompt = p.prompt.replace(xs[0], x) - p.negative_prompt = 
p.negative_prompt.replace(xs[0], x) - - -def apply_order(p, x, xs): - token_order = [] - for token in x: - token_order.append((p.prompt.find(token), token)) - token_order.sort(key=lambda t: t[0]) - prompt_parts = [] - for _, token in token_order: - n = p.prompt.find(token) - prompt_parts.append(p.prompt[0:n]) - p.prompt = p.prompt[n + len(token):] - prompt_tmp = "" - for idx, part in enumerate(prompt_parts): - prompt_tmp += part - prompt_tmp += x[idx] - p.prompt = prompt_tmp + p.prompt - - -def apply_sampler(p, x, xs): - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - shared.log.warning(f"XYZ grid: unknown sampler: {x}") - else: - p.sampler_name = sampler_name - - -def confirm_samplers(p, xs): - for x in xs: - if x.lower() not in sd_samplers.samplers_map: - shared.log.warning(f"XYZ grid: unknown sampler: {x}") - - -def apply_checkpoint(p, x, xs): - if x == shared.opts.sd_model_checkpoint: - return - info = sd_models.get_closet_checkpoint_match(x) - if info is None: - shared.log.warning(f"XYZ grid: unknown checkpoint: {x}") - else: - sd_models.reload_model_weights(shared.sd_model, info) - - -def confirm_checkpoints(p, xs): - for x in xs: - if sd_models.get_closet_checkpoint_match(x) is None: - shared.log.warning(f"XYZ grid: Unknown checkpoint: {x}") - - -def apply_clip_skip(p, x, xs): - shared.opts.data["CLIP_stop_at_last_layers"] = x - - -def apply_upscale_latent_space(p, x, xs): - if x.lower().strip() != '0': - shared.opts.data["use_scale_latent_for_hires_fix"] = True - else: - shared.opts.data["use_scale_latent_for_hires_fix"] = False - - -def find_vae(name: str): - if name.lower() in ['auto', 'automatic']: - return sd_vae.unspecified - if name.lower() == 'none': - return None - else: - choices = [x for x in sorted(sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] - if len(choices) == 0: - shared.log.warning(f"No VAE found for {name}; using automatic") - return sd_vae.unspecified - else: - 
return sd_vae.vae_dict[choices[0]] - - -def apply_vae(p, x, xs): - sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) - - -def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): - p.styles.extend(x.split(',')) - - -def apply_fallback(p, x, xs): - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - shared.log.warning(f"XYZ grid: unknown sampler: {x}") - else: - shared.opts.data["xyz_fallback_sampler"] = sampler_name - - -def apply_uni_pc_order(p, x, xs): - shared.opts.data["uni_pc_order"] = min(x, p.steps - 1) - - -def apply_face_restore(p, opt, x): - opt = opt.lower() - if opt == 'codeformer': - is_active = True - p.face_restoration_model = 'CodeFormer' - elif opt == 'gfpgan': - is_active = True - p.face_restoration_model = 'GFPGAN' - else: - is_active = opt in ('true', 'yes', 'y', '1') - p.restore_faces = is_active - - -def apply_token_merging_ratio_hr(p, x, xs): - shared.opts.data["token_merging_ratio_hr"] = x - - -def apply_token_merging_ratio(p, x, xs): - shared.opts.data["token_merging_ratio"] = x - - -def apply_token_merging_random(p, x, xs): - is_active = x.lower() in ('true', 'yes', 'y', '1') - shared.opts.data["token_merging_random"] = is_active - - -def format_value_add_label(p, opt, x): - if type(x) == float: - x = round(x, 8) - return f"{opt.label}: {x}" - - -def format_value(p, opt, x): - if type(x) == float: - x = round(x, 8) - return x - - -def format_value_join_list(p, opt, x): - return ", ".join(x) - - -def do_nothing(p, x, xs): - pass - - -def format_nothing(p, opt, x): - return "" - - -def str_permutations(x): - """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" - return x - - -class AxisOption: - def __init__(self, label, tipe, apply, fmt=format_value_add_label, confirm=None, cost=0.0, choices=None): - self.label = label - self.type = tipe - self.apply = apply - self.format_value = fmt - self.confirm = confirm - self.cost = cost - 
self.choices = choices - - -class AxisOptionImg2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = True - -class AxisOptionTxt2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = False - - -axis_options = [ - AxisOption("Nothing", str, do_nothing, fmt=format_nothing), - AxisOption("Seed", int, apply_field("seed")), - AxisOption("Var. seed", int, apply_field("subseed")), - AxisOption("Var. strength", float, apply_field("subseed_strength")), - AxisOption("Steps", int, apply_field("steps")), - AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), - AxisOption("CFG Scale", float, apply_field("cfg_scale")), - AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), - AxisOption("Prompt S/R", str, apply_prompt, fmt=format_value), - AxisOption("Prompt order", str_permutations, apply_order, fmt=format_value_join_list), - AxisOptionTxt2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionImg2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), - AxisOption("Checkpoint name", str, apply_checkpoint, fmt=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Sigma Churn", float, apply_field("s_churn")), - AxisOption("Sigma min", float, apply_field("s_tmin")), - AxisOption("Sigma max", float, apply_field("s_tmax")), - AxisOption("Sigma noise", float, apply_field("s_noise")), - AxisOption("Eta", float, apply_field("eta")), - AxisOption("Clip skip", int, apply_clip_skip), - AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x 
in shared.sd_upscalers]]), - AxisOptionTxt2Img("Fallback latent upscaler sampler", str, apply_fallback, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), - AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), - AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), - AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), - AxisOption("Face restore", str, apply_face_restore, fmt=format_value), - AxisOption("ToMe ratio",float,apply_token_merging_ratio), - AxisOption("ToMe ratio for Hires fix",float,apply_token_merging_ratio_hr), - AxisOption("ToMe random pertubations",str,apply_token_merging_random, choices = lambda: ["Yes","No"]) -] - - -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size): - hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] - title_texts = [[images.GridAnnotation(z)] for z in z_labels] - list_size = (len(xs) * len(ys) * len(zs)) - processed_result = None - shared.state.job_count = list_size * p.n_iter - - def process_cell(x, y, z, ix, iy, iz): - nonlocal processed_result - - def index(ix, iy, iz): - return ix + iy * len(xs) + iz * len(xs) * len(ys) - - shared.state.job = f"{index(ix, iy, iz) + 1} out of {list_size}" - processed: Processed = cell(x, y, z, ix, iy, iz) - if processed_result is None: - # Use our first processed result object as a template container to hold our full results - processed_result = copy(processed) - processed_result.images = [None] * list_size - processed_result.all_prompts = [None] * list_size - processed_result.all_seeds = [None] * list_size - processed_result.infotexts = [None] * list_size - 
processed_result.index_of_first_image = 1 - idx = index(ix, iy, iz) - if processed.images: - # Non-empty list indicates some degree of success. - processed_result.images[idx] = processed.images[0] - processed_result.all_prompts[idx] = processed.prompt - processed_result.all_seeds[idx] = processed.seed - processed_result.infotexts[idx] = processed.infotexts[0] - else: - cell_mode = "P" - cell_size = (processed_result.width, processed_result.height) - if processed_result.images[0] is not None: - cell_mode = processed_result.images[0].mode - #This corrects size in case of batches: - cell_size = processed_result.images[0].size - processed_result.images[idx] = Image.new(cell_mode, cell_size) - - if first_axes_processed == 'x': - for ix, x in enumerate(xs): - if second_axes_processed == 'y': - for iy, y in enumerate(ys): - for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: - for iz, z in enumerate(zs): - for iy, y in enumerate(ys): - process_cell(x, y, z, ix, iy, iz) - elif first_axes_processed == 'y': - for iy, y in enumerate(ys): - if second_axes_processed == 'x': - for ix, x in enumerate(xs): - for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: - for iz, z in enumerate(zs): - for ix, x in enumerate(xs): - process_cell(x, y, z, ix, iy, iz) - elif first_axes_processed == 'z': - for iz, z in enumerate(zs): - if second_axes_processed == 'x': - for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - process_cell(x, y, z, ix, iy, iz) - else: - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - process_cell(x, y, z, ix, iy, iz) - - if not processed_result: - # Should never happen, I've only seen it on one of four open tabs and it needed to refresh. 
- shared.log.error("XYZ grid: Processing could not begin, you may need to refresh the tab or restart the service") - return Processed(p, []) - elif not any(processed_result.images): - shared.log.error("XYZ grid: Failed to return even a single processed image") - return Processed(p, []) - - z_count = len(zs) - # sub_grids = [None] * z_count - for i in range(z_count): - start_index = (i * len(xs) * len(ys)) + i - end_index = start_index + len(xs) * len(ys) - grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys)) - if draw_legend: - grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size) - processed_result.images.insert(i, grid) - processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index]) - processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index]) - processed_result.infotexts.insert(i, processed_result.infotexts[start_index]) - sub_grid_size = processed_result.images[0].size - z_grid = images.image_grid(processed_result.images[:z_count], rows=1) - if draw_legend: - z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) - processed_result.images.insert(0, z_grid) - #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) - #processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) - processed_result.infotexts.insert(0, processed_result.infotexts[0]) - return processed_result - - -class SharedSettingsStackHelper(object): - def __enter__(self): - #Save overridden settings so they can be restored later. 
- self.CLIP_stop_at_last_layers = shared.opts.CLIP_stop_at_last_layers - self.vae = shared.opts.sd_vae - self.uni_pc_order = shared.opts.uni_pc_order - self.token_merging_ratio_hr = shared.opts.token_merging_ratio_hr - self.token_merging_ratio = shared.opts.token_merging_ratio - self.token_merging_random = shared.opts.token_merging_random - self.sd_model_checkpoint = shared.opts.sd_model_checkpoint - self.sd_vae_checkpoint = shared.opts.sd_vae - - def __exit__(self, exc_type, exc_value, tb): - #Restore overriden settings after plot generation. - shared.opts.data["sd_vae"] = self.vae - shared.opts.data["uni_pc_order"] = self.uni_pc_order - shared.opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - shared.opts.data["token_merging_ratio_hr"] = self.token_merging_ratio_hr - shared.opts.data["token_merging_ratio"] = self.token_merging_ratio - shared.opts.data["token_merging_random"] = self.token_merging_random - if self.sd_model_checkpoint != shared.opts.sd_model_checkpoint: - shared.opts.data["sd_model_checkpoint"] = self.sd_model_checkpoint - sd_models.reload_model_weights() - if self.sd_vae_checkpoint != shared.opts.sd_vae: - shared.opts.data["sd_vae"] = self.sd_vae_checkpoint - sd_vae.reload_vae_weights() - - -re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") -re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") -re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") - - -class Script(scripts.Script): - def title(self): - return "X/Y/Z plot" - - def ui(self, is_img2img): - self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] - with gr.Row(): - with gr.Column(scale=19): - with gr.Row(): - x_type = gr.Dropdown(label="X type", 
choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) - x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True) - fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) - - with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) - y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) - - with gr.Row(): - z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) - z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) - z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True) - fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) - with gr.Row(variant="compact", elem_id="axis_options"): - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) - include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) - include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) - with gr.Row(variant="compact", elem_id="axis_options"): - 
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - with gr.Row(variant="compact", elem_id="swap_axes"): - swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") - swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") - swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") - - def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown): - return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown - - xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown] - swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) - yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown] - swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) - xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown] - swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) - - def fill(x_type): - axis = self.current_axis_options[x_type] - return axis.choices() if axis.choices else gr.update() - - fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown]) - fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown]) - fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown]) - - def select_axis(axis_type,axis_values_dropdown): - choices = self.current_axis_options[axis_type].choices - has_choices = choices is not None - current_values = axis_values_dropdown - if has_choices: - choices = choices() - if isinstance(current_values,str): - current_values = current_values.split(",") - current_values = list(filter(lambda x: x in choices, 
current_values)) - return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values) - - x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown]) - y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown]) - z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown]) - - def get_dropdown_update_from_params(axis,params): - val_key = axis + " Values" - vals = params.get(val_key,"") - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] - return gr.update(value = valslist) - - self.infotext_fields = ( - (x_type, "X Type"), - (x_values, "X Values"), - (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)), - (y_type, "Y Type"), - (y_values, "Y Values"), - (y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)), - (z_type, "Z Type"), - (z_values, "Z Values"), - (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), - ) - - return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] - - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): # pylint: disable=arguments-differ - shared.log.debug(f'xyzgrid: {x_type}|{x_values}|{x_values_dropdown}|{y_type}|{y_values}|{y_values_dropdown}|{z_type}|{z_values}|{z_values_dropdown}|{draw_legend}|{include_lone_images}|{include_sub_grids}|{no_fixed_seeds}|{margin_size}') - if not no_fixed_seeds: - processing.fix_seed(p) - if not shared.opts.return_grid: 
- p.batch_size = 1 - def process_axis(opt, vals, vals_dropdown): - if opt.label == 'Nothing': - return [0] - if opt.choices is not None: - valslist = vals_dropdown - else: - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] - if opt.type == int: - valslist_ext = [] - for val in valslist: - m = re_range.fullmatch(val) - mc = re_range_count.fullmatch(val) - if m is not None: - start = int(m.group(1)) - end = int(m.group(2))+1 - step = int(m.group(3)) if m.group(3) is not None else 1 - valslist_ext += list(range(start, end, step)) - elif mc is not None: - start = int(mc.group(1)) - end = int(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] - else: - valslist_ext.append(val) - valslist = valslist_ext - elif opt.type == float: - valslist_ext = [] - for val in valslist: - m = re_range_float.fullmatch(val) - mc = re_range_count_float.fullmatch(val) - if m is not None: - start = float(m.group(1)) - end = float(m.group(2)) - step = float(m.group(3)) if m.group(3) is not None else 1 - valslist_ext += np.arange(start, end + step, step).tolist() - elif mc is not None: - start = float(mc.group(1)) - end = float(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() - else: - valslist_ext.append(val) - valslist = valslist_ext - elif opt.type == str_permutations: - valslist = list(permutations(valslist)) - valslist = [opt.type(x) for x in valslist] - # Confirm options are valid before starting - if opt.confirm: - opt.confirm(p, valslist) - return valslist - - x_opt = self.current_axis_options[x_type] - if x_opt.choices is not None: - x_values = ",".join(x_values_dropdown) - xs = process_axis(x_opt, x_values, x_values_dropdown) - y_opt = self.current_axis_options[y_type] - if y_opt.choices is not None: - y_values = ",".join(y_values_dropdown) - ys = 
process_axis(y_opt, y_values, y_values_dropdown) - z_opt = self.current_axis_options[z_type] - if z_opt.choices is not None: - z_values = ",".join(z_values_dropdown) - zs = process_axis(z_opt, z_values, z_values_dropdown) - Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes - grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) - assert grid_mp < shared.opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {shared.opts.img_max_size_mp} MPixels)' - - def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label in ['Seed', 'Var. seed']: - return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] - else: - return axis_list - - if not no_fixed_seeds: - xs = fix_axis_seeds(x_opt, xs) - ys = fix_axis_seeds(y_opt, ys) - zs = fix_axis_seeds(z_opt, zs) - - if x_opt.label == 'Steps': - total_steps = sum(xs) * len(ys) * len(zs) - elif y_opt.label == 'Steps': - total_steps = sum(ys) * len(xs) * len(zs) - elif z_opt.label == 'Steps': - total_steps = sum(zs) * len(xs) * len(ys) - else: - total_steps = p.steps * len(xs) * len(ys) * len(zs) - if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - if x_opt.label == "Hires steps": - total_steps += sum(xs) * len(ys) * len(zs) - elif y_opt.label == "Hires steps": - total_steps += sum(ys) * len(xs) * len(zs) - elif z_opt.label == "Hires steps": - total_steps += sum(zs) * len(xs) * len(ys) - elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) - else: - total_steps *= 2 - total_steps *= p.n_iter - image_cell_count = p.n_iter * p.batch_size - shared.log.info(f"XYZ grid: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} {len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}") - shared.state.xyz_plot_x = AxisInfo(x_opt, xs) - shared.state.xyz_plot_y = AxisInfo(y_opt, ys) 
- shared.state.xyz_plot_z = AxisInfo(z_opt, zs) - # If one of the axes is very slow to change between (like SD model checkpoint), then make sure it is in the outer iteration of the nested `for` loop. - first_axes_processed = 'z' - second_axes_processed = 'y' - if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: - first_axes_processed = 'x' - if y_opt.cost > z_opt.cost: - second_axes_processed = 'y' - else: - second_axes_processed = 'z' - elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: - first_axes_processed = 'y' - if x_opt.cost > z_opt.cost: - second_axes_processed = 'x' - else: - second_axes_processed = 'z' - elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: - first_axes_processed = 'z' - if x_opt.cost > y_opt.cost: - second_axes_processed = 'x' - else: - second_axes_processed = 'y' - grid_infotext = [None] * (1 + len(zs)) - - def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted: - return Processed(p, [], p.seed, "") - pc = copy(p) - pc.styles = pc.styles[:] - x_opt.apply(pc, x, xs) - y_opt.apply(pc, y, ys) - z_opt.apply(pc, z, zs) - res = process_images(pc) - # Sets subgrid infotexts - subgrid_index = 1 + iz - if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0: - pc.extra_generation_params = copy(pc.extra_generation_params) - pc.extra_generation_params['Script'] = self.title() - if x_opt.label != 'Nothing': - pc.extra_generation_params["X Type"] = x_opt.label - pc.extra_generation_params["X Values"] = x_values - if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) - if y_opt.label != 'Nothing': - pc.extra_generation_params["Y Type"] = y_opt.label - pc.extra_generation_params["Y Values"] = y_values - if y_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) - grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - # Sets main grid infotext - if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0: - pc.extra_generation_params = copy(pc.extra_generation_params) - - if z_opt.label != 'Nothing': - pc.extra_generation_params["Z Type"] = z_opt.label - pc.extra_generation_params["Z Values"] = z_values - if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs]) - grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - return res - - with SharedSettingsStackHelper(): - processed = draw_xyz_grid( - p, - xs=xs, - ys=ys, - zs=zs, - x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], - y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - z_labels=[z_opt.format_value(p, z_opt, z) for z in zs], - cell=cell, - draw_legend=draw_legend, - include_lone_images=include_lone_images, - include_sub_grids=include_sub_grids, - first_axes_processed=first_axes_processed, - second_axes_processed=second_axes_processed, - margin_size=margin_size - ) - - if not processed.images: - # It broke, no further handling needed. 
- return processed - z_count = len(zs) - # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids) - processed.infotexts[:1+z_count] = grid_infotext[:1+z_count] - if not include_lone_images: - # Don't need sub-images anymore, drop from list: - processed.images = processed.images[:z_count+1] - if shared.opts.grid_save: - # Auto-save main and sub-grids: - grid_count = z_count + 1 if z_count > 1 else 1 - for g in range(grid_count): - adj_g = g-1 if g > 0 else g - images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=shared.opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) - if not include_sub_grids: - # Done with sub-grids, drop all related information: - for _sg in range(z_count): - del processed.images[1] - del processed.all_prompts[1] - del processed.all_seeds[1] - del processed.infotexts[1] - return processed + # pylint: disable=unused-argument, attribute-defined-outside-init + +import re +import csv +import random +from collections import namedtuple +from copy import copy +from itertools import permutations, chain +from io import StringIO +from PIL import Image +import numpy as np +import gradio as gr +import modules.scripts as scripts +import modules.shared as shared +from modules import images, sd_samplers, processing, sd_models, sd_vae +from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img +from modules.ui_components import ToolButton + +fill_values_symbol = "\U0001f4d2" # 📒 +AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) + + +def apply_field(field): + def fun(p, x, xs): + setattr(p, field, x) + return fun + + +def apply_prompt(p, x, xs): + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + shared.log.warning(f"XYZ grid: prompt S/R did not find {xs[0]} in prompt or negative prompt.") + else: + p.prompt = p.prompt.replace(xs[0], x) + 
p.negative_prompt = p.negative_prompt.replace(xs[0], x) + + +def apply_order(p, x, xs): + token_order = [] + for token in x: + token_order.append((p.prompt.find(token), token)) + token_order.sort(key=lambda t: t[0]) + prompt_parts = [] + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + + +def apply_sampler(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + shared.log.warning(f"XYZ grid: unknown sampler: {x}") + else: + p.sampler_name = sampler_name + + +def confirm_samplers(p, xs): + for x in xs: + if x.lower() not in sd_samplers.samplers_map: + shared.log.warning(f"XYZ grid: unknown sampler: {x}") + + +def apply_checkpoint(p, x, xs): + if x == shared.opts.sd_model_checkpoint: + return + info = sd_models.get_closet_checkpoint_match(x) + if info is None: + shared.log.warning(f"XYZ grid: unknown checkpoint: {x}") + else: + sd_models.reload_model_weights(shared.sd_model, info) + + +def confirm_checkpoints(p, xs): + for x in xs: + if sd_models.get_closet_checkpoint_match(x) is None: + shared.log.warning(f"XYZ grid: Unknown checkpoint: {x}") + + +def apply_clip_skip(p, x, xs): + shared.opts.data["CLIP_stop_at_last_layers"] = x + + +def apply_upscale_latent_space(p, x, xs): + if x.lower().strip() != '0': + shared.opts.data["use_scale_latent_for_hires_fix"] = True + else: + shared.opts.data["use_scale_latent_for_hires_fix"] = False + + +def find_vae(name: str): + if name.lower() in ['auto', 'automatic']: + return sd_vae.unspecified + if name.lower() == 'none': + return None + else: + choices = [x for x in sorted(sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + shared.log.warning(f"No VAE found for {name}; using automatic") + return 
sd_vae.unspecified + else: + return sd_vae.vae_dict[choices[0]] + + +def apply_vae(p, x, xs): + sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) + + +def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): + p.styles.extend(x.split(',')) + + +def apply_fallback(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + shared.log.warning(f"XYZ grid: unknown sampler: {x}") + else: + shared.opts.data["xyz_fallback_sampler"] = sampler_name + + +def apply_uni_pc_order(p, x, xs): + shared.opts.data["uni_pc_order"] = min(x, p.steps - 1) + + +def apply_face_restore(p, opt, x): + opt = opt.lower() + if opt == 'codeformer': + is_active = True + p.face_restoration_model = 'CodeFormer' + elif opt == 'gfpgan': + is_active = True + p.face_restoration_model = 'GFPGAN' + else: + is_active = opt in ('true', 'yes', 'y', '1') + p.restore_faces = is_active + + +def apply_token_merging_ratio_hr(p, x, xs): + shared.opts.data["token_merging_ratio_hr"] = x + + +def apply_token_merging_ratio(p, x, xs): + shared.opts.data["token_merging_ratio"] = x + + +def apply_token_merging_random(p, x, xs): + is_active = x.lower() in ('true', 'yes', 'y', '1') + shared.opts.data["token_merging_random"] = is_active + + +def format_value_add_label(p, opt, x): + if type(x) == float: + x = round(x, 8) + return f"{opt.label}: {x}" + + +def format_value(p, opt, x): + if type(x) == float: + x = round(x, 8) + return x + + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + +def do_nothing(p, x, xs): + pass + + +def format_nothing(p, opt, x): + return "" + + +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + +class AxisOption: + def __init__(self, label, tipe, apply, fmt=format_value_add_label, confirm=None, cost=0.0, choices=None): + self.label = label + self.type = tipe + self.apply = apply + self.format_value = fmt + self.confirm = 
confirm + self.cost = cost + self.choices = choices + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False + + +axis_options = [ + AxisOption("Nothing", str, do_nothing, fmt=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, fmt=format_value), + AxisOption("Prompt order", str_permutations, apply_order, fmt=format_value_join_list), + AxisOptionTxt2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), + AxisOption("Checkpoint name", str, apply_checkpoint, fmt=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: 
[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionTxt2Img("Fallback latent upscaler sampler", str, apply_fallback, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), + AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), + AxisOption("Face restore", str, apply_face_restore, fmt=format_value), + AxisOption("ToMe ratio",float,apply_token_merging_ratio), + AxisOption("ToMe ratio for Hires fix",float,apply_token_merging_ratio_hr), + AxisOption("ToMe random pertubations",str,apply_token_merging_random, choices = lambda: ["Yes","No"]) +] + + +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size): + hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] + title_texts = [[images.GridAnnotation(z)] for z in z_labels] + list_size = (len(xs) * len(ys) * len(zs)) + processed_result = None + shared.state.job_count = list_size * p.n_iter + + def process_cell(x, y, z, ix, iy, iz): + nonlocal processed_result + + def index(ix, iy, iz): + return ix + iy * len(xs) + iz * len(xs) * len(ys) + + shared.state.job = f"{index(ix, iy, iz) + 1} out of {list_size}" + processed: Processed = cell(x, y, z, ix, iy, iz) + if processed_result is None: + # Use our first processed result object as a template container to hold our full results + processed_result = copy(processed) + processed_result.images = [None] * list_size + processed_result.all_prompts = [None] * list_size + processed_result.all_seeds = [None] * list_size + 
processed_result.infotexts = [None] * list_size + processed_result.index_of_first_image = 1 + idx = index(ix, iy, iz) + if processed.images: + # Non-empty list indicates some degree of success. + processed_result.images[idx] = processed.images[0] + processed_result.all_prompts[idx] = processed.prompt + processed_result.all_seeds[idx] = processed.seed + processed_result.infotexts[idx] = processed.infotexts[0] + else: + cell_mode = "P" + cell_size = (processed_result.width, processed_result.height) + if processed_result.images[0] is not None: + cell_mode = processed_result.images[0].mode + #This corrects size in case of batches: + cell_size = processed_result.images[0].size + processed_result.images[idx] = Image.new(cell_mode, cell_size) + + if first_axes_processed == 'x': + for ix, x in enumerate(xs): + if second_axes_processed == 'y': + for iy, y in enumerate(ys): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'y': + for iy, y in enumerate(ys): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'z': + for iz, z in enumerate(zs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + + if not processed_result: + # Should never happen, I've only seen it on one of four open tabs and it needed to refresh. 
+ shared.log.error("XYZ grid: Processing could not begin, you may need to refresh the tab or restart the service") + return Processed(p, []) + elif not any(processed_result.images): + shared.log.error("XYZ grid: Failed to return even a single processed image") + return Processed(p, []) + + z_count = len(zs) + # sub_grids = [None] * z_count + for i in range(z_count): + start_index = (i * len(xs) * len(ys)) + i + end_index = start_index + len(xs) * len(ys) + grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size) + processed_result.images.insert(i, grid) + processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index]) + processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index]) + processed_result.infotexts.insert(i, processed_result.infotexts[start_index]) + sub_grid_size = processed_result.images[0].size + z_grid = images.image_grid(processed_result.images[:z_count], rows=1) + if draw_legend: + z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) + processed_result.images.insert(0, z_grid) + #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) + #processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) + processed_result.infotexts.insert(0, processed_result.infotexts[0]) + return processed_result + + +class SharedSettingsStackHelper(object): + def __enter__(self): + #Save overridden settings so they can be restored later. 
+ self.CLIP_stop_at_last_layers = shared.opts.CLIP_stop_at_last_layers + self.vae = shared.opts.sd_vae + self.uni_pc_order = shared.opts.uni_pc_order + self.token_merging_ratio_hr = shared.opts.token_merging_ratio_hr + self.token_merging_ratio = shared.opts.token_merging_ratio + self.token_merging_random = shared.opts.token_merging_random + self.sd_model_checkpoint = shared.opts.sd_model_checkpoint + self.sd_vae_checkpoint = shared.opts.sd_vae + + def __exit__(self, exc_type, exc_value, tb): + #Restore overriden settings after plot generation. + shared.opts.data["sd_vae"] = self.vae + shared.opts.data["uni_pc_order"] = self.uni_pc_order + shared.opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + shared.opts.data["token_merging_ratio_hr"] = self.token_merging_ratio_hr + shared.opts.data["token_merging_ratio"] = self.token_merging_ratio + shared.opts.data["token_merging_random"] = self.token_merging_random + if self.sd_model_checkpoint != shared.opts.sd_model_checkpoint: + shared.opts.data["sd_model_checkpoint"] = self.sd_model_checkpoint + sd_models.reload_model_weights() + if self.sd_vae_checkpoint != shared.opts.sd_vae: + shared.opts.data["sd_vae"] = self.sd_vae_checkpoint + sd_vae.reload_vae_weights() + + +re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") +re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") +re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") +re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + + +class Script(scripts.Script): + def title(self): + return "X/Y/Z plot" + + def ui(self, is_img2img): + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + with gr.Row(): + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", 
choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) + + with gr.Row(): + z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) + z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) + z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True) + fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) + with gr.Row(variant="compact", elem_id="axis_options"): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) + with gr.Row(variant="compact", elem_id="axis_options"): + 
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + with gr.Row(variant="compact", elem_id="swap_axes"): + swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") + swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") + swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") + + def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown): + return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown + + xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown] + swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) + yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown] + swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) + xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown] + swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) + + def fill(x_type): + axis = self.current_axis_options[x_type] + return axis.choices() if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown]) + fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown]) + + def select_axis(axis_type,axis_values_dropdown): + choices = self.current_axis_options[axis_type].choices + has_choices = choices is not None + current_values = axis_values_dropdown + if has_choices: + choices = choices() + if isinstance(current_values,str): + current_values = current_values.split(",") + current_values = list(filter(lambda x: x in choices, 
# NOTE(review): this region of the file is a whitespace-mangled unified-diff fragment.
# Below is a token-faithful reconstruction (diff '+' markers dropped, conventional
# indentation restored). It holds the truncated tail of Script.ui() followed by the
# complete Script.run(); the opening of ui() lies before this chunk and is not visible.

                # Tail of a statement truncated by the chunk boundary (presumably
                # `current_values = list(filter(lambda x: x in choices, current_values))`
                # inside ui()'s select_axis helper — TODO confirm against the full file).
                current_values))
            # One update per wired output: show the fill button and the choices dropdown
            # only when the selected axis option declares fixed choices; otherwise show
            # the free-text values box instead.
            return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)

        # Re-evaluate widget visibility whenever the user switches an axis type.
        x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
        y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
        z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])

        def get_dropdown_update_from_params(axis,params):
            """Build a dropdown update from the "<axis> Values" infotext entry (a CSV string)."""
            val_key = axis + " Values"
            vals = params.get(val_key,"")
            valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
            return gr.update(value = valslist)

        # Map UI components to infotext keys so pasted generation parameters repopulate the form.
        self.infotext_fields = (
            (x_type, "X Type"),
            (x_values, "X Values"),
            (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
            (y_type, "Y Type"),
            (y_values, "Y Values"),
            (y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
            (z_type, "Z Type"),
            (z_values, "Z Values"),
            (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
        )

        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]

    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): # pylint: disable=arguments-differ
        """Generate the X/Y/Z grid: parse per-axis value lists, render one image per
        (x, y, z) cell via draw_xyz_grid, then splice in the captured infotexts and
        optionally auto-save the main grid and sub-grids.

        Returns the Processed object produced by draw_xyz_grid (possibly with
        sub-images/sub-grids dropped, depending on the include_* flags).
        """
        shared.log.debug(f'xyzgrid: {x_type}|{x_values}|{x_values_dropdown}|{y_type}|{y_values}|{y_values_dropdown}|{z_type}|{z_values}|{z_values_dropdown}|{draw_legend}|{include_lone_images}|{include_sub_grids}|{no_fixed_seeds}|{margin_size}')
        if not no_fixed_seeds:
            processing.fix_seed(p)
        if not shared.opts.return_grid:
            p.batch_size = 1

        def process_axis(opt, vals, vals_dropdown):
            """Turn one axis' raw values (dropdown selection or CSV text) into a typed list.

            For int/float axes, range syntax is expanded via the re_range* patterns
            (defined earlier in the file, not visible in this chunk). May raise from
            opt.confirm() if a value is invalid.
            """
            if opt.label == 'Nothing':
                return [0]
            if opt.choices is not None:
                valslist = vals_dropdown
            else:
                valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
            if opt.type == int:
                valslist_ext = []
                for val in valslist:
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:
                        # Range form: inclusive integer range with optional step.
                        start = int(m.group(1))
                        end = int(m.group(2))+1
                        step = int(m.group(3)) if m.group(3) is not None else 1
                        valslist_ext += list(range(start, end, step))
                    elif mc is not None:
                        # Count form: `num` evenly spaced integers between start and end.
                        start = int(mc.group(1))
                        end = int(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1
                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
                    else:
                        valslist_ext.append(val)
                valslist = valslist_ext
            elif opt.type == float:
                valslist_ext = []
                for val in valslist:
                    m = re_range_float.fullmatch(val)
                    mc = re_range_count_float.fullmatch(val)
                    if m is not None:
                        start = float(m.group(1))
                        end = float(m.group(2))
                        step = float(m.group(3)) if m.group(3) is not None else 1
                        # `end + step` makes the upper bound inclusive (modulo float rounding).
                        valslist_ext += np.arange(start, end + step, step).tolist()
                    elif mc is not None:
                        start = float(mc.group(1))
                        end = float(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1
                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
                    else:
                        valslist_ext.append(val)
                valslist = valslist_ext
            elif opt.type == str_permutations:
                valslist = list(permutations(valslist))
            # Coerce every entry to the axis' declared type.
            valslist = [opt.type(x) for x in valslist]
            # Confirm options are valid before starting
            if opt.confirm:
                opt.confirm(p, valslist)
            return valslist

        x_opt = self.current_axis_options[x_type]
        if x_opt.choices is not None:
            # Normalize dropdown selections back into the CSV string stored in infotext.
            x_values = ",".join(x_values_dropdown)
        xs = process_axis(x_opt, x_values, x_values_dropdown)
        y_opt = self.current_axis_options[y_type]
        if y_opt.choices is not None:
            y_values = ",".join(y_values_dropdown)
        ys = process_axis(y_opt, y_values, y_values_dropdown)
        z_opt = self.current_axis_options[z_type]
        if z_opt.choices is not None:
            z_values = ",".join(z_values_dropdown)
        zs = process_axis(z_opt, z_values, z_values_dropdown)
        Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
        grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
        # NOTE(review): `assert` is stripped under `python -O`; a raised exception would be safer here.
        assert grid_mp < shared.opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {shared.opts.img_max_size_mp} MPixels)'

        def fix_axis_seeds(axis_opt, axis_list):
            """On seed axes, replace None/''/-1 placeholders with random 32-bit seeds."""
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list

        if not no_fixed_seeds:
            xs = fix_axis_seeds(x_opt, xs)
            ys = fix_axis_seeds(y_opt, ys)
            zs = fix_axis_seeds(z_opt, zs)

        # Total-step estimate for progress reporting; a 'Steps' axis replaces p.steps per cell.
        if x_opt.label == 'Steps':
            total_steps = sum(xs) * len(ys) * len(zs)
        elif y_opt.label == 'Steps':
            total_steps = sum(ys) * len(xs) * len(zs)
        elif z_opt.label == 'Steps':
            total_steps = sum(zs) * len(xs) * len(ys)
        else:
            total_steps = p.steps * len(xs) * len(ys) * len(zs)
        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
            # Account for the hires second pass (a 'Hires steps' axis overrides it per cell).
            if x_opt.label == "Hires steps":
                total_steps += sum(xs) * len(ys) * len(zs)
            elif y_opt.label == "Hires steps":
                total_steps += sum(ys) * len(xs) * len(zs)
            elif z_opt.label == "Hires steps":
                total_steps += sum(zs) * len(xs) * len(ys)
            elif p.hr_second_pass_steps:
                total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
            else:
                total_steps *= 2
        total_steps *= p.n_iter
        image_cell_count = p.n_iter * p.batch_size
        # NOTE(review): `grid=` and `cells=` both log len(zs) — `cells` looks like a copy-paste slip; verify intent.
        shared.log.info(f"XYZ grid: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} {len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}")
        # Expose axis metadata on shared.state so other components can read it during the run.
        shared.state.xyz_plot_x = AxisInfo(x_opt, xs)
        shared.state.xyz_plot_y = AxisInfo(y_opt, ys)
        shared.state.xyz_plot_z = AxisInfo(z_opt, zs)
        # If one of the axes is very slow to change between (like SD model checkpoint), then make sure it is in the outer iteration of the nested `for` loop.
        first_axes_processed = 'z'
        second_axes_processed = 'y'
        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
            first_axes_processed = 'x'
            if y_opt.cost > z_opt.cost:
                second_axes_processed = 'y'
            else:
                second_axes_processed = 'z'
        elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
            first_axes_processed = 'y'
            if x_opt.cost > z_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'z'
        elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
            first_axes_processed = 'z'
            if x_opt.cost > y_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'y'
        # Slot 0 holds the main-grid infotext; slots 1..len(zs) hold per-sub-grid infotexts.
        grid_infotext = [None] * (1 + len(zs))

        def cell(x, y, z, ix, iy, iz):
            """Render one grid cell: copy `p`, apply the three axis values, and process.

            Also captures the sub-grid infotext (first cell of each z-slice) and the
            main-grid infotext (very first cell) into grid_infotext as a side effect.
            """
            if shared.state.interrupted:
                return Processed(p, [], p.seed, "")
            pc = copy(p)
            pc.styles = pc.styles[:]  # shallow-copy styles so per-cell mutation does not leak into p
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)
            z_opt.apply(pc, z, zs)
            res = process_images(pc)
            # Sets subgrid infotexts
            subgrid_index = 1 + iz
            if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)
                pc.extra_generation_params['Script'] = self.title()
                if x_opt.label != 'Nothing':
                    pc.extra_generation_params["X Type"] = x_opt.label
                    pc.extra_generation_params["X Values"] = x_values
                    if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
                if y_opt.label != 'Nothing':
                    pc.extra_generation_params["Y Type"] = y_opt.label
                    pc.extra_generation_params["Y Values"] = y_values
                    if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
                grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
            # Sets main grid infotext
            if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)

                if z_opt.label != 'Nothing':
                    pc.extra_generation_params["Z Type"] = z_opt.label
                    pc.extra_generation_params["Z Values"] = z_values
                    if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
            return res

        # SharedSettingsStackHelper presumably restores axis-overridden global settings on
        # exit (defined elsewhere in the file — not visible in this chunk; verify).
        with SharedSettingsStackHelper():
            processed = draw_xyz_grid(
                p,
                xs=xs,
                ys=ys,
                zs=zs,
                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
                cell=cell,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images,
                include_sub_grids=include_sub_grids,
                first_axes_processed=first_axes_processed,
                second_axes_processed=second_axes_processed,
                margin_size=margin_size
            )

        if not processed.images:
            # It broke, no further handling needed.
            return processed
        z_count = len(zs)
        # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
        processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]
        if not include_lone_images:
            # Don't need sub-images anymore, drop from list:
            processed.images = processed.images[:z_count+1]
        if shared.opts.grid_save:
            # Auto-save main and sub-grids:
            grid_count = z_count + 1 if z_count > 1 else 1
            for g in range(grid_count):
                # adj_g maps a grid index to the prompt/seed entry it was generated from.
                adj_g = g-1 if g > 0 else g
                images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=shared.opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
        if not include_sub_grids:
            # Done with sub-grids, drop all related information:
            for _sg in range(z_count):
                del processed.images[1]
                del processed.all_prompts[1]
                del processed.all_seeds[1]
                del processed.infotexts[1]
        return processed