Skip to content

Commit

Permalink
[Frontend] support image embeds
Browse files Browse the repository at this point in the history
Signed-off-by: chaunceyjiang <[email protected]>
  • Loading branch information
chaunceyjiang committed Feb 28, 2025
1 parent a7f3731 commit aba83db
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 3 deletions.
99 changes: 97 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,29 @@ class AudioURL(TypedDict, total=False):
"""


class ImageEmbeds(TypedDict, total=False):
    """
    Image embeds to be used in the chat completion API.

    Fields hold base64-encoded serialized tensors; the chat parser decodes
    them in place via the media connector.
    """
    # BUGFIX: the docstring above was originally placed after the fields,
    # where it documented nothing (it read as an attribute docstring).
    image_embeds: Required[str]
    # image_sizes: Optional[List[int]] TODO(@chaunceyjiang)
    # NOTE(review): Optional here means "str or None", not "may be omitted";
    # with total=False the key is already omittable.
    image_grid_thw: Optional[str]


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Chat-completion content part referencing an audio clip by URL."""

    # Discriminator: always the literal string "audio_url".
    type: Required[Literal["audio_url"]]
    # The audio payload, wrapped in an AudioURL dict.
    audio_url: Required[AudioURL]


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    """Chat-completion content part carrying precomputed image embeddings."""

    # Discriminator: always the literal string "image_embeds".
    type: Required[Literal["image_embeds"]]
    # Either a bare base64 string or an ImageEmbeds dict with extra fields.
    image_embeds: Required[Union[str, ImageEmbeds]]


