diff --git a/.gitignore b/.gitignore index 3382bf80d..1c4049a7b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,15 @@ # defaults __pycache__ -/params.txt /cache.json /config.json -/ui-config.json +/params.txt /setup.log /styles.csv +/ui-config.json /user.css /webui-user.bat /webui-user.sh +/html/extensions.json /javascript/themes.json node_modules pnpm-lock.yaml diff --git a/extensions-builtin/stable-diffusion-webui-images-browser b/extensions-builtin/stable-diffusion-webui-images-browser index c751a9eee..5f9e5ad49 160000 --- a/extensions-builtin/stable-diffusion-webui-images-browser +++ b/extensions-builtin/stable-diffusion-webui-images-browser @@ -1 +1 @@ -Subproject commit c751a9eeedef738fd0a731db1dd65690690b6161 +Subproject commit 5f9e5ad49035b6237f6c20f528d9313bb41e6f75 diff --git a/installer.py b/installer.py index be164ca73..84b1fe746 100644 --- a/installer.py +++ b/installer.py @@ -318,8 +318,9 @@ def run_extension_installer(folder): # get list of all enabled extensions def list_extensions(folder): - if opts.get('disable_all_extensions', 'none') != 'none': - log.debug('Disabled extensions: all') + disabled_extensions = opts.get('disable_all_extensions', 'none') + if disabled_extensions != 'none': + log.debug(f'Disabled extensions: {disabled_extensions}') return [] disabled_extensions = set(opts.get('disabled_extensions', [])) if len(disabled_extensions) > 0: @@ -350,7 +351,7 @@ def install_extensions(): if not args.skip_extensions: run_extension_installer(os.path.join(folder, ext)) log.info(f'Extensions enabled: {extensions_enabled}') - if (len(extensions_duplicates) > 0): + if len(extensions_duplicates) > 0: log.warning(f'Extensions duplicates: {extensions_duplicates}') diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index faec76d54..055d5bd7f 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -1,107 +1,106 @@ -let currentWidth = null; -let currentHeight = null; -let arFrameTimeout = setTimeout(() => {}, 0); - -function dimensionChange(e, is_width, is_height) { - if (is_width) { - currentWidth = e.target.value * 1.0; - } - if (is_height) { - currentHeight = e.target.value * 1.0; - } - - const inImg2img = gradioApp().querySelector('#tab_img2img').style.display == 'block'; - - if (!inImg2img) { - return; - } - - let targetElement = null; - - const tabIndex = get_tab_index('mode_img2img'); - if (tabIndex == 0) { // img2img - targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); - } else if (tabIndex == 1) { // Sketch - targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); - } else if (tabIndex == 2) { // Inpaint - targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); - } else if (tabIndex == 3) { // Inpaint sketch - targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); - } - - if (targetElement) { - let arPreviewRect = gradioApp().querySelector('#imageARPreview'); - if (!arPreviewRect) { - arPreviewRect = document.createElement('div'); - arPreviewRect.id = 'imageARPreview'; - gradioApp().appendChild(arPreviewRect); - } - - const viewportOffset = targetElement.getBoundingClientRect(); - - viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight); - - scaledx = targetElement.naturalWidth * viewportscale; - scaledy = targetElement.naturalHeight * viewportscale; - - cleintRectTop = (viewportOffset.top 
+ window.scrollY); - cleintRectLeft = (viewportOffset.left + window.scrollX); - cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2); - cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2); - - viewRectTop = cleintRectCentreY - (scaledy / 2); - viewRectLeft = cleintRectCentreX - (scaledx / 2); - arRectWidth = scaledx; - arRectHeight = scaledy; - - arscale = Math.min(arRectWidth / currentWidth, arRectHeight / currentHeight); - arscaledx = currentWidth * arscale; - arscaledy = currentHeight * arscale; - - arRectTop = cleintRectCentreY - (arscaledy / 2); - arRectLeft = cleintRectCentreX - (arscaledx / 2); - arRectWidth = arscaledx; - arRectHeight = arscaledy; - - arPreviewRect.style.top = `${arRectTop}px`; - arPreviewRect.style.left = `${arRectLeft}px`; - arPreviewRect.style.width = `${arRectWidth}px`; - arPreviewRect.style.height = `${arRectHeight}px`; - - clearTimeout(arFrameTimeout); - arFrameTimeout = setTimeout(() => { - arPreviewRect.style.display = 'none'; - }, 2000); - - arPreviewRect.style.display = 'block'; - } -} - -onUiUpdate(() => { - const arPreviewRect = gradioApp().querySelector('#imageARPreview'); - if (arPreviewRect) { - arPreviewRect.style.display = 'none'; - } - const tabImg2img = gradioApp().querySelector('#tab_img2img'); - if (tabImg2img) { - const inImg2img = tabImg2img.style.display == 'block'; - if (inImg2img) { - const inputs = gradioApp().querySelectorAll('input'); - inputs.forEach((e) => { - const is_width = e.parentElement.id == 'img2img_width'; - const is_height = e.parentElement.id == 'img2img_height'; - - if ((is_width || is_height) && !e.classList.contains('scrollwatch')) { - e.addEventListener('input', (e) => { dimensionChange(e, is_width, is_height); }); - e.classList.add('scrollwatch'); - } - if (is_width) { - currentWidth = e.value * 1.0; - } - if (is_height) { - currentHeight = e.value * 1.0; - } - }); - } - } -}); +let currentWidth = null; +let currentHeight = null; +let arFrameTimeout = setTimeout(() => {}, 0); + +function dimensionChange(e, is_width, is_height) { + if (is_width) { + currentWidth = e.target.value * 1.0; + } + if (is_height) { + currentHeight = e.target.value * 1.0; + } + + const inImg2img = gradioApp().querySelector('#tab_img2img').style.display === 'block'; + + if (!inImg2img) { + return; + } + + let targetElement = null; + + const tabIndex = get_tab_index('mode_img2img'); + if (tabIndex === 0) { // img2img + targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); + } else if (tabIndex === 1) { // Sketch + targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); + } else if (tabIndex === 2) { // Inpaint + targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); + } else if (tabIndex === 3) { // Inpaint sketch + targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); + } + + if (targetElement) { + let arPreviewRect = gradioApp().querySelector('#imageARPreview'); + if (!arPreviewRect) { + arPreviewRect = document.createElement('div'); + arPreviewRect.id = 'imageARPreview'; + gradioApp().appendChild(arPreviewRect); + } + + const viewportOffset = targetElement.getBoundingClientRect(); + + viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight); + + scaledx = targetElement.naturalWidth * viewportscale; + scaledy = targetElement.naturalHeight * viewportscale; + + cleintRectTop = (viewportOffset.top + 
window.scrollY); + cleintRectLeft = (viewportOffset.left + window.scrollX); + cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2); + cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2); + + viewRectTop = cleintRectCentreY - (scaledy / 2); + viewRectLeft = cleintRectCentreX - (scaledx / 2); + arRectWidth = scaledx; + arRectHeight = scaledy; + + arscale = Math.min(arRectWidth / currentWidth, arRectHeight / currentHeight); + arscaledx = currentWidth * arscale; + arscaledy = currentHeight * arscale; + + arRectTop = cleintRectCentreY - (arscaledy / 2); + arRectLeft = cleintRectCentreX - (arscaledx / 2); + arRectWidth = arscaledx; + arRectHeight = arscaledy; + + arPreviewRect.style.top = `${arRectTop}px`; + arPreviewRect.style.left = `${arRectLeft}px`; + arPreviewRect.style.width = `${arRectWidth}px`; + arPreviewRect.style.height = `${arRectHeight}px`; + + clearTimeout(arFrameTimeout); + arFrameTimeout = setTimeout(() => { + arPreviewRect.style.display = 'none'; + }, 2000); + arPreviewRect.style.display = 'block'; + } +} + +onUiUpdate(() => { + const arPreviewRect = gradioApp().querySelector('#imageARPreview'); + if (arPreviewRect) { + arPreviewRect.style.display = 'none'; + } + const tabImg2img = gradioApp().querySelector('#tab_img2img'); + if (tabImg2img) { + const inImg2img = tabImg2img.style.display === 'block'; + if (inImg2img) { + const inputs = gradioApp().querySelectorAll('input'); + inputs.forEach((e) => { + const is_width = e.parentElement.id === 'img2img_width'; + const is_height = e.parentElement.id === 'img2img_height'; + + if ((is_width || is_height) && !e.classList.contains('scrollwatch')) { + e.addEventListener('input', (e) => { dimensionChange(e, is_width, is_height); }); + e.classList.add('scrollwatch'); + } + if (is_width) { + currentWidth = e.value * 1.0; + } + if (is_height) { + currentHeight = e.value * 1.0; + } + }); + } + } +}); diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 39d880165..2296ba297 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -1,176 +1,176 @@ -contextMenuInit = function () { - let eventListenerApplied = false; - const menuSpecs = new Map(); - - const uid = function () { - return Date.now().toString(36) + Math.random().toString(36).substr(2); - }; - - function showContextMenu(event, element, menuEntries) { - const posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; - const posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; - - const oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - - const tabButton = uiCurrentTab; - const baseStyle = window.getComputedStyle(tabButton); - - const contextMenu = document.createElement('nav'); - contextMenu.id = 'context-menu'; - contextMenu.style.background = baseStyle.background; - contextMenu.style.color = baseStyle.color; - contextMenu.style.fontFamily = baseStyle.fontFamily; - contextMenu.style.top = `${posy}px`; - contextMenu.style.left = `${posx}px`; - - const contextMenuList = document.createElement('ul'); - contextMenuList.className = 'context-menu-items'; - contextMenu.append(contextMenuList); - - menuEntries.forEach((entry) => { - const contextMenuEntry = document.createElement('a'); - contextMenuEntry.innerHTML = entry.name; - contextMenuEntry.addEventListener('click', (e) => { - entry.func(); - }); - contextMenuList.append(contextMenuEntry); - }); - - gradioApp().appendChild(contextMenu); - - const menuWidth = 
contextMenu.offsetWidth + 4; - const menuHeight = contextMenu.offsetHeight + 4; - - const windowWidth = window.innerWidth; - const windowHeight = window.innerHeight; - - if ((windowWidth - posx) < menuWidth) { - contextMenu.style.left = `${windowWidth - menuWidth}px`; - } - - if ((windowHeight - posy) < menuHeight) { - contextMenu.style.top = `${windowHeight - menuHeight}px`; - } - } - - function appendContextMenuOption(targetElementSelector, entryName, entryFunction) { - currentItems = menuSpecs.get(targetElementSelector); - - if (!currentItems) { - currentItems = []; - menuSpecs.set(targetElementSelector, currentItems); - } - const newItem = { - id: `${targetElementSelector}_${uid()}`, - name: entryName, - func: entryFunction, - isNew: true, - }; - - currentItems.push(newItem); - return newItem.id; - } - - function removeContextMenuOption(uid) { - menuSpecs.forEach((v, k) => { - let index = -1; - v.forEach((e, ei) => { if (e.id == uid) { index = ei; } }); - if (index >= 0) { - v.splice(index, 1); - } - }); - } - - function addContextMenuEventListener() { - if (eventListenerApplied) { - return; - } - gradioApp().addEventListener('click', (e) => { - const source = e.composedPath()[0]; - if (source.id && source.id.indexOf('check_progress') > -1) { - return; - } - - const oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - }); - gradioApp().addEventListener('contextmenu', (e) => { - const oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - menuSpecs.forEach((v, k) => { - if (e.composedPath()[0].matches(k)) { - showContextMenu(e, e.composedPath()[0], v); - e.preventDefault(); - } - }); - }); - eventListenerApplied = true; - } - - return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]; -}; - -initResponse = contextMenuInit(); -appendContextMenuOption = initResponse[0]; -removeContextMenuOption = initResponse[1]; -addContextMenuEventListener = initResponse[2]; - -(function () { - // Start example Context Menu Items - const generateOnRepeat = function (genbuttonid, interruptbuttonid) { - const genbutton = gradioApp().querySelector(genbuttonid); - const busy = document.getElementById('progressbar')?.style.display == 'block'; - if (!busy) { - genbutton.click(); - } - clearInterval(window.generateOnRepeatInterval); - window.generateOnRepeatInterval = setInterval( - () => { - const busy = document.getElementById('progressbar')?.style.display == 'block'; - if (!busy) { - genbutton.click(); - } - }, - 500, - ); - }; - - appendContextMenuOption('#txt2img_generate', 'Generate forever', () => { - generateOnRepeat('#txt2img_generate', '#txt2img_interrupt'); - }); - appendContextMenuOption('#img2img_generate', 'Generate forever', () => { - generateOnRepeat('#img2img_generate', '#img2img_interrupt'); - }); - - const cancelGenerateForever = function () { - clearInterval(window.generateOnRepeatInterval); - }; - - appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever); - appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever); - appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever); - appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever); - - appendContextMenuOption( - '#roll', - 'Roll three', - () => { - const rollbutton = get_uiCurrentTabContent().querySelector('#roll'); - setTimeout(() => { rollbutton.click(); }, 100); - 
setTimeout(() => { rollbutton.click(); }, 200); - setTimeout(() => { rollbutton.click(); }, 300); - }, - ); -}()); -// End example Context Menu Items - -onUiUpdate(() => { - addContextMenuEventListener(); -}); +contextMenuInit = function () { + let eventListenerApplied = false; + const menuSpecs = new Map(); + + const uid = function () { + return Date.now().toString(36) + Math.random().toString(36).substr(2); + }; + + function showContextMenu(event, element, menuEntries) { + const posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; + const posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; + + const oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); + } + + const tabButton = uiCurrentTab; + const baseStyle = window.getComputedStyle(tabButton); + + const contextMenu = document.createElement('nav'); + contextMenu.id = 'context-menu'; + contextMenu.style.background = baseStyle.background; + contextMenu.style.color = baseStyle.color; + contextMenu.style.fontFamily = baseStyle.fontFamily; + contextMenu.style.top = `${posy}px`; + contextMenu.style.left = `${posx}px`; + + const contextMenuList = document.createElement('ul'); + contextMenuList.className = 'context-menu-items'; + contextMenu.append(contextMenuList); + + menuEntries.forEach((entry) => { + const contextMenuEntry = document.createElement('a'); + contextMenuEntry.innerHTML = entry.name; + contextMenuEntry.addEventListener('click', (e) => { + entry.func(); + }); + contextMenuList.append(contextMenuEntry); + }); + + gradioApp().appendChild(contextMenu); + + const menuWidth = contextMenu.offsetWidth + 4; + const menuHeight = contextMenu.offsetHeight + 4; + + const windowWidth = window.innerWidth; + const windowHeight = window.innerHeight; + + if ((windowWidth - posx) < menuWidth) { + contextMenu.style.left = `${windowWidth - menuWidth}px`; + } + + if ((windowHeight - posy) < menuHeight) { + contextMenu.style.top = `${windowHeight - menuHeight}px`; + } + } + + function appendContextMenuOption(targetElementSelector, entryName, entryFunction) { + currentItems = menuSpecs.get(targetElementSelector); + + if (!currentItems) { + currentItems = []; + menuSpecs.set(targetElementSelector, currentItems); + } + const newItem = { + id: `${targetElementSelector}_${uid()}`, + name: entryName, + func: entryFunction, + isNew: true, + }; + + currentItems.push(newItem); + return newItem.id; + } + + function removeContextMenuOption(uid) { + menuSpecs.forEach((v, k) => { + let index = -1; + v.forEach((e, ei) => { if (e.id === uid) { index = ei; } }); + if (index >= 0) { + v.splice(index, 1); + } + }); + } + + function addContextMenuEventListener() { + if (eventListenerApplied) { + return; + } + gradioApp().addEventListener('click', (e) => { + const source = e.composedPath()[0]; + if (source.id && source.id.indexOf('check_progress') > -1) { + return; + } + + const oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); + } + }); + gradioApp().addEventListener('contextmenu', (e) => { + const oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); + } + menuSpecs.forEach((v, k) => { + if (e.composedPath()[0].matches(k)) { + showContextMenu(e, e.composedPath()[0], v); + e.preventDefault(); + } + }); + }); + eventListenerApplied = true; + } + + return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]; +}; + +initResponse = contextMenuInit(); 
+appendContextMenuOption = initResponse[0]; +removeContextMenuOption = initResponse[1]; +addContextMenuEventListener = initResponse[2]; + +(function () { + // Start example Context Menu Items + const generateOnRepeat = function (genbuttonid, interruptbuttonid) { + const genbutton = gradioApp().querySelector(genbuttonid); + const busy = document.getElementById('progressbar')?.style.display === 'block'; + if (!busy) { + genbutton.click(); + } + clearInterval(window.generateOnRepeatInterval); + window.generateOnRepeatInterval = setInterval( + () => { + const busy = document.getElementById('progressbar')?.style.display === 'block'; + if (!busy) { + genbutton.click(); + } + }, + 500, + ); + }; + + appendContextMenuOption('#txt2img_generate', 'Generate forever', () => { + generateOnRepeat('#txt2img_generate', '#txt2img_interrupt'); + }); + appendContextMenuOption('#img2img_generate', 'Generate forever', () => { + generateOnRepeat('#img2img_generate', '#img2img_interrupt'); + }); + + const cancelGenerateForever = function () { + clearInterval(window.generateOnRepeatInterval); + }; + + appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever); + appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever); + appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever); + appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever); + + appendContextMenuOption( + '#roll', + 'Roll three', + () => { + const rollbutton = get_uiCurrentTabContent().querySelector('#roll'); + setTimeout(() => { rollbutton.click(); }, 100); + setTimeout(() => { rollbutton.click(); }, 200); + setTimeout(() => { rollbutton.click(); }, 300); + }, + ); +}()); +// End example Context Menu Items + +onUiUpdate(() => { + addContextMenuEventListener(); +}); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 27fa4341d..55e229e89 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -47,7 +47,7 @@ function dropReplaceImage(imgWrap, files) { window.document.addEventListener('dragover', (e) => { const target = e.composedPath()[0]; const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap && target.placeholder && target.placeholder.indexOf('Prompt') == -1) return; + if (!imgWrap && target.placeholder && target.placeholder.indexOf('Prompt') === -1) return; e.stopPropagation(); e.preventDefault(); e.dataTransfer.dropEffect = 'copy'; @@ -56,7 +56,7 @@ window.document.addEventListener('dragover', (e) => { window.document.addEventListener('drop', (e) => { const target = e.composedPath()[0]; if (!target.placeholder) return; - if (target.placeholder.indexOf('Prompt') == -1) return; + if (target.placeholder.indexOf('Prompt') === -1) return; const imgWrap = target.closest('[data-testid="image"]'); if (!imgWrap) return; e.stopPropagation(); diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 87ec52617..594baae67 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,119 +1,119 @@ -function keyupEditAttention(event) { - const target = event.originalTarget || event.composedPath()[0]; - if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return; - if (!(event.metaKey || event.ctrlKey)) return; - - const isPlus = event.key == 'ArrowUp'; - const isMinus = event.key == 'ArrowDown'; - if (!isPlus && !isMinus) return; - - let { selectionStart } = target; - let { selectionEnd } = target; - let text = 
target.value; - - function selectCurrentParenthesisBlock(OPEN, CLOSE) { - if (selectionStart !== selectionEnd) return false; - - // Find opening parenthesis around current cursor - const before = text.substring(0, selectionStart); - let beforeParen = before.lastIndexOf(OPEN); - if (beforeParen == -1) return false; - let beforeParenClose = before.lastIndexOf(CLOSE); - while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); - beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); - } - - // Find closing parenthesis around current cursor - const after = text.substring(selectionStart); - let afterParen = after.indexOf(CLOSE); - if (afterParen == -1) return false; - let afterParenOpen = after.indexOf(OPEN); - while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(CLOSE, afterParen + 1); - afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); - } - if (beforeParen === -1 || afterParen === -1) return false; - - // Set the selection to the text between the parenthesis - const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); - const lastColon = parenContent.lastIndexOf(':'); - selectionStart = beforeParen + 1; - selectionEnd = selectionStart + lastColon; - target.setSelectionRange(selectionStart, selectionEnd); - return true; - } - - function selectCurrentWord() { - if (selectionStart !== selectionEnd) return false; - const delimiters = `${opts.keyedit_delimiters} \r\n\t`; - - // seek backward until to find beggining - while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { - selectionStart--; - } - - // seek forward to find end - while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) { - selectionEnd++; - } - - target.setSelectionRange(selectionStart, selectionEnd); - return true; - } - - // If the user hasn't selected anything, let's select their current parenthesis block or word - if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { - selectCurrentWord(); - } - - event.preventDefault(); - - closeCharacter = ')'; - delta = opts.keyedit_precision_attention; - - if (selectionStart > 0 && text[selectionStart - 1] == '<') { - closeCharacter = '>'; - delta = opts.keyedit_precision_extra; - } else if (selectionStart == 0 || text[selectionStart - 1] != '(') { - // do not include spaces at the end - while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { - selectionEnd -= 1; - } - if (selectionStart == selectionEnd) { - return; - } - - text = `${text.slice(0, selectionStart)}(${text.slice(selectionStart, selectionEnd)}:1.0)${text.slice(selectionEnd)}`; - - selectionStart += 1; - selectionEnd += 1; - } - - end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); - if (isNaN(weight)) return; - - weight += isPlus ? 
delta : -delta; - weight = parseFloat(weight.toPrecision(12)); - if (String(weight).length == 1) weight += '.0'; - - if (closeCharacter == ')' && weight == 1) { - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); - selectionStart--; - selectionEnd--; - } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); - } - - target.focus(); - target.value = text; - target.selectionStart = selectionStart; - target.selectionEnd = selectionEnd; - - updateInput(target); -} - -addEventListener('keydown', (event) => { - keyupEditAttention(event); -}); +function keyupEditAttention(event) { + const target = event.originalTarget || event.composedPath()[0]; + if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return; + if (!(event.metaKey || event.ctrlKey)) return; + + const isPlus = event.key === 'ArrowUp'; + const isMinus = event.key === 'ArrowDown'; + if (!isPlus && !isMinus) return; + + let { selectionStart } = target; + let { selectionEnd } = target; + let text = target.value; + + function selectCurrentParenthesisBlock(OPEN, CLOSE) { + if (selectionStart !== selectionEnd) return false; + + // Find opening parenthesis around current cursor + const before = text.substring(0, selectionStart); + let beforeParen = before.lastIndexOf(OPEN); + if (beforeParen === -1) return false; + let beforeParenClose = before.lastIndexOf(CLOSE); + while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { + beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); + beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); + } + + // Find closing parenthesis around current cursor + const after = text.substring(selectionStart); + let afterParen = after.indexOf(CLOSE); + if (afterParen === -1) return false; + let afterParenOpen = after.indexOf(OPEN); + while (afterParenOpen !== -1 && afterParen > afterParenOpen) { + afterParen = after.indexOf(CLOSE, afterParen + 1); + afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); + } + if (beforeParen === -1 || afterParen === -1) return false; + + // Set the selection to the text between the parenthesis + const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); + const lastColon = parenContent.lastIndexOf(':'); + selectionStart = beforeParen + 1; + selectionEnd = selectionStart + lastColon; + target.setSelectionRange(selectionStart, selectionEnd); + return true; + } + + function selectCurrentWord() { + if (selectionStart !== selectionEnd) return false; + const delimiters = `${opts.keyedit_delimiters} \r\n\t`; + + // seek backward until to find beggining + while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { + selectionStart--; + } + + // seek forward to find end + while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) { + selectionEnd++; + } + + target.setSelectionRange(selectionStart, selectionEnd); + return true; + } + + // If the user hasn't selected anything, let's select their current parenthesis block or word + if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { + selectCurrentWord(); + } + + event.preventDefault(); + + closeCharacter = ')'; + delta = opts.keyedit_precision_attention; + + if (selectionStart > 0 && text[selectionStart - 1] === '<') { + closeCharacter = '>'; + delta = opts.keyedit_precision_extra; + } else if (selectionStart === 0 || text[selectionStart - 1] != '(') { + // do not include spaces at the end + 
while (selectionEnd > selectionStart && text[selectionEnd - 1] === ' ') { + selectionEnd -= 1; + } + if (selectionStart === selectionEnd) { + return; + } + + text = `${text.slice(0, selectionStart)}(${text.slice(selectionStart, selectionEnd)}:1.0)${text.slice(selectionEnd)}`; + + selectionStart += 1; + selectionEnd += 1; + } + + end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + if (isNaN(weight)) return; + + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if (String(weight).length === 1) weight += '.0'; + + if (closeCharacter === ')' && weight === 1) { + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); + selectionStart--; + selectionEnd--; + } else { + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + } + + target.focus(); + target.value = text; + target.selectionStart = selectionStart; + target.selectionEnd = selectionEnd; + + updateInput(target); +} + +addEventListener('keydown', (event) => { + keyupEditAttention(event); +}); diff --git a/javascript/extensions.js b/javascript/extensions.js index 2bcaa50fc..5f8049d95 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -1,32 +1,41 @@ -function extensions_apply(_a, _b, disable_all){ - var disable = [] - var update = [] - gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ - if(x.name.startsWith("enable_") && ! x.checked) disable.push(x.name.substr(7)) - if(x.name.startsWith("update_") && x.checked) update.push(x.name.substr(7)) - }) - restart_reload() - return [JSON.stringify(disable), JSON.stringify(update), disable_all] -} - -function extensions_check(_, _){ - var disable = [] - gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ - if(x.name.startsWith("enable_") && ! x.checked) disable.push(x.name.substr(7)) - }) - gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ - x.innerHTML = "Loading..." - }) - var id = randomId() - requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, null, null, false) - return [id, JSON.stringify(disable)] -} - -function install_extension_from_index(button, url){ - button.disabled = "disabled" - button.value = "Installing..." 
- textarea = gradioApp().querySelector('#extension_to_install textarea') - textarea.value = url - updateInput(textarea) - gradioApp().querySelector('#install_extension_button').click() -} +function extensions_apply(_a, _b, disable_all) { + const disable = []; + const update = []; + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach((x) => { + if (x.name.startsWith('enable_') && !x.checked) disable.push(x.name.substr(7)); + if (x.name.startsWith('update_') && x.checked) update.push(x.name.substr(7)); + }); + restart_reload(); + return [JSON.stringify(disable), JSON.stringify(update), disable_all]; +} + +function extensions_check(_a, _b) { + const disable = []; + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach((x) => { + if (x.name.startsWith('enable_') && !x.checked) disable.push(x.name.substr(7)); + }); + gradioApp().querySelectorAll('#extensions .extension_status').forEach((x) => { + x.innerHTML = 'Loading...'; + }); + const id = randomId(); + // requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, null, null, false); + return [id, JSON.stringify(disable)]; +} + +function install_extension(button, url) { + button.disabled = 'disabled'; + button.value = 'Installing...'; + const textarea = gradioApp().querySelector('#extension_to_install textarea'); + textarea.value = url; + updateInput(textarea); + gradioApp().querySelector('#install_extension_button').click(); +} + +function uninstall_extension(button, url) { + button.disabled = 'disabled'; + button.value = 'Uninstalling...'; + const textarea = gradioApp().querySelector('#extension_to_install textarea'); + textarea.value = url; + updateInput(textarea); + gradioApp().querySelector('#uninstall_extension_button').click(); +} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 9b2a610f7..5b9254cd5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,170 +1,170 @@ -function setupExtraNetworksForTab(tabname) { - gradioApp().querySelector(`#${tabname}_extra_tabs`).classList.add('extra-networks'); - const tabs = gradioApp().querySelector(`#${tabname}_extra_tabs > div`); - const search = gradioApp().querySelector(`#${tabname}_extra_search textarea`); - const refresh = gradioApp().getElementById(`${tabname}_extra_refresh`); - const descriptInput = gradioApp().getElementById(`${tabname}_description_input`); - const close = gradioApp().getElementById(`${tabname}_extra_close`); - search.classList.add('search'); - tabs.appendChild(search); - tabs.appendChild(refresh); - tabs.appendChild(close); - tabs.appendChild(descriptInput); - search.addEventListener('input', (evt) => { - searchTerm = search.value.toLowerCase(); - gradioApp().querySelectorAll(`#${tabname}_extra_tabs div.card`).forEach((elem) => { - text = `${elem.querySelector('.name').textContent.toLowerCase()} ${elem.querySelector('.search_term').textContent.toLowerCase()}`; - elem.style.display = text.indexOf(searchTerm) == -1 ? 
'none' : ''; - }); - }); -} - -const activePromptTextarea = {}; - -function setupExtraNetworks() { - setupExtraNetworksForTab('txt2img'); - setupExtraNetworksForTab('img2img'); - function registerPrompt(tabname, id) { - const textarea = gradioApp().querySelector(`#${id} > label > textarea`); - if (!activePromptTextarea[tabname]) activePromptTextarea[tabname] = textarea; - textarea.addEventListener('focus', () => { - activePromptTextarea[tabname] = textarea; - }); - } - registerPrompt('txt2img', 'txt2img_prompt'); - registerPrompt('txt2img', 'txt2img_neg_prompt'); - registerPrompt('img2img', 'img2img_prompt'); - registerPrompt('img2img', 'img2img_neg_prompt'); -} - -onUiLoaded(setupExtraNetworks); -const re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; -const re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; - -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - let m = text.match(re_extranet); - if (!m) return false; - const partToSearch = m[1]; - let replaced = false; - const newTextareaText = textarea.value.replaceAll(re_extranet_g, (found, index) => { - m = found.match(re_extranet); - if (m[1] == partToSearch) { - replaced = true; - return ''; - } - return found; - }); - if (replaced) { - textarea.value = newTextareaText; - return true; - } - return false; -} - -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - const textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector(`#${tabname}_prompt > label > textarea`); - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; - updateInput(textarea); -} - -function saveCardPreview(event, tabname, filename) { - const textarea = gradioApp().querySelector(`#${tabname}_preview_filename > label > textarea`); - const button = gradioApp().getElementById(`${tabname}_save_preview`); - textarea.value = filename; - updateInput(textarea); - button.click(); - event.stopPropagation(); - event.preventDefault(); -} - -function saveCardDescription(event, tabname, filename, descript) { - const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); - const button = gradioApp().getElementById(`${tabname}_save_description`); - const description = gradioApp().getElementById(`${tabname}_description_input`); - textarea.value = filename; - description.value = descript; - updateInput(textarea); - button.click(); - event.stopPropagation(); - event.preventDefault(); -} - -function readCardDescription(event, tabname, filename, descript, extraPage, cardName) { - const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); - const description_textarea = gradioApp().querySelector(`#${tabname}_description_input > label > textarea`); - const button = gradioApp().getElementById(`${tabname}_read_description`); - textarea.value = filename; - description_textarea.value = descript; - updateInput(textarea); - updateInput(description_textarea); - button.click(); - event.stopPropagation(); - event.preventDefault(); -} - -function extraNetworksSearchButton(tabs_id, event) { - searchTextarea = gradioApp().querySelector(`#${tabs_id} > div > textarea`); - button = event.target; - text = button.classList.contains('search-all') ? 
'' : button.textContent.trim(); - searchTextarea.value = text; - updateInput(searchTextarea); -} - -let globalPopup = null; -let globalPopupInner = null; -function popup(contents) { - if (!globalPopup) { - globalPopup = document.createElement('div'); - globalPopup.onclick = function () { globalPopup.style.display = 'none'; }; - globalPopup.classList.add('global-popup'); - const close = document.createElement('div'); - close.classList.add('global-popup-close'); - close.onclick = function () { globalPopup.style.display = 'none'; }; - close.title = 'Close'; - globalPopup.appendChild(close); - globalPopupInner = document.createElement('div'); - globalPopupInner.onclick = function (event) { event.stopPropagation(); return false; }; - globalPopupInner.classList.add('global-popup-inner'); - globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); - } - globalPopupInner.innerHTML = ''; - globalPopupInner.appendChild(contents); - globalPopup.style.display = 'flex'; -} - -function readCardMetadata(event, extraPage, cardName) { - requestGet('./sd_extra_networks/metadata', { page: extraPage, item: cardName }, (data) => { - if (data && data.metadata) { - elem = document.createElement('pre'); - elem.classList.add('popup-metadata'); - elem.textContent = data.metadata; - popup(elem); - } - }, () => {}); - event.stopPropagation(); - event.preventDefault(); -} - -function requestGet(url, data, handler, errorHandler) { - const xhr = new XMLHttpRequest(); - const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&'); - xhr.open('GET', `${url}?${args}`, true); - xhr.onreadystatechange = function () { - if (xhr.readyState === 4) { - if (xhr.status === 200) { - try { - const js = JSON.parse(xhr.responseText); - handler(js); - } catch (error) { - console.error(error); - errorHandler(); - } - } else { - errorHandler(); - } - } - }; - const js = JSON.stringify(data); - xhr.send(js); -} +function setupExtraNetworksForTab(tabname) { + gradioApp().querySelector(`#${tabname}_extra_tabs`).classList.add('extra-networks'); + const tabs = gradioApp().querySelector(`#${tabname}_extra_tabs > div`); + const search = gradioApp().querySelector(`#${tabname}_extra_search textarea`); + const refresh = gradioApp().getElementById(`${tabname}_extra_refresh`); + const descriptInput = gradioApp().getElementById(`${tabname}_description_input`); + const close = gradioApp().getElementById(`${tabname}_extra_close`); + search.classList.add('search'); + tabs.appendChild(search); + tabs.appendChild(refresh); + tabs.appendChild(close); + tabs.appendChild(descriptInput); + search.addEventListener('input', (evt) => { + searchTerm = search.value.toLowerCase(); + gradioApp().querySelectorAll(`#${tabname}_extra_tabs div.card`).forEach((elem) => { + text = `${elem.querySelector('.name').textContent.toLowerCase()} ${elem.querySelector('.search_term').textContent.toLowerCase()}`; + elem.style.display = text.indexOf(searchTerm) == -1 ? 
'none' : ''; + }); + }); +} + +const activePromptTextarea = {}; + +function setupExtraNetworks() { + setupExtraNetworksForTab('txt2img'); + setupExtraNetworksForTab('img2img'); + function registerPrompt(tabname, id) { + const textarea = gradioApp().querySelector(`#${id} > label > textarea`); + if (!activePromptTextarea[tabname]) activePromptTextarea[tabname] = textarea; + textarea.addEventListener('focus', () => { + activePromptTextarea[tabname] = textarea; + }); + } + registerPrompt('txt2img', 'txt2img_prompt'); + registerPrompt('txt2img', 'txt2img_neg_prompt'); + registerPrompt('img2img', 'img2img_prompt'); + registerPrompt('img2img', 'img2img_neg_prompt'); +} + +onUiLoaded(setupExtraNetworks); +const re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; +const re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; + +function tryToRemoveExtraNetworkFromPrompt(textarea, text) { + let m = text.match(re_extranet); + if (!m) return false; + const partToSearch = m[1]; + let replaced = false; + const newTextareaText = textarea.value.replaceAll(re_extranet_g, (found, index) => { + m = found.match(re_extranet); + if (m[1] == partToSearch) { + replaced = true; + return ''; + } + return found; + }); + if (replaced) { + textarea.value = newTextareaText; + return true; + } + return false; +} + +function cardClicked(tabname, textToAdd, allowNegativePrompt) { + const textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector(`#${tabname}_prompt > label > textarea`); + if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + updateInput(textarea); +} + +function saveCardPreview(event, tabname, filename) { + const textarea = gradioApp().querySelector(`#${tabname}_preview_filename > label > textarea`); + const button = gradioApp().getElementById(`${tabname}_save_preview`); + textarea.value = filename; + updateInput(textarea); + button.click(); + event.stopPropagation(); + event.preventDefault(); +} + +function saveCardDescription(event, tabname, filename, descript) { + const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); + const button = gradioApp().getElementById(`${tabname}_save_description`); + const description = gradioApp().getElementById(`${tabname}_description_input`); + textarea.value = filename; + description.value = descript; + updateInput(textarea); + button.click(); + event.stopPropagation(); + event.preventDefault(); +} + +function readCardDescription(event, tabname, filename, descript, extraPage, cardName) { + const textarea = gradioApp().querySelector(`#${tabname}_description_filename > label > textarea`); + const description_textarea = gradioApp().querySelector(`#${tabname}_description_input > label > textarea`); + const button = gradioApp().getElementById(`${tabname}_read_description`); + textarea.value = filename; + description_textarea.value = descript; + updateInput(textarea); + updateInput(description_textarea); + button.click(); + event.stopPropagation(); + event.preventDefault(); +} + +function extraNetworksSearchButton(tabs_id, event) { + searchTextarea = gradioApp().querySelector(`#${tabs_id} > div > textarea`); + button = event.target; + text = button.classList.contains('search-all') ? 
'' : button.textContent.trim(); + searchTextarea.value = text; + updateInput(searchTextarea); +} + +let globalPopup = null; +let globalPopupInner = null; +function popup(contents) { + if (!globalPopup) { + globalPopup = document.createElement('div'); + globalPopup.onclick = function () { globalPopup.style.display = 'none'; }; + globalPopup.classList.add('global-popup'); + const close = document.createElement('div'); + close.classList.add('global-popup-close'); + close.onclick = function () { globalPopup.style.display = 'none'; }; + close.title = 'Close'; + globalPopup.appendChild(close); + globalPopupInner = document.createElement('div'); + globalPopupInner.onclick = function (event) { event.stopPropagation(); return false; }; + globalPopupInner.classList.add('global-popup-inner'); + globalPopup.appendChild(globalPopupInner); + gradioApp().appendChild(globalPopup); + } + globalPopupInner.innerHTML = ''; + globalPopupInner.appendChild(contents); + globalPopup.style.display = 'flex'; +} + +function readCardMetadata(event, extraPage, cardName) { + requestGet('./sd_extra_networks/metadata', { page: extraPage, item: cardName }, (data) => { + if (data && data.metadata) { + elem = document.createElement('pre'); + elem.classList.add('popup-metadata'); + elem.textContent = data.metadata; + popup(elem); + } + }, () => {}); + event.stopPropagation(); + event.preventDefault(); +} + +function requestGet(url, data, handler, errorHandler) { + const xhr = new XMLHttpRequest(); + const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&'); + xhr.open('GET', `${url}?${args}`, true); + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + try { + const js = JSON.parse(xhr.responseText); + handler(js); + } catch (error) { + console.error(error); + errorHandler(); + } + } else { + errorHandler(); + } + } + }; + const js = JSON.stringify(data); + xhr.send(js); +} diff --git a/javascript/style.css b/javascript/style.css index 0f77632cb..04902104c 100644 --- a/javascript/style.css +++ b/javascript/style.css @@ -505,11 +505,25 @@ div#extras_scale_to_tab div.form{ font-size: 95%; } -#available_extensions .info{ +#extensions .name{ + font-size: 1.1rem +} + +#extensions .type{ + opacity: 0.5; + font-size: 90%; + text-align: center; +} + +#extensions .version{ + opacity: 0.7; +} + +#extensions .info{ margin: 0; } -#available_extensions .date_added{ +#extensions .date{ opacity: 0.85; font-size: 90%; } diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 8a50ec601..db73b03d9 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -1,13 +1,10 @@ - - - -function start_training_textual_inversion(){ - gradioApp().querySelector('#ti_error').innerHTML='' - var id = randomId() - const onProgress = (progress) => gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; - // requestProgress(id_task, progressbarContainer, gallery, atEnd = null, onProgress = null, once = false) { - requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), null, onProgress, false) - var res = args_to_array(arguments) - res[0] = id - return res -} +function start_training_textual_inversion() { + gradioApp().querySelector('#ti_error').innerHTML='' + var id = randomId() + const onProgress = (progress) => gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; + // requestProgress(id_task, progressbarContainer, gallery, atEnd 
= null, onProgress = null, once = false) { + requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), null, onProgress, false) + var res = args_to_array(arguments) + res[0] = id + return res +} diff --git a/modules/call_queue.py b/modules/call_queue.py index 2ea136a19..3af8eca08 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -14,9 +14,7 @@ def wrap_queued_call(func): def f(*args, **kwargs): with queue_lock: res = func(*args, **kwargs) - return res - return f @@ -38,9 +36,9 @@ def f(*args, **kwargs): progress.finish_task(id_task) shared.state.end() return res - return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) + def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats @@ -73,21 +71,17 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): if extra_outputs_array is None: extra_outputs_array = [None, ''] res = extra_outputs_array + [f"
GPU active {active_peak} MB reserved {reserved_peak} MB | System peak {sys_peak} MB total {sys_total} MB
" else: vram_html = '' - res[-1] += f"Time taken: {elapsed_text}
{vram_html}{self.commit_hash[:8]}
{datetime.fromtimestamp(self.commit_date).strftime('%a %b%d %Y %H:%M')}
" except Exception as ex: shared.log.error(f"Failed reading extension data from Git repository: {self.name}: {ex}") self.remote = None def list_files(self, subdir, extension): from modules import scripts - dirpath = os.path.join(self.path, subdir) if not os.path.isdir(dirpath): return [] - res = [] for filename in sorted(os.listdir(dirpath)): priority = '50' @@ -80,9 +80,7 @@ def list_files(self, subdir, extension): with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f: priority = str(f.read().strip()) res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority)) - res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] - return res def check_updates(self): @@ -92,7 +90,6 @@ def check_updates(self): self.can_update = True self.status = "new commits" return - try: origin = repo.rev_parse('origin') if repo.head.commit != origin: @@ -103,7 +100,6 @@ def check_updates(self): self.can_update = False self.status = "unknown (remote error)" return - self.can_update = False self.status = "latest" @@ -119,19 +115,15 @@ def fetch_and_reset_hard(self, commit='origin'): def list_extensions(): extensions.clear() - if not os.path.isdir(extensions_dir): return - - if shared.opts.disable_all_extensions == "all" or shared.opts.disable_all_extensions == "extra": - shared.log.warning("Option set: Disable all extensions") - + if shared.opts.disable_all_extensions == "all" or shared.opts.disable_all_extensions == "user": + shared.log.warning(f"Option set: Disable extensions: {shared.opts.disable_all_extensions}") extension_paths = [] extension_names = [] for dirname in [extensions_builtin_dir, extensions_dir]: if not os.path.isdir(dirname): return - for extension_dirname in sorted(os.listdir(dirname)): path = os.path.join(dirname, extension_dirname) if not os.path.isdir(path): @@ -141,7 +133,6 @@ def list_extensions(): continue extension_names.append(extension_dirname) extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) - for dirname, path, is_builtin in extension_paths: extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) extensions.append(extension) diff --git a/modules/shared.py b/modules/shared.py index 509f314ea..f82ace1c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -463,7 +463,7 @@ def refresh_themes(): options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable these extensions"), - "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}), + "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "user", "all"]}), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 69de248b9..585fd34aa 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,16 +1,51 @@ import json import os.path -import time import shutil import errno import html +from datetime import datetime import git import gradio as gr from modules import extensions, shared, paths, errors from modules.call_queue import wrap_gradio_gpu_call -available_extensions = {"extensions": []} -STYLE_PRIMARY = ' style="color: var(--primary-400)"' + +extensions_index = 
"https://vladmandic.github.io/sd-data/pages/extensions.json" +hide_tags = ["localization"] +extensions_list = [] + + +def update_extension_list(): + global extensions_list # pylint: disable=global-statement + try: + with open(os.path.join(paths.script_path, "html", "extensions.json"), "r", encoding="utf-8") as f: + extensions_list = json.loads(f.read()) + shared.log.debug(f'Extensions list loaded: {os.path.join(paths.script_path, "html", "extensions.json")}') + except: + shared.log.debug(f'Extensions list failed to load: {os.path.join(paths.script_path, "html", "extensions.json")}') + found = [] + for ext in extensions_list: + installed = [extension for extension in extensions.extensions if extension.git_name == ext['name'] or extension.name == ext['name']] + if len(installed) > 0: + found.append(installed[0]) + not_matched = [extension for extension in extensions.extensions if extension not in found] + for ext in not_matched: + ext.read_info_from_repo() + entry = { + "name": ext.name or "", + "description": ext.description or "", + "url": ext.remote or "", + "tags": [], + "stars": 0, + "issues": 0, + "commits": 0, + "size": 0, + "long": ext.git_name or ext.name or "", + "added": ext.ctime, + "created": ext.ctime, + "updated": ext.mtime, + } + extensions_list.append(entry) def check_access(): @@ -19,6 +54,7 @@ def check_access(): def apply_and_restart(disable_list, update_list, disable_all): check_access() + shared.log.debug(f'Extensions apply: disable={disable_list} update={update_list}') disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" update = json.loads(update_list) @@ -37,18 +73,16 @@ def apply_and_restart(disable_list, update_list, disable_all): shared.restart_server(restart=True) -def check_updates(_id_task, disable_list): +def check_updates(_id_task, disable_list, search_text, sort_column): check_access() - disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" - exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.log.info(f'Extensions update check: update={len(exts)} disabled={len(disable_list)}') shared.state.job_count = len(exts) - for ext in exts: + shared.log.debug(f'Extensions update: {ext.name}') shared.state.textinfo = ext.name - try: ext.check_updates() except FileNotFoundError as e: @@ -56,10 +90,8 @@ def check_updates(_id_task, disable_list): raise except Exception: errors.display(e, f'extensions check update: {ext.name}') - shared.state.nextjob() - - return extension_table(), "" + return refresh_extensions_list_from_data(search_text, sort_column), "Update complete, please restart the server" def make_commit_link(commit_hash, remote, text=None): @@ -72,84 +104,26 @@ def make_commit_link(commit_hash, remote, text=None): return text -def extension_table(): - code = f""" -Extension | -Type | -URL | -Version | -Update | -
---|---|---|---|---|
- | {"system" if ext.is_builtin else 'user'} | -{remote} | -{version_link} | -{ext_status} | -
Enabled | Extension | Description | -Action | +Type | +Current version | +||
---|---|---|---|---|---|---|---|
{html.escape(name)} {tags_text} |
- {html.escape(description)} Added: {html.escape(added)} |
+ {enabled_code} | +{html.escape(name)} {tags_text} |
+ {html.escape(description)}
+ Created {html.escape(created)} | Added {html.escape(added)} | Updated {html.escape(updated)} +Stars {html.escape(str(stars))} | Size {html.escape(str(size))} | Commits {html.escape(str(commits))} | Issues {html.escape(str(issues))} + |
+ {type_code} | +{version_code} | {install_code} | -
Extension hidden: {hidden}
" - - return code, list(tags) + """ + code += "Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8
") - - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) - color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - - return [info, pixels, mask_blur, direction, noise_q, color_variation] - - def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): - initial_seed_and_info = [None, None] - - process_width = p.width - process_height = p.height - - p.mask_blur = mask_blur*4 - p.inpaint_full_res = False - p.inpainting_fill = 1 - p.do_not_save_samples = True - p.do_not_save_grid = True - - left = pixels if "left" in direction else 0 - right = pixels if "right" in direction else 0 - up = pixels if "up" in direction else 0 - down = pixels if "down" in direction else 0 - - init_img = p.init_images[0] - target_w = math.ceil((init_img.width + left + right) / 64) * 64 - target_h = math.ceil((init_img.height + up + down) / 64) * 64 - - if left > 0: - left = left * (target_w - init_img.width) // (left + right) - - if right > 0: - right = target_w - init_img.width - left - - if up > 0: - up = up * (target_h - init_img.height) // (up + down) - - if down > 0: - down = target_h - init_img.height - up - - def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): - is_horiz = is_left or is_right - is_vert = is_top or is_bottom - pixels_horiz = expand_pixels if is_horiz else 0 - pixels_vert = expand_pixels if is_vert else 0 - - images_to_process = [] - output_images = [] - for n in range(count): - res_w = init[n].width + pixels_horiz - res_h = init[n].height + pixels_vert - process_res_w = math.ceil(res_w / 64) * 64 - process_res_h = math.ceil(res_h / 64) * 64 - - img = Image.new("RGB", (process_res_w, process_res_h)) - img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) - mask = Image.new("RGB", (process_res_w, process_res_h), "white") - draw = ImageDraw.Draw(mask) - draw.rectangle(( - expand_pixels + mask_blur if is_left else 0, - expand_pixels + mask_blur if is_top else 0, - mask.width - expand_pixels - mask_blur if is_right else res_w, - mask.height - expand_pixels - mask_blur if is_bottom else res_h, - ), fill="black") - - np_image = (np.asarray(img) / 255.0).astype(np.float64) - np_mask = (np.asarray(mask) / 255.0).astype(np.float64) - noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) - output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")) - - target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width - target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height - p.width = target_width if is_horiz else img.width - p.height = target_height if is_vert else img.height - - crop_region = ( - 0 if is_left else output_images[n].width - target_width, - 0 if is_top else output_images[n].height - target_height, - target_width if is_left else output_images[n].width, - target_height if is_top 
else output_images[n].height, - ) - mask = mask.crop(crop_region) - p.image_mask = mask - - image_to_process = output_images[n].crop(crop_region) - images_to_process.append(image_to_process) - - p.init_images = images_to_process - - latent_mask = Image.new("RGB", (p.width, p.height), "white") - draw = ImageDraw.Draw(latent_mask) - draw.rectangle(( - expand_pixels + mask_blur * 2 if is_left else 0, - expand_pixels + mask_blur * 2 if is_top else 0, - mask.width - expand_pixels - mask_blur * 2 if is_right else res_w, - mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h, - ), fill="black") - p.latent_mask = latent_mask - - proc = process_images(p) - - if initial_seed_and_info[0] is None: - initial_seed_and_info[0] = proc.seed - initial_seed_and_info[1] = proc.info - - for n in range(count): - output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height)) - output_images[n] = output_images[n].crop((0, 0, res_w, res_h)) - - return output_images - - batch_count = p.n_iter - batch_size = p.batch_size - p.n_iter = 1 - state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)) - all_processed_images = [] - - for i in range(batch_count): - imgs = [init_img] * batch_size - state.job = f"Batch {i + 1} out of {batch_count}" - - if left > 0: - imgs = expand(imgs, batch_size, left, is_left=True) - if right > 0: - imgs = expand(imgs, batch_size, right, is_right=True) - if up > 0: - imgs = expand(imgs, batch_size, up, is_top=True) - if down > 0: - imgs = expand(imgs, batch_size, down, is_bottom=True) - - all_processed_images += imgs - - all_images = all_processed_images - - combined_grid_image = images.image_grid(all_processed_images) - unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple - if opts.return_grid and not unwanted_grid_because_of_img_count: - all_images = [combined_grid_image] + all_processed_images - - res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1]) - - if opts.samples_save: - for img in all_processed_images: - images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p) - - if opts.grid_save and not unwanted_grid_because_of_img_count: - images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.samples_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p) - - return res +import math +import numpy as np +import skimage +import gradio as gr +from PIL import Image, ImageDraw +import modules.scripts as scripts +from modules import images +from modules.processing import Processed, process_images +from modules.shared import opts, state + + +# this function is taken from https://github.com/parlance-zz/g-diffuser-bot +def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05): + # helper fft routines that keep ortho normalization and auto-shift before and after fft + def _fft2(data): + if data.ndim > 2: # has channels + out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho") + out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c]) + else: # one channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:, :] 
= np.fft.fft2(np.fft.fftshift(data), norm="ortho") + out_fft[:, :] = np.fft.ifftshift(out_fft[:, :]) + + return out_fft + + def _ifft2(data): + if data.ndim > 2: # has channels + out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho") + out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c]) + else: # one channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho") + out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :]) + + return out_ifft + + def _get_gaussian_window(width, height, std=3.14, mode=0): + window_scale_x = float(width / min(width, height)) + window_scale_y = float(height / min(width, height)) + + window = np.zeros((width, height)) + x = (np.arange(width) / width * 2. - 1.) * window_scale_x + for y in range(height): + fy = (y / height * 2. - 1.) * window_scale_y + if mode == 0: + window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std) + else: + window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian + + return window + + def _get_masked_window_rgb(np_mask_grey, hardness=1.): + np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) + if hardness != 1.: + hardened = np_mask_grey[:] ** hardness + else: + hardened = np_mask_grey[:] + for c in range(3): + np_mask_rgb[:, :, c] = hardened[:] + return np_mask_rgb + + width = _np_src_image.shape[0] + height = _np_src_image.shape[1] + num_channels = _np_src_image.shape[2] + + np_src_image = _np_src_image[:] * (1. - np_mask_rgb) + np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3. + img_mask = np_mask_grey > 1e-6 + ref_mask = np_mask_grey < 1e-3 + + windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey)) + windowed_image /= np.max(windowed_image) + windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color + + src_fft = _fft2(windowed_image) # get feature statistics from masked src img + src_dist = np.absolute(src_fft) + src_phase = src_fft / src_dist + + # create a generator with a static seed to make outpainting deterministic / only follow global seed + rng = np.random.default_rng(0) + + noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise + noise_rgb = rng.random((width, height, num_channels)) + noise_grey = np.sum(noise_rgb, axis=2) / 3. + noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter + for c in range(num_channels): + noise_rgb[:, :, c] += (1. - color_variation) * noise_grey + + noise_fft = _fft2(noise_rgb) + for c in range(num_channels): + noise_fft[:, :, c] *= noise_window + noise_rgb = np.real(_ifft2(noise_fft)) + shaped_noise_fft = _fft2(noise_rgb) + shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping + + brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now + contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. + + # scikit-image is used for histogram matching, very convenient! 
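# A minimal, standalone sketch of the histogram-matching call used just below (illustration only;
# assumes scikit-image >= 0.19 for the `channel_axis` keyword, and uses made-up stand-in arrays
# for shaped_noise[img_mask, :] and the unmasked source pixels):
import numpy as np
from skimage import exposure

rng = np.random.default_rng(0)
noise_pixels = rng.random((1000, 3))                               # candidate noise, per-pixel RGB rows
reference_pixels = rng.normal(0.5, 0.1, (800, 3)).clip(0.0, 1.0)   # pixels sampled from the source image
# Per channel, remap the noise values so their distribution matches the reference pixels;
# this is what keeps the outpainted region tonally consistent with the original image.
matched = exposure.match_histograms(noise_pixels, reference_pixels, channel_axis=1)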
+ shaped_noise = np.real(_ifft2(shaped_noise_fft)) + shaped_noise -= np.min(shaped_noise) + shaped_noise /= np.max(shaped_noise) + shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1) + shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb + + matched_noise = shaped_noise[:] + + return np.clip(matched_noise, 0., 1.) + + + +class Script(scripts.Script): + def title(self): + return "Outpainting" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + info = gr.HTML("Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8
") + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) + + return [info, pixels, mask_blur, direction, noise_q, color_variation] + + def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): + initial_seed_and_info = [None, None] + + process_width = p.width + process_height = p.height + + p.mask_blur = mask_blur*4 + p.inpaint_full_res = False + p.inpainting_fill = 1 + p.do_not_save_samples = True + p.do_not_save_grid = True + + left = pixels if "left" in direction else 0 + right = pixels if "right" in direction else 0 + up = pixels if "up" in direction else 0 + down = pixels if "down" in direction else 0 + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + left + right) / 64) * 64 + target_h = math.ceil((init_img.height + up + down) / 64) * 64 + + if left > 0: + left = left * (target_w - init_img.width) // (left + right) + + if right > 0: + right = target_w - init_img.width - left + + if up > 0: + up = up * (target_h - init_img.height) // (up + down) + + if down > 0: + down = target_h - init_img.height - up + + def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): + is_horiz = is_left or is_right + is_vert = is_top or is_bottom + pixels_horiz = expand_pixels if is_horiz else 0 + pixels_vert = expand_pixels if is_vert else 0 + + images_to_process = [] + output_images = [] + for n in range(count): + res_w = init[n].width + pixels_horiz + res_h = init[n].height + pixels_vert + process_res_w = math.ceil(res_w / 64) * 64 + process_res_h = math.ceil(res_h / 64) * 64 + + img = Image.new("RGB", (process_res_w, process_res_h)) + img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) + mask = Image.new("RGB", (process_res_w, process_res_h), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + expand_pixels + mask_blur if is_left else 0, + expand_pixels + mask_blur if is_top else 0, + mask.width - expand_pixels - mask_blur if is_right else res_w, + mask.height - expand_pixels - mask_blur if is_bottom else res_h, + ), fill="black") + + np_image = (np.asarray(img) / 255.0).astype(np.float64) + np_mask = (np.asarray(mask) / 255.0).astype(np.float64) + noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) + output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")) + + target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width + target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height + p.width = target_width if is_horiz else img.width + p.height = target_height if is_vert else img.height + + crop_region = ( + 0 if is_left else output_images[n].width - target_width, + 0 if is_top else output_images[n].height - target_height, + target_width if is_left else output_images[n].width, + target_height if is_top 
else output_images[n].height, + ) + mask = mask.crop(crop_region) + p.image_mask = mask + + image_to_process = output_images[n].crop(crop_region) + images_to_process.append(image_to_process) + + p.init_images = images_to_process + + latent_mask = Image.new("RGB", (p.width, p.height), "white") + draw = ImageDraw.Draw(latent_mask) + draw.rectangle(( + expand_pixels + mask_blur * 2 if is_left else 0, + expand_pixels + mask_blur * 2 if is_top else 0, + mask.width - expand_pixels - mask_blur * 2 if is_right else res_w, + mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h, + ), fill="black") + p.latent_mask = latent_mask + + proc = process_images(p) + + if initial_seed_and_info[0] is None: + initial_seed_and_info[0] = proc.seed + initial_seed_and_info[1] = proc.info + + for n in range(count): + output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height)) + output_images[n] = output_images[n].crop((0, 0, res_w, res_h)) + + return output_images + + batch_count = p.n_iter + batch_size = p.batch_size + p.n_iter = 1 + state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)) + all_processed_images = [] + + for i in range(batch_count): + imgs = [init_img] * batch_size + state.job = f"Batch {i + 1} out of {batch_count}" + + if left > 0: + imgs = expand(imgs, batch_size, left, is_left=True) + if right > 0: + imgs = expand(imgs, batch_size, right, is_right=True) + if up > 0: + imgs = expand(imgs, batch_size, up, is_top=True) + if down > 0: + imgs = expand(imgs, batch_size, down, is_bottom=True) + + all_processed_images += imgs + + all_images = all_processed_images + + combined_grid_image = images.image_grid(all_processed_images) + unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple + if opts.return_grid and not unwanted_grid_because_of_img_count: + all_images = [combined_grid_image] + all_processed_images + + res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1]) + + if opts.samples_save: + for img in all_processed_images: + images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p) + + if opts.grid_save and not unwanted_grid_because_of_img_count: + images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.samples_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p) + + return res diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index e80478b7c..c14fca6a6 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -1,146 +1,143 @@ -import math - -import modules.scripts as scripts -import gradio as gr -from PIL import Image, ImageDraw - -from modules import images, processing, devices -from modules.processing import Processed, process_images -from modules.shared import opts, cmd_opts, state - - -class Script(scripts.Script): - def title(self): - return "Outpainting alternative" - - def show(self, is_img2img): - return is_img2img - - def ui(self, is_img2img): - if not is_img2img: - return None - - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - inpainting_fill = 
gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - - return [pixels, mask_blur, inpainting_fill, direction] - - def run(self, p, pixels, mask_blur, inpainting_fill, direction): - initial_seed = None - initial_info = None - - p.mask_blur = mask_blur * 2 - p.inpainting_fill = inpainting_fill - p.inpaint_full_res = False - - left = pixels if "left" in direction else 0 - right = pixels if "right" in direction else 0 - up = pixels if "up" in direction else 0 - down = pixels if "down" in direction else 0 - - init_img = p.init_images[0] - target_w = math.ceil((init_img.width + left + right) / 64) * 64 - target_h = math.ceil((init_img.height + up + down) / 64) * 64 - - if left > 0: - left = left * (target_w - init_img.width) // (left + right) - if right > 0: - right = target_w - init_img.width - left - - if up > 0: - up = up * (target_h - init_img.height) // (up + down) - - if down > 0: - down = target_h - init_img.height - up - - img = Image.new("RGB", (target_w, target_h)) - img.paste(init_img, (left, up)) - - mask = Image.new("L", (img.width, img.height), "white") - draw = ImageDraw.Draw(mask) - draw.rectangle(( - left + (mask_blur * 2 if left > 0 else 0), - up + (mask_blur * 2 if up > 0 else 0), - mask.width - right - (mask_blur * 2 if right > 0 else 0), - mask.height - down - (mask_blur * 2 if down > 0 else 0) - ), fill="black") - - latent_mask = Image.new("L", (img.width, img.height), "white") - latent_draw = ImageDraw.Draw(latent_mask) - latent_draw.rectangle(( - left + (mask_blur//2 if left > 0 else 0), - up + (mask_blur//2 if up > 0 else 0), - mask.width - right - (mask_blur//2 if right > 0 else 0), - mask.height - down - (mask_blur//2 if down > 0 else 0) - ), fill="black") - - devices.torch_gc() - - grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) - grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) - grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) - - p.n_iter = 1 - p.batch_size = 1 - p.do_not_save_grid = True - p.do_not_save_samples = True - - work = [] - work_mask = [] - work_latent_mask = [] - work_results = [] - - for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): - for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): - x, w = tiledata[0:2] - - if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: - continue - - work.append(tiledata[2]) - work_mask.append(tiledata_mask[2]) - work_latent_mask.append(tiledata_latent_mask[2]) - - batch_count = len(work) - print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") - - state.job_count = batch_count - - for i in range(batch_count): - p.init_images = [work[i]] - p.image_mask = work_mask[i] - p.latent_mask = work_latent_mask[i] - - state.job = f"Batch {i + 1} out of {batch_count}" - processed = process_images(p) - - if initial_seed is None: - initial_seed = processed.seed - initial_info = processed.info - - p.seed = processed.seed + 1 - work_results += processed.images - - - image_index = 0 - for y, h, row in grid.tiles: - for tiledata in 
row: - x, w = tiledata[0:2] - - if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: - continue - - tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) - image_index += 1 - - combined_image = images.combine_grid(grid) - - if opts.samples_save: - images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p) - - processed = Processed(p, [combined_image], initial_seed, initial_info) - - return processed - +import math +import gradio as gr +from PIL import Image, ImageDraw +import modules.scripts as scripts +from modules import images, devices +from modules.processing import Processed, process_images +from modules.shared import opts, state + + +class Script(scripts.Script): + def title(self): + return "Outpainting alternative" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + + return [pixels, mask_blur, inpainting_fill, direction] + + def run(self, p, pixels, mask_blur, inpainting_fill, direction): + initial_seed = None + initial_info = None + + p.mask_blur = mask_blur * 2 + p.inpainting_fill = inpainting_fill + p.inpaint_full_res = False + + left = pixels if "left" in direction else 0 + right = pixels if "right" in direction else 0 + up = pixels if "up" in direction else 0 + down = pixels if "down" in direction else 0 + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + left + right) / 64) * 64 + target_h = math.ceil((init_img.height + up + down) / 64) * 64 + + if left > 0: + left = left * (target_w - init_img.width) // (left + right) + if right > 0: + right = target_w - init_img.width - left + + if up > 0: + up = up * (target_h - init_img.height) // (up + down) + + if down > 0: + down = target_h - init_img.height - up + + img = Image.new("RGB", (target_w, target_h)) + img.paste(init_img, (left, up)) + + mask = Image.new("L", (img.width, img.height), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + left + (mask_blur * 2 if left > 0 else 0), + up + (mask_blur * 2 if up > 0 else 0), + mask.width - right - (mask_blur * 2 if right > 0 else 0), + mask.height - down - (mask_blur * 2 if down > 0 else 0) + ), fill="black") + + latent_mask = Image.new("L", (img.width, img.height), "white") + latent_draw = ImageDraw.Draw(latent_mask) + latent_draw.rectangle(( + left + (mask_blur//2 if left > 0 else 0), + up + (mask_blur//2 if up > 0 else 0), + mask.width - right - (mask_blur//2 if right > 0 else 0), + mask.height - down - (mask_blur//2 if down > 0 else 0) + ), fill="black") + + devices.torch_gc() + + grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, 
overlap=pixels) + + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + work_mask = [] + work_latent_mask = [] + work_results = [] + + for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): + for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + + work.append(tiledata[2]) + work_mask.append(tiledata_mask[2]) + work_latent_mask.append(tiledata_latent_mask[2]) + + batch_count = len(work) + print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") + + state.job_count = batch_count + + for i in range(batch_count): + p.init_images = [work[i]] + p.image_mask = work_mask[i] + p.latent_mask = work_latent_mask[i] + + state.job = f"Batch {i + 1} out of {batch_count}" + processed = process_images(p) + + if initial_seed is None: + initial_seed = processed.seed + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p) + + processed = Processed(p, [combined_image], initial_seed, initial_info) + + return processed diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py index 4bc2c3696..251443642 100644 --- a/scripts/postprocessing_codeformer.py +++ b/scripts/postprocessing_codeformer.py @@ -1,36 +1,34 @@ -from PIL import Image -import numpy as np - -from modules import scripts_postprocessing, codeformer_model -import gradio as gr - -from modules.ui_components import FormRow - - -class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): - name = "CodeFormer" - order = 3000 - - def ui(self): - with FormRow(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer visibility", value=1.0, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer weight (0 = max), 1 = min)", value=0.2, elem_id="extras_codeformer_weight") - - return { - "codeformer_visibility": codeformer_visibility, - "codeformer_weight": codeformer_weight, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): - if codeformer_visibility == 0: - return - - restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(pp.image, res, codeformer_visibility) - - pp.image = res - pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) - pp.info["CodeFormer weight"] = round(codeformer_weight, 3) +from PIL import Image +import numpy as np +import gradio as gr +from modules import scripts_postprocessing, codeformer_model +from modules.ui_components import FormRow + + 
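# A minimal sketch of the "visibility" pattern shared by the face-restoration postprocessors
# below (CodeFormer here, GFPGAN further down): run the restorer on a uint8 copy of the image,
# then alpha-blend the result over the original. `restore_fn` is a stand-in name, not a real API.
from PIL import Image
import numpy as np

def blend_with_visibility(original: Image.Image, restore_fn, visibility: float) -> Image.Image:
    if visibility == 0:
        return original                                            # restorer skipped entirely
    restored = Image.fromarray(restore_fn(np.array(original, dtype=np.uint8)))
    if visibility < 1.0:
        restored = Image.blend(original, restored, visibility)     # partial blend toward the restored face
    return restored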
+class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): + name = "CodeFormer" + order = 3000 + + def ui(self): + with FormRow(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer visibility", value=1.0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CodeFormer weight (0 = max), 1 = min)", value=0.2, elem_id="extras_codeformer_weight") + + return { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0: + return + + restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(pp.image, res, codeformer_visibility) + + pp.image = res + pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) + pp.info["CodeFormer weight"] = round(codeformer_weight, 3) diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py index d854f3f77..7cea285d7 100644 --- a/scripts/postprocessing_gfpgan.py +++ b/scripts/postprocessing_gfpgan.py @@ -1,33 +1,31 @@ -from PIL import Image -import numpy as np - -from modules import scripts_postprocessing, gfpgan_model -import gradio as gr - -from modules.ui_components import FormRow - - -class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): - name = "GFPGAN" - order = 2000 - - def ui(self): - with FormRow(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") - - return { - "gfpgan_visibility": gfpgan_visibility, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): - if gfpgan_visibility == 0: - return - - restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(pp.image, res, gfpgan_visibility) - - pp.image = res - pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) +from PIL import Image +import numpy as np +import gradio as gr +from modules import scripts_postprocessing, gfpgan_model +from modules.ui_components import FormRow + + +class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): + name = "GFPGAN" + order = 2000 + + def ui(self): + with FormRow(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + + return { + "gfpgan_visibility": gfpgan_visibility, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): + if gfpgan_visibility == 0: + return + + restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(pp.image, res, gfpgan_visibility) + + pp.image = res + pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 43df74aff..bc04030af 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -1,134 +1,134 @@ -from PIL import Image -import numpy as np -import gradio as gr -from modules import scripts_postprocessing, shared -from 
modules.ui_components import FormRow, ToolButton -from modules.ui import switch_values_symbol - -upscale_cache = {} - - -class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): - name = "Upscale" - order = 1000 - - def ui(self): - selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated - - with gr.Column(): - with FormRow(): - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: - with FormRow(): - with gr.Row(elem_id="upscaling_column_size", scale=4): - upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") - upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with FormRow(): - extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - - with FormRow(): - extras_upscaler_2 = gr.Dropdown(label='Secondary Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") - - upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) - tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) - tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) - - return { - "upscale_mode": selected_tab, - "upscale_by": upscaling_resize, - "upscale_to_width": upscaling_resize_w, - "upscale_to_height": upscaling_resize_h, - "upscale_crop": upscaling_crop, - "upscaler_1_name": extras_upscaler_1, - "upscaler_2_name": extras_upscaler_2, - "upscaler_2_visibility": extras_upscaler_2_visibility, - } - - def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): - if upscale_mode == 1: - upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) - info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" - else: - info["Postprocess upscale by"] = upscale_by - - cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) - cached_image = upscale_cache.pop(cache_key, None) - - if cached_image is not None: - image = cached_image - else: - image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) - - upscale_cache[cache_key] = image - if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: - upscale_cache.pop(next(iter(upscale_cache), None), None) - - if upscale_mode == 1 and upscale_crop: - cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) - cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) - image = cropped - 
info["Postprocess crop to"] = f"{image.width}x{image.height}" - - return image - - def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): # pylint: disable=arguments-differ - if upscaler_1_name == "None": - upscaler_1_name = None - - upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) - assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' - - if not upscaler1: - return - - if upscaler_2_name == "None": - upscaler_2_name = None - - upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) - assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' - - upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) - pp.info["Postprocess upscaler"] = upscaler1.name - - if upscaler2 and upscaler_2_visibility > 0: - second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) - upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) - - pp.info["Postprocess upscaler 2"] = upscaler2.name - - pp.image = upscaled_image - - def image_changed(self): - upscale_cache.clear() - - -class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): - name = "Simple Upscale" - order = 900 - - def ui(self): - with FormRow(): - upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) - - return { - "upscale_by": upscale_by, - "upscaler_name": upscaler_name, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): # pylint: disable=arguments-differ - if upscaler_name is None or upscaler_name == "None": - return - - upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) - assert upscaler1, f'could not find upscaler named {upscaler_name}' - - pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) - pp.info["Postprocess upscaler"] = upscaler1.name +from PIL import Image +import numpy as np +import gradio as gr +from modules import scripts_postprocessing, shared +from modules.ui_components import FormRow, ToolButton +from modules.ui import switch_values_symbol + +upscale_cache = {} + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Upscale" + order = 1000 + + def ui(self): + selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated + + with gr.Column(): + with FormRow(): + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + with gr.Row(elem_id="upscaling_column_size", scale=4): + upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(minimum=64, 
maximum=4096, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") + upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Secondary Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } + + def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): + if upscale_mode == 1: + upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) + info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" + else: + info["Postprocess upscale by"] = upscale_by + + cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + cached_image = upscale_cache.pop(cache_key, None) + + if cached_image is not None: + image = cached_image + else: + image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) + + upscale_cache[cache_key] = image + if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: + upscale_cache.pop(next(iter(upscale_cache), None), None) + + if upscale_mode == 1 and upscale_crop: + cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) + cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) + image = cropped + info["Postprocess crop to"] = f"{image.width}x{image.height}" + + return image + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): # pylint: disable=arguments-differ + if upscaler_1_name == "None": + upscaler_1_name = None + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) + assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' + + if not upscaler1: + return + + if upscaler_2_name == "None": + upscaler_2_name = None + + upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) + assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' + + upscaled_image = self.upscale(pp.image, 
pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + pp.info["Postprocess upscaler"] = upscaler1.name + + if upscaler2 and upscaler_2_visibility > 0: + second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) + + pp.info["Postprocess upscaler 2"] = upscaler2.name + + pp.image = upscaled_image + + def image_changed(self): + upscale_cache.clear() + + +class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): + name = "Simple Upscale" + order = 900 + + def ui(self): + with FormRow(): + upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) + + return { + "upscale_by": upscale_by, + "upscaler_name": upscaler_name, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): # pylint: disable=arguments-differ + if upscaler_name is None or upscaler_name == "None": + return + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) + assert upscaler1, f'could not find upscaler named {upscaler_name}' + + pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) + pp.info["Postprocess upscaler"] = upscaler1.name diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index e9b115170..6772074ad 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -1,111 +1,106 @@ -import math -from collections import namedtuple -from copy import copy -import random - -import modules.scripts as scripts -import gradio as gr - -from modules import images -from modules.processing import process_images, Processed -from modules.shared import opts, cmd_opts, state -import modules.sd_samplers - - -def draw_xy_grid(xs, ys, x_label, y_label, cell): - res = [] - - ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] - hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] - - first_processed = None - - state.job_count = len(xs) * len(ys) - - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed = cell(x, y) - if first_processed is None: - first_processed = processed - - res.append(processed.images[0]) - - grid = images.image_grid(res, rows=len(ys)) - grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - - first_processed.images = [grid] - - return first_processed - - -class Script(scripts.Script): - def title(self): - return "Prompt matrix" - - def ui(self, is_img2img): - gr.HTML('Will upscale the image by the selected scale factor; use width and height sliders to set tile size
") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) - scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) - - return [info, overlap, upscaler_index, scale_factor] - - def run(self, p, _, overlap, upscaler_index, scale_factor): - if isinstance(upscaler_index, str): - upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) - processing.fix_seed(p) - upscaler = shared.sd_upscalers[upscaler_index] - - p.extra_generation_params["SD upscale overlap"] = overlap - p.extra_generation_params["SD upscale upscaler"] = upscaler.name - - initial_info = None - seed = p.seed - - init_img = p.init_images[0] - init_img = images.flatten(init_img, opts.img2img_background_color) - - if upscaler.name != "None": - img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) - else: - img = init_img - - devices.torch_gc() - - grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) - - batch_size = p.batch_size - upscale_count = p.n_iter - p.n_iter = 1 - p.do_not_save_grid = True - p.do_not_save_samples = True - - work = [] - - for y, h, row in grid.tiles: - for tiledata in row: - work.append(tiledata[2]) - - batch_count = math.ceil(len(work) / batch_size) - state.job_count = batch_count * upscale_count - - print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") - - result_images = [] - for n in range(upscale_count): - start_seed = seed + n - p.seed = start_seed - - work_results = [] - for i in range(batch_count): - p.batch_size = batch_size - p.init_images = work[i * batch_size:(i + 1) * batch_size] - - state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" - processed = processing.process_images(p) - - if initial_info is None: - initial_info = processed.info - - p.seed = processed.seed + 1 - work_results += processed.images - - image_index = 0 - for y, h, row in grid.tiles: - for tiledata in row: - tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) - image_index += 1 - - combined_image = images.combine_grid(grid) - result_images.append(combined_image) - - if opts.samples_save: - images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) - - processed = Processed(p, result_images, seed, initial_info) - - return processed +import math +import gradio as gr +from PIL import Image +import modules.scripts as scripts +from modules import processing, shared, images, devices +from modules.processing import Processed +from modules.shared import opts, state + + +class Script(scripts.Script): + def title(self): + return "SD upscale" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + info = gr.HTML("Will upscale the image by the selected scale factor; use width and height sliders to set tile size
") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) + + return [info, overlap, upscaler_index, scale_factor] + + def run(self, p, _, overlap, upscaler_index, scale_factor): + if isinstance(upscaler_index, str): + upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) + processing.fix_seed(p) + upscaler = shared.sd_upscalers[upscaler_index] + + p.extra_generation_params["SD upscale overlap"] = overlap + p.extra_generation_params["SD upscale upscaler"] = upscaler.name + + initial_info = None + seed = p.seed + + init_img = p.init_images[0] + init_img = images.flatten(init_img, opts.img2img_background_color) + + if upscaler.name != "None": + img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) + else: + img = init_img + + devices.torch_gc() + + grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) + + batch_size = p.batch_size + upscale_count = p.n_iter + p.n_iter = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + + for y, h, row in grid.tiles: + for tiledata in row: + work.append(tiledata[2]) + + batch_count = math.ceil(len(work) / batch_size) + state.job_count = batch_count * upscale_count + + print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") + + result_images = [] + for n in range(upscale_count): + start_seed = seed + n + p.seed = start_seed + + work_results = [] + for i in range(batch_count): + p.batch_size = batch_size + p.init_images = work[i * batch_size:(i + 1) * batch_size] + + state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" + processed = processing.process_images(p) + + if initial_info is None: + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + result_images.append(combined_image) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) + + processed = Processed(p, result_images, seed, initial_info) + + return processed diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 41f226848..58ab7507f 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -1,675 +1,675 @@ - # pylint: disable=unused-argument, attribute-defined-outside-init - -import re -import csv -import random -from collections import namedtuple -from copy import copy -from itertools import permutations, chain -from io import StringIO -from PIL import Image -import numpy as np -import gradio as gr -import modules.scripts as scripts -import modules.shared as shared -from modules import images, sd_samplers, processing, sd_models, sd_vae -from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img -from modules.ui_components import ToolButton - 
-fill_values_symbol = "\U0001f4d2" # 📒 -AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) - - -def apply_field(field): - def fun(p, x, xs): - setattr(p, field, x) - return fun - - -def apply_prompt(p, x, xs): - if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: - shared.log.warning(f"XYZ grid: prompt S/R did not find {xs[0]} in prompt or negative prompt.") - else: - p.prompt = p.prompt.replace(xs[0], x) - p.negative_prompt = p.negative_prompt.replace(xs[0], x) - - -def apply_order(p, x, xs): - token_order = [] - for token in x: - token_order.append((p.prompt.find(token), token)) - token_order.sort(key=lambda t: t[0]) - prompt_parts = [] - for _, token in token_order: - n = p.prompt.find(token) - prompt_parts.append(p.prompt[0:n]) - p.prompt = p.prompt[n + len(token):] - prompt_tmp = "" - for idx, part in enumerate(prompt_parts): - prompt_tmp += part - prompt_tmp += x[idx] - p.prompt = prompt_tmp + p.prompt - - -def apply_sampler(p, x, xs): - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - shared.log.warning(f"XYZ grid: unknown sampler: {x}") - else: - p.sampler_name = sampler_name - - -def confirm_samplers(p, xs): - for x in xs: - if x.lower() not in sd_samplers.samplers_map: - shared.log.warning(f"XYZ grid: unknown sampler: {x}") - - -def apply_checkpoint(p, x, xs): - if x == shared.opts.sd_model_checkpoint: - return - info = sd_models.get_closet_checkpoint_match(x) - if info is None: - shared.log.warning(f"XYZ grid: unknown checkpoint: {x}") - else: - sd_models.reload_model_weights(shared.sd_model, info) - - -def confirm_checkpoints(p, xs): - for x in xs: - if sd_models.get_closet_checkpoint_match(x) is None: - shared.log.warning(f"XYZ grid: Unknown checkpoint: {x}") - - -def apply_clip_skip(p, x, xs): - shared.opts.data["CLIP_stop_at_last_layers"] = x - - -def apply_upscale_latent_space(p, x, xs): - if x.lower().strip() != '0': - shared.opts.data["use_scale_latent_for_hires_fix"] = True - else: - shared.opts.data["use_scale_latent_for_hires_fix"] = False - - -def find_vae(name: str): - if name.lower() in ['auto', 'automatic']: - return sd_vae.unspecified - if name.lower() == 'none': - return None - else: - choices = [x for x in sorted(sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] - if len(choices) == 0: - shared.log.warning(f"No VAE found for {name}; using automatic") - return sd_vae.unspecified - else: - return sd_vae.vae_dict[choices[0]] - - -def apply_vae(p, x, xs): - sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) - - -def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): - p.styles.extend(x.split(',')) - - -def apply_fallback(p, x, xs): - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - shared.log.warning(f"XYZ grid: unknown sampler: {x}") - else: - shared.opts.data["xyz_fallback_sampler"] = sampler_name - - -def apply_uni_pc_order(p, x, xs): - shared.opts.data["uni_pc_order"] = min(x, p.steps - 1) - - -def apply_face_restore(p, opt, x): - opt = opt.lower() - if opt == 'codeformer': - is_active = True - p.face_restoration_model = 'CodeFormer' - elif opt == 'gfpgan': - is_active = True - p.face_restoration_model = 'GFPGAN' - else: - is_active = opt in ('true', 'yes', 'y', '1') - p.restore_faces = is_active - - -def apply_token_merging_ratio_hr(p, x, xs): - shared.opts.data["token_merging_ratio_hr"] = x - - -def apply_token_merging_ratio(p, x, xs): - shared.opts.data["token_merging_ratio"] = x - - -def 
apply_token_merging_random(p, x, xs): - is_active = x.lower() in ('true', 'yes', 'y', '1') - shared.opts.data["token_merging_random"] = is_active - - -def format_value_add_label(p, opt, x): - if type(x) == float: - x = round(x, 8) - return f"{opt.label}: {x}" - - -def format_value(p, opt, x): - if type(x) == float: - x = round(x, 8) - return x - - -def format_value_join_list(p, opt, x): - return ", ".join(x) - - -def do_nothing(p, x, xs): - pass - - -def format_nothing(p, opt, x): - return "" - - -def str_permutations(x): - """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" - return x - - -class AxisOption: - def __init__(self, label, tipe, apply, fmt=format_value_add_label, confirm=None, cost=0.0, choices=None): - self.label = label - self.type = tipe - self.apply = apply - self.format_value = fmt - self.confirm = confirm - self.cost = cost - self.choices = choices - - -class AxisOptionImg2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = True - -class AxisOptionTxt2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = False - - -axis_options = [ - AxisOption("Nothing", str, do_nothing, fmt=format_nothing), - AxisOption("Seed", int, apply_field("seed")), - AxisOption("Var. seed", int, apply_field("subseed")), - AxisOption("Var. strength", float, apply_field("subseed_strength")), - AxisOption("Steps", int, apply_field("steps")), - AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), - AxisOption("CFG Scale", float, apply_field("cfg_scale")), - AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), - AxisOption("Prompt S/R", str, apply_prompt, fmt=format_value), - AxisOption("Prompt order", str_permutations, apply_order, fmt=format_value_join_list), - AxisOptionTxt2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionImg2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), - AxisOption("Checkpoint name", str, apply_checkpoint, fmt=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Sigma Churn", float, apply_field("s_churn")), - AxisOption("Sigma min", float, apply_field("s_tmin")), - AxisOption("Sigma max", float, apply_field("s_tmax")), - AxisOption("Sigma noise", float, apply_field("s_noise")), - AxisOption("Eta", float, apply_field("eta")), - AxisOption("Clip skip", int, apply_clip_skip), - AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), - AxisOptionTxt2Img("Fallback latent upscaler sampler", str, apply_fallback, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionImg2Img("Cond. 
Image Mask Weight", float, apply_field("inpainting_mask_weight")), - AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), - AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), - AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), - AxisOption("Face restore", str, apply_face_restore, fmt=format_value), - AxisOption("ToMe ratio",float,apply_token_merging_ratio), - AxisOption("ToMe ratio for Hires fix",float,apply_token_merging_ratio_hr), - AxisOption("ToMe random pertubations",str,apply_token_merging_random, choices = lambda: ["Yes","No"]) -] - - -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size): - hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] - title_texts = [[images.GridAnnotation(z)] for z in z_labels] - list_size = (len(xs) * len(ys) * len(zs)) - processed_result = None - shared.state.job_count = list_size * p.n_iter - - def process_cell(x, y, z, ix, iy, iz): - nonlocal processed_result - - def index(ix, iy, iz): - return ix + iy * len(xs) + iz * len(xs) * len(ys) - - shared.state.job = f"{index(ix, iy, iz) + 1} out of {list_size}" - processed: Processed = cell(x, y, z, ix, iy, iz) - if processed_result is None: - # Use our first processed result object as a template container to hold our full results - processed_result = copy(processed) - processed_result.images = [None] * list_size - processed_result.all_prompts = [None] * list_size - processed_result.all_seeds = [None] * list_size - processed_result.infotexts = [None] * list_size - processed_result.index_of_first_image = 1 - idx = index(ix, iy, iz) - if processed.images: - # Non-empty list indicates some degree of success. - processed_result.images[idx] = processed.images[0] - processed_result.all_prompts[idx] = processed.prompt - processed_result.all_seeds[idx] = processed.seed - processed_result.infotexts[idx] = processed.infotexts[0] - else: - cell_mode = "P" - cell_size = (processed_result.width, processed_result.height) - if processed_result.images[0] is not None: - cell_mode = processed_result.images[0].mode - #This corrects size in case of batches: - cell_size = processed_result.images[0].size - processed_result.images[idx] = Image.new(cell_mode, cell_size) - - if first_axes_processed == 'x': - for ix, x in enumerate(xs): - if second_axes_processed == 'y': - for iy, y in enumerate(ys): - for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: - for iz, z in enumerate(zs): - for iy, y in enumerate(ys): - process_cell(x, y, z, ix, iy, iz) - elif first_axes_processed == 'y': - for iy, y in enumerate(ys): - if second_axes_processed == 'x': - for ix, x in enumerate(xs): - for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: - for iz, z in enumerate(zs): - for ix, x in enumerate(xs): - process_cell(x, y, z, ix, iy, iz) - elif first_axes_processed == 'z': - for iz, z in enumerate(zs): - if second_axes_processed == 'x': - for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - process_cell(x, y, z, ix, iy, iz) - else: - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - process_cell(x, y, z, ix, iy, iz) - - if not processed_result: - # Should never happen, I've only seen it on one of four open tabs and it needed to refresh. 
- shared.log.error("XYZ grid: Processing could not begin, you may need to refresh the tab or restart the service") - return Processed(p, []) - elif not any(processed_result.images): - shared.log.error("XYZ grid: Failed to return even a single processed image") - return Processed(p, []) - - z_count = len(zs) - # sub_grids = [None] * z_count - for i in range(z_count): - start_index = (i * len(xs) * len(ys)) + i - end_index = start_index + len(xs) * len(ys) - grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys)) - if draw_legend: - grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size) - processed_result.images.insert(i, grid) - processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index]) - processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index]) - processed_result.infotexts.insert(i, processed_result.infotexts[start_index]) - sub_grid_size = processed_result.images[0].size - z_grid = images.image_grid(processed_result.images[:z_count], rows=1) - if draw_legend: - z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) - processed_result.images.insert(0, z_grid) - #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) - #processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) - processed_result.infotexts.insert(0, processed_result.infotexts[0]) - return processed_result - - -class SharedSettingsStackHelper(object): - def __enter__(self): - #Save overridden settings so they can be restored later. - self.CLIP_stop_at_last_layers = shared.opts.CLIP_stop_at_last_layers - self.vae = shared.opts.sd_vae - self.uni_pc_order = shared.opts.uni_pc_order - self.token_merging_ratio_hr = shared.opts.token_merging_ratio_hr - self.token_merging_ratio = shared.opts.token_merging_ratio - self.token_merging_random = shared.opts.token_merging_random - self.sd_model_checkpoint = shared.opts.sd_model_checkpoint - self.sd_vae_checkpoint = shared.opts.sd_vae - - def __exit__(self, exc_type, exc_value, tb): - #Restore overriden settings after plot generation. 
- shared.opts.data["sd_vae"] = self.vae - shared.opts.data["uni_pc_order"] = self.uni_pc_order - shared.opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - shared.opts.data["token_merging_ratio_hr"] = self.token_merging_ratio_hr - shared.opts.data["token_merging_ratio"] = self.token_merging_ratio - shared.opts.data["token_merging_random"] = self.token_merging_random - if self.sd_model_checkpoint != shared.opts.sd_model_checkpoint: - shared.opts.data["sd_model_checkpoint"] = self.sd_model_checkpoint - sd_models.reload_model_weights() - if self.sd_vae_checkpoint != shared.opts.sd_vae: - shared.opts.data["sd_vae"] = self.sd_vae_checkpoint - sd_vae.reload_vae_weights() - - -re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") -re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") -re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") - - -class Script(scripts.Script): - def title(self): - return "X/Y/Z plot" - - def ui(self, is_img2img): - self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] - with gr.Row(): - with gr.Column(scale=19): - with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) - x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True) - fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) - - with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) - y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) - - with gr.Row(): - z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) - z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) - z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True) - fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) - with gr.Row(variant="compact", elem_id="axis_options"): - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) - include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) - include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) - with gr.Row(variant="compact", elem_id="axis_options"): - margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, 
elem_id=self.elem_id("margin_size")) - with gr.Row(variant="compact", elem_id="swap_axes"): - swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") - swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") - swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") - - def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown): - return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown - - xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown] - swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) - yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown] - swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) - xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown] - swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) - - def fill(x_type): - axis = self.current_axis_options[x_type] - return axis.choices() if axis.choices else gr.update() - - fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown]) - fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown]) - fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown]) - - def select_axis(axis_type,axis_values_dropdown): - choices = self.current_axis_options[axis_type].choices - has_choices = choices is not None - current_values = axis_values_dropdown - if has_choices: - choices = choices() - if isinstance(current_values,str): - current_values = current_values.split(",") - current_values = list(filter(lambda x: x in choices, current_values)) - return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values) - - x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown]) - y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown]) - z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown]) - - def get_dropdown_update_from_params(axis,params): - val_key = axis + " Values" - vals = params.get(val_key,"") - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] - return gr.update(value = valslist) - - self.infotext_fields = ( - (x_type, "X Type"), - (x_values, "X Values"), - (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)), - (y_type, "Y Type"), - (y_values, "Y Values"), - (y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)), - (z_type, "Z Type"), - (z_values, "Z Values"), - (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), - ) - - return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] - - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): # pylint: 
disable=arguments-differ - shared.log.debug(f'xyzgrid: {x_type}|{x_values}|{x_values_dropdown}|{y_type}|{y_values}|{y_values_dropdown}|{z_type}|{z_values}|{z_values_dropdown}|{draw_legend}|{include_lone_images}|{include_sub_grids}|{no_fixed_seeds}|{margin_size}') - if not no_fixed_seeds: - processing.fix_seed(p) - if not shared.opts.return_grid: - p.batch_size = 1 - def process_axis(opt, vals, vals_dropdown): - if opt.label == 'Nothing': - return [0] - if opt.choices is not None: - valslist = vals_dropdown - else: - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] - if opt.type == int: - valslist_ext = [] - for val in valslist: - m = re_range.fullmatch(val) - mc = re_range_count.fullmatch(val) - if m is not None: - start = int(m.group(1)) - end = int(m.group(2))+1 - step = int(m.group(3)) if m.group(3) is not None else 1 - valslist_ext += list(range(start, end, step)) - elif mc is not None: - start = int(mc.group(1)) - end = int(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] - else: - valslist_ext.append(val) - valslist = valslist_ext - elif opt.type == float: - valslist_ext = [] - for val in valslist: - m = re_range_float.fullmatch(val) - mc = re_range_count_float.fullmatch(val) - if m is not None: - start = float(m.group(1)) - end = float(m.group(2)) - step = float(m.group(3)) if m.group(3) is not None else 1 - valslist_ext += np.arange(start, end + step, step).tolist() - elif mc is not None: - start = float(mc.group(1)) - end = float(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() - else: - valslist_ext.append(val) - valslist = valslist_ext - elif opt.type == str_permutations: - valslist = list(permutations(valslist)) - valslist = [opt.type(x) for x in valslist] - # Confirm options are valid before starting - if opt.confirm: - opt.confirm(p, valslist) - return valslist - - x_opt = self.current_axis_options[x_type] - if x_opt.choices is not None: - x_values = ",".join(x_values_dropdown) - xs = process_axis(x_opt, x_values, x_values_dropdown) - y_opt = self.current_axis_options[y_type] - if y_opt.choices is not None: - y_values = ",".join(y_values_dropdown) - ys = process_axis(y_opt, y_values, y_values_dropdown) - z_opt = self.current_axis_options[z_type] - if z_opt.choices is not None: - z_values = ",".join(z_values_dropdown) - zs = process_axis(z_opt, z_values, z_values_dropdown) - Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes - grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) - assert grid_mp < shared.opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {shared.opts.img_max_size_mp} MPixels)' - - def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label in ['Seed', 'Var. 
seed']: - return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] - else: - return axis_list - - if not no_fixed_seeds: - xs = fix_axis_seeds(x_opt, xs) - ys = fix_axis_seeds(y_opt, ys) - zs = fix_axis_seeds(z_opt, zs) - - if x_opt.label == 'Steps': - total_steps = sum(xs) * len(ys) * len(zs) - elif y_opt.label == 'Steps': - total_steps = sum(ys) * len(xs) * len(zs) - elif z_opt.label == 'Steps': - total_steps = sum(zs) * len(xs) * len(ys) - else: - total_steps = p.steps * len(xs) * len(ys) * len(zs) - if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - if x_opt.label == "Hires steps": - total_steps += sum(xs) * len(ys) * len(zs) - elif y_opt.label == "Hires steps": - total_steps += sum(ys) * len(xs) * len(zs) - elif z_opt.label == "Hires steps": - total_steps += sum(zs) * len(xs) * len(ys) - elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) - else: - total_steps *= 2 - total_steps *= p.n_iter - image_cell_count = p.n_iter * p.batch_size - shared.log.info(f"XYZ grid: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} {len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}") - shared.state.xyz_plot_x = AxisInfo(x_opt, xs) - shared.state.xyz_plot_y = AxisInfo(y_opt, ys) - shared.state.xyz_plot_z = AxisInfo(z_opt, zs) - # If one of the axes is very slow to change between (like SD model checkpoint), then make sure it is in the outer iteration of the nested `for` loop. - first_axes_processed = 'z' - second_axes_processed = 'y' - if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: - first_axes_processed = 'x' - if y_opt.cost > z_opt.cost: - second_axes_processed = 'y' - else: - second_axes_processed = 'z' - elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: - first_axes_processed = 'y' - if x_opt.cost > z_opt.cost: - second_axes_processed = 'x' - else: - second_axes_processed = 'z' - elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: - first_axes_processed = 'z' - if x_opt.cost > y_opt.cost: - second_axes_processed = 'x' - else: - second_axes_processed = 'y' - grid_infotext = [None] * (1 + len(zs)) - - def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted: - return Processed(p, [], p.seed, "") - pc = copy(p) - pc.styles = pc.styles[:] - x_opt.apply(pc, x, xs) - y_opt.apply(pc, y, ys) - z_opt.apply(pc, z, zs) - res = process_images(pc) - # Sets subgrid infotexts - subgrid_index = 1 + iz - if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0: - pc.extra_generation_params = copy(pc.extra_generation_params) - pc.extra_generation_params['Script'] = self.title() - if x_opt.label != 'Nothing': - pc.extra_generation_params["X Type"] = x_opt.label - pc.extra_generation_params["X Values"] = x_values - if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) - if y_opt.label != 'Nothing': - pc.extra_generation_params["Y Type"] = y_opt.label - pc.extra_generation_params["Y Values"] = y_values - if y_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) - grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - # Sets main grid infotext - if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0: - pc.extra_generation_params = copy(pc.extra_generation_params) - - if z_opt.label != 'Nothing': - pc.extra_generation_params["Z Type"] = z_opt.label - pc.extra_generation_params["Z Values"] = z_values - if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs]) - grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - return res - - with SharedSettingsStackHelper(): - processed = draw_xyz_grid( - p, - xs=xs, - ys=ys, - zs=zs, - x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], - y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - z_labels=[z_opt.format_value(p, z_opt, z) for z in zs], - cell=cell, - draw_legend=draw_legend, - include_lone_images=include_lone_images, - include_sub_grids=include_sub_grids, - first_axes_processed=first_axes_processed, - second_axes_processed=second_axes_processed, - margin_size=margin_size - ) - - if not processed.images: - # It broke, no further handling needed. - return processed - z_count = len(zs) - # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids) - processed.infotexts[:1+z_count] = grid_infotext[:1+z_count] - if not include_lone_images: - # Don't need sub-images anymore, drop from list: - processed.images = processed.images[:z_count+1] - if shared.opts.grid_save: - # Auto-save main and sub-grids: - grid_count = z_count + 1 if z_count > 1 else 1 - for g in range(grid_count): - adj_g = g-1 if g > 0 else g - images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=shared.opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) - if not include_sub_grids: - # Done with sub-grids, drop all related information: - for _sg in range(z_count): - del processed.images[1] - del processed.all_prompts[1] - del processed.all_seeds[1] - del processed.infotexts[1] - return processed + # pylint: disable=unused-argument, attribute-defined-outside-init + +import re +import csv +import random +from collections import namedtuple +from copy import copy +from itertools import permutations, chain +from io import StringIO +from PIL import Image +import numpy as np +import gradio as gr +import modules.scripts as scripts +import modules.shared as shared +from modules import images, sd_samplers, processing, sd_models, sd_vae +from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img +from modules.ui_components import ToolButton + +fill_values_symbol = "\U0001f4d2" # 📒 +AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) + + +def apply_field(field): + def fun(p, x, xs): + setattr(p, field, x) + return fun + + +def apply_prompt(p, x, xs): + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + shared.log.warning(f"XYZ grid: prompt S/R did not find {xs[0]} in prompt or negative prompt.") + else: + p.prompt = p.prompt.replace(xs[0], x) + p.negative_prompt = p.negative_prompt.replace(xs[0], x) + + +def apply_order(p, x, xs): + token_order = [] + for token in x: + token_order.append((p.prompt.find(token), token)) + 
token_order.sort(key=lambda t: t[0]) + prompt_parts = [] + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + + +def apply_sampler(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + shared.log.warning(f"XYZ grid: unknown sampler: {x}") + else: + p.sampler_name = sampler_name + + +def confirm_samplers(p, xs): + for x in xs: + if x.lower() not in sd_samplers.samplers_map: + shared.log.warning(f"XYZ grid: unknown sampler: {x}") + + +def apply_checkpoint(p, x, xs): + if x == shared.opts.sd_model_checkpoint: + return + info = sd_models.get_closet_checkpoint_match(x) + if info is None: + shared.log.warning(f"XYZ grid: unknown checkpoint: {x}") + else: + sd_models.reload_model_weights(shared.sd_model, info) + + +def confirm_checkpoints(p, xs): + for x in xs: + if sd_models.get_closet_checkpoint_match(x) is None: + shared.log.warning(f"XYZ grid: Unknown checkpoint: {x}") + + +def apply_clip_skip(p, x, xs): + shared.opts.data["CLIP_stop_at_last_layers"] = x + + +def apply_upscale_latent_space(p, x, xs): + if x.lower().strip() != '0': + shared.opts.data["use_scale_latent_for_hires_fix"] = True + else: + shared.opts.data["use_scale_latent_for_hires_fix"] = False + + +def find_vae(name: str): + if name.lower() in ['auto', 'automatic']: + return sd_vae.unspecified + if name.lower() == 'none': + return None + else: + choices = [x for x in sorted(sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + shared.log.warning(f"No VAE found for {name}; using automatic") + return sd_vae.unspecified + else: + return sd_vae.vae_dict[choices[0]] + + +def apply_vae(p, x, xs): + sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) + + +def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): + p.styles.extend(x.split(',')) + + +def apply_fallback(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + shared.log.warning(f"XYZ grid: unknown sampler: {x}") + else: + shared.opts.data["xyz_fallback_sampler"] = sampler_name + + +def apply_uni_pc_order(p, x, xs): + shared.opts.data["uni_pc_order"] = min(x, p.steps - 1) + + +def apply_face_restore(p, opt, x): + opt = opt.lower() + if opt == 'codeformer': + is_active = True + p.face_restoration_model = 'CodeFormer' + elif opt == 'gfpgan': + is_active = True + p.face_restoration_model = 'GFPGAN' + else: + is_active = opt in ('true', 'yes', 'y', '1') + p.restore_faces = is_active + + +def apply_token_merging_ratio_hr(p, x, xs): + shared.opts.data["token_merging_ratio_hr"] = x + + +def apply_token_merging_ratio(p, x, xs): + shared.opts.data["token_merging_ratio"] = x + + +def apply_token_merging_random(p, x, xs): + is_active = x.lower() in ('true', 'yes', 'y', '1') + shared.opts.data["token_merging_random"] = is_active + + +def format_value_add_label(p, opt, x): + if type(x) == float: + x = round(x, 8) + return f"{opt.label}: {x}" + + +def format_value(p, opt, x): + if type(x) == float: + x = round(x, 8) + return x + + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + +def do_nothing(p, x, xs): + pass + + +def format_nothing(p, opt, x): + return "" + + +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of 
permutations""" + return x + + +class AxisOption: + def __init__(self, label, tipe, apply, fmt=format_value_add_label, confirm=None, cost=0.0, choices=None): + self.label = label + self.type = tipe + self.apply = apply + self.format_value = fmt + self.confirm = confirm + self.cost = cost + self.choices = choices + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False + + +axis_options = [ + AxisOption("Nothing", str, do_nothing, fmt=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, fmt=format_value), + AxisOption("Prompt order", str_permutations, apply_order, fmt=format_value_join_list), + AxisOptionTxt2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), + AxisOption("Checkpoint name", str, apply_checkpoint, fmt=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionTxt2Img("Fallback latent upscaler sampler", str, apply_fallback, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Cond. 
Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), + AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), + AxisOption("Face restore", str, apply_face_restore, fmt=format_value), + AxisOption("ToMe ratio",float,apply_token_merging_ratio), + AxisOption("ToMe ratio for Hires fix",float,apply_token_merging_ratio_hr), + AxisOption("ToMe random pertubations",str,apply_token_merging_random, choices = lambda: ["Yes","No"]) +] + + +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size): + hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] + title_texts = [[images.GridAnnotation(z)] for z in z_labels] + list_size = (len(xs) * len(ys) * len(zs)) + processed_result = None + shared.state.job_count = list_size * p.n_iter + + def process_cell(x, y, z, ix, iy, iz): + nonlocal processed_result + + def index(ix, iy, iz): + return ix + iy * len(xs) + iz * len(xs) * len(ys) + + shared.state.job = f"{index(ix, iy, iz) + 1} out of {list_size}" + processed: Processed = cell(x, y, z, ix, iy, iz) + if processed_result is None: + # Use our first processed result object as a template container to hold our full results + processed_result = copy(processed) + processed_result.images = [None] * list_size + processed_result.all_prompts = [None] * list_size + processed_result.all_seeds = [None] * list_size + processed_result.infotexts = [None] * list_size + processed_result.index_of_first_image = 1 + idx = index(ix, iy, iz) + if processed.images: + # Non-empty list indicates some degree of success. + processed_result.images[idx] = processed.images[0] + processed_result.all_prompts[idx] = processed.prompt + processed_result.all_seeds[idx] = processed.seed + processed_result.infotexts[idx] = processed.infotexts[0] + else: + cell_mode = "P" + cell_size = (processed_result.width, processed_result.height) + if processed_result.images[0] is not None: + cell_mode = processed_result.images[0].mode + #This corrects size in case of batches: + cell_size = processed_result.images[0].size + processed_result.images[idx] = Image.new(cell_mode, cell_size) + + if first_axes_processed == 'x': + for ix, x in enumerate(xs): + if second_axes_processed == 'y': + for iy, y in enumerate(ys): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'y': + for iy, y in enumerate(ys): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'z': + for iz, z in enumerate(zs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + + if not processed_result: + # Should never happen, I've only seen it on one of four open tabs and it needed to refresh. 
+ shared.log.error("XYZ grid: Processing could not begin, you may need to refresh the tab or restart the service") + return Processed(p, []) + elif not any(processed_result.images): + shared.log.error("XYZ grid: Failed to return even a single processed image") + return Processed(p, []) + + z_count = len(zs) + # sub_grids = [None] * z_count + for i in range(z_count): + start_index = (i * len(xs) * len(ys)) + i + end_index = start_index + len(xs) * len(ys) + grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size) + processed_result.images.insert(i, grid) + processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index]) + processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index]) + processed_result.infotexts.insert(i, processed_result.infotexts[start_index]) + sub_grid_size = processed_result.images[0].size + z_grid = images.image_grid(processed_result.images[:z_count], rows=1) + if draw_legend: + z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) + processed_result.images.insert(0, z_grid) + #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) + #processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) + processed_result.infotexts.insert(0, processed_result.infotexts[0]) + return processed_result + + +class SharedSettingsStackHelper(object): + def __enter__(self): + #Save overridden settings so they can be restored later. + self.CLIP_stop_at_last_layers = shared.opts.CLIP_stop_at_last_layers + self.vae = shared.opts.sd_vae + self.uni_pc_order = shared.opts.uni_pc_order + self.token_merging_ratio_hr = shared.opts.token_merging_ratio_hr + self.token_merging_ratio = shared.opts.token_merging_ratio + self.token_merging_random = shared.opts.token_merging_random + self.sd_model_checkpoint = shared.opts.sd_model_checkpoint + self.sd_vae_checkpoint = shared.opts.sd_vae + + def __exit__(self, exc_type, exc_value, tb): + #Restore overriden settings after plot generation. 
+ shared.opts.data["sd_vae"] = self.vae + shared.opts.data["uni_pc_order"] = self.uni_pc_order + shared.opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + shared.opts.data["token_merging_ratio_hr"] = self.token_merging_ratio_hr + shared.opts.data["token_merging_ratio"] = self.token_merging_ratio + shared.opts.data["token_merging_random"] = self.token_merging_random + if self.sd_model_checkpoint != shared.opts.sd_model_checkpoint: + shared.opts.data["sd_model_checkpoint"] = self.sd_model_checkpoint + sd_models.reload_model_weights() + if self.sd_vae_checkpoint != shared.opts.sd_vae: + shared.opts.data["sd_vae"] = self.sd_vae_checkpoint + sd_vae.reload_vae_weights() + + +re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") +re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") +re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") +re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + + +class Script(scripts.Script): + def title(self): + return "X/Y/Z plot" + + def ui(self, is_img2img): + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + with gr.Row(): + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) + + with gr.Row(): + z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) + z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) + z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True) + fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) + with gr.Row(variant="compact", elem_id="axis_options"): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) + with gr.Row(variant="compact", elem_id="axis_options"): + margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, 
elem_id=self.elem_id("margin_size")) + with gr.Row(variant="compact", elem_id="swap_axes"): + swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") + swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") + swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") + + def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown): + return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown + + xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown] + swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) + yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown] + swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) + xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown] + swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) + + def fill(x_type): + axis = self.current_axis_options[x_type] + return axis.choices() if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown]) + fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown]) + + def select_axis(axis_type,axis_values_dropdown): + choices = self.current_axis_options[axis_type].choices + has_choices = choices is not None + current_values = axis_values_dropdown + if has_choices: + choices = choices() + if isinstance(current_values,str): + current_values = current_values.split(",") + current_values = list(filter(lambda x: x in choices, current_values)) + return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values) + + x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown]) + y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown]) + z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown]) + + def get_dropdown_update_from_params(axis,params): + val_key = axis + " Values" + vals = params.get(val_key,"") + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] + return gr.update(value = valslist) + + self.infotext_fields = ( + (x_type, "X Type"), + (x_values, "X Values"), + (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)), + (y_type, "Y Type"), + (y_values, "Y Values"), + (y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)), + (z_type, "Z Type"), + (z_values, "Z Values"), + (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), + ) + + return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] + + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): # pylint: 
disable=arguments-differ + shared.log.debug(f'xyzgrid: {x_type}|{x_values}|{x_values_dropdown}|{y_type}|{y_values}|{y_values_dropdown}|{z_type}|{z_values}|{z_values_dropdown}|{draw_legend}|{include_lone_images}|{include_sub_grids}|{no_fixed_seeds}|{margin_size}') + if not no_fixed_seeds: + processing.fix_seed(p) + if not shared.opts.return_grid: + p.batch_size = 1 + def process_axis(opt, vals, vals_dropdown): + if opt.label == 'Nothing': + return [0] + if opt.choices is not None: + valslist = vals_dropdown + else: + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] + if opt.type == int: + valslist_ext = [] + for val in valslist: + m = re_range.fullmatch(val) + mc = re_range_count.fullmatch(val) + if m is not None: + start = int(m.group(1)) + end = int(m.group(2))+1 + step = int(m.group(3)) if m.group(3) is not None else 1 + valslist_ext += list(range(start, end, step)) + elif mc is not None: + start = int(mc.group(1)) + end = int(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] + else: + valslist_ext.append(val) + valslist = valslist_ext + elif opt.type == float: + valslist_ext = [] + for val in valslist: + m = re_range_float.fullmatch(val) + mc = re_range_count_float.fullmatch(val) + if m is not None: + start = float(m.group(1)) + end = float(m.group(2)) + step = float(m.group(3)) if m.group(3) is not None else 1 + valslist_ext += np.arange(start, end + step, step).tolist() + elif mc is not None: + start = float(mc.group(1)) + end = float(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() + else: + valslist_ext.append(val) + valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) + valslist = [opt.type(x) for x in valslist] + # Confirm options are valid before starting + if opt.confirm: + opt.confirm(p, valslist) + return valslist + + x_opt = self.current_axis_options[x_type] + if x_opt.choices is not None: + x_values = ",".join(x_values_dropdown) + xs = process_axis(x_opt, x_values, x_values_dropdown) + y_opt = self.current_axis_options[y_type] + if y_opt.choices is not None: + y_values = ",".join(y_values_dropdown) + ys = process_axis(y_opt, y_values, y_values_dropdown) + z_opt = self.current_axis_options[z_type] + if z_opt.choices is not None: + z_values = ",".join(z_values_dropdown) + zs = process_axis(z_opt, z_values, z_values_dropdown) + Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes + grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) + assert grid_mp < shared.opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {shared.opts.img_max_size_mp} MPixels)' + + def fix_axis_seeds(axis_opt, axis_list): + if axis_opt.label in ['Seed', 'Var. 
seed']: + return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] + else: + return axis_list + + if not no_fixed_seeds: + xs = fix_axis_seeds(x_opt, xs) + ys = fix_axis_seeds(y_opt, ys) + zs = fix_axis_seeds(z_opt, zs) + + if x_opt.label == 'Steps': + total_steps = sum(xs) * len(ys) * len(zs) + elif y_opt.label == 'Steps': + total_steps = sum(ys) * len(xs) * len(zs) + elif z_opt.label == 'Steps': + total_steps = sum(zs) * len(xs) * len(ys) + else: + total_steps = p.steps * len(xs) * len(ys) * len(zs) + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) * len(zs) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) * len(zs) + elif z_opt.label == "Hires steps": + total_steps += sum(zs) * len(xs) * len(ys) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) + else: + total_steps *= 2 + total_steps *= p.n_iter + image_cell_count = p.n_iter * p.batch_size + shared.log.info(f"XYZ grid: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} {len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}") + shared.state.xyz_plot_x = AxisInfo(x_opt, xs) + shared.state.xyz_plot_y = AxisInfo(y_opt, ys) + shared.state.xyz_plot_z = AxisInfo(z_opt, zs) + # If one of the axes is very slow to change between (like SD model checkpoint), then make sure it is in the outer iteration of the nested `for` loop. + first_axes_processed = 'z' + second_axes_processed = 'y' + if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: + first_axes_processed = 'x' + if y_opt.cost > z_opt.cost: + second_axes_processed = 'y' + else: + second_axes_processed = 'z' + elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: + first_axes_processed = 'y' + if x_opt.cost > z_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'z' + elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: + first_axes_processed = 'z' + if x_opt.cost > y_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'y' + grid_infotext = [None] * (1 + len(zs)) + + def cell(x, y, z, ix, iy, iz): + if shared.state.interrupted: + return Processed(p, [], p.seed, "") + pc = copy(p) + pc.styles = pc.styles[:] + x_opt.apply(pc, x, xs) + y_opt.apply(pc, y, ys) + z_opt.apply(pc, z, zs) + res = process_images(pc) + # Sets subgrid infotexts + subgrid_index = 1 + iz + if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0: + pc.extra_generation_params = copy(pc.extra_generation_params) + pc.extra_generation_params['Script'] = self.title() + if x_opt.label != 'Nothing': + pc.extra_generation_params["X Type"] = x_opt.label + pc.extra_generation_params["X Values"] = x_values + if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) + if y_opt.label != 'Nothing': + pc.extra_generation_params["Y Type"] = y_opt.label + pc.extra_generation_params["Y Values"] = y_values + if y_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) + grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) + # Sets main grid infotext + if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0: + pc.extra_generation_params = copy(pc.extra_generation_params) + + if z_opt.label != 'Nothing': + pc.extra_generation_params["Z Type"] = z_opt.label + pc.extra_generation_params["Z Values"] = z_values + if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs]) + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) + return res + + with SharedSettingsStackHelper(): + processed = draw_xyz_grid( + p, + xs=xs, + ys=ys, + zs=zs, + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + z_labels=[z_opt.format_value(p, z_opt, z) for z in zs], + cell=cell, + draw_legend=draw_legend, + include_lone_images=include_lone_images, + include_sub_grids=include_sub_grids, + first_axes_processed=first_axes_processed, + second_axes_processed=second_axes_processed, + margin_size=margin_size + ) + + if not processed.images: + # It broke, no further handling needed. + return processed + z_count = len(zs) + # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids) + processed.infotexts[:1+z_count] = grid_infotext[:1+z_count] + if not include_lone_images: + # Don't need sub-images anymore, drop from list: + processed.images = processed.images[:z_count+1] + if shared.opts.grid_save: + # Auto-save main and sub-grids: + grid_count = z_count + 1 if z_count > 1 else 1 + for g in range(grid_count): + adj_g = g-1 if g > 0 else g + images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=shared.opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) + if not include_sub_grids: + # Done with sub-grids, drop all related information: + for _sg in range(z_count): + del processed.images[1] + del processed.all_prompts[1] + del processed.all_seeds[1] + del processed.infotexts[1] + return processed