diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d41c5426..07bea87ad 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,6 +34,7 @@
   - xyz grid handle ip adapter name and scale
   - lazy loading of image may prevent metadata from being loaded on time
   - allow startup without valid models folder
+  - fix interrogate api endpoint
   - handle extensions that install conflicting versions of packages `onnxruntime`, `opencv2-python`
diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py
index 23d5e48ef..f28c3bd8b 100644
--- a/modules/api/endpoints.py
+++ b/modules/api/endpoints.py
@@ -77,22 +77,28 @@ def post_interrogate(req: models.ReqInterrogate):
     image = helpers.decode_base64_to_image(req.image)
     image = image.convert('RGB')
     if req.model == "clip":
-        caption = shared.interrogator.interrogate(image)
-        return models.ResInterrogate(caption)
-    elif req.model == "deepdanbooru":
-        from mobules import deepbooru
+        try:
+            caption = shared.interrogator.interrogate(image)
+        except Exception as e:
+            caption = str(e)
+        return models.ResInterrogate(caption=caption)
+    elif req.model == "deepdanbooru" or req.model == 'deepbooru':
+        from modules import deepbooru
         caption = deepbooru.model.tag(image)
-        return models.ResInterrogate(caption)
+        return models.ResInterrogate(caption=caption)
     else:
         from modules.ui_interrogate import interrogate_image, analyze_image, get_models
         if req.model not in get_models():
             raise HTTPException(status_code=404, detail="Model not found")
-        caption = interrogate_image(image, model=req.model, mode=req.mode)
+        try:
+            caption = interrogate_image(image, model=req.model, mode=req.mode)
+        except Exception as e:
+            caption = str(e)
         if not req.analyze:
-            return models.ResInterrogate(caption)
+            return models.ResInterrogate(caption=caption)
         else:
             medium, artist, movement, trending, flavor = analyze_image(image, model=req.model)
-            return models.ResInterrogate(caption, medium, artist, movement, trending, flavor)
+            return models.ResInterrogate(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
 
 def post_unload_checkpoint():
     from modules import sd_models
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 16bdbebe1..3932489fd 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -79,7 +79,8 @@ def checkpoint_wrapper(self):
 
     def load_blip_model(self):
         self.create_fake_fairscale()
-        import models.blip # pylint: disable=no-name-in-module
+        from repositories.blip import models
+        from repositories.blip.models import blip
         import modules.modelloader as modelloader
         model_path = os.path.join(paths.models_path, "BLIP")
         download_name='model_base_caption_capfilt_large.pth',
@@ -90,7 +91,7 @@ def load_blip_model(self):
             ext_filter=[".pth"],
             download_name=download_name,
         )
-        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member
+        blip_model = blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member
         blip_model.eval()
         return blip_model
 
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 6dc39de21..a8926a452 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -241,7 +241,6 @@ def tokenize(self, texts):
 
     def encode_with_transformers(self, tokens):
         clip_skip = int(opts.data['clip_skip']) or 1
-        print('HERE', type(clip_skip), clip_skip)
         outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-clip_skip)
         if clip_skip > 1:
             z = outputs.hidden_states[-clip_skip]
diff --git a/repositories/blip/models/blip.py b/repositories/blip/models/blip.py
index 38678f65e..32cdee3dc 100644
--- a/repositories/blip/models/blip.py
+++ b/repositories/blip/models/blip.py
@@ -8,8 +8,8 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-from models.vit import VisionTransformer, interpolate_pos_embed
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from repositories.blip.models.vit import VisionTransformer, interpolate_pos_embed
+from repositories.blip.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 
 import torch