From de61864358e5ea649f3bdc3584af4ef5eeab3e31 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 21 Dec 2023 16:48:15 -0800 Subject: [PATCH 1/4] fix api calls --- ollama/_client.py | 20 ++++++++++---------- tests/test_client.py | 36 ++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index d0fa30f..26c4202 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -122,7 +122,7 @@ def pull( 'POST', '/api/pull', json={ - 'model': model, + 'name': model, 'insecure': insecure, 'stream': stream, }, @@ -139,7 +139,7 @@ def push( 'POST', '/api/push', json={ - 'model': model, + 'name': model, 'insecure': insecure, 'stream': stream, }, @@ -164,7 +164,7 @@ def create( 'POST', '/api/create', json={ - 'model': model, + 'name': model, 'modelfile': modelfile, 'stream': stream, }, @@ -208,7 +208,7 @@ def _create_blob(self, path: Union[str, Path]) -> str: return digest def delete(self, model: str) -> Mapping[str, Any]: - response = self._request('DELETE', '/api/delete', json={'model': model}) + response = self._request('DELETE', '/api/delete', json={'name': model}) return {'status': 'success' if response.status_code == 200 else 'error'} def list(self) -> Mapping[str, Any]: @@ -219,7 +219,7 @@ def copy(self, source: str, target: str) -> Mapping[str, Any]: return {'status': 'success' if response.status_code == 200 else 'error'} def show(self, model: str) -> Mapping[str, Any]: - return self._request_json('GET', '/api/show', json={'model': model}) + return self._request_json('GET', '/api/show', json={'name': model}) class AsyncClient(BaseClient): @@ -325,7 +325,7 @@ async def pull( 'POST', '/api/pull', json={ - 'model': model, + 'name': model, 'insecure': insecure, 'stream': stream, }, @@ -342,7 +342,7 @@ async def push( 'POST', '/api/push', json={ - 'model': model, + 'name': model, 'insecure': insecure, 'stream': stream, }, @@ -367,7 +367,7 @@ async def create( 'POST', '/api/create', json={ - 'model': model, + 'name': model, 'modelfile': modelfile, 'stream': stream, }, @@ -418,7 +418,7 @@ async def upload_bytes(): return digest async def delete(self, model: str) -> Mapping[str, Any]: - response = await self._request('DELETE', '/api/delete', json={'model': model}) + response = await self._request('DELETE', '/api/delete', json={'name': model}) return {'status': 'success' if response.status_code == 200 else 'error'} async def list(self) -> Mapping[str, Any]: @@ -430,7 +430,7 @@ async def copy(self, source: str, target: str) -> Mapping[str, Any]: return {'status': 'success' if response.status_code == 200 else 'error'} async def show(self, model: str) -> Mapping[str, Any]: - return await self._request_json('GET', '/api/show', json={'model': model}) + return await self._request_json('GET', '/api/show', json={'name': model}) def _encode_image(image) -> str: diff --git a/tests/test_client.py b/tests/test_client.py index fe151dc..3b8ce8b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -229,7 +229,7 @@ def test_client_pull(httpserver: HTTPServer): '/api/pull', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': False, }, @@ -259,7 +259,7 @@ def generate(): '/api/pull', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': True, }, @@ -275,7 +275,7 @@ def test_client_push(httpserver: HTTPServer): '/api/push', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': False, }, @@ -291,7 +291,7 @@ def test_client_push_stream(httpserver: HTTPServer): '/api/push', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': True, }, @@ -308,7 +308,7 @@ def test_client_create_path(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -331,7 +331,7 @@ def test_client_create_path_relative(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -363,7 +363,7 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -386,7 +386,7 @@ def test_client_create_modelfile(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -404,7 +404,7 @@ def test_client_create_from_library(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM llama2\n', 'stream': False, }, @@ -584,7 +584,7 @@ async def test_async_client_pull(httpserver: HTTPServer): '/api/pull', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': False, }, @@ -601,7 +601,7 @@ async def test_async_client_pull_stream(httpserver: HTTPServer): '/api/pull', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': True, }, @@ -618,7 +618,7 @@ async def test_async_client_push(httpserver: HTTPServer): '/api/push', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': False, }, @@ -635,7 +635,7 @@ async def test_async_client_push_stream(httpserver: HTTPServer): '/api/push', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'insecure': False, 'stream': True, }, @@ -653,7 +653,7 @@ async def test_async_client_create_path(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -677,7 +677,7 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -701,7 +701,7 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -725,7 +725,7 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'stream': False, }, @@ -744,7 +744,7 @@ async def test_async_client_create_from_library(httpserver: HTTPServer): '/api/create', method='POST', json={ - 'model': 'dummy', + 'name': 'dummy', 'modelfile': 'FROM llama2\n', 'stream': False, }, From 2236de230cd784a3fe971f88739d67fcf62f3cf5 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 21 Dec 2023 16:38:52 -0800 Subject: [PATCH 2/4] s/base/base_url/ --- ollama/_client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 26c4202..997e3e2 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -19,13 +19,13 @@ class BaseClient: - def __init__(self, client, base_url='http://127.0.0.1:11434') -> None: + def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None: self._client = client(base_url=base_url, follow_redirects=True, timeout=None) class Client(BaseClient): - def __init__(self, base='http://localhost:11434') -> None: - super().__init__(httpx.Client, base) + def __init__(self, base_url: str = 'http://localhost:11434') -> None: + super().__init__(httpx.Client, base_url) def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response = self._client.request(method, url, **kwargs) @@ -223,8 +223,8 @@ def show(self, model: str) -> Mapping[str, Any]: class AsyncClient(BaseClient): - def __init__(self, base='http://localhost:11434') -> None: - super().__init__(httpx.AsyncClient, base) + def __init__(self, base_url: str = 'http://localhost:11434') -> None: + super().__init__(httpx.AsyncClient, base_url) async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response = await self._client.request(method, url, **kwargs) From e8a66b8de168e9331485ecdfe685158ceafd37fb Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 21 Dec 2023 16:37:07 -0800 Subject: [PATCH 3/4] examples --- examples/simple-chat-stream/main.py | 6 ++---- examples/simple-chat/main.py | 2 +- examples/simple-generate-stream/main.py | 5 +++++ examples/simple-generate/main.py | 5 +++++ examples/simple-multimodal/main.py | 2 +- 5 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 examples/simple-generate-stream/main.py create mode 100644 examples/simple-generate/main.py diff --git a/examples/simple-chat-stream/main.py b/examples/simple-chat-stream/main.py index 33d1740..2a57346 100644 --- a/examples/simple-chat-stream/main.py +++ b/examples/simple-chat-stream/main.py @@ -8,10 +8,8 @@ }, ] -for message in chat('mistral', messages=messages, stream=True): - if message := message.get('message'): - if message.get('role') == 'assistant': - print(message.get('content', ''), end='', flush=True) +for part in chat('mistral', messages=messages, stream=True): + print(part['message']['content'], end='', flush=True) # end with a newline print() diff --git a/examples/simple-chat/main.py b/examples/simple-chat/main.py index 5019307..90c5f90 100644 --- a/examples/simple-chat/main.py +++ b/examples/simple-chat/main.py @@ -9,4 +9,4 @@ ] response = chat('mistral', messages=messages) -print(response['message']) +print(response['message']['content']) diff --git a/examples/simple-generate-stream/main.py b/examples/simple-generate-stream/main.py new file mode 100644 index 0000000..a24b410 --- /dev/null +++ b/examples/simple-generate-stream/main.py @@ -0,0 +1,5 @@ +from ollama import generate + + +for part in generate('mistral', 'Why is the sky blue?', stream=True): + print(part['response'], end='', flush=True) diff --git a/examples/simple-generate/main.py b/examples/simple-generate/main.py new file mode 100644 index 0000000..e39e295 --- /dev/null +++ b/examples/simple-generate/main.py @@ -0,0 +1,5 @@ +from ollama import generate + + +response = generate('mistral', 'Why is the sky blue?') +print(response['response']) diff --git a/examples/simple-multimodal/main.py b/examples/simple-multimodal/main.py index 97eba59..44b3716 100644 --- a/examples/simple-multimodal/main.py +++ b/examples/simple-multimodal/main.py @@ -24,6 +24,6 @@ raw.raise_for_status() for response in generate('llava', 'explain this comic:', images=[raw.content], stream=True): - print(response.get('response'), end='', flush=True) + print(response['response'], end='', flush=True) print() From d5b1cc60fa538a4bf0ec7fe80f2557a35b86a064 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 22 Dec 2023 11:50:21 -0800 Subject: [PATCH 4/4] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0183092..ec7801c 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ for part in ollama.chat(model='llama2', messages=[message], stream=True): ``` -### Using the Synchronous Client +## Using the Synchronous Client ```python from ollama import Client @@ -42,7 +42,7 @@ for part in Client().chat(model='llama2', messages=[message], stream=True): print(part['message']['content'], end='', flush=True) ``` -### Using the Asynchronous Client +## Using the Asynchronous Client ```python import asyncio