From bf56bed81619929ddaec96ae15e6ac51a5a4df9c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 9 Jan 2024 13:04:07 -0800 Subject: [PATCH 1/2] remove default async client --- ollama/__init__.py | 35 ++++++++++++----------------------- ollama/_client.py | 20 +++++++++++--------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/ollama/__init__.py b/ollama/__init__.py index 423e12c..2d0e94f 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -1,4 +1,5 @@ -from ollama._client import Client, AsyncClient, Message, Options +from ollama._client import Client, AsyncClient +from ollama._types import Message, Options __all__ = [ 'Client', @@ -16,26 +17,14 @@ 'show', ] -_default_client = Client() +_client = Client() -generate = _default_client.generate -chat = _default_client.chat -pull = _default_client.pull -push = _default_client.push -create = _default_client.create -delete = _default_client.delete -list = _default_client.list -copy = _default_client.copy -show = _default_client.show - -_async_default_client = AsyncClient() - -async_generate = _async_default_client.generate -async_chat = _async_default_client.chat -async_pull = _async_default_client.pull -async_push = _async_default_client.push -async_create = _async_default_client.create -async_delete = _async_default_client.delete -async_list = _async_default_client.list -async_copy = _async_default_client.copy -async_show = _async_default_client.show +generate = _client.generate +chat = _client.chat +pull = _client.pull +push = _client.push +create = _client.create +delete = _client.delete +list = _client.list +copy = _client.copy +show = _client.show diff --git a/ollama/_client.py b/ollama/_client.py index 997e3e2..9110212 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -1,3 +1,4 @@ +import os import io import json import httpx @@ -19,12 +20,13 @@ class BaseClient: - def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None: + def __init__(self, client, base_url: Optional[str] = None) -> None: + base_url = base_url or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434') self._client = client(base_url=base_url, follow_redirects=True, timeout=None) class Client(BaseClient): - def __init__(self, base_url: str = 'http://localhost:11434') -> None: + def __init__(self, base_url: Optional[str] = None) -> None: super().__init__(httpx.Client, base_url) def _request(self, method: str, url: str, **kwargs) -> httpx.Response: @@ -38,10 +40,10 @@ def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: with self._client.stream(method, url, **kwargs) as r: for line in r.iter_lines(): - part = json.loads(line) - if e := part.get('error'): + partial = json.loads(line) + if e := partial.get('error'): raise Exception(e) - yield part + yield partial def generate( self, @@ -223,7 +225,7 @@ def show(self, model: str) -> Mapping[str, Any]: class AsyncClient(BaseClient): - def __init__(self, base_url: str = 'http://localhost:11434') -> None: + def __init__(self, base_url: Optional[str] = None) -> None: super().__init__(httpx.AsyncClient, base_url) async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: @@ -239,10 +241,10 @@ async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mappin async def inner(): async with self._client.stream(method, url, **kwargs) as r: async for line in r.aiter_lines(): - part = json.loads(line) - if e := part.get('error'): + partial = json.loads(line) + if e := partial.get('error'): raise Exception(e) - yield part + yield partial return inner() From c67ef1ae340e49ef4fbef7e3c5a1cce80a951da2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 9 Jan 2024 15:24:30 -0800 Subject: [PATCH 2/2] fix: type hints --- ollama/_client.py | 78 ++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 9110212..19ec847 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -34,9 +34,6 @@ def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response.raise_for_status() return response - def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: - return self._request(method, url, **kwargs).json() - def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: with self._client.stream(method, url, **kwargs) as r: for line in r.iter_lines(): @@ -45,6 +42,14 @@ def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any] raise Exception(e) yield partial + def _request_stream( + self, + *args, + stream: bool = False, + **kwargs, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json() + def generate( self, model: str = '', @@ -61,8 +66,7 @@ def generate( if not model: raise Exception('must provide a model') - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/generate', json={ @@ -77,6 +81,7 @@ def generate( 'format': format, 'options': options or {}, }, + stream=stream, ) def chat( @@ -100,8 +105,7 @@ def chat( if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/chat', json={ @@ -111,6 +115,7 @@ def chat( 'format': format, 'options': options or {}, }, + stream=stream, ) def pull( @@ -119,8 +124,7 @@ def pull( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/pull', json={ @@ -128,6 +132,7 @@ def pull( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) def push( @@ -136,8 +141,7 @@ def push( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/push', json={ @@ -145,6 +149,7 @@ def push( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) def create( @@ -161,8 +166,7 @@ def create( else: raise Exception('must provide either path or modelfile') - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/create', json={ @@ -170,6 +174,7 @@ def create( 'modelfile': modelfile, 'stream': stream, }, + stream=stream, ) def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: @@ -214,14 +219,14 @@ def delete(self, model: str) -> Mapping[str, Any]: return {'status': 'success' if response.status_code == 200 else 'error'} def list(self) -> Mapping[str, Any]: - return self._request_json('GET', '/api/tags').get('models', []) + return self._request('GET', '/api/tags').json().get('models', []) def copy(self, source: str, target: str) -> Mapping[str, Any]: response = self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} def show(self, model: str) -> Mapping[str, Any]: - return self._request_json('GET', '/api/show', json={'name': model}) + return self._request('GET', '/api/show', json={'name': model}).json() class AsyncClient(BaseClient): @@ -233,10 +238,6 @@ async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response.raise_for_status() return response - async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: - response = await self._request(method, url, **kwargs) - return response.json() - async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]: async def inner(): async with self._client.stream(method, url, **kwargs) as r: @@ -248,6 +249,18 @@ async def inner(): return inner() + async def _request_stream( + self, + *args, + stream: bool = False, + **kwargs, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + if stream: + return await self._stream(*args, **kwargs) + + response = await self._request(*args, **kwargs) + return response.json() + async def generate( self, model: str = '', @@ -264,8 +277,7 @@ async def generate( if not model: raise Exception('must provide a model') - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/generate', json={ @@ -280,6 +292,7 @@ async def generate( 'format': format, 'options': options or {}, }, + stream=stream, ) async def chat( @@ -303,8 +316,7 @@ async def chat( if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/chat', json={ @@ -314,6 +326,7 @@ async def chat( 'format': format, 'options': options or {}, }, + stream=stream, ) async def pull( @@ -322,8 +335,7 @@ async def pull( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/pull', json={ @@ -331,6 +343,7 @@ async def pull( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) async def push( @@ -339,8 +352,7 @@ async def push( insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/push', json={ @@ -348,6 +360,7 @@ async def push( 'insecure': insecure, 'stream': stream, }, + stream=stream, ) async def create( @@ -364,8 +377,7 @@ async def create( else: raise Exception('must provide either path or modelfile') - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/create', json={ @@ -373,6 +385,7 @@ async def create( 'modelfile': modelfile, 'stream': stream, }, + stream=stream, ) async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: @@ -424,15 +437,16 @@ async def delete(self, model: str) -> Mapping[str, Any]: return {'status': 'success' if response.status_code == 200 else 'error'} async def list(self) -> Mapping[str, Any]: - response = await self._request_json('GET', '/api/tags') - return response.get('models', []) + response = await self._request('GET', '/api/tags') + return response.json().get('models', []) async def copy(self, source: str, target: str) -> Mapping[str, Any]: response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} async def show(self, model: str) -> Mapping[str, Any]: - return await self._request_json('GET', '/api/show', json={'name': model}) + response = await self._request('GET', '/api/show', json={'name': model}) + return response.json() def _encode_image(image) -> str: