Skip to content

Commit

Permalink
joy caption 添加在线模式
Browse files Browse the repository at this point in the history
  • Loading branch information
aidenli committed Sep 6, 2024
1 parent 68f5fed commit c822464
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 43 deletions.
2 changes: 1 addition & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def install_package(full_name, package_name):


def check_and_install_packages():
packages = ["pygtrans"]
packages = ["pygtrans", "fake_useragent"]
for package in packages:
package_name = re.match(r"^([^\s=<>!]+)", package.strip())
if package_name:
Expand Down
67 changes: 56 additions & 11 deletions nodes/JoyCaption/JoyCaption.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from ..config import LoadConfig, print_log
import os
from huggingface_hub import snapshot_download
import gc
import torch
from PIL import Image
import numpy as np
from pathlib import Path
from .online import joy_caption_online
from ..utils import create_nonceid
import time

config_data = LoadConfig()

Expand Down Expand Up @@ -97,6 +99,7 @@ def INPUT_TYPES(self):
{"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01},
),
"clear_cache": ("BOOLEAN", {"default": False}),
"newbie": ("BOOLEAN", {"default": False}),
}
}

Expand All @@ -108,15 +111,8 @@ def INPUT_TYPES(self):
llama_model = None
clip_model = None

def run(
self,
model,
image,
prompt,
max_new_tokens,
top_k,
temperature,
clear_cache,
def run_local(
self, model, image, prompt, max_new_tokens, top_k, temperature, clear_cache
):
if self.llama_model is None:
# load LLM
Expand Down Expand Up @@ -239,11 +235,60 @@ def run(
generate_ids = generate_ids[:, :-1]

caption = self.llama_model.tokenizer.batch_decode(
generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
generate_ids,
skip_special_tokens=False,
clean_up_tokenization_spaces=False,
)[0]
print_log(f"Exec Finished")
if clear_cache is True:
self.clip_model = None
self.llama_model = None

return (caption.strip(),)

def run_online(self, images):
tmp_folder = os.path.join(config_data["base_path"], "tmp")
if not os.path.exists(tmp_folder):
os.mkdir(tmp_folder)

for batch_number, image in enumerate(images):
# 只处理一张
i = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
file_name = f"{time.time()}_{create_nonceid(10)}.png"
file_path = os.path.join(tmp_folder, file_name)
img.save(file_path)
break

jco = joy_caption_online()
result = jco.analyze(file_path).strip()
try:
os.remove(file_path)
except Exception as e:
print_log(f"删除临时文件失败:{e}")

return (result,)

def run(
self,
model,
image,
prompt,
max_new_tokens,
top_k,
temperature,
clear_cache,
newbie,
):
if newbie == True:
return self.run_online(image)
else:
return self.run_local(
model,
image,
prompt,
max_new_tokens,
top_k,
temperature,
clear_cache,
)
119 changes: 119 additions & 0 deletions nodes/JoyCaption/online.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import requests
import json
from pathlib import Path
from fake_useragent import UserAgent
from ..config import print_log
from ..utils import create_nonceid, get_system_proxy


class joy_caption_online:
def __init__(self, host="fancyfeast-joy-caption-pre-alpha.hf.space") -> None:
# 获取系统代理
proxy = get_system_proxy()
proxies = {
"http": f"{proxy}",
"https": f"{proxy}",
}

ua = UserAgent(platforms="pc")
# 定义头部信息
headers = {
"origin": f"https://{host}",
"referer": f"https://{host}/?__theme=light",
"user-agent": ua.random,
}

self.__session = requests.Session()
self.__session.trust_env = False
self.__session.proxies = proxies
self.__session.headers = headers
self.__host = host
pass

def __upload_image(self, img_path, file_info):
file_info.name
nonceid = create_nonceid(11)

try:
response = self.__session.post(
f"https://{self.__host}/upload?upload_id={nonceid}",
files={"files": (file_info.name, open(file_info, "rb"))},
)

arr_img = json.loads(response.text)
return arr_img[0]
except requests.exceptions.RequestException as e:
print_log(f"请求发生异常:{e}")
except Exception as e:
print_log(f"业务处理异常:{e}")
return ""

def __add_queue(self, img_url, file_info):
nonceid = create_nonceid(10)
try:
response = self.__session.post(
f"https://{self.__host}/queue/join?__theme=light",
json={
"data": [
{
"path": img_url,
"url": f"https://{{self.__host}}/file={img_url}",
"orig_name": file_info.name,
"size": file_info.stat().st_size,
"mime_type": "image/jpeg",
"meta": {"_type": "gradio.FileData"},
}
],
"event_data": None,
"fn_index": 0,
"trigger_id": 5,
"session_hash": nonceid,
},
)
return nonceid
except requests.exceptions.RequestException as e:
print_log(f"请求发生异常:{e}")
except Exception as e:
print_log(f"业务处理异常:{e}")

return ""

def __get_result(self, nonceid):
try:
response = self.__session.get(
f"https://{self.__host}/queue/data?session_hash={nonceid}",
stream=True,
)

for line in response.iter_lines():
if line:
decoded_line = line.decode("utf-8")
if decoded_line.startswith("data:") == False:
continue

data = decoded_line[5:]
json_content = json.loads(data)

if json_content["msg"] != "process_completed":
continue

return json_content["output"]["data"][0]
except requests.exceptions.RequestException as e:
print_log(f"请求发生异常:{e}")
except Exception as e:
print_log(f"业务处理异常:{e}")

return ""

def analyze(self, img_path):
file_info = Path(img_path)
print_log("上传图片进行分析")
img_url = self.__upload_image(img_path, file_info)
if img_url == "":
return ""
print_log("提交服务器处理中...")
nonceid = self.__add_queue(img_url, file_info)
if nonceid == "":
return ""
print_log("图片分析中...")
return self.__get_result(nonceid)
10 changes: 6 additions & 4 deletions nodes/Translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from hashlib import md5
from .config import LoadConfig
from .utils import get_system_proxy
from pygtrans import Translate

# 语言列表
Expand Down Expand Up @@ -68,10 +69,6 @@

class TranslateNode:
def __init__(self):
config_data = LoadConfig()
self.appid = config_data["Baidu"]["AppId"]
self.appkey = config_data["Baidu"]["Secret"]
self.proxy = config_data["Google"]["proxy"] if "Google" in config_data else ""
pass

@classmethod
Expand Down Expand Up @@ -154,6 +151,11 @@ def trans_by_google(self, from_lang, to_lang, text):
return text.translatedText

def run(self, from_lang, to_lang, text, platform, clip=None):
config_data = LoadConfig()
self.appid = config_data["Baidu"]["AppId"]
self.appkey = config_data["Baidu"]["Secret"]
self.proxy = get_system_proxy()

if platform == "Google":
translate_str = self.trans_by_google(from_lang, to_lang, text)
else:
Expand Down
24 changes: 0 additions & 24 deletions nodes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,6 @@
import json
from inspect import currentframe, stack, getmodule
import time
import winreg


# 获取系统代理地址
def get_system_proxy():
try:
internet_settings = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Internet Settings",
)
proxy_server, _ = winreg.QueryValueEx(internet_settings, "ProxyServer")
proxy_enable, _ = winreg.QueryValueEx(internet_settings, "ProxyEnable")
if proxy_enable:
return proxy_server
else:
return None
except FileNotFoundError:
return None


config_template = {
Expand All @@ -28,7 +10,6 @@ def get_system_proxy():
"model_download": "https://hf-mirror.com/fancyfeast/joytag/tree/main",
"hf_project": "fancyfeast/joytag",
},
"Google": {"proxy": "http://localhost:10809"},
}

current_path = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -61,11 +42,6 @@ def LoadConfig():
# 合并最新的配置项(当config_template有变动的时候)
config_data = merge_config(config_data, config_template)

# 获取系统代理地址,并修改配置文件
proxy = get_system_proxy()
if proxy:
config_data["Google"]["proxy"] = proxy

with open(config_path, "w") as f:
f.write(json.dumps(config_data, indent=4))

Expand Down
26 changes: 26 additions & 0 deletions nodes/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import string
import secrets
import winreg


def create_nonceid(length=10):
alphabet = string.ascii_letters + string.digits
nonceid = "".join(secrets.choice(alphabet) for i in range(length))
return nonceid


# 获取系统代理地址
def get_system_proxy():
try:
internet_settings = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Internet Settings",
)
proxy_server, _ = winreg.QueryValueEx(internet_settings, "ProxyServer")
proxy_enable, _ = winreg.QueryValueEx(internet_settings, "ProxyEnable")
if proxy_enable:
return proxy_server
else:
return None
except FileNotFoundError:
return None
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[project]
name = "comfyui_nyjy"
description = "A comfyui node that provides translation and image reverse push functions(JoyTag & JoyCaption)."
version = "1.1.2"
version = "1.2.0"
license = {file = "LICENSE"}
dependencies = ["torch", "torchvision>=0.15.2", "einops>=0.7.0", "safetensors>=0.4.1", "pillow>=9.4.0", "huggingface_hub>=0.23.5", "accelerate", "transformers>=4.43.3", "sentencepiece", "bitsandbytes==0.43.3", "pygtrans"]
dependencies = ["torch", "torchvision>=0.15.2", "einops>=0.7.0", "safetensors>=0.4.1", "pillow>=9.4.0", "huggingface_hub>=0.23.5", "accelerate", "transformers>=4.43.3", "sentencepiece", "bitsandbytes==0.43.3", "pygtrans", "fake_useragent"]

[project.urls]
Repository = "https://github.com/aidenli/ComfyUI_NYJY"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ accelerate
transformers>=4.43.3
sentencepiece
bitsandbytes==0.43.3
pygtrans
pygtrans
fake_useragent

0 comments on commit c822464

Please sign in to comment.