Skip to content

Commit

Permalink
feat: support image understanding from html
Browse files Browse the repository at this point in the history
  • Loading branch information
IcyKallen committed Jan 17, 2025
1 parent 6f81db0 commit 0cda45a
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 49 deletions.
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
Binary file not shown.
155 changes: 123 additions & 32 deletions source/lambda/job/dep/llm_bot_dep/figure_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
import io
import json
import logging
import mimetypes
import os
import re
import tempfile
import urllib.request
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse

import boto3
import requests
from botocore.exceptions import ClientError
from PIL import Image

CHART_UNDERSTAND_PROMPT = """
您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
Expand Down Expand Up @@ -43,16 +50,38 @@
# Prompt/template resources are located relative to this module's directory.
FIGURE_CLASSIFICATION_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "prompt/figure_classification.txt")
MERMAID_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "prompt/mermaid_template.txt")

# Minimum image dimensions: process_single_image skips (returns None for) any
# image whose width or height is below these thresholds.
MIN_WIDTH = 50 # minimum width in pixels
MIN_HEIGHT = 50 # minimum height in pixels

# Module-level logger and S3 client.
logger = logging.getLogger(__name__)
s3_client = boto3.client("s3")


class figureUnderstand:
"""A class to understand and process figures using LLM.
This class provides methods to analyze images using Claude 3 Sonnet model,
classify them, and generate appropriate descriptions or representations.
"""

def __init__(self):
    """Initialize the figureUnderstand class.

    Creates the Bedrock runtime client and loads the Mermaid prompt
    definitions from the JSON file at MERMAID_PROMPT_PATH.

    Raises:
        OSError: If the prompt file cannot be opened.
        json.JSONDecodeError: If the prompt file is not valid JSON.
    """
    self.bedrock_runtime = boto3.client(service_name="bedrock-runtime")
    # Open the prompt file in a context manager so the handle is closed
    # deterministically rather than whenever the garbage collector runs.
    # The encoding is pinned to UTF-8 so that loading does not depend on the
    # process locale.
    with open(MERMAID_PROMPT_PATH, "r", encoding="utf-8") as prompt_file:
        self.mermaid_prompt = json.load(prompt_file)

def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
"""Invoke the LLM model with an image and prompt.
Args:
img: Either a base64 encoded string or PIL Image object
prompt (str): The prompt to send to the model
prefix (str): Starting tag for the output
stop (str): Ending tag for the output
Returns:
str: The model's response with prefix and stop tags
"""
# If img is already a base64 string, use it directly
if isinstance(img, str):
base64_encoded = img
Expand Down Expand Up @@ -174,63 +203,125 @@ def upload_image_to_s3(image_path: str, bucket: str, file_name: str, splitting_t
return object_key


def process_markdown_images_with_llm(content: str, bucket_name: str, file_name: str) -> str:
"""Process images in markdown content and upload them to S3.
def download_image_from_url(img_url: str) -> str:
    """Download image from URL and save to temporary file.

    The temporary file's extension is derived from the response's
    Content-Type header, falling back to ".jpg" when the header is missing
    or its media type is not recognized.

    Args:
        img_url (str): URL of the image to download

    Returns:
        str: Path to temporary file containing the image. The file is created
            with ``delete=False``, so the caller is responsible for removing it.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server responds with an error status.
        OSError: If the temporary file cannot be created or written.
    """
    response = requests.get(img_url, timeout=10)
    response.raise_for_status()

    # Content-Type may carry parameters (e.g. "image/png; charset=binary").
    # mimetypes.guess_extension only matches the bare media type, so strip
    # everything after the first ";" before the lookup; otherwise such
    # responses would silently be saved with the ".jpg" fallback.
    content_type = response.headers.get("Content-Type", "")
    media_type = content_type.split(";", 1)[0].strip()
    ext = mimetypes.guess_extension(media_type) or ".jpg"
    if ext == ".jpe":
        # Some Python versions map image/jpeg to the uncommon ".jpe" suffix.
        ext = ".jpg"

    # The context manager guarantees the handle is closed even if the write
    # fails; delete=False keeps the file on disk for the caller.
    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
        temp_file.write(response.content)
    return temp_file.name


def process_single_image(
    img_path: str, context: str, image_tag: str, bucket_name: str, file_name: str, idx: int
) -> str:
    """Run LLM understanding on one image and upload the image to S3.

    Images narrower than MIN_WIDTH or shorter than MIN_HEIGHT are skipped.
    Otherwise the image is base64-encoded, passed to
    figureUnderstand.figure_understand together with its context, uploaded
    via upload_image_to_s3, and the placeholder link in the LLM output is
    replaced by the value upload_image_to_s3 returns.

    Args:
        img_path (str): Path to the image file
        context (str): Surrounding text context for the image
        image_tag (str): Tag to identify the image in the context
        bucket_name (str): S3 bucket name for uploading
        file_name (str): Base file name for S3 path
        idx (int): Index number of the image

    Returns:
        str: The understanding text for the image, or None if the image is
            smaller than the minimum size.

    Raises:
        Any exception raised while opening, encoding, analysing, or
        uploading the image is propagated to the caller.
    """
    # Read the dimensions first so tiny images are rejected cheaply.
    with Image.open(img_path) as opened_image:
        width, height = opened_image.size

    if width < MIN_WIDTH or height < MIN_HEIGHT:
        logger.warning(f"Image {idx} is too small ({width}x{height}). Skipping processing.")
        return None

    encoded_image = encode_image_to_base64(img_path)
    analyzer = figureUnderstand()

    # The LLM output is generated against a placeholder link that is swapped
    # for the uploaded object's key once the upload has completed.
    placeholder_link = f"{idx:05d}.jpg"
    description = analyzer.figure_understand(encoded_image, context, image_tag, s3_link=placeholder_link)

    uploaded_key = upload_image_to_s3(img_path, bucket_name, file_name, "image", idx)
    return description.replace(f"<link>{idx:05d}.jpg</link>", f"<link>{uploaded_key}</link>")


