feat: support image understanding from html #512

Merged 2 commits on Jan 17, 2025
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
Binary file not shown.
155 changes: 123 additions & 32 deletions source/lambda/job/dep/llm_bot_dep/figure_llm.py
@@ -2,12 +2,19 @@
import io
import json
import logging
import mimetypes
import os
import re
import tempfile
import urllib.request
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse

import boto3
import requests
from botocore.exceptions import ClientError
from PIL import Image

CHART_UNDERSTAND_PROMPT = """
您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
@@ -43,16 +50,38 @@
FIGURE_CLASSIFICATION_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "prompt/figure_classification.txt")
MERMAID_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "prompt/mermaid_template.txt")

# Add minimum size threshold constants
MIN_WIDTH = 50 # minimum width in pixels
MIN_HEIGHT = 50 # minimum height in pixels

logger = logging.getLogger(__name__)
s3_client = boto3.client("s3")


class figureUnderstand:
"""A class to understand and process figures using LLM.

This class provides methods to analyze images using Claude 3 Sonnet model,
classify them, and generate appropriate descriptions or representations.
"""

def __init__(self):
"""Initialize the figureUnderstand class with Bedrock runtime client."""
self.bedrock_runtime = boto3.client(service_name="bedrock-runtime")
self.mermaid_prompt = json.load(open(MERMAID_PROMPT_PATH, "r"))

def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
"""Invoke the LLM model with an image and prompt.

Args:
img: Either a base64 encoded string or PIL Image object
prompt (str): The prompt to send to the model
prefix (str): Starting tag for the output
stop (str): Ending tag for the output

Returns:
str: The model's response with prefix and stop tags
"""
# If img is already a base64 string, use it directly
if isinstance(img, str):
base64_encoded = img
@@ -174,63 +203,125 @@ def upload_image_to_s3(image_path: str, bucket: str, file_name: str, splitting_t
return object_key


def process_markdown_images_with_llm(content: str, bucket_name: str, file_name: str) -> str:
"""Process images in markdown content and upload them to S3.
def download_image_from_url(img_url: str) -> str:
"""Download image from URL and save to temporary file.

Returns:
str: Path to temporary file containing the image
"""
response = requests.get(img_url, timeout=10)
response.raise_for_status()

content_type = response.headers.get("Content-Type", "")
ext = mimetypes.guess_extension(content_type) or ".jpg"
if ext == ".jpe":
ext = ".jpg"

temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
temp_file.write(response.content)
temp_file.close()
return temp_file.name
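A small, self-contained sketch of the extension-guessing step in `download_image_from_url` above. The Content-Type values are illustrative; `mimetypes.guess_extension` can return `.jpe` for `image/jpeg` on some Python versions and platforms, which is why the code normalizes it to `.jpg`:

```python
import mimetypes

# Illustrative Content-Type headers a server might return.
for content_type in ("image/png", "image/jpeg", "application/octet-stream", ""):
    ext = mimetypes.guess_extension(content_type) or ".jpg"  # fall back to .jpg when unknown
    if ext == ".jpe":  # some platforms map image/jpeg to .jpe
        ext = ".jpg"
    print(f"{content_type or '<missing>'} -> {ext}")
```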


def process_single_image(
img_path: str, context: str, image_tag: str, bucket_name: str, file_name: str, idx: int
) -> str:
"""Process a single image and return its understanding text.

Args:
content (str): The markdown content containing images to process
bucket_name (str): The S3 bucket where images will be uploaded
file_name (str): The file name for organizing uploads
img_path (str): Path to the image file
context (str): Surrounding text context for the image
image_tag (str): Tag to identify the image in the context
bucket_name (str): S3 bucket name for uploading
file_name (str): Base file name for S3 path
idx (int): Index number of the image

Returns:
str: Processed markdown with updated image references
str: The processed understanding text for the image, or None if image is too small

Raises:
Various exceptions during image processing and upload
"""
with Image.open(img_path) as img:
width, height = img.size
if width < MIN_WIDTH or height < MIN_HEIGHT:
logger.warning(f"Image {idx} is too small ({width}x{height}). Skipping processing.")
return None

image_base64 = encode_image_to_base64(img_path)
figure_llm = figureUnderstand()

# Get image understanding
understanding = figure_llm.figure_understand(image_base64, context, image_tag, s3_link=f"{idx:05d}.jpg")

# Update S3 link
updated_s3_link = upload_image_to_s3(img_path, bucket_name, file_name, "image", idx)
understanding = understanding.replace(f"<link>{idx:05d}.jpg</link>", f"<link>{updated_s3_link}</link>")

return understanding


def process_markdown_images_with_llm(content: str, bucket_name: str, file_name: str) -> str:
"""Process all images in markdown content and upload them to S3.

This function:
1. Finds all markdown image references in the content
2. Downloads images if they are URLs
3. Processes each image with LLM
4. Uploads images to S3
5. Replaces image references with processed understanding

Args:
content (str): The markdown content containing images
bucket_name (str): S3 bucket name for uploading
file_name (str): Base file name for S3 path

Returns:
str: Updated content with processed image understandings
"""
# Regular expression to find markdown image syntax: ![alt text](image_path)
image_pattern = r"!\[([^\]]*)\]\(([^)]+)\)"

# Keep track of where we last ended to maintain the full text
last_end = 0
result = ""

for idx, match in enumerate(re.finditer(image_pattern, content), 1):
# Generate unique identifier for this image
image_tag = f"[IMAGE_{idx:05d}]"

# Get the full image match and its position
start, end = match.span()
img_path = match.group(2)  # Get the image path from the markdown syntax
image_tag = f"[IMAGE_{idx:05d}]"

# Add the text before the image
result += content[last_end:start]

# Get context (200 characters before and after)
context_start = max(0, start - 200)
context_end = min(len(content), end + 200)
context = f"{content[context_start:start]}\n<image>\n{image_tag}\n</image>\n{content[end:context_end]}"

try:
# Convert image to base64
image_base64 = encode_image_to_base64(img_path)

# Get image understanding
understanding = figure_llm.figure_understand(image_base64, context, image_tag, s3_link=f"{idx:05d}.jpg")

updated_s3_link = upload_image_to_s3(img_path, bucket_name, file_name, "image", idx)
understanding = understanding.replace(f"<link>{idx:05d}.jpg</link>", f"<link>{updated_s3_link}</link>")

# Add the understanding text
result += f"\n\n{understanding}\n\n"
# Handle URL images
if img_path.startswith(("http://", "https://")):
try:
img_path = download_image_from_url(img_path)
except Exception as e:
logger.error(f"Error downloading image from URL {img_path}: {e}")
result += match.group(1)
last_end = end
continue

# Get context
context_start = max(0, start - 200)
context_end = min(len(content), end + 200)
context = f"{content[context_start:start]}\n<image>\n{image_tag}\n</image>\n{content[end:context_end]}"

# Process the image
understanding = process_single_image(img_path, context, image_tag, bucket_name, file_name, idx)

if understanding:
result += f"\n\n{understanding}\n\n"
else:
result += match.group(1)

except Exception as e:
logger.error(f"Error processing image {idx}: {e}")
# If there's an error, keep the original markdown image syntax
result += match.group(0)
result += match.group(1)

last_end = end

# Add any remaining text after the last image
result += content[last_end:]

return result
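A minimal usage sketch for the new entry point, assuming AWS credentials with Bedrock and S3 permissions are configured; the markdown snippet, image URL, and bucket name below are placeholders, not values from this PR:

```python
from llm_bot_dep.figure_llm import process_markdown_images_with_llm

markdown = (
    "# Quarterly report\n\n"
    "Revenue trend for Q3:\n\n"
    "![revenue chart](https://example.com/charts/q3-revenue.png)\n"
)

# Remote images are downloaded to a temp file, classified/described by the LLM,
# uploaded to S3, and the markdown reference is replaced with the understanding text.
updated = process_markdown_images_with_llm(
    content=markdown,
    bucket_name="my-doc-processing-bucket",  # hypothetical bucket
    file_name="reports/q3.md",
)
print(updated)
```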
16 changes: 0 additions & 16 deletions source/lambda/job/dep/llm_bot_dep/loaders/docx.py
@@ -1,6 +1,5 @@
import logging
import os
import sys
import uuid
from datetime import datetime
from pathlib import Path
@@ -14,9 +13,6 @@
from llm_bot_dep.splitter_utils import MarkdownHeaderTextSplitter
from PIL import Image

# sys.path.append("/home/ubuntu/icyxu/code/solutions/Intelli-Agent/source/lambda/job/dep")


logger = logging.getLogger(__name__)


@@ -118,15 +114,3 @@ def process_doc(s3, **kwargs):
doc_list = splitter.split_text(doc)

return doc_list


# if __name__ == "__main__":
# import boto3

# s3 = boto3.client("s3")
# kwargs = {
# "res_bucket": "ai-customer-service-sharedconstructaicustomerservi-wywyift3c084",
# "bucket": "ai-customer-service-sharedconstructaicustomerservi-wywyift3c084",
# "key": "workshop/CATS.docx",
# }
# process_doc(s3, **kwargs)
1 change: 0 additions & 1 deletion source/lambda/job/dep/llm_bot_dep/loaders/html.py
@@ -67,7 +67,6 @@ def load(self, file_content: str, bucket_name: str, file_name: str):
html_content = self.clean_html(file_content)
file_content = markdownify.markdownify(html_content, heading_style="ATX")
file_content = process_markdown_images_with_llm(file_content, bucket_name, file_name)
print(file_content)
doc = Document(
page_content=file_content,
metadata={"file_type": "html", "file_path": self.aws_path},
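For context on how the HTML loader change above feeds the image pipeline: markdownify turns `<img>` tags into standard markdown image syntax, which is exactly what the regex in `process_markdown_images_with_llm` matches. A rough sketch of that hand-off, with an illustrative HTML fragment (no AWS calls involved):

```python
import re

import markdownify

html = '<p>System overview:</p><img src="https://example.com/diagram.png" alt="architecture diagram">'

# Same conversion the HTML loader performs before image processing.
md = markdownify.markdownify(html, heading_style="ATX")

# Same pattern used in process_markdown_images_with_llm to locate images.
image_pattern = r"!\[([^\]]*)\]\(([^)]+)\)"
for match in re.finditer(image_pattern, md):
    alt_text, img_path = match.group(1), match.group(2)
    print(f"found image: alt={alt_text!r}, path={img_path}")
    # Each match would then be downloaded (if remote), understood by the LLM,
    # and uploaded to S3 by process_markdown_images_with_llm.
```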