diff --git a/ollama/__init__.py b/ollama/__init__.py index 423e12c..2d0e94f 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -1,4 +1,5 @@ -from ollama._client import Client, AsyncClient, Message, Options +from ollama._client import Client, AsyncClient +from ollama._types import Message, Options __all__ = [ 'Client', @@ -16,26 +17,14 @@ 'show', ] -_default_client = Client() +_client = Client() -generate = _default_client.generate -chat = _default_client.chat -pull = _default_client.pull -push = _default_client.push -create = _default_client.create -delete = _default_client.delete -list = _default_client.list -copy = _default_client.copy -show = _default_client.show - -_async_default_client = AsyncClient() - -async_generate = _async_default_client.generate -async_chat = _async_default_client.chat -async_pull = _async_default_client.pull -async_push = _async_default_client.push -async_create = _async_default_client.create -async_delete = _async_default_client.delete -async_list = _async_default_client.list -async_copy = _async_default_client.copy -async_show = _async_default_client.show +generate = _client.generate +chat = _client.chat +pull = _client.pull +push = _client.push +create = _client.create +delete = _client.delete +list = _client.list +copy = _client.copy +show = _client.show diff --git a/ollama/_client.py b/ollama/_client.py index 997e3e2..19ec847 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -1,3 +1,4 @@ +import os import io import json import httpx @@ -19,12 +20,13 @@ class BaseClient: - def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None: + def __init__(self, client, base_url: Optional[str] = None) -> None: + base_url = base_url or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434') self._client = client(base_url=base_url, follow_redirects=True, timeout=None) class Client(BaseClient): - def __init__(self, base_url: str = 'http://localhost:11434') -> None: + def __init__(self, base_url: Optional[str] = None) -> None: super().__init__(httpx.Client, base_url) def _request(self, method: str, url: str, **kwargs) -> httpx.Response: @@ -32,16 +34,21 @@ def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response.raise_for_status() return response - def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: - return self._request(method, url, **kwargs).json() - def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: with self._client.stream(method, url, **kwargs) as r: for line in r.iter_lines(): - part = json.loads(line) - if e := part.get('error'): + partial = json.loads(line) + if e := partial.get('error'): raise Exception(e) - yield part + yield partial + + def _request_stream( + self, + *args, + stream: bool = False, + **kwargs, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json() def generate( self, @@ -59,8 +66,7 @@ def generate( if not model: raise Exception('must provide a model') - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/generate', json={ @@ -75,6 +81,7 @@ def generate( 'format': format, 'options': options or {}, }, + stream=stream, ) def chat( @@ -98,8 +105,7 @@ def chat( if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/chat', json={ @@ -109,6 +115,7 @@ def chat( 'format': format, 'options': options or {}, }, + stream=stream, ) def pull( @@ -117,8 +124,7 @@ def pull( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/pull', json={ @@ -126,6 +132,7 @@ def pull( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) def push( @@ -134,8 +141,7 @@ def push( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/push', json={ @@ -143,6 +149,7 @@ def push( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) def create( @@ -159,8 +166,7 @@ def create( else: raise Exception('must provide either path or modelfile') - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/create', json={ @@ -168,6 +174,7 @@ def create( 'modelfile': modelfile, 'stream': stream, }, + stream=stream, ) def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: @@ -212,18 +219,18 @@ def delete(self, model: str) -> Mapping[str, Any]: return {'status': 'success' if response.status_code == 200 else 'error'} def list(self) -> Mapping[str, Any]: - return self._request_json('GET', '/api/tags').get('models', []) + return self._request('GET', '/api/tags').json().get('models', []) def copy(self, source: str, target: str) -> Mapping[str, Any]: response = self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} def show(self, model: str) -> Mapping[str, Any]: - return self._request_json('GET', '/api/show', json={'name': model}) + return self._request('GET', '/api/show', json={'name': model}).json() class AsyncClient(BaseClient): - def __init__(self, base_url: str = 'http://localhost:11434') -> None: + def __init__(self, base_url: Optional[str] = None) -> None: super().__init__(httpx.AsyncClient, base_url) async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: @@ -231,21 +238,29 @@ async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response.raise_for_status() return response - async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: - response = await self._request(method, url, **kwargs) - return response.json() - async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]: async def inner(): async with self._client.stream(method, url, **kwargs) as r: async for line in r.aiter_lines(): - part = json.loads(line) - if e := part.get('error'): + partial = json.loads(line) + if e := partial.get('error'): raise Exception(e) - yield part + yield partial return inner() + async def _request_stream( + self, + *args, + stream: bool = False, + **kwargs, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + if stream: + return await self._stream(*args, **kwargs) + + response = await self._request(*args, **kwargs) + return response.json() + async def generate( self, model: str = '', @@ -262,8 +277,7 @@ async def generate( if not model: raise Exception('must provide a model') - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/generate', json={ @@ -278,6 +292,7 @@ async def generate( 'format': format, 'options': options or {}, }, + stream=stream, ) async def chat( @@ -301,8 +316,7 @@ async def chat( if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/chat', json={ @@ -312,6 +326,7 @@ async def chat( 'format': format, 'options': options or {}, }, + stream=stream, ) async def pull( @@ -320,8 +335,7 @@ async def pull( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/pull', json={ @@ -329,6 +343,7 @@ async def pull( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) async def push( @@ -337,8 +352,7 @@ async def push( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/push', json={ @@ -346,6 +360,7 @@ async def push( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) async def create( @@ -362,8 +377,7 @@ async def create( else: raise Exception('must provide either path or modelfile') - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/create', json={ @@ -371,6 +385,7 @@ async def create( 'modelfile': modelfile, 'stream': stream, }, + stream=stream, ) async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: @@ -422,15 +437,16 @@ async def delete(self, model: str) -> Mapping[str, Any]: return {'status': 'success' if response.status_code == 200 else 'error'} async def list(self) -> Mapping[str, Any]: - response = await self._request_json('GET', '/api/tags') - return response.get('models', []) + response = await self._request('GET', '/api/tags') + return response.json().get('models', []) async def copy(self, source: str, target: str) -> Mapping[str, Any]: response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} async def show(self, model: str) -> Mapping[str, Any]: - return await self._request_json('GET', '/api/show', json={'name': model}) + response = await self._request('GET', '/api/show', json={'name': model}) + return response.json() def _encode_image(image) -> str: