Skip to content

Commit

Permalink
优化图片爬取
Browse files Browse the repository at this point in the history
  • Loading branch information
aidenli committed Dec 25, 2024
1 parent 1171a3b commit 69d1a33
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 49 deletions.
87 changes: 39 additions & 48 deletions nodes/civitai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from urllib.parse import quote
import random
from .utils import get_system_proxy
from lxml import html
from .utils import print_log, save_image_bytes_for_preview
import urllib.request
import folder_paths
import os
import node_helpers
import torch
import numpy as np
from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL import Image, ImageOps, ImageSequence


class CivitaiPromptNode:
Expand Down Expand Up @@ -124,7 +122,7 @@ def filter_meta(item):

self.__next_cursor = json_rsp["result"]["data"]["json"]["nextCursor"]
return [
item["id"]
(item["id"], item["url"])
for item in json_rsp["result"]["data"]["json"]["items"]
if filter_meta(item)
]
Expand Down Expand Up @@ -159,57 +157,50 @@ def get_image_detail(self, image_id) -> tuple[str, str, int]:
print_log(f"获取图片详情失败[{image_id}]:{e}")
return ("", "", 1)

def get_image(self, image_id, image_url):
    """Download the 450px-wide JPEG rendition of an image from the image CDN.

    The CDN URL contains a path segment that is read at call time from a
    ``hash.txt`` file hosted on GitHub, followed by the image's ``url``
    token and its id.

    Args:
        image_id: Id of the image; used as the file name in the CDN URL
            and in log messages.
        image_url: The ``url`` field of the image item as returned by the
            image list query; used as a path segment of the CDN URL.

    Returns:
        The raw response body (bytes) when the CDN answers with HTTP 200,
        otherwise ``None``. Network and parsing errors are logged and also
        yield ``None``; this method never raises.
    """
    try:
        # Always pass a timeout: requests waits indefinitely by default,
        # which would stall the whole node on a dead connection.
        response = self.__session.get(
            "https://raw.githubusercontent.com/aidenli/c_pre/refs/heads/main/hash.txt",
            timeout=10,
        )
        # Strip the trailing newline so it does not end up inside the URL.
        image_hash = response.text.strip()
        if response.status_code != 200 or not image_hash:
            # Without a valid hash the CDN URL cannot be built; bail out
            # instead of requesting a URL that contains an error page body.
            print_log(f"获取图片hash失败[{image_id}]:HTTP {response.status_code}")
            return None

        img_url = f"https://image.{self.__host}/{image_hash}/{image_url}/width=450/{image_id}.jpeg"
        print_log(f"下载图片: {img_url}")
        img_response = self.__session.get(img_url, timeout=20)
        if img_response.status_code == 200:
            return img_response.content

        print_log(f"下载图片失败[{image_id}]:HTTP {img_response.status_code}")
        return None
    except Exception as e:
        # Best-effort download: the caller treats None as "no preview".
        print_log(f"下载图片失败[{image_id}]:{e}")
        return None

def get_output_image(self, image_id):
    """Load the saved ``<image_id>.jpeg`` from the output directory as a tensor.

    Every frame of the file is EXIF-transposed, converted to RGB and scaled
    to float32 values in [0, 1].

    Args:
        image_id: Id of the image; the file read is
            ``os.path.join(self.output_dir, f"{image_id}.jpeg")``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(frames, height, width, 3)``.
        Only the first frame is returned when the file has a single frame
        or its format is listed in ``excluded_formats`` (MPO). Returns
        ``None`` when the file is missing, unreadable or contains no frames;
        this method never raises.
    """
    try:
        image_path = os.path.join(self.output_dir, f"{image_id}.jpeg")
        excluded_formats = ["MPO"]
        output_images = []

        # Use the image as a context manager so the underlying file handle
        # is released even if decoding a frame fails.
        with node_helpers.pillow(Image.open, image_path) as img:
            for frame in ImageSequence.Iterator(img):
                frame = node_helpers.pillow(ImageOps.exif_transpose, frame)

                # Rescale 32-bit integer images before the RGB conversion.
                if frame.mode == "I":
                    frame = frame.point(lambda px: px * (1 / 255))
                rgb = frame.convert("RGB")

                pixels = np.array(rgb).astype(np.float32) / 255.0
                # [None,] adds the leading batch dimension.
                output_images.append(torch.from_numpy(pixels)[None,])

            image_format = img.format

        if not output_images:
            print_log(f"获取图片失败[{image_id}]:no frames decoded")
            return None

        if len(output_images) > 1 and image_format not in excluded_formats:
            return torch.cat(output_images, dim=0)
        return output_images[0]
    except Exception as e:
        # A missing file (e.g. the preview download failed) lands here.
        print_log(f"获取图片失败[{image_id}]:{e}")
        return None

def choise_image(
self, fixed_prompt, preview_image, mirror_sites
Expand Down Expand Up @@ -247,7 +238,7 @@ def choise_image(
while cur < 5:
cur += 1
idx = random.randint(0, len(self.__image_list) - 1)
image_id = self.__image_list.pop(idx)
(image_id, image_url) = self.__image_list.pop(idx)
print_log(f"获取图片[{image_id}]的提示词")
positive, negative, errcode = self.get_image_detail(image_id)
if errcode > 0:
Expand All @@ -258,15 +249,15 @@ def choise_image(
self.__cache_positive = positive
self.__cache_negative = negative
if preview_image:
# print_log(f"获取图片[{image_id}]的内容")
# image_content = self.get_image(image_id)
print_log(f"获取图片[{image_id}]的内容")
image_content = self.get_image(image_id, image_url)

# if image_content is not None:
# # 保存图片到output目录
# with open(
# os.path.join(self.output_dir, f"{image_id}.jpeg"), "wb"
# ) as f:
# f.write(image_content)
if image_content is not None:
# 保存图片到output目录
with open(
os.path.join(self.output_dir, f"{image_id}.jpeg"), "wb"
) as f:
f.write(image_content)

# 保存提示词
with open(
Expand All @@ -276,7 +267,7 @@ def choise_image(
if len(negative) > 0:
f.write(f"\n\n---------------------\nnegative:\n{negative}")

previews = [] # [save_image_bytes_for_preview(image_content)]
previews = [save_image_bytes_for_preview(image_content)]
self.__cache_previews = previews
return {
"result": (positive, negative, self.get_output_image(image_id)),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui_nyjy"
description = "A comfyui node that provides translation and image reverse push functions(JoyTag & JoyCaption)."
version = "1.8.4"
version = "1.8.5"
license = {file = "LICENSE"}
dependencies = ["torch", "torchvision>=0.15.2", "einops>=0.7.0", "safetensors>=0.4.1", "pillow>=9.4.0", "huggingface_hub>=0.23.5", "accelerate", "transformers>=4.43.3", "sentencepiece", "bitsandbytes>=0.43.3", "pygtrans", "fake_useragent", "lxml", "openai", "gradio_client"]

Expand Down

0 comments on commit 69d1a33

Please sign in to comment.