class VideoURL(TypedDict, total=False):
url: Required[str]
"""
Expand Down Expand Up @@ -108,6 +124,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]

Expand Down Expand Up @@ -508,6 +525,10 @@ def mm_placeholder_counts(self) -> Dict[str, int]:
def parse_image(self, image_url: str) -> None:
    """Parse the image referenced by *image_url*; abstract, implemented by
    the sync/async parser subclasses below."""
    raise NotImplementedError

@abstractmethod
def parse_image_embeds(self, image_embeds: Union[str, ImageEmbeds]) -> None:
    """Parse precomputed image embeddings given as a base64 string or an
    ImageEmbeds dict; abstract, implemented by parser subclasses."""
    raise NotImplementedError

@abstractmethod
def parse_audio(self, audio_url: str) -> None:
    """Parse the audio referenced by *audio_url*; abstract, implemented by
    parser subclasses."""
    raise NotImplementedError
Expand Down Expand Up @@ -538,6 +559,37 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)

def _parse_image_embeds_params(
        self, image_embeds: ImageEmbeds) -> Dict[str, Any]:
    """Decode the optional base64-encoded ``image_grid_thw`` field in place.

    Returns *image_embeds* (as a plain dict) with ``image_grid_thw``
    replaced by its decoded value when present and non-empty.
    """
    # BUGFIX: the original f-string was split across lines with nested
    # double quotes — a syntax error before Python 3.12.
    grid_thw = image_embeds.get("image_grid_thw", "")
    if grid_thw:
        # Wrap the raw base64 payload in a data URL so the connector can
        # decode it like any other embedding.
        embedding_url = f"data:image/embeds;base64,{grid_thw}"
        image_embeds["image_grid_thw"] = (
            self._connector.fetch_image_embedding(embedding_url))

    return cast(Dict[str, Any], image_embeds)

def parse_image_embeds(self,
                       image_embeds: Union[str, ImageEmbeds]) -> None:
    """Parse precomputed image embeddings and register them with the tracker.

    Accepts either a bare base64-encoded tensor string or an ImageEmbeds
    dict whose fields are each base64-encoded.

    Raises:
        ValueError: if *image_embeds* is neither a str nor a dict.
    """
    if isinstance(image_embeds, dict):
        image_data = image_embeds.get("image_embeds", "")
        embedding_url = f"data:image/embeds;base64,{image_data}"
        embedding = self._connector.fetch_image_embedding(embedding_url)

        embeds = cast(Dict[str, Any], image_embeds)
        embeds["image_embeds"] = embedding  # decoded image data
        embeds |= self._parse_image_embeds_params(image_embeds)

        placeholder = self._tracker.add("image", embeds)
    elif isinstance(image_embeds, str):
        embedding_url = f"data:image/embeds;base64,{image_embeds}"
        embedding = self._connector.fetch_image_embedding(embedding_url)
        # BUGFIX: register the decoded embedding; the original passed the
        # raw base64 string to the tracker and discarded the fetched tensor.
        placeholder = self._tracker.add("image", embedding)
    else:
        # BUGFIX: the original fell through with `placeholder` unbound,
        # raising a confusing NameError instead of a clear error.
        raise ValueError(
            "Image embeds must be a base64 string or an ImageEmbeds dict, "
            f"got {type(image_embeds).__name__}")

    self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)

Expand Down Expand Up @@ -574,6 +626,39 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)

def _parse_image_embeds_params(
        self, image_embeds: ImageEmbeds) -> Dict[str, Any]:
    """Decode the optional base64-encoded ``image_grid_thw`` field in place
    and return *image_embeds* as a plain dict.
    """
    # BUGFIX: the original f-string was split across lines with nested
    # double quotes — a syntax error before Python 3.12.
    grid_thw = image_embeds.get("image_grid_thw", "")
    if grid_thw:
        embedding_url = f"data:image/embeds;base64,{grid_thw}"
        image_embeds["image_grid_thw"] = (
            self._connector.fetch_image_embedding(embedding_url))
    return cast(Dict[str, Any], image_embeds)

def parse_image_embeds(self,
                       image_embeds: Union[str, ImageEmbeds]) -> None:
    """Parse precomputed image embeddings and register an already-resolved
    future for them with the tracker.

    Accepts either a bare base64-encoded tensor string or an ImageEmbeds
    dict whose fields are each base64-encoded.

    Raises:
        ValueError: if *image_embeds* is neither a str nor a dict.
            (BUGFIX: the original left the future unresolved for such
            input, so any awaiter would hang forever.)
    """
    if not isinstance(image_embeds, (str, dict)):
        raise ValueError(
            "Image embeds must be a base64 string or an ImageEmbeds dict, "
            f"got {type(image_embeds).__name__}")

    future: asyncio.Future[Union[str, Dict[str, Any]]] = asyncio.Future()

    if isinstance(image_embeds, dict):
        image_data = image_embeds.get("image_embeds", "")
        embedding_url = f"data:image/embeds;base64,{image_data}"
        embedding = self._connector.fetch_image_embedding(embedding_url)

        embeds = cast(Dict[str, Any], image_embeds)
        embeds["image_embeds"] = embedding  # decoded image data
        embeds |= self._parse_image_embeds_params(image_embeds)
        future.set_result(embeds)
    else:
        embedding_url = f"data:image/embeds;base64,{image_embeds}"
        embedding = self._connector.fetch_image_embedding(embedding_url)
        future.set_result(embedding)

    placeholder = self._tracker.add("image", future)
    self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)

Expand Down Expand Up @@ -679,12 +764,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio, ImageEmbeds]

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[
Expand All @@ -695,6 +781,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"input_audio":
Expand Down Expand Up @@ -764,6 +852,7 @@ def _parse_chat_message_content_mm_part(


# Content-part "type" values accepted by the multimodal chat-message parser.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds",
                                       "audio_url", "input_audio", "video_url")


Expand Down Expand Up @@ -838,7 +927,13 @@ def _parse_chat_message_content_part(
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "image_embeds":
if isinstance(content, dict):
content = cast(ImageEmbeds, content)
if isinstance(content, str):
content = cast(str, content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
Expand Down
19 changes: 19 additions & 0 deletions vllm/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,22 @@ def encode_base64(
data = buffer.getvalue()

return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """(De)serializes image-embedding tensors for the media API.

    Tensors travel as base64-encoded ``torch.save`` payloads so that
    ``encode_base64`` and ``load_base64`` round-trip.
    """

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a ``torch.save`` payload from raw bytes."""
        buffer = BytesIO(data)
        # SECURITY BUGFIX: the payload comes from API users; plain
        # torch.load unpickles arbitrary objects and allows remote code
        # execution. weights_only=True restricts it to tensor data.
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Deserialize a base64-encoded ``torch.save`` payload.

        *media_type* is accepted for interface compatibility and ignored.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load a tensor saved with ``torch.save`` from *filepath*."""
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Encode *media* as a base64-encoded ``torch.save`` payload.

        BUGFIX: the original encoded the raw ``media.numpy()`` buffer,
        which ``load_base64`` (torch.load) cannot read back — the class
        could not round-trip its own output.
        """
        buffer = BytesIO()
        torch.save(media, buffer)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
22 changes: 21 additions & 1 deletion vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
Expand All @@ -16,7 +17,7 @@

from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO

Expand Down Expand Up @@ -245,6 +246,25 @@ async def fetch_video_async(
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)

def fetch_image_embedding(
    self,
    image_embedding_url: str,
) -> torch.Tensor:
    """Load an image-embedding tensor from a ``data:`` or ``file:`` URL."""
    media_io = ImageEmbeddingMediaIO()
    parsed = urlparse(image_embedding_url)

    if parsed.scheme == "data":
        return self._load_data_url(parsed, media_io)
    if parsed.scheme == "file":
        return self._load_file_url(parsed, media_io)

    raise ValueError("The URL must be either a data or file URL.")


global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""
Expand Down

0 comments on commit aba83db

Please sign in to comment.