From 13987cf7b8bd00d40432580d4a1fd7a4f7513b7a Mon Sep 17 00:00:00 2001
From: shiimizu
Date: Sun, 1 Sep 2024 00:39:10 -0700
Subject: [PATCH] Removed hooks in favor of a `PhotoMakerLoraLoaderPlus` node.
 Resolves #37

---
 README.md     |   6 +-
 photomaker.py | 140 ++++++++++++++++++---
 utils.py      | 258 ++------------------------------------------------
 3 files changed, 108 insertions(+), 296 deletions(-)

diff --git a/README.md b/README.md
index f2bd6be..63e557d 100644
--- a/README.md
+++ b/README.md
@@ -20,18 +20,20 @@ PhotoMaker implementation that follows the ComfyUI way of doing things. The code
    git clone https://github.com/shiimizu/ComfyUI-PhotoMaker-Plus.git
    ```
 4. Download the model(s) from Hugging Face ([V1](https://huggingface.co/TencentARC/PhotoMaker), [V2](https://huggingface.co/TencentARC/PhotoMaker-V2)) and place it in a `photomaker` folder in your `models` folder such as `ComfyUI/models/photomaker`.
-5. Load the LoRA within the model using the `LoraLoaderModelOnly` node.
+5. Check out the [example workflows](https://github.com/shiimizu/ComfyUI-PhotoMaker-Plus/tree/main/examples).
 
 ## Features of this `Plus` version
 
 * Better face resemblance by using `CLIPImageProcessor` like in the original code.
-* Automatic PhotoMaker LoRA detection & loading via the LoraLoader nodes.
 * Customizable trigger word
 * Allows multiple trigger words in the prompt
 * Extra nodes such as `PhotoMakerStyles` and `PrepImagesForClipVisionFromPath`
 
 ## Important news
 
+**2024-09-01**
+* A `PhotoMakerLoraLoaderPlus` node was added. Use that to load the LoRA.
+
 **2024-07-26**
 * Support for PhotoMaker V2. This uses InsightFace, so make sure to use the new `PhotoMakerLoaderPlus` and `PhotoMakerInsightFaceLoader` nodes.
 
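A note on the feature list's first bullet: `CLIPImageProcessor` is the Hugging Face preprocessor that the reference PhotoMaker code applies to ID images before they reach the ID encoder. A minimal sketch of that preprocessing step, with a placeholder file name:

```python
from PIL import Image
from transformers import CLIPImageProcessor

# Defaults resize the short side to 224, center-crop to 224x224,
# and normalize with CLIP's mean/std.
processor = CLIPImageProcessor()
pixel_values = processor(Image.open("face.jpg"), return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])
```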
{"required": { + "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), ), + }, + } + RETURN_TYPES = ("PHOTOMAKER", ) FUNCTION = "load_photomaker_model" CATEGORY = "PhotoMaker" def load_photomaker_model(self, photomaker_model_name): - photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) - if 'v1' in photomaker_model_name: - photomaker_model = PhotoMakerIDEncoder() - else: + self.load_data(None, None, photomaker_model_name, 0, 0)[0] + if 'qformer_perceiver.token_norm.weight' in self.loaded_clipvision[1].keys(): photomaker_model = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken() - data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) - if "id_encoder" in data: - data = data["id_encoder"] - photomaker_model.load_state_dict(data) + else: + photomaker_model = PhotoMakerIDEncoder() + photomaker_model.load_state_dict(self.loaded_clipvision[1]) + photomaker_model.loader = self + photomaker_model.filename = photomaker_model_name return (photomaker_model,) + def load_data(self, model, clip, name, strength_model, strength_clip): + model_lora, clip_lora = model, clip + + path = folder_paths.get_full_path("photomaker", name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == path: + lora = self.loaded_lora[1] + else: + temp = self.loaded_lora + self.loaded_lora = None + del temp + temp = self.loaded_clipvision + self.loaded_clipvision = None + del temp + + if lora is None: + data = comfy.utils.load_torch_file(path, safe_load=True) + clipvision = data.get("id_encoder", None) + lora = data.get("lora_weights", None) + self.loaded_lora = (path, lora) + self.loaded_clipvision = (path, clipvision) + + if model is not None and (strength_model > 0 or strength_clip > 0): + model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + +class PhotoMakerLoraLoaderPlus: + def __init__(self): + self.loaded_lora = None + self.loaded_clipvision = None + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL",), + "photomaker": ("PHOTOMAKER",), + "lora_strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }, + } + RETURN_TYPES = ("MODEL", ) + FUNCTION = "load_photomaker_lora" + + CATEGORY = "PhotoMaker" + + def load_photomaker_lora(self, model, photomaker, lora_strength): + return (photomaker.loader.load_data(model, None, photomaker.filename, lora_strength, 0)[0],) + class PhotoMakerInsightFaceLoader: @classmethod def INPUT_TYPES(s): @@ -86,26 +136,33 @@ def INPUT_TYPES(s): def apply_photomaker(self, clip: CLIP, photomaker: Union[PhotoMakerIDEncoder, PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken], image: Tensor, trigger_word: str, text: str, insightface_opt=None): if (num_images := len(image)) == 0: raise ValueError("No image provided or found.") + trigger_word=trigger_word.strip() tokens = clip.tokenize(text) class_tokens_mask = {} out_tokens = {} + num_tokens = getattr(photomaker, 'num_tokens', 2) + num_tokens = 1 for key, val in tokens.items(): clip_tokenizer = getattr(clip.tokenizer, f'clip_{key}', clip.tokenizer) - img_token = clip_tokenizer.tokenizer(trigger_word.strip(), truncation=False, add_special_tokens=False)["input_ids"][0] # only get the first token + img_token = clip_tokenizer.tokenizer(trigger_word, truncation=False, add_special_tokens=False)["input_ids"][0] # only get the first token _tokens = torch.tensor([[tpy[0] for tpy in tpy_] for tpy_ in val ] , dtype=torch.int32) 
@@ -86,26 +136,33 @@ def INPUT_TYPES(s):
     def apply_photomaker(self, clip: CLIP, photomaker: Union[PhotoMakerIDEncoder, PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken], image: Tensor, trigger_word: str, text: str, insightface_opt=None):
         if (num_images := len(image)) == 0:
             raise ValueError("No image provided or found.")
+        trigger_word = trigger_word.strip()
         tokens = clip.tokenize(text)
         class_tokens_mask = {}
         out_tokens = {}
+        num_tokens = getattr(photomaker, 'num_tokens', 2)
+        num_tokens = 1  # overrides the model default above; a single ID token is used per image
         for key, val in tokens.items():
             clip_tokenizer = getattr(clip.tokenizer, f'clip_{key}', clip.tokenizer)
-            img_token = clip_tokenizer.tokenizer(trigger_word.strip(), truncation=False, add_special_tokens=False)["input_ids"][0] # only get the first token
+            img_token = clip_tokenizer.tokenizer(trigger_word, truncation=False, add_special_tokens=False)["input_ids"][0] # only get the first token
             _tokens = torch.tensor([[tpy[0] for tpy in tpy_] for tpy_ in val ] , dtype=torch.int32)
             _weights = torch.tensor([[tpy[1] for tpy in tpy_] for tpy_ in val] , dtype=torch.float32)
             start_token = clip_tokenizer.start_token
             end_token = clip_tokenizer.end_token
             pad_token = clip_tokenizer.pad_token
-            tokens_mask = tokenize_with_trigger_word(_tokens, _weights, num_images,img_token,start_token, end_token, pad_token, return_mask=True)[0]
-            tokens_new, weights_new, num_trigger_tokens_processed = tokenize_with_trigger_word(_tokens, _weights, num_images,img_token,start_token, end_token, pad_token)
+            tokens_mask = tokenize_with_trigger_word(_tokens, _weights, num_images, num_tokens, img_token, start_token, end_token, pad_token, return_mask=True)[0]
+            tokens_new, weights_new, num_trigger_tokens_processed = tokenize_with_trigger_word(_tokens, _weights, num_images, num_tokens, img_token, start_token, end_token, pad_token)
             token_weight_pairs = [[(tt,ww) for tt,ww in zip(x.tolist(), y.tolist())] for x,y in zip(tokens_new, weights_new)]
             mask = (tokens_mask == -1).tolist()
             class_tokens_mask[key] = mask
             out_tokens[key] = token_weight_pairs
 
         cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
 
+        if num_trigger_tokens_processed == 0 or not trigger_word:
+            logging.warning("No trigger token found.")
+            return ([[cond, {"pooled_output": pooled}]],)
+
         prompt_embeds = cond
         device_orig = prompt_embeds.device
         first_key = next(iter(tokens.keys()))
@@ -129,8 +186,9 @@ def apply_photomaker(self, clip: CLIP, photomaker: Union[PhotoMakerIDEncoder, Ph
             pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
 
         if photomaker.__class__.__name__ == 'PhotoMakerIDEncoder':
-            cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
-                            class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device))
+            cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0),
+                              prompt_embeds=cond.to(photomaker.load_device),
+                              class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
         else:
             if insightface_opt is None:
                 raise ValueError(f"InsightFace is required for PhotoMaker V2")
@@ -171,9 +229,6 @@ def tensor_to_pil_np(_img):
 
         return ([[cond, {"pooled_output": pooled}]],)
 
-
-from .style_template import styles
-
 class PhotoMakerStyles:
     @classmethod
     def INPUT_TYPES(s):
@@ -262,8 +317,8 @@ def prep_images_for_clip_vision_from_path(self, path:str, interpolation:str, cro
             clip_preprocess = CLIPImageProcessor(resample=resample, do_normalize=False, do_resize=do_resize)
             id_pixel_values = clip_preprocess(input_id_images, return_tensors="pt").pixel_values.movedim(1,-1)
         except TypeError as err:
-            print('[PhotoMaker]:', err)
-            print('[PhotoMaker]: You may need to update transformers.')
+            logging.warning('[PhotoMaker]: %s', err)
+            logging.warning('[PhotoMaker]: You may need to update transformers.')
             input_id_images = [self.image_loader.load_image(image_path)[0] for image_path in image_path_list]
             do_resize = not all(img.shape[-3:-3+2] == size for img in input_id_images)
             if do_resize:
@@ -272,26 +327,21 @@ def prep_images_for_clip_vision_from_path(self, path:str, interpolation:str, cro
             id_pixel_values = torch.cat(input_id_images)
         return (id_pixel_values,)
 
-# supported = False
-# try:
-#     from comfy_extras.nodes_photomaker import PhotoMakerLoader as _PhotoMakerLoader
-#     supported = True
-# except Exception: ...
 NODE_CLASS_MAPPINGS = {
-#     **({} if supported else {"PhotoMakerLoader": PhotoMakerLoaderPlus}),
     "PhotoMakerLoaderPlus": PhotoMakerLoaderPlus,
     "PhotoMakerEncodePlus": PhotoMakerEncodePlus,
     "PhotoMakerStyles": PhotoMakerStyles,
+    "PhotoMakerLoraLoaderPlus": PhotoMakerLoraLoaderPlus,
     "PrepImagesForClipVisionFromPath": PrepImagesForClipVisionFromPath,
     "PhotoMakerInsightFaceLoader": PhotoMakerInsightFaceLoader,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
-#     **({} if supported else {"PhotoMakerLoader": "Load PhotoMaker"}),
     "PhotoMakerLoaderPlus": "PhotoMaker Loader Plus",
     "PhotoMakerEncodePlus": "PhotoMaker Encode Plus",
     "PhotoMakerStyles": "Apply PhotoMaker Style",
+    "PhotoMakerLoraLoaderPlus": "PhotoMaker LoRA Loader Plus",
     "PrepImagesForClipVisionFromPath": "Prepare Images For CLIP Vision From Path",
     "PhotoMakerInsightFaceLoader": "PhotoMaker InsightFace Loader",
-}
+}
\ No newline at end of file
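Inside `apply_photomaker`, the `class_tokens_mask` marks which positions of the prompt embedding the ID encoder should overwrite with image-derived features: the trigger token is removed and the class word before it is repeated once per ID image. A toy version of that bookkeeping, with made-up token ids, assuming a class word precedes the trigger as PhotoMaker prompts require:

```python
tokens = [320, 1125, 539, 42, 2533]  # e.g. "a photo man img ..." with 42 as the trigger id
trigger_id, num_images = 42, 2

out, mask = [], []
for t in tokens:
    if t == trigger_id:
        prev = out.pop()                  # drop the single class-token copy
        mask.pop()
        out.extend([prev] * num_images)   # repeat it once per ID image
        mask.extend([True] * num_images)  # flag these slots for the ID encoder
    else:
        out.append(t)
        mask.append(False)

print(out)   # [320, 1125, 539, 539, 2533]
print(mask)  # [False, False, True, True, False]
```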
diff --git a/utils.py b/utils.py
index 5fee7eb..e13c278 100644
--- a/utils.py
+++ b/utils.py
@@ -1,171 +1,13 @@
 import os
-import sys
 import PIL
 import PIL.Image
 import PIL.ImageOps
-import inspect
-import importlib
-import types
-import functools
-from textwrap import dedent, indent
-from copy import copy
 import torch
-from typing import List, Union
-from collections import namedtuple
-from .model import PhotoMakerIDEncoder
-import comfy.sd1_clip
-from comfy.sd1_clip import escape_important, token_weights, unescape_important
+from typing import Union
 import torch.nn.functional as F
 import torchvision.transforms as TT
 
-Hook = namedtuple('Hook', ['fn', 'module_name', 'target', 'orig_key', 'module_name_nt', 'module_name_unix'])
-
-def hook_clip_model_CLIPVisionModelProjection():
-    return create_hook(PhotoMakerIDEncoder, 'comfy.clip_model', 'CLIPVisionModelProjection')
-
-def hook_tokenize_with_weights():
-    import comfy.sd1_clip
-    if not hasattr(comfy.sd1_clip.SDTokenizer, 'tokenize_with_weights_original'):
-        comfy.sd1_clip.SDTokenizer.tokenize_with_weights_original = comfy.sd1_clip.SDTokenizer.tokenize_with_weights
-    comfy.sd1_clip.SDTokenizer.tokenize_with_weights = tokenize_with_weights
-    return create_hook(tokenize_with_weights, 'comfy.sd1_clip', 'SDTokenizer.tokenize_with_weights')
-
-def hook_load_torch_file():
-    import comfy.utils
-    if not hasattr(comfy.utils, 'load_torch_file_original'):
-        comfy.utils.load_torch_file_original = comfy.utils.load_torch_file
-    replace_str="""
-    if sd.get('id_encoder', None) and (lora_weights:=sd.get('lora_weights', None)) and len(sd) == 2:
-        def find_outer_instance(target:str, target_type):
-            import inspect
-            frame = inspect.currentframe()
-            i = 0
-            while frame and i < 5:
-                if (found:=frame.f_locals.get(target, None)) is not None:
-                    if isinstance(found, target_type):
-                        return found
-                frame = frame.f_back
-                i += 1
-            return None
-        if find_outer_instance('lora_name', str) is not None:
-            sd = lora_weights
-    return sd"""
-    source = inspect.getsource(comfy.utils.load_torch_file_original)
-    modified_source = source.replace("return sd", replace_str)
-    fn = write_to_file_and_return_fn(comfy.utils.load_torch_file_original, modified_source, 'w')
-    return create_hook(fn, 'comfy.utils')
-
-def create_hook(fn, module_name, target = None, orig_key = None):
-    if target is None: target = fn.__name__
-    if orig_key is None: orig_key = f'{target}_original'
-    module_name_nt = '\\'.join(module_name.split('.'))
-    module_name_unix = '/'.join(module_name.split('.'))
-    return Hook(fn, module_name, target, orig_key, module_name_nt, module_name_unix)
-
-def hook_all(restore=False, hooks = None):
-    if hooks is None:
-        hooks: List[Hook] = [
-            hook_clip_model_CLIPVisionModelProjection(),
-        ]
-    for m in list(sys.modules.keys()):
-        for hook in hooks:
-            if hook is None:
-                continue
-            if hook.module_name == m or (os.name != 'nt' and m.endswith(hook.module_name_unix)) or (os.name == 'nt' and m.endswith(hook.module_name_nt)):
-                if hasattr(sys.modules[m], hook.target):
-                    if not hasattr(sys.modules[m], hook.orig_key):
-                        if (orig_fn:=getattr(sys.modules[m], hook.target, None)) is not None:
-                            setattr(sys.modules[m], hook.orig_key, orig_fn)
-                    if restore:
-                        setattr(sys.modules[m], hook.target, getattr(sys.modules[m], hook.orig_key, None))
-                    else:
-                        setattr(sys.modules[m], hook.target, hook.fn)
-
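The `hook_all` deleted above swapped attributes on every imported alias of a target module, stashing each original under an `*_original` attribute so it could be put back. Reduced to a self-contained sketch (function names here are illustrative):

```python
import sys

def install_hook(module_name, target, replacement):
    """Replace module.target, keeping the original for a later restore."""
    module = sys.modules[module_name]
    backup = f"{target}_original"
    if not hasattr(module, backup):   # stash only on the first install
        setattr(module, backup, getattr(module, target))
    setattr(module, target, replacement)

def restore_hook(module_name, target):
    module = sys.modules[module_name]
    original = getattr(module, f"{target}_original", None)
    if original is not None:
        setattr(module, target, original)
```

Since every module that re-imported the target keeps its own reference, patches like this are sensitive to import order, which is presumably why this commit trades them for an explicit node.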
-def tokenize_with_weights(self: comfy.sd1_clip.SDTokenizer, text:str, return_word_ids=False, tokens=None, return_tokens=False):
-    '''
-    Takes a prompt and converts it to a list of (token, weight, word id) elements.
-    Tokens can both be integer tokens and pre computed CLIP tensors.
-    Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
-    Returned list has the dimensions NxM where M is the input size of CLIP
-    '''
-
-    if tokens is None:
-        tokens = []
-    if not tokens:
-        text = escape_important(text)
-        parsed_weights = token_weights(text, 1.0)
-
-        #tokenize words
-        tokens = []
-        for weighted_segment, weight in parsed_weights:
-            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
-            to_tokenize = [x for x in to_tokenize if x != ""]
-            for word in to_tokenize:
-                #if we find an embedding, deal with the embedding
-                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
-                    embed, leftover = self._try_get_embedding(embedding_name)
-                    if embed is None:
-                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
-                    else:
-                        if len(embed.shape) == 1:
-                            tokens.append([(embed, weight)])
-                        else:
-                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
-                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
-                    if leftover != "":
-                        word = leftover
-                    else:
-                        continue
-                #parse word
-                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
-    if return_tokens: return tokens
-
-    #reshape token array to CLIP input size
-    batched_tokens = []
-    batch = []
-    if self.start_token is not None:
-        batch.append((self.start_token, 1.0, 0))
-    batched_tokens.append(batch)
-    for i, t_group in enumerate(tokens):
-        #determine if we're going to try and keep the tokens in a single batch
-        is_large = len(t_group) >= self.max_word_length
-
-        while len(t_group) > 0:
-            if len(t_group) + len(batch) > self.max_length - 1:
-                remaining_length = self.max_length - len(batch) - 1
-                #break word in two and add end token
-                if is_large:
-                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                    batch.append((self.end_token, 1.0, 0))
-                    t_group = t_group[remaining_length:]
-                #add end token and pad
-                else:
-                    batch.append((self.end_token, 1.0, 0))
-                    if self.pad_to_max_length:
-                        batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
-                #start new batch
-                batch = []
-                if self.start_token is not None:
-                    batch.append((self.start_token, 1.0, 0))
-                batched_tokens.append(batch)
-            else:
-                batch.extend([(t,w,i+1) for t,w in t_group])
-                t_group = []
-
-    #fill last batch
-    batch.append((self.end_token, 1.0, 0))
-    if self.pad_to_max_length:
-        batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
-    if self.min_length is not None and len(batch) < self.min_length:
-        batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
-
-    if not return_word_ids:
-        batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
-
-    return batched_tokens
-
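The function removed above was a vendored copy of ComfyUI's `SDTokenizer.tokenize_with_weights`, kept only so it could be monkey-patched. Its main job is reshaping a flat token stream into fixed 77-slot CLIP batches; a simplified version of just that reshaping, ignoring embeddings, weights, and word boundaries:

```python
def batch_tokens(flat, start=49406, end=49407, pad=0, max_len=77):
    """Wrap token chunks with BOS/EOS and pad each batch to max_len."""
    body = max_len - 2  # room left once BOS and EOS are added
    batches = []
    for i in range(0, len(flat), body):
        batch = [start] + flat[i:i + body] + [end]
        batch += [pad] * (max_len - len(batch))
        batches.append(batch)
    return batches

print([len(b) for b in batch_tokens(list(range(100)))])  # [77, 77]
```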
-def tokenize_with_trigger_word(tokens, weights, num_images, img_token, start_token=49406, end_token=49407, pad_token=0, max_len=77, return_mask=False):
+def tokenize_with_trigger_word(tokens, weights, num_images, num_tokens, img_token, start_token=49406, end_token=49407, pad_token=0, max_len=77, return_mask=False):
     """
     Filters out the image token(s).
     Repeats the preceding token if any.
@@ -178,26 +20,26 @@ def tokenize_with_trigger_word(tokens, weights, num_images, img_token, start_tok
     split = torch.tensor_split(clean_tokens, img_token_indices + 1, dim=-1)
     split_mask = torch.tensor_split(clean_tokens_mask, img_token_indices + 1, dim=-1)
 
-    l0 = []
-    lw0 = []
+    tt = []
+    ww = []
     for chunk, chunk_mask in zip(split, split_mask):
         img_token_exists = chunk == img_token
         img_token_not_exists = ~img_token_exists
-        pad_amount = img_token_exists.nonzero().view(-1).shape[0] * num_images
+        pad_amount = img_token_exists.nonzero().view(-1).shape[0] * num_images * num_tokens
         chunk_clean, chunk_mask_clean = chunk[img_token_not_exists], chunk_mask[img_token_not_exists]
         if pad_amount > 0 and len(chunk_clean) > 0:
             count += 1
-            l0.append(torch.nn.functional.pad(chunk_clean[:-1], (0, pad_amount), 'constant', chunk_clean[-1] if not return_mask else -1))
-            lw0.append(torch.nn.functional.pad(chunk_mask_clean[:-1], (0, pad_amount), 'constant', chunk_mask_clean[-1] if not return_mask else -1))
+            tt.append(torch.nn.functional.pad(chunk_clean[:-1], (0, pad_amount), 'constant', chunk_clean[-1] if not return_mask else -1))
+            ww.append(torch.nn.functional.pad(chunk_mask_clean[:-1], (0, pad_amount), 'constant', chunk_mask_clean[-1] if not return_mask else -1))
 
     if count == 0:
         return (tokens, weights, count)
 
     # rebatch and pad
     out = []
-    outw=[]
+    outw = []
     one = torch.tensor([1.0])
-    for tc, tcw in zip(torch.cat(l0).split(max_len - 2), torch.cat(lw0).split(max_len - 2)):
+    for tc, tcw in zip(torch.cat(tt).split(max_len - 2), torch.cat(ww).split(max_len - 2)):
         out.append(torch.cat([torch.tensor([start_token]), tc, torch.tensor([end_token])]))
         outw.append(torch.cat([one, tcw, one]))
 
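The new `num_tokens` parameter scales the padding: each trigger occurrence now reserves `num_images * num_tokens` slots, all filled with the preceding class token (or with `-1` when building the mask). The arithmetic on a single chunk, with made-up values:

```python
import torch
import torch.nn.functional as F

chunk_clean = torch.tensor([320, 1125, 539])  # trigger already filtered out; 539 is the class token
num_images, num_tokens = 2, 1
pad_amount = 1 * num_images * num_tokens      # one trigger occurrence in this chunk

# Drop the single class-token copy, then right-pad with its value:
padded = F.pad(chunk_clean[:-1], (0, pad_amount), 'constant', chunk_clean[-1].item())
print(padded)  # tensor([ 320, 1125,  539,  539])
```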
@@ -376,85 +218,3 @@ def prepImage(image, interpolation="LANCZOS", crop_position="center", size=(224,
         output = output.permute([0,2,3,1])
 
     return output
-
-def inject_code(original_func, data, mode='a'):
-    # Get the source code of the original function
-    original_source = inspect.getsource(original_func)
-
-    # Split the source code into lines
-    lines = original_source.split("\n")
-
-    for item in data:
-        # Find the line number of the target line
-        target_line_number = None
-        for i, line in enumerate(lines):
-            if item['target_line'] not in line: continue
-            target_line_number = i + 1
-
-            if item.get("mode","insert") == "replace":
-                lines[i] = lines[i].replace(item['target_line'], item['code_to_insert'])
-                break
-
-            # Find the indentation of the line where the new code will be inserted
-            indentation = ''
-            for char in line:
-                if char == ' ':
-                    indentation += char
-                else:
-                    break
-
-            # Indent the new code to match the original
-            code_to_insert = item['code_to_insert']
-            if item.get("dedent",True):
-                code_to_insert = dedent(item['code_to_insert'])
-            code_to_insert = indent(code_to_insert, indentation)
-
-            break
-
-        # Insert the code to be injected after the target line
-        if item.get("mode","insert") == "insert" and target_line_number is not None:
-            lines.insert(target_line_number, code_to_insert)
-
-    # Recreate the modified source code
-    modified_source = "\n".join(lines)
-    modified_source = dedent(modified_source.strip("\n"))
-    return write_to_file_and_return_fn(original_func, modified_source, mode)
-
-def write_to_file_and_return_fn(original_func, source:str, mode='a'):
-    # Write the modified source code to a temporary file so the
-    # source code and stack traces can still be viewed when debugging.
-    custom_name = ".patches.py"
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    temp_file_path = os.path.join(current_dir, custom_name)
-    with open(temp_file_path, mode) as temp_file:
-        temp_file.write(source)
-        temp_file.write("\n")
-        temp_file.flush()
-
-        MODULE_PATH = temp_file.name
-        MODULE_NAME = __name__.split('.')[0].replace('-','_') + "_patch_modules"
-        spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[spec.name] = module
-        spec.loader.exec_module(module)
-
-    # Retrieve the modified function from the module
-    modified_function = getattr(module, original_func.__name__)
-
-    # Adapted from https://stackoverflow.com/a/49077211
-    def copy_func(f, globals=None, module=None, code=None, update_wrapper=True):
-        if globals is None: globals = f.__globals__
-        if code is None: code = f.__code__
-        g = types.FunctionType(code, globals, name=f.__name__,
-                               argdefs=f.__defaults__, closure=f.__closure__)
-        if update_wrapper: g = functools.update_wrapper(g, f)
-        if module is not None: g.__module__ = module
-        g.__kwdefaults__ = copy(f.__kwdefaults__)
-        return g
-
-    return copy_func(original_func, code=modified_function.__code__, update_wrapper=False)
-
-
-hook_all(hooks=[
-    # hook_tokenize_with_weights(),
-    hook_load_torch_file(),
-])
\ No newline at end of file
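Taken together, the commit replaces import-time patching with an explicit data flow: the loader caches both halves of the checkpoint, the LoRA node applies one half to the diffusion model, and the encoder consumes the other. Called headlessly, the chain would look roughly like this; `model`, `clip`, `id_images`, and the file name are assumed to come from the usual ComfyUI loaders:

```python
# Sketch only: the node classes are from this patch, everything else is assumed.
loader = PhotoMakerLoaderPlus()
(photomaker,) = loader.load_photomaker_model("photomaker-v1.bin")

# Applies the LoRA weights bundled in the same checkpoint to the model.
(model_patched,) = PhotoMakerLoraLoaderPlus().load_photomaker_lora(
    model, photomaker, lora_strength=1.0)

# Encodes the prompt, swapping ID embeddings in at the trigger word.
(positive,) = PhotoMakerEncodePlus().apply_photomaker(
    clip, photomaker, image=id_images, trigger_word="img",
    text="photo of a man img")
```

V2 checkpoints additionally require `insightface_opt` from `PhotoMakerInsightFaceLoader`.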