Skip to content

Commit

Permalink
Merge pull request #4 from jmorganca/mxyng/rm-default-async
Browse files Browse the repository at this point in the history
Mxyng/rm default async
  • Loading branch information
mxyng authored Jan 10, 2024
2 parents e2be701 + c67ef1a commit 0f9211d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 64 deletions.
35 changes: 12 additions & 23 deletions ollama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ollama._client import Client, AsyncClient, Message, Options
from ollama._client import Client, AsyncClient
from ollama._types import Message, Options

__all__ = [
'Client',
Expand All @@ -16,26 +17,14 @@
'show',
]

_default_client = Client()
_client = Client()

generate = _default_client.generate
chat = _default_client.chat
pull = _default_client.pull
push = _default_client.push
create = _default_client.create
delete = _default_client.delete
list = _default_client.list
copy = _default_client.copy
show = _default_client.show

_async_default_client = AsyncClient()

async_generate = _async_default_client.generate
async_chat = _async_default_client.chat
async_pull = _async_default_client.pull
async_push = _async_default_client.push
async_create = _async_default_client.create
async_delete = _async_default_client.delete
async_list = _async_default_client.list
async_copy = _async_default_client.copy
async_show = _async_default_client.show
generate = _client.generate
chat = _client.chat
pull = _client.pull
push = _client.push
create = _client.create
delete = _client.delete
list = _client.list
copy = _client.copy
show = _client.show
98 changes: 57 additions & 41 deletions ollama/_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import io
import json
import httpx
Expand All @@ -19,29 +20,35 @@


class BaseClient:
def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None:
def __init__(self, client, base_url: Optional[str] = None) -> None:
base_url = base_url or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434')
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)


class Client(BaseClient):
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
def __init__(self, base_url: Optional[str] = None) -> None:
super().__init__(httpx.Client, base_url)

def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
return response

def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
return self._request(method, url, **kwargs).json()

def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
for line in r.iter_lines():
part = json.loads(line)
if e := part.get('error'):
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
yield part
yield partial

def _request_stream(
self,
*args,
stream: bool = False,
**kwargs,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json()

def generate(
self,
Expand All @@ -59,8 +66,7 @@ def generate(
if not model:
raise Exception('must provide a model')

fn = self._stream if stream else self._request_json
return fn(
return self._request_stream(
'POST',
'/api/generate',
json={
Expand All @@ -75,6 +81,7 @@ def generate(
'format': format,
'options': options or {},
},
stream=stream,
)

def chat(
Expand All @@ -98,8 +105,7 @@ def chat(
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]

fn = self._stream if stream else self._request_json
return fn(
return self._request_stream(
'POST',
'/api/chat',
json={
Expand All @@ -109,6 +115,7 @@ def chat(
'format': format,
'options': options or {},
},
stream=stream,
)

def pull(
Expand All @@ -117,15 +124,15 @@ def pull(
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn(
return self._request_stream(
'POST',
'/api/pull',
json={
'name': model,
'insecure': insecure,
'stream': stream,
},
stream=stream,
)

def push(
Expand All @@ -134,15 +141,15 @@ def push(
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn(
return self._request_stream(
'POST',
'/api/push',
json={
'name': model,
'insecure': insecure,
'stream': stream,
},
stream=stream,
)

def create(
Expand All @@ -159,15 +166,15 @@ def create(
else:
raise Exception('must provide either path or modelfile')

fn = self._stream if stream else self._request_json
return fn(
return self._request_stream(
'POST',
'/api/create',
json={
'name': model,
'modelfile': modelfile,
'stream': stream,
},
stream=stream,
)

def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
Expand Down Expand Up @@ -212,40 +219,48 @@ def delete(self, model: str) -> Mapping[str, Any]:
return {'status': 'success' if response.status_code == 200 else 'error'}

def list(self) -> Mapping[str, Any]:
return self._request_json('GET', '/api/tags').get('models', [])
return self._request('GET', '/api/tags').json().get('models', [])

def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}

def show(self, model: str) -> Mapping[str, Any]:
return self._request_json('GET', '/api/show', json={'name': model})
return self._request('GET', '/api/show', json={'name': model}).json()


class AsyncClient(BaseClient):
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
def __init__(self, base_url: Optional[str] = None) -> None:
super().__init__(httpx.AsyncClient, base_url)

async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
response.raise_for_status()
return response

async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
response = await self._request(method, url, **kwargs)
return response.json()

async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
async for line in r.aiter_lines():
part = json.loads(line)
if e := part.get('error'):
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
yield part
yield partial

return inner()

async def _request_stream(
self,
*args,
stream: bool = False,
**kwargs,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if stream:
return await self._stream(*args, **kwargs)

response = await self._request(*args, **kwargs)
return response.json()

async def generate(
self,
model: str = '',
Expand All @@ -262,8 +277,7 @@ async def generate(
if not model:
raise Exception('must provide a model')

fn = self._stream if stream else self._request_json
return await fn(
return await self._request_stream(
'POST',
'/api/generate',
json={
Expand All @@ -278,6 +292,7 @@ async def generate(
'format': format,
'options': options or {},
},
stream=stream,
)

async def chat(
Expand All @@ -301,8 +316,7 @@ async def chat(
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]

fn = self._stream if stream else self._request_json
return await fn(
return await self._request_stream(
'POST',
'/api/chat',
json={
Expand All @@ -312,6 +326,7 @@ async def chat(
'format': format,
'options': options or {},
},
stream=stream,
)

async def pull(
Expand All @@ -320,15 +335,15 @@ async def pull(
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn(
return await self._request_stream(
'POST',
'/api/pull',
json={
'name': model,
'insecure': insecure,
'stream': stream,
},
stream=stream,
)

async def push(
Expand All @@ -337,15 +352,15 @@ async def push(
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn(
return await self._request_stream(
'POST',
'/api/push',
json={
'name': model,
'insecure': insecure,
'stream': stream,
},
stream=stream,
)

async def create(
Expand All @@ -362,15 +377,15 @@ async def create(
else:
raise Exception('must provide either path or modelfile')

fn = self._stream if stream else self._request_json
return await fn(
return await self._request_stream(
'POST',
'/api/create',
json={
'name': model,
'modelfile': modelfile,
'stream': stream,
},
stream=stream,
)

async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
Expand Down Expand Up @@ -422,15 +437,16 @@ async def delete(self, model: str) -> Mapping[str, Any]:
return {'status': 'success' if response.status_code == 200 else 'error'}

async def list(self) -> Mapping[str, Any]:
response = await self._request_json('GET', '/api/tags')
return response.get('models', [])
response = await self._request('GET', '/api/tags')
return response.json().get('models', [])

async def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}

async def show(self, model: str) -> Mapping[str, Any]:
return await self._request_json('GET', '/api/show', json={'name': model})
response = await self._request('GET', '/api/show', json={'name': model})
return response.json()


def _encode_image(image) -> str:
Expand Down

0 comments on commit 0f9211d

Please sign in to comment.