Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] support image embeds #13955

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 97 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,29 @@ class AudioURL(TypedDict, total=False):
"""


class ImageEmbeds(TypedDict, total=False):
    """
    Image embeds to be used in the chat completion API.
    """
    # NOTE: this docstring was previously placed after the fields, where it
    # was a no-op string expression instead of the class ``__doc__``.

    # Base64-encoded serialized tensor of precomputed image embeddings.
    image_embeds: Required[str]
    # image_sizes: Optional[List[int]] TODO(@chaunceyjiang)
    # Base64-encoded grid tensor; presumably the (t, h, w) layout some
    # models need for positional encodings — TODO confirm with callers.
    image_grid_thw: Optional[str]


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """A chat message content part that references audio by URL."""

    type: Required[Literal["audio_url"]]
    """The type of the content part."""

    audio_url: Required[AudioURL]
    """The audio resource being referenced."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    """A chat message content part carrying precomputed image embeddings."""

    type: Required[Literal["image_embeds"]]
    """The type of the content part."""

    image_embeds: Required[Union[str, ImageEmbeds]]
    """Embedding payload: a bare base64 string or an ImageEmbeds dict."""


class VideoURL(TypedDict, total=False):
url: Required[str]
"""
Expand Down Expand Up @@ -108,6 +124,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]

Expand Down Expand Up @@ -508,6 +525,10 @@ def mm_placeholder_counts(self) -> Dict[str, int]:
def parse_image(self, image_url: str) -> None:
    """Parse an image item given by its URL. Subclasses must override."""
    raise NotImplementedError

@abstractmethod
def parse_image_embeds(self, image_embeds: Union[str, ImageEmbeds]) -> None:
    """Parse precomputed image embeddings, given either as a base64
    string or as an :class:`ImageEmbeds` dict. Subclasses must override."""
    raise NotImplementedError

@abstractmethod
def parse_audio(self, audio_url: str) -> None:
    """Parse an audio item given by its URL. Subclasses must override."""
    raise NotImplementedError
Expand Down Expand Up @@ -538,6 +559,37 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)

def _parse_image_embeds_params(self,
                               image_embeds: ImageEmbeds) -> Dict[str, Any]:
    """Decode the optional base64-encoded fields of ``image_embeds``.

    Currently only ``image_grid_thw`` is decoded; it is replaced in
    place with the deserialized tensor and the mapping is returned as a
    plain dict.
    """
    # Bind the field to a local first: the original nested a
    # double-quoted subscript inside a double-quoted f-string across
    # lines, which is a SyntaxError before Python 3.12.
    grid_thw_b64 = image_embeds.get("image_grid_thw", "")
    if grid_thw_b64:
        embedding_url = f"data:image/embeds;base64,{grid_thw_b64}"
        image_embeds["image_grid_thw"] = (
            self._connector.fetch_image_embedding(embedding_url))

    return cast(Dict[str, Any], image_embeds)

def parse_image_embeds(self,
                       image_embeds: Union[str, ImageEmbeds]) -> None:
    """Decode user-supplied image embeddings and register them with the
    multimodal tracker.

    ``image_embeds`` is either a bare base64 string or an
    :class:`ImageEmbeds` dict whose fields are base64 strings.

    Raises:
        ValueError: if ``image_embeds`` is neither a str nor a dict.
    """
    if isinstance(image_embeds, dict):
        image_data = image_embeds.get("image_embeds", "")
        embedding_url = f"data:image/embeds;base64,{image_data}"
        embedding = self._connector.fetch_image_embedding(embedding_url)

        embeds = cast(Dict[str, Any], image_embeds)
        embeds["image_embeds"] = embedding  # decoded image data
        # Decode any remaining base64 fields (e.g. image_grid_thw).
        embeds |= self._parse_image_embeds_params(image_embeds)

        placeholder = self._tracker.add("image", embeds)
    elif isinstance(image_embeds, str):
        embedding_url = f"data:image/embeds;base64,{image_embeds}"
        embedding = self._connector.fetch_image_embedding(embedding_url)
        # BUGFIX: register the decoded tensor, not the raw base64
        # string (the original left ``embedding`` unused here).
        placeholder = self._tracker.add("image", embedding)
    else:
        # Previously fell through with ``placeholder`` unbound,
        # raising an opaque NameError.
        raise ValueError(
            "Invalid image_embeds format; expected a base64 string "
            "or an ImageEmbeds dict.")

    self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)

Expand Down Expand Up @@ -574,6 +626,39 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)

def _parse_image_embeds_params(self,
                               image_embeds: ImageEmbeds) -> Dict[str, Any]:
    """Decode the optional base64-encoded fields of ``image_embeds``.

    Currently only ``image_grid_thw`` is decoded; it is replaced in
    place with the deserialized tensor and the mapping is returned as a
    plain dict.
    """
    # Bind the field to a local first: the original nested a
    # double-quoted subscript inside a double-quoted f-string across
    # lines, which is a SyntaxError before Python 3.12.
    grid_thw_b64 = image_embeds.get("image_grid_thw", "")
    if grid_thw_b64:
        embedding_url = f"data:image/embeds;base64,{grid_thw_b64}"
        image_embeds["image_grid_thw"] = (
            self._connector.fetch_image_embedding(embedding_url))
    return cast(Dict[str, Any], image_embeds)

def parse_image_embeds(self,
                       image_embeds: Union[str, ImageEmbeds]) -> None:
    """Decode user-supplied image embeddings and register them with the
    multimodal tracker.

    The async tracker stores futures, so the (synchronously) decoded
    result is wrapped in an already-resolved future.

    Raises:
        ValueError: if ``image_embeds`` is neither a str nor a dict.
    """
    future: asyncio.Future[Union[str, Dict[str, Any]]] = asyncio.Future()

    if isinstance(image_embeds, dict):
        image_data = image_embeds.get("image_embeds", "")
        embedding_url = f"data:image/embeds;base64,{image_data}"
        embedding = self._connector.fetch_image_embedding(embedding_url)

        embeds = cast(Dict[str, Any], image_embeds)
        embeds["image_embeds"] = embedding  # decoded image data
        # Decode any remaining base64 fields (e.g. image_grid_thw).
        embeds |= self._parse_image_embeds_params(image_embeds)
        future.set_result(embeds)
    elif isinstance(image_embeds, str):
        embedding_url = f"data:image/embeds;base64,{image_embeds}"
        embedding = self._connector.fetch_image_embedding(embedding_url)
        future.set_result(embedding)
    else:
        # BUGFIX: previously an unexpected type left the future
        # unresolved, so any consumer awaiting it hung forever.
        raise ValueError(
            "Invalid image_embeds format; expected a base64 string "
            "or an ImageEmbeds dict.")

    placeholder = self._tracker.add("image", future)
    self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)

Expand Down Expand Up @@ -679,12 +764,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
# These "parsers" are plain typing casts: request content was already
# validated by Pydantic upstream, so no runtime re-validation is needed.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio, ImageEmbeds]

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[
Expand All @@ -695,6 +781,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"input_audio":
Expand Down Expand Up @@ -764,6 +852,7 @@ def _parse_chat_message_content_mm_part(


# Content-part ``type`` values accepted in multimodal chat messages.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds",
                                       "audio_url", "input_audio", "video_url")


Expand Down Expand Up @@ -838,7 +927,13 @@ def _parse_chat_message_content_part(
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "image_embeds":
if isinstance(content, dict):
content = cast(ImageEmbeds, content)
if isinstance(content, str):
content = cast(str, content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
Expand Down
19 changes: 19 additions & 0 deletions vllm/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,22 @@ def encode_base64(
data = buffer.getvalue()

return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """(De)serializes image-embedding tensors for transport through
    data/file URLs in the chat completion API."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a ``torch.save``-produced byte string."""
        buffer = BytesIO(data)
        # SECURITY: this data comes from API users; weights_only=True
        # prevents arbitrary code execution via pickle deserialization.
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode a base64 payload and deserialize it (``media_type``
        is ignored)."""
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load a tensor saved with ``torch.save`` from a local path."""
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Encode a tensor's raw buffer as base64.

        NOTE(review): this emits the raw numpy buffer, which is NOT the
        ``torch.save`` format that ``load_bytes`` expects, so
        encode/load do not round-trip — confirm this is intentional.
        """
        return base64.b64encode(media.numpy()).decode('utf-8')
22 changes: 21 additions & 1 deletion vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
Expand All @@ -16,7 +17,7 @@

from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO

Expand Down Expand Up @@ -245,6 +246,25 @@ async def fetch_video_async(
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)

def fetch_image_embedding(
    self,
    image_embedding_url: str,
) -> torch.Tensor:
    """
    Load image embedding from a URL.

    Only ``data:`` and ``file:`` URLs are supported; any other scheme
    is rejected.
    """
    media_io = ImageEmbeddingMediaIO()
    parsed = urlparse(image_embedding_url)
    scheme = parsed.scheme

    if scheme == "data":
        return self._load_data_url(parsed, media_io)
    if scheme == "file":
        return self._load_file_url(parsed, media_io)

    raise ValueError("The URL must be either a data or file URL.")


global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""
Expand Down
Loading