Skip to content

Commit

Permalink
add reference models
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic committed Nov 5, 2023
1 parent 2919774 commit d878970
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .eslintrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@
"switch_to_extras": "readonly",
"get_tab_index": "readonly",
"create_submit_args": "readonly",
"restart_reload": "readonly",
"restartReload": "readonly",
"updateInput": "readonly",
"toggleCompact": "readonly",
// settings.js
"register_drag_drop": "readonly",
"registerDragDrop": "readonly",
//extraNetworks.js
"requestGet": "readonly",
"getENActiveTab": "readonly",
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ cache
.idea/
/localizations

# unexcluded so folders get created
# force included
!/models/VAE-approx
!/models/VAE-approx/model.pt
!/models/Reference
!/models/Reference/**/*
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Also, [Wiki](https://github.com/vladmandic/automatic/wiki) has been updated with
Some highlights: [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVINO), [IntelArc](https://github.com/vladmandic/automatic/wiki/Intel-ARC), [DirectML](https://github.com/vladmandic/automatic/wiki/DirectML), [ONNX/Olive>](https://github.com/vladmandic/automatic/wiki/ONNX-Runtime)

- **Diffusers**
- since now **SD.Next** supports **12** different model types, we've added reference model for each type in
*Extra networks -> Reference* for easier select & auto-download
Models can still be downloaded manually, this is just a convenience feature & a showcase for supported models
- new model type: [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
its a *distilled* model, this time 50% smaller and faster version of SD-XL!
(and quality does not suffer, its just more optimized)
Expand Down
40 changes: 40 additions & 0 deletions html/reference.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"RunwayML SD 1.5": {
"path": "runwayml/stable-diffusion-v1-5"
},
"StabilityAI SD 2.1": {
"path": "stabilityai/stable-diffusion-2-1-base"
},
"StabilityAI SD-XL 1.0 Base": {
"path": "stabilityai/stable-diffusion-xl-base-1.0"
},
"Segmind SSD-1B": {
"path": "segmind/SSD-1B"
},
"Segmind Tiny": {
"path": "segmind/tiny-sd"
},
"LCM Dreamshaper 7": {
"path": "SimianLuo/LCM_Dreamshaper_v7"
},
"Warp Wuerstchen": {
"path": "warp-ai/wuerstchen"
},
"Kandinsky 2.1": {
"path": "kandinsky-community/kandinsky-2-1"
},
"Kandinsky 2.2": {
"path": "kandinsky-community/kandinsky-2-2-decoder"
},
"DeepFloyd IF Medium": {
"path": "DeepFloyd/IF-I-M-v1.0"
},
"Tsinghua UniDiffuser": {
"path": "thu-ml/unidiffuser-v1",
"desc": "UniDiffuser is a unified diffusion framework to fit all distributions relevant to a set of multi-modal data in one transformer. UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead.\nSpecifically, UniDiffuser employs a variation of transformer, called U-ViT, which parameterizes the joint noise prediction network. Other components perform as encoders and decoders of different modalities, including a pretrained image autoencoder from Stable Diffusion, a pretrained image ViT-B/32 CLIP encoder, a pretrained text ViT-L CLIP encoder, and a GPT-2 text decoder finetuned by ourselves.",
"preview": "unidiffuser-v1.jpg"
},
"Sudo-AI Zero123": {
"path": "sudo-ai/zero123plus-v1.1"
}
}
2 changes: 0 additions & 2 deletions installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,6 @@ def check_torch():
import xformers
if torch.__version__ != '2.0.1+cu118' and xformers.__version__ in ['0.0.22', '0.0.21', '0.0.20']:
log.warning(f'Likely incompatible torch with: xformers=={xformers.__version__} installed: torch=={torch.__version__} required: torch==2.1.0+cu118 - build xformers manually or downgrade torch')
if 'cu118' not in torch.__version__:
log.warning(f'Likely incompatible Cuda with: xformers=={xformers.__version__} installed: torch=={torch.__version__} required: torch==2.1.0+cu118 - build xformers manually or downgrade torch')
elif not args.experimental and not args.use_xformers:
uninstall('xformers')
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion javascript/extensions.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function extensions_apply(extensions_disabled_list, extensions_update_list, disa
if (x.name.startsWith('enable_') && !x.checked) disable.push(x.name.substring(7));
if (x.name.startsWith('update_') && x.checked) update.push(x.name.substring(7));
});
restart_reload();
restartReload();
log('Extensions apply:', { disable, update });
return [JSON.stringify(disable), JSON.stringify(update), disable_all];
}
Expand Down
2 changes: 1 addition & 1 deletion javascript/settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ onAfterUiUpdate(async () => {
const jsdata = textarea.value;
updateOpts(jsdata);
executeCallbacks(optionsChangedCallbacks);
register_drag_drop();
registerDragDrop();

Object.defineProperty(textarea, 'value', {
set(newValue) {
Expand Down
28 changes: 17 additions & 11 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ function recalculate_prompts_inpaint(...args) {
return Array.from(arguments);
}

function register_drag_drop() {
function registerDragDrop() {
const qs = gradioApp().getElementById('quicksettings');
if (!qs) return;
qs.addEventListener('dragover', (evt) => {
Expand Down Expand Up @@ -297,33 +297,33 @@ function getTranslation(...args) {
return null;
}

function monitor_server_status() {
function monitorServerStatus() {
document.open();
document.write(`
<html>
<head><title>SD.Next</title></head>
<body style="background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center">
<h1>Waiting for server...</h1>
<script>
function monitor_server_status() {
function monitorServerStatus() {
fetch('/sdapi/v1/progress')
.then((res) => { !res?.ok ? setTimeout(monitor_server_status, 1000) : location.reload(); })
.catch((e) => setTimeout(monitor_server_status, 1000))
.then((res) => { !res?.ok ? setTimeout(monitorServerStatus, 1000) : location.reload(); })
.catch((e) => setTimeout(monitorServerStatus, 1000))
}
window.onload = () => monitor_server_status();
window.onload = () => monitorServerStatus();
</script>
</body>
</html>
`);
document.close();
}

function restart_reload() {
function restartReload() {
document.body.style = 'background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center';
document.body.innerHTML = '<h1>Server shutdown in progress...</h1>';
fetch('/sdapi/v1/progress')
.then((res) => setTimeout(restart_reload, 1000))
.catch((e) => setTimeout(monitor_server_status, 500));
.then((res) => setTimeout(restartReload, 1000))
.catch((e) => setTimeout(monitorServerStatus, 500));
return [];
}

Expand Down Expand Up @@ -351,6 +351,12 @@ function selectVAE(name) {
log(`Change VAE: ${desiredVAEName}`);
}

function selectReference(name) {
console.log('HERE', name);
desiredCheckpointName = name;
gradioApp().getElementById('change_reference').click();
}

function currentImg2imgSourceResolution(_a, _b, scaleBy) {
const img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img');
return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy];
Expand All @@ -361,7 +367,7 @@ function updateImg2imgResizeToTextAfterChangingImage() {
return [];
}

function create_theme_element() {
function createThemeElement() {
const el = document.createElement('img');
el.id = 'theme-preview';
el.className = 'theme-preview';
Expand Down Expand Up @@ -393,7 +399,7 @@ function previewTheme() {
if (theme) {
window.open(theme.subdomain, '_blank');
} else {
const el = document.getElementById('theme-preview') || create_theme_element();
const el = document.getElementById('theme-preview') || createThemeElement();
el.style.display = el.style.display === 'block' ? 'none' : 'block';
name = name.replace('/', '-');
el.src = `/file=html/${name}.jpg`;
Expand Down
Binary file added models/Reference/unidiffuser-v1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 25 additions & 5 deletions modules/modelloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,21 +213,24 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
pipeline_dir = None

ok = True
err = None
try:
pipeline_dir = DiffusionPipeline.download(hub_id, **download_config)
except Exception as e:
err = e
ok = False
shared.log.warning(f"Diffusers download error: {hub_id} {e}")
if not ok:
# shared.log.warning(f"Diffusers download error: {hub_id} {e}")
if not ok and 'Repository Not Found' not in str(err):
try:
download_config.pop('load_connected_pipeline')
download_config.pop('variant')
pipeline_dir = hf.snapshot_download(hub_id, **download_config)
except Exception as e:
shared.log.warning(f"Diffusers hub download error: {hub_id} {e}")
except Exception:
# shared.log.warning(f"Diffusers download error: {hub_id} {e}")
pass

if pipeline_dir is None:
shared.log.error(f"Diffusers no pipeline folder: {hub_id}")
shared.log.error(f"Diffusers download error: {hub_id} {err}")
return None
try:
# TODO diffusers is this real error?
Expand Down Expand Up @@ -314,6 +317,23 @@ def find_diffuser(name: str):
return None


def load_reference(name: str):
found = [r for r in diffuser_repos if name == r['name'] or name == r['friendly'] or name == r['path']]
if len(found) > 0: # already downloaded
shared.log.debug(f'Reference model: {found[0]}')
return True
shared.log.debug(f'Reference download: {name}')
model_dir = download_diffusers_model(name, shared.opts.diffusers_dir)
if model_dir is None:
shared.log.debug(f'Reference download failed: {name}')
return False
else:
shared.log.debug(f'Reference download complete: {name}')
from modules import sd_models
sd_models.list_models()
return True


modelloader_directories = {}
cache_last = 0
cache_time = 1
Expand Down
21 changes: 18 additions & 3 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,9 @@ def reload_sd_weights():
inputs=components,
outputs=[text_settings, result],
)
defaults_submit.click(fn=lambda: modules.shared.restore_defaults(restart=True), _js="restart_reload")
restart_submit.click(fn=lambda: modules.shared.restart_server(restart=True), _js="restart_reload")
shutdown_submit.click(fn=lambda: modules.shared.restart_server(restart=False), _js="restart_reload")
defaults_submit.click(fn=lambda: modules.shared.restore_defaults(restart=True), _js="restartReload")
restart_submit.click(fn=lambda: modules.shared.restart_server(restart=True), _js="restartReload")
shutdown_submit.click(fn=lambda: modules.shared.restart_server(restart=False), _js="restartReload")

for _i, k, _item in quicksettings_list:
component = component_dict[k]
Expand Down Expand Up @@ -1194,6 +1194,21 @@ def reload_sd_weights():
outputs=[component_dict['sd_vae'], text_settings],
)

def reference_submit(model):
from modules import modelloader
loaded = modelloader.load_reference(model)
if loaded:
return model if loaded else opts.sd_model_checkpoint
print('HERE', model, loaded)
return loaded

button_set_reference = gr.Button('Change reference', elem_id='change_reference', visible=False)
button_set_reference.click(
fn=reference_submit,
_js="function(v){ return desiredCheckpointName; }",
inputs=[component_dict['sd_model_checkpoint']],
outputs=[component_dict['sd_model_checkpoint']],
)
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

def get_settings_values():
Expand Down
14 changes: 9 additions & 5 deletions modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import gradio as gr
from PIL import Image
from starlette.responses import FileResponse, JSONResponse
from modules import shared, scripts, modelloader
from modules import paths, shared, scripts, modelloader
from modules.ui_components import ToolButton
import modules.ui_symbols as symbols


allowed_dirs = []
dir_cache = {} # key=path, value=(mtime, listdir(path))
refresh_time = 0
Expand Down Expand Up @@ -238,8 +239,11 @@ def create_page(self, tabname, skip = False):
allowed_folders = [os.path.abspath(x) for x in self.allowed_directories_for_previews()]
for parentdir, dirs in {d: modelloader.directory_directories(d) for d in allowed_folders}.items():
for tgt in dirs.keys():
if shared.opts.diffusers_dir in tgt:
subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1
if shared.backend == shared.Backend.DIFFUSERS:
if os.path.join(paths.models_path, 'Reference') in tgt:
subdirs['Reference'] = 1
if shared.opts.diffusers_dir in tgt:
subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1
if 'models--' in tgt:
continue
subdir = tgt[len(parentdir):].replace("\\", "/")
Expand Down Expand Up @@ -737,7 +741,7 @@ def ui_scan_click(title):
return ui_refresh_click(title)

def ui_save_click():
from modules import paths, generation_parameters_copypaste
from modules import generation_parameters_copypaste
filename = os.path.join(paths.data_path, "params.txt")
if os.path.exists(filename):
with open(filename, "r", encoding="utf8") as file:
Expand All @@ -749,7 +753,7 @@ def ui_save_click():
return res

def ui_quicksave_click(name):
from modules import paths, generation_parameters_copypaste
from modules import generation_parameters_copypaste
fn = os.path.join(paths.data_path, "params.txt")
if os.path.exists(fn):
with open(fn, "r", encoding="utf8") as file:
Expand Down
35 changes: 31 additions & 4 deletions modules/ui_extra_networks_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import html
import json
import os
from modules import shared, ui_extra_networks, sd_models
from modules import shared, ui_extra_networks, sd_models, paths

reference_dir = os.path.join(paths.models_path, 'Reference')

class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
Expand All @@ -11,11 +12,35 @@ def __init__(self):
def refresh(self):
shared.refresh_checkpoints()

def list_reference(self):
if shared.backend != shared.Backend.DIFFUSERS:
return []
reference_models = shared.readfile(os.path.join('html', 'reference.json'))
for k, v in reference_models.items():
name = os.path.join(reference_dir, k)
yield {
"type": 'Model',
"name": name,
"title": name,
"filename": v['path'],
"search_term": self.search_terms_from_path(name),
"preview": self.find_preview(os.path.join(reference_dir, os.path.basename(v['path']))),
"local_preview": f"{os.path.splitext(name)[0]}.{shared.opts.samples_format}",
"onclick": '"' + html.escape(f"""return selectReference({json.dumps(v['path'])})""") + '"',
"hash": None,
"mtime": 0,
"size": 0,
"info": {},
"metadata": {},
"description": v.get('desc', ''),
}

def list_items(self):
checkpoint: sd_models.CheckpointInfo
checkpoints = sd_models.checkpoints_list.copy()
for name, checkpoint in checkpoints.items():
try:
exists = os.path.exists(checkpoint.filename)
record = {
"type": 'Model',
"name": checkpoint.name,
Expand All @@ -27,14 +52,16 @@ def list_items(self):
"local_preview": f"{os.path.splitext(checkpoint.filename)[0]}.{shared.opts.samples_format}",
"metadata": checkpoint.metadata,
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"mtime": os.path.getmtime(checkpoint.filename),
"size": os.path.getsize(checkpoint.filename),
"mtime": os.path.getmtime(checkpoint.filename) if exists else 0,
"size": os.path.getsize(checkpoint.filename) if exists else 0,
}
record["info"] = self.find_info(checkpoint.filename)
record["description"] = self.find_description(checkpoint.filename, record["info"])
yield record
except Exception as e:
shared.log.debug(f"Extra networks error: type=model file={name} {e}")
for record in self.list_reference():
yield record

def allowed_directories_for_previews(self):
return [v for v in [shared.opts.ckpt_dir, shared.opts.diffusers_dir, sd_models.model_path] if v is not None]
return [v for v in [shared.opts.ckpt_dir, shared.opts.diffusers_dir, reference_dir, sd_models.model_path] if v is not None]

0 comments on commit d878970

Please sign in to comment.