def process_markdown_images_with_llm(content: str, bucket_name: str, file_name: str) -> str:
    """Process all images in markdown content and upload them to S3.

    This function:
    1. Finds all markdown image references in the content
    2. Downloads images if they are URLs
    3. Processes each image with LLM
    4. Uploads images to S3
    5. Replaces image references with processed understanding

    An image reference is left unchanged in the output when the image cannot
    be downloaded, when processing raises an exception, or when
    process_single_image returns no understanding for it.

    Args:
        content (str): The markdown content containing images
        bucket_name (str): S3 bucket name for uploading
        file_name (str): Base file name for S3 path

    Returns:
        str: Updated content with processed image understandings
    """
    # Regular expression to find markdown image syntax: ![alt text](image_path)
    image_pattern = r"!\[([^\]]*)\]\(([^)]+)\)"

    # Collect output pieces and join once at the end instead of repeatedly
    # concatenating onto a growing string.
    pieces = []
    last_end = 0

    for idx, match in enumerate(re.finditer(image_pattern, content), 1):
        start, end = match.span()
        img_path = match.group(2)
        image_tag = f"[IMAGE_{idx:05d}]"

        # Add the text between the previous image and this one.
        pieces.append(content[last_end:start])
        last_end = end

        # Path of the temporary file created for a URL image, if any; it is
        # removed in the finally clause so downloads do not pile up on disk.
        temp_path = None
        try:
            # Handle URL images
            if img_path.startswith(("http://", "https://")):
                try:
                    temp_path = download_image_from_url(img_path)
                except Exception as e:
                    logger.error(f"Error downloading image from URL {img_path}: {e}")
                    pieces.append(match.group(0))
                    continue
                img_path = temp_path

            # Give the LLM up to 200 characters of text on either side of the
            # image, with the image itself replaced by its tag.
            context_start = max(0, start - 200)
            context_end = min(len(content), end + 200)
            context = f"{content[context_start:start]}\n<image>\n{image_tag}\n</image>\n{content[end:context_end]}"

            # Process the image
            understanding = process_single_image(img_path, context, image_tag, bucket_name, file_name, idx)

            if understanding:
                pieces.append(f"\n\n{understanding}\n\n")
            else:
                # No understanding was produced (e.g. image too small).
                pieces.append(match.group(0))

        except Exception as e:
            logger.error(f"Error processing image {idx}: {e}")
            # If there's an error, keep the original markdown image syntax
            pieces.append(match.group(0))
        finally:
            # NOTE(review): assumes upload_image_to_s3 has finished reading the
            # file by the time process_single_image returns - confirm.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError as e:
                    logger.warning(f"Could not remove temporary image file {temp_path}: {e}")

    # Add any remaining text after the last image
    pieces.append(content[last_end:])

    return "".join(pieces)
16 changes: 0 additions & 16 deletions source/lambda/job/dep/llm_bot_dep/loaders/docx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import sys
import uuid
from datetime import datetime
from pathlib import Path
Expand All @@ -14,9 +13,6 @@
from llm_bot_dep.splitter_utils import MarkdownHeaderTextSplitter
from PIL import Image

# sys.path.append("/home/ubuntu/icyxu/code/solutions/Intelli-Agent/source/lambda/job/dep")


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -118,15 +114,3 @@ def process_doc(s3, **kwargs):
doc_list = splitter.split_text(doc)

return doc_list


# if __name__ == "__main__":
# import boto3

# s3 = boto3.client("s3")
# kwargs = {
# "res_bucket": "ai-customer-service-sharedconstructaicustomerservi-wywyift3c084",
# "bucket": "ai-customer-service-sharedconstructaicustomerservi-wywyift3c084",
# "key": "workshop/CATS.docx",
# }
# process_doc(s3, **kwargs)
1 change: 0 additions & 1 deletion source/lambda/job/dep/llm_bot_dep/loaders/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def load(self, file_content: str, bucket_name: str, file_name: str):
html_content = self.clean_html(file_content)
file_content = markdownify.markdownify(html_content, heading_style="ATX")
file_content = process_markdown_images_with_llm(file_content, bucket_name, file_name)
print(file_content)
doc = Document(
page_content=file_content,
metadata={"file_type": "html", "file_path": self.aws_path},
Expand Down

0 comments on commit 0cda45a

Please sign in to comment.