From a0d55a5956e23f12638b6bb1b666169adc2ebb9e Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 13 Nov 2024 18:22:20 -0500
Subject: [PATCH] pulid with refine pass

Signed-off-by: Vladimir Mandic
---
 modules/processing_vae.py          |  6 +++
 modules/prompt_parser_diffusers.py | 76 +++++++++++++++++-------------
 modules/pulid/pulid_sdxl.py        |  4 ++
 wiki                               |  2 +-
 4 files changed, 53 insertions(+), 35 deletions(-)

diff --git a/modules/processing_vae.py b/modules/processing_vae.py
index 0473beeb1..3c0357c81 100644
--- a/modules/processing_vae.py
+++ b/modules/processing_vae.py
@@ -35,6 +35,8 @@ def create_latents(image, p, dtype=None, device=None):
 
 def full_vae_decode(latents, model):
     t0 = time.time()
+    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
+        model = model.pipe
     if model is None or not hasattr(model, 'vae'):
         shared.log.error('VAE not found in model')
         return []
@@ -148,6 +150,8 @@ def taesd_vae_encode(image):
 def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None):
     t0 = time.time()
     model = model or shared.sd_model
+    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
+        model = model.pipe
     if latents is None or not torch.is_tensor(latents): # already decoded
         return latents
     prev_job = shared.state.job
@@ -196,6 +200,8 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None,
 def vae_encode(image, model, full_quality=True): # pylint: disable=unused-variable
     if shared.state.interrupted or shared.state.skipped:
         return []
+    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
+        model = model.pipe
     if not hasattr(model, 'vae'):
         shared.log.error('VAE not found in model')
         return []
diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py
index 97696a5e1..234272907 100644
--- a/modules/prompt_parser_diffusers.py
+++ b/modules/prompt_parser_diffusers.py
@@ -19,20 +19,25 @@
 embedder = None
 
 
-def prompt_compatible():
+def prompt_compatible(pipe = None):
+    pipe = pipe or shared.sd_model
     if (
-        'StableDiffusion' not in shared.sd_model.__class__.__name__ and
-        'DemoFusion' not in shared.sd_model.__class__.__name__ and
-        'StableCascade' not in shared.sd_model.__class__.__name__ and
-        'Flux' not in shared.sd_model.__class__.__name__
+        'StableDiffusion' not in pipe.__class__.__name__ and
+        'DemoFusion' not in pipe.__class__.__name__ and
+        'StableCascade' not in pipe.__class__.__name__ and
+        'Flux' not in pipe.__class__.__name__
     ):
-        shared.log.warning(f"Prompt parser not supported: {shared.sd_model.__class__.__name__}")
+        shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
         return False
     return True
 
 
-def prepare_model():
-    pipe = shared.sd_model
+def prepare_model(pipe = None):
+    pipe = pipe or shared.sd_model
+    if not hasattr(pipe, "text_encoder") and hasattr(shared.sd_model, "pipe"):
+        pipe = pipe.pipe
+    if not hasattr(pipe, "text_encoder"):
+        return None
     if shared.opts.diffusers_offload_mode == "balanced":
         pipe = sd_models.apply_balanced_offload(pipe)
     elif hasattr(pipe, "maybe_free_model_hooks"):
@@ -62,7 +67,10 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
         earlyout = self.checkcache(p)
         if earlyout:
             return
-        pipe = prepare_model()
+        pipe = prepare_model(p.sd_model)
+        if pipe is None:
+            shared.log.error("Prompt encode: cannot find text encoder in model")
+            return
         # per prompt in batch
         for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts)):
             self.prepare_schedule(prompt, negative_prompt)
@@ -168,8 +176,8 @@ def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
             self.negative_pooleds[batchidx].append(negative_pooled)
 
         if debug_enabled:
-            get_tokens('positive', positive_prompt)
-            get_tokens('negative', negative_prompt)
+            get_tokens(pipe, 'positive', positive_prompt)
+            get_tokens(pipe, 'negative', negative_prompt)
         pipe = prepare_model()
 
     def __call__(self, key, step=0):
@@ -288,25 +296,25 @@ def get_prompt_schedule(prompt, steps):
     return temp, len(schedule) > 1
 
 
-def get_tokens(msg, prompt):
+def get_tokens(pipe, msg, prompt):
     global token_dict, token_type # pylint: disable=global-statement
     if not shared.native:
         return 0
-    if shared.sd_loaded and hasattr(shared.sd_model, 'tokenizer') and shared.sd_model.tokenizer is not None:
+    if shared.sd_loaded and hasattr(pipe, 'tokenizer') and pipe.tokenizer is not None:
         if token_dict is None or token_type != shared.sd_model_type:
             token_type = shared.sd_model_type
-            fn = shared.sd_model.tokenizer.name_or_path
+            fn = pipe.tokenizer.name_or_path
             if fn.endswith('tokenizer'):
-                fn = os.path.join(shared.sd_model.tokenizer.name_or_path, 'vocab.json')
+                fn = os.path.join(pipe.tokenizer.name_or_path, 'vocab.json')
             else:
-                fn = os.path.join(shared.sd_model.tokenizer.name_or_path, 'tokenizer', 'vocab.json')
+                fn = os.path.join(pipe.tokenizer.name_or_path, 'tokenizer', 'vocab.json')
             token_dict = shared.readfile(fn, silent=True)
-            for k, v in shared.sd_model.tokenizer.added_tokens_decoder.items():
+            for k, v in pipe.tokenizer.added_tokens_decoder.items():
                 token_dict[str(v)] = k
             shared.log.debug(f'Tokenizer: words={len(token_dict)} file="{fn}"')
-        has_bos_token = shared.sd_model.tokenizer.bos_token_id is not None
-        has_eos_token = shared.sd_model.tokenizer.eos_token_id is not None
-        ids = shared.sd_model.tokenizer(prompt)
+        has_bos_token = pipe.tokenizer.bos_token_id is not None
+        has_eos_token = pipe.tokenizer.eos_token_id is not None
+        ids = pipe.tokenizer(prompt)
         ids = getattr(ids, 'input_ids', [])
         tokens = []
         for i in ids:
@@ -337,10 +345,10 @@ def normalize_prompt(pairs: list):
     return pairs
 
 
-def get_prompts_with_weights(prompt: str):
+def get_prompts_with_weights(pipe, prompt: str):
     t0 = time.time()
-    manager = DiffusersTextualInversionManager(shared.sd_model, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
-    prompt = manager.maybe_convert_prompt(prompt, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
+    manager = DiffusersTextualInversionManager(pipe, pipe.tokenizer or pipe.tokenizer_2)
+    prompt = manager.maybe_convert_prompt(prompt, pipe.tokenizer or pipe.tokenizer_2)
     texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
     if shared.opts.prompt_mean_norm:
         texts_and_weights = normalize_prompt(texts_and_weights)
@@ -348,7 +356,7 @@ def get_prompts_with_weights(prompt: str):
     if debug_enabled:
         all_tokens = 0
         for text in texts:
-            tokens = get_tokens('section', text)
+            tokens = get_tokens(pipe, 'section', text)
             all_tokens += tokens
         debug(f'Prompt tokenizer: parser={shared.opts.prompt_attention} tokens={all_tokens}')
     debug(f'Prompt: weights={texts_and_weights} time={(time.time() - t0):.3f}')
@@ -412,7 +420,7 @@ def pad_to_same_length(pipe, embeds, empty_embedding_providers=None):
     return embeds
 
 
-def split_prompts(prompt, SD3 = False):
+def split_prompts(pipe, prompt, SD3 = False):
     if prompt.find("TE2:") != -1:
         prompt, prompt2 = prompt.split("TE2:")
     else:
@@ -430,7 +438,7 @@ def split_prompts(prompt, SD3 = False):
     prompt3 = " " if prompt3.strip() == "" else prompt3.strip()
 
     if SD3 and prompt3 != " ":
-        ps, _ws = get_prompts_with_weights(prompt3)
+        ps, _ws = get_prompts_with_weights(pipe, prompt3)
         prompt3 = " ".join(ps)
     return prompt, prompt2, prompt3
 
@@ -438,15 +446,15 @@ def split_prompts(prompt, SD3 = False):
 def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
     device = devices.device
     SD3 = hasattr(pipe, 'text_encoder_3')
-    prompt, prompt_2, prompt_3 = split_prompts(prompt, SD3)
-    neg_prompt, neg_prompt_2, neg_prompt_3 = split_prompts(neg_prompt, SD3)
+    prompt, prompt_2, prompt_3 = split_prompts(pipe, prompt, SD3)
+    neg_prompt, neg_prompt_2, neg_prompt_3 = split_prompts(pipe, neg_prompt, SD3)
 
     if prompt != prompt_2:
-        ps = [get_prompts_with_weights(p) for p in [prompt, prompt_2]]
-        ns = [get_prompts_with_weights(p) for p in [neg_prompt, neg_prompt_2]]
+        ps = [get_prompts_with_weights(pipe, p) for p in [prompt, prompt_2]]
+        ns = [get_prompts_with_weights(pipe, p) for p in [neg_prompt, neg_prompt_2]]
     else:
-        ps = 2 * [get_prompts_with_weights(prompt)]
-        ns = 2 * [get_prompts_with_weights(neg_prompt)]
+        ps = 2 * [get_prompts_with_weights(pipe, prompt)]
+        ns = 2 * [get_prompts_with_weights(pipe, neg_prompt)]
     positives, positive_weights = zip(*ps)
     negatives, negative_weights = zip(*ns)
 
@@ -561,8 +569,8 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
 
 def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
     is_sd3 = hasattr(pipe, 'text_encoder_3')
-    prompt, prompt_2, _prompt_3 = split_prompts(prompt, is_sd3)
-    neg_prompt, neg_prompt_2, _neg_prompt_3 = split_prompts(neg_prompt, is_sd3)
+    prompt, prompt_2, _prompt_3 = split_prompts(pipe, prompt, is_sd3)
+    neg_prompt, neg_prompt_2, _neg_prompt_3 = split_prompts(pipe, neg_prompt, is_sd3)
     try:
         prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer)
         neg_prompt = pipe.maybe_convert_prompt(neg_prompt, pipe.tokenizer)
diff --git a/modules/pulid/pulid_sdxl.py b/modules/pulid/pulid_sdxl.py
index 3053d759d..7ee9a138e 100644
--- a/modules/pulid/pulid_sdxl.py
+++ b/modules/pulid/pulid_sdxl.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image
 from diffusers import StableDiffusionXLPipeline
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 
@@ -353,6 +354,9 @@ def __call__(
         debug(f'PulID call: width={width} height={height} cfg={guidance_scale} steps={num_inference_steps} seed={seed} strength={strength} id_scale={id_scale} output={output_type}')
         self.step = 0 # pylint: disable=attribute-defined-outside-init
         self.callback_on_step_end = callback_on_step_end # pylint: disable=attribute-defined-outside-init
+        if isinstance(image, list) and len(image) > 0 and isinstance(image[0], Image.Image):
+            if image[0].width != width or image[0].height != height: # override width/height if different
+                width, height = image[0].width, image[0].height
         size = (1, height, width)
         # sigmas
         sigmas = self.get_sigmas_karras(num_inference_steps).to(self.device)
diff --git a/wiki b/wiki
index 352fc655b..96f28bb7c 160000
--- a/wiki
+++ b/wiki
@@ -1 +1 @@
-Subproject commit 352fc655b0dc9edb22aac093186da087ba18b474
+Subproject commit 96f28bb7cec5a4e198a3244a88309f1957f75d03
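
Note (not part of the patch): the recurring change above is an "unwrap the inner pipeline" guard. Wrapper pipelines such as the PuLID refine pass keep the actual diffusers pipeline on a .pipe attribute, so helpers that need vae, text_encoder or tokenizer first fall back to that inner object before reporting an error. A minimal standalone sketch of the pattern follows; the Wrapper class and unwrap helper are hypothetical names used only for illustration, since the patch inlines this check directly in full_vae_decode, vae_decode, vae_encode and prepare_model.

# Illustrative sketch only; 'Wrapper', 'FakePipe' and 'unwrap' are hypothetical names,
# the patch inlines this check at each call site instead of defining a helper.
class Wrapper:
    """Stand-in for a wrapper pipeline (e.g. a refine pass) that keeps the real pipeline on .pipe."""
    def __init__(self, pipe):
        self.pipe = pipe

def unwrap(model, attr: str):
    """Return the object that actually carries attr, falling back to model.pipe when the wrapper lacks it."""
    if not hasattr(model, attr) and hasattr(model, 'pipe'):
        model = model.pipe
    return model if hasattr(model, attr) else None

class FakePipe:
    vae = object()
    text_encoder = object()

outer = Wrapper(FakePipe())
assert unwrap(outer, 'vae') is outer.pipe        # wrapper gets unwrapped to the inner pipeline
assert unwrap(FakePipe(), 'vae') is not None     # a plain pipeline passes through unchanged
assert unwrap(object(), 'text_encoder') is None  # nothing usable found; caller logs an error and bails out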