Skip to content

Commit

Permalink
fix interrogate api
Browse files (browse the repository at this point in the history)
  • Loading branch information
vladmandic committed Feb 10, 2024
1 parent 8ec457c commit 19e5062
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- xyz grid handle ip adapter name and scale
- lazy loading of image may prevent metadata from being loaded on time
- allow startup without valid models folder
- fix interrogate api endpoint
- handle extensions that install conflicting versions of packages
`onnxruntime`, `opencv-python`

Expand Down
22 changes: 14 additions & 8 deletions modules/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,28 @@ def post_interrogate(req: models.ReqInterrogate):
image = helpers.decode_base64_to_image(req.image)
image = image.convert('RGB')
if req.model == "clip":
caption = shared.interrogator.interrogate(image)
return models.ResInterrogate(caption)
elif req.model == "deepdanbooru":
from mobules import deepbooru
try:
caption = shared.interrogator.interrogate(image)
except Exception as e:
caption = str(e)
return models.ResInterrogate(caption=caption)
elif req.model == "deepdanbooru" or req.model == 'deepbooru':
from modules import deepbooru
caption = deepbooru.model.tag(image)
return models.ResInterrogate(caption)
return models.ResInterrogate(caption=caption)
else:
from modules.ui_interrogate import interrogate_image, analyze_image, get_models
if req.model not in get_models():
raise HTTPException(status_code=404, detail="Model not found")
caption = interrogate_image(image, model=req.model, mode=req.mode)
try:
caption = interrogate_image(image, model=req.model, mode=req.mode)
except Exception as e:
caption = str(e)
if not req.analyze:
return models.ResInterrogate(caption)
return models.ResInterrogate(caption=caption)
else:
medium, artist, movement, trending, flavor = analyze_image(image, model=req.model)
return models.ResInterrogate(caption, medium, artist, movement, trending, flavor)
return models.ResInterrogate(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)

def post_unload_checkpoint():
from modules import sd_models
Expand Down
5 changes: 3 additions & 2 deletions modules/interrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def checkpoint_wrapper(self):

def load_blip_model(self):
self.create_fake_fairscale()
import models.blip # pylint: disable=no-name-in-module
from repositories.blip import models
from repositories.blip.models import blip
import modules.modelloader as modelloader
model_path = os.path.join(paths.models_path, "BLIP")
download_name='model_base_caption_capfilt_large.pth',
Expand All @@ -90,7 +91,7 @@ def load_blip_model(self):
ext_filter=[".pth"],
download_name=download_name,
)
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member
blip_model = blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member
blip_model.eval()

return blip_model
Expand Down
1 change: 0 additions & 1 deletion modules/sd_hijack_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def tokenize(self, texts):

def encode_with_transformers(self, tokens):
clip_skip = int(opts.data['clip_skip']) or 1
print('HERE', type(clip_skip), clip_skip)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-clip_skip)
if clip_skip > 1:
z = outputs.hidden_states[-clip_skip]
Expand Down
4 changes: 2 additions & 2 deletions repositories/blip/models/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import warnings
warnings.filterwarnings("ignore")

from models.vit import VisionTransformer, interpolate_pos_embed
from models.med import BertConfig, BertModel, BertLMHeadModel
from repositories.blip.models.vit import VisionTransformer, interpolate_pos_embed
from repositories.blip.models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer

import torch
Expand Down

0 comments on commit 19e5062

Please sign in to comment.