Skip to content

Commit

Permalink
Merge pull request #2 from jmorganca/mxyng/fix
Browse files Browse the repository at this point in the history
fix endpoints
  • Loading branch information
mxyng authored Dec 23, 2023
2 parents 187bd29 + d5b1cc6 commit 349d9c3
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 41 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ for part in ollama.chat(model='llama2', messages=[message], stream=True):
```


### Using the Synchronous Client
## Using the Synchronous Client

```python
from ollama import Client
Expand All @@ -42,7 +42,7 @@ for part in Client().chat(model='llama2', messages=[message], stream=True):
print(part['message']['content'], end='', flush=True)
```

### Using the Asynchronous Client
## Using the Asynchronous Client

```python
import asyncio
Expand Down
6 changes: 2 additions & 4 deletions examples/simple-chat-stream/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
},
]

for message in chat('mistral', messages=messages, stream=True):
if message := message.get('message'):
if message.get('role') == 'assistant':
print(message.get('content', ''), end='', flush=True)
for part in chat('mistral', messages=messages, stream=True):
print(part['message']['content'], end='', flush=True)

# end with a newline
print()
2 changes: 1 addition & 1 deletion examples/simple-chat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
]

response = chat('mistral', messages=messages)
print(response['message'])
print(response['message']['content'])
5 changes: 5 additions & 0 deletions examples/simple-generate-stream/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ollama import generate


# Stream the completion for a fixed prompt, printing each chunk as it
# arrives so the answer appears incrementally on one line.
for chunk in generate('mistral', 'Why is the sky blue?', stream=True):
    print(chunk['response'], end='', flush=True)
5 changes: 5 additions & 0 deletions examples/simple-generate/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ollama import generate


# Request a complete (non-streaming) generation and print the full
# response text in one go.
result = generate('mistral', 'Why is the sky blue?')
print(result['response'])
2 changes: 1 addition & 1 deletion examples/simple-multimodal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
raw.raise_for_status()

for response in generate('llava', 'explain this comic:', images=[raw.content], stream=True):
print(response.get('response'), end='', flush=True)
print(response['response'], end='', flush=True)

print()
30 changes: 15 additions & 15 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@


class BaseClient:
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None:
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)


class Client(BaseClient):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.Client, base)
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
super().__init__(httpx.Client, base_url)

def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
Expand Down Expand Up @@ -122,7 +122,7 @@ def pull(
'POST',
'/api/pull',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
Expand All @@ -139,7 +139,7 @@ def push(
'POST',
'/api/push',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
Expand All @@ -164,7 +164,7 @@ def create(
'POST',
'/api/create',
json={
'model': model,
'name': model,
'modelfile': modelfile,
'stream': stream,
},
Expand Down Expand Up @@ -208,7 +208,7 @@ def _create_blob(self, path: Union[str, Path]) -> str:
return digest

def delete(self, model: str) -> Mapping[str, Any]:
response = self._request('DELETE', '/api/delete', json={'model': model})
response = self._request('DELETE', '/api/delete', json={'name': model})
return {'status': 'success' if response.status_code == 200 else 'error'}

def list(self) -> Mapping[str, Any]:
Expand All @@ -219,12 +219,12 @@ def copy(self, source: str, target: str) -> Mapping[str, Any]:
return {'status': 'success' if response.status_code == 200 else 'error'}

def show(self, model: str) -> Mapping[str, Any]:
return self._request_json('GET', '/api/show', json={'model': model})
return self._request_json('GET', '/api/show', json={'name': model})


class AsyncClient(BaseClient):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base)
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base_url)

async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
Expand Down Expand Up @@ -325,7 +325,7 @@ async def pull(
'POST',
'/api/pull',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
Expand All @@ -342,7 +342,7 @@ async def push(
'POST',
'/api/push',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
Expand All @@ -367,7 +367,7 @@ async def create(
'POST',
'/api/create',
json={
'model': model,
'name': model,
'modelfile': modelfile,
'stream': stream,
},
Expand Down Expand Up @@ -418,7 +418,7 @@ async def upload_bytes():
return digest

async def delete(self, model: str) -> Mapping[str, Any]:
response = await self._request('DELETE', '/api/delete', json={'model': model})
response = await self._request('DELETE', '/api/delete', json={'name': model})
return {'status': 'success' if response.status_code == 200 else 'error'}

async def list(self) -> Mapping[str, Any]:
Expand All @@ -430,7 +430,7 @@ async def copy(self, source: str, target: str) -> Mapping[str, Any]:
return {'status': 'success' if response.status_code == 200 else 'error'}

async def show(self, model: str) -> Mapping[str, Any]:
return await self._request_json('GET', '/api/show', json={'model': model})
return await self._request_json('GET', '/api/show', json={'name': model})


def _encode_image(image) -> str:
Expand Down
36 changes: 18 additions & 18 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_client_pull(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
Expand Down Expand Up @@ -259,7 +259,7 @@ def generate():
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
Expand All @@ -275,7 +275,7 @@ def test_client_push(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
Expand All @@ -291,7 +291,7 @@ def test_client_push_stream(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
Expand All @@ -308,7 +308,7 @@ def test_client_create_path(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -331,7 +331,7 @@ def test_client_create_path_relative(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand Down Expand Up @@ -363,7 +363,7 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -386,7 +386,7 @@ def test_client_create_modelfile(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -404,7 +404,7 @@ def test_client_create_from_library(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
Expand Down Expand Up @@ -584,7 +584,7 @@ async def test_async_client_pull(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
Expand All @@ -601,7 +601,7 @@ async def test_async_client_pull_stream(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
Expand All @@ -618,7 +618,7 @@ async def test_async_client_push(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
Expand All @@ -635,7 +635,7 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
Expand All @@ -653,7 +653,7 @@ async def test_async_client_create_path(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -677,7 +677,7 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -701,7 +701,7 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -725,7 +725,7 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
Expand All @@ -744,7 +744,7 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
Expand Down

0 comments on commit 349d9c3

Please sign in to comment.