diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index c50c631dafccc..dd8c3fb5ed5bc 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -48,6 +48,15 @@ class AudioURL(TypedDict, total=False):
     """
 
 
+class ImageEmbeds(TypedDict, total=False):
+    image_embeds: Required[str]
+    """
+    Image embeds to be used in the chat completion API.
+    """
+    # image_sizes: Optional[List[int]] TODO(@chaunceyjiang)
+    # image_grid_thw: Optional[List[int]] TODO(@chaunceyjiang)
+
+
 class ChatCompletionContentPartAudioParam(TypedDict, total=False):
     audio_url: Required[AudioURL]
 
@@ -55,6 +64,13 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
     """The type of the content part."""
 
 
+class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
+    image_embeds: Required[Union[str, ImageEmbeds]]
+
+    type: Required[Literal["image_embeds"]]
+    """The type of the content part."""
+
+
 class VideoURL(TypedDict, total=False):
     url: Required[str]
     """
@@ -108,6 +124,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
     ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
     CustomChatCompletionContentSimpleImageParam,
+    ChatCompletionContentPartImageEmbedsParam,
     CustomChatCompletionContentSimpleAudioParam,
     CustomChatCompletionContentSimpleVideoParam, str]
 
@@ -508,6 +525,10 @@ def mm_placeholder_counts(self) -> Dict[str, int]:
     def parse_image(self, image_url: str) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def parse_image_embeds(self, image_embeds: Union[str, ImageEmbeds]) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def parse_audio(self, audio_url: str) -> None:
         raise NotImplementedError
@@ -538,6 +559,18 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image)
         self._add_placeholder(placeholder)
 
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, ImageEmbeds]) -> None:
+        if isinstance(image_embeds, dict):
+            embedding = self._connector.fetch_image_embedding(
+                image_embeds.get("image_embeds", ""))
+            placeholder = self._tracker.add(
+                "image", image_embeds | {"image_embeds": embedding})
+        else:
+            embedding = self._connector.fetch_image_embedding(image_embeds)
+            placeholder = self._tracker.add("image", embedding)
+        self._add_placeholder(placeholder)
+
     def parse_audio(self, audio_url: str) -> None:
         audio = self._connector.fetch_audio(audio_url)
 
@@ -574,6 +607,20 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image_coro)
         self._add_placeholder(placeholder)
 
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, ImageEmbeds]) -> None:
+        future: asyncio.Future = asyncio.Future()
+        if isinstance(image_embeds, dict):
+            embedding = self._connector.fetch_image_embedding(
+                image_embeds.get("image_embeds", ""))
+            image_embeds = image_embeds | {"image_embeds": embedding}
+            future.set_result(image_embeds)
+        else:
+            embedding = self._connector.fetch_image_embedding(image_embeds)
+            future.set_result(embedding)
+        placeholder = self._tracker.add("image", future)
+        self._add_placeholder(placeholder)
+
     def parse_audio(self, audio_url: str) -> None:
         audio_coro = self._connector.fetch_audio_async(audio_url)
 
@@ -679,12 +726,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
 # No need to validate using Pydantic again
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
 _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
 _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
 
-_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
+_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio, ImageEmbeds]
 
 # Define a mapping from part types to their corresponding parsing functions.
 MM_PARSER_MAP: Dict[
@@ -695,6 +743,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
     lambda part: _TextParser(part).get("text", ""),
     "image_url":
     lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
+    "image_embeds":
+    lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
     "audio_url":
     lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
     "input_audio":
@@ -764,6 +814,7 @@ def _parse_chat_message_content_mm_part(
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
+                                       "image_embeds",
                                        "audio_url", "input_audio", "video_url")
 
@@ -838,7 +889,10 @@ def _parse_chat_message_content_part(
         str_content = cast(str, content)
         mm_parser.parse_image(str_content)
         return {'type': 'image'} if wrap_dicts else None
-
+    if part_type == "image_embeds":
+        content = cast(Union[str, ImageEmbeds], content)
+        mm_parser.parse_image_embeds(content)
+        return {'type': 'image'} if wrap_dicts else None
     if part_type == "audio_url":
         str_content = cast(str, content)
         mm_parser.parse_audio(str_content)
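
A minimal client-side sketch of the new "image_embeds" content part, assuming an OpenAI-compatible vLLM server is running; the server URL, model name, input file, and the "tensor/pt" data-URL media type are illustrative placeholders, and the base64 payload uses the torch.save serialization that ImageEmbeddingMediaIO (added in vllm/multimodal/image.py below) decodes:

import base64
from io import BytesIO

import torch
from openai import OpenAI

# Precomputed image embedding, e.g. exported from the model's vision tower.
image_embeds = torch.load("image_embeds.pt")  # hypothetical file

# Serialize with torch.save so the server's ImageEmbeddingMediaIO can
# decode it with torch.load.
buffer = BytesIO()
torch.save(image_embeds, buffer)
embeds_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",  # placeholder model
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the image."},
            {
                "type": "image_embeds",
                "image_embeds": f"data:tensor/pt;base64,{embeds_b64}",
            },
        ],
    }],
)
print(response.choices[0].message.content)
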
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 98ece8f806f1d..9850a24601dd1 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -134,3 +134,21 @@ def encode_base64(
         data = buffer.getvalue()
 
         return base64.b64encode(data).decode('utf-8')
+
+
+class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
+
+    def load_bytes(self, data: bytes) -> torch.Tensor:
+        # Tensors are serialized with torch.save, so decode with torch.load.
+        return torch.load(BytesIO(data), weights_only=True)
+
+    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
+        return self.load_bytes(base64.b64decode(data))
+
+    def load_file(self, filepath: Path) -> torch.Tensor:
+        return torch.load(filepath, weights_only=True)
+
+    def encode_base64(self, media: torch.Tensor) -> str:
+        buffer = BytesIO()
+        torch.save(media, buffer)
+        return base64.b64encode(buffer.getvalue()).decode('utf-8')
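
A quick round-trip sketch for the ImageEmbeddingMediaIO added above (with load_bytes and encode_base64 fixed to use torch.load/torch.save); it assumes torch, BytesIO, and Path are already imported at the top of image.py, and the embedding shape and "tensor/pt" media type are illustrative:

import torch

from vllm.multimodal.image import ImageEmbeddingMediaIO

media_io = ImageEmbeddingMediaIO()
embeds = torch.randn(576, 1024)  # e.g. (num_image_tokens, hidden_size)

# encode_base64 serializes with torch.save, so load_base64 must invert it.
encoded = media_io.encode_base64(embeds)
decoded = media_io.load_base64("tensor/pt", encoded)

assert torch.equal(embeds, decoded)
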
+ """ + image_embedding_io = ImageEmbeddingMediaIO() + url_spec = urlparse(image_embedding_url) + + if url_spec.scheme == "data": + return self._load_data_url(url_spec, image_embedding_io) + + if url_spec.scheme == "file": + return self._load_file_url(url_spec, image_embedding_io) + + msg = "The URL must be either a data or file URL." + raise ValueError(msg) + global_media_connector = MediaConnector() """The global :class:`MediaConnector` instance used by vLLM."""