Skip to content

Commit

Permalink
Add support for passing in Image type to generate
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen committed Jan 6, 2025
1 parent ee349ec commit 0ab8c02
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def generate(
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> GenerateResponse: ...
Expand All @@ -207,7 +207,7 @@ def generate(
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[GenerateResponse]: ...
Expand All @@ -224,7 +224,7 @@ def generate(
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
Expand Down Expand Up @@ -691,7 +691,7 @@ async def generate(
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> GenerateResponse: ...
Expand All @@ -709,7 +709,7 @@ async def generate(
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[GenerateResponse]: ...
Expand All @@ -726,7 +726,7 @@ async def generate(
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import json
from pydantic import ValidationError, BaseModel
from ollama._types import Image
import pytest
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -284,6 +285,46 @@ def test_client_generate(httpserver: HTTPServer):
assert response['response'] == 'Because it is.'


def test_client_generate_with_image_type(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'What is in this image?',
'stream': False,
'images': [PNG_BASE64],
},
).respond_with_json(
{
'model': 'dummy',
'response': 'A blue sky.',
}
)

client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'What is in this image?', images=[Image(value=PNG_BASE64)])
assert response['model'] == 'dummy'
assert response['response'] == 'A blue sky.'


def test_client_generate_with_invalid_image(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'What is in this image?',
'stream': False,
'images': ['invalid_base64'],
},
).respond_with_json({'error': 'Invalid image data'}, status=400)

client = Client(httpserver.url_for('/'))
with pytest.raises(ValueError):
client.generate('dummy', 'What is in this image?', images=[Image(value='invalid_base64')])


def test_client_generate_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
Expand Down

0 comments on commit 0ab8c02

Please sign in to comment.