From 0ab8c026890ad91676e886ccf3ff3148e4e1fb71 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Mon, 6 Jan 2025 10:50:20 -0800 Subject: [PATCH] Add support for passing in Image type to generate --- ollama/_client.py | 12 ++++++------ tests/test_client.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 4b62765..56ebb77 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -189,7 +189,7 @@ def generate( stream: Literal[False] = False, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> GenerateResponse: ... @@ -207,7 +207,7 @@ def generate( stream: Literal[True] = True, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Iterator[GenerateResponse]: ... @@ -224,7 +224,7 @@ def generate( stream: bool = False, raw: Optional[bool] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Union[GenerateResponse, Iterator[GenerateResponse]]: @@ -691,7 +691,7 @@ async def generate( stream: Literal[False] = False, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> GenerateResponse: ... @@ -709,7 +709,7 @@ async def generate( stream: Literal[True] = True, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> AsyncIterator[GenerateResponse]: ... @@ -726,7 +726,7 @@ async def generate( stream: bool = False, raw: Optional[bool] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]: diff --git a/tests/test_client.py b/tests/test_client.py index d837a1a..8564222 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,7 @@ import os import json from pydantic import ValidationError, BaseModel +from ollama._types import Image import pytest import tempfile from pathlib import Path @@ -284,6 +285,46 @@ def test_client_generate(httpserver: HTTPServer): assert response['response'] == 'Because it is.' +def test_client_generate_with_image_type(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'What is in this image?', + 'stream': False, + 'images': [PNG_BASE64], + }, + ).respond_with_json( + { + 'model': 'dummy', + 'response': 'A blue sky.', + } + ) + + client = Client(httpserver.url_for('/')) + response = client.generate('dummy', 'What is in this image?', images=[Image(value=PNG_BASE64)]) + assert response['model'] == 'dummy' + assert response['response'] == 'A blue sky.' + + +def test_client_generate_with_invalid_image(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'What is in this image?', + 'stream': False, + 'images': ['invalid_base64'], + }, + ).respond_with_json({'error': 'Invalid image data'}, status=400) + + client = Client(httpserver.url_for('/')) + with pytest.raises(ValueError): + client.generate('dummy', 'What is in this image?', images=[Image(value='invalid_base64')]) + + def test_client_generate_stream(httpserver: HTTPServer): def stream_handler(_: Request): def generate():