Skip to content

Commit

Permalink
type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed Dec 21, 2023
1 parent 6f55659 commit dabcca6
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 43 deletions.
4 changes: 3 additions & 1 deletion ollama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ._client import Client, AsyncClient
from ollama._client import Client, AsyncClient, Message, Options

__all__ = [
'Client',
'AsyncClient',
'Message',
'Options',
'generate',
'chat',
'pull',
Expand Down
166 changes: 124 additions & 42 deletions ollama/_client.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,62 @@
import io
import json
import httpx
from os import PathLike
from pathlib import Path
from hashlib import sha256
from base64 import b64encode

from typing import Any, AnyStr, Union, Optional, List, Mapping

import sys
if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator
else:
from collections.abc import Iterator, AsyncIterator

from ollama._types import Message, Options


class BaseClient:

def __init__(self, client, base_url='http://127.0.0.1:11434'):
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)


class Client(BaseClient):

def __init__(self, base='http://localhost:11434'):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.Client, base)

def _request(self, method, url, **kwargs):
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
return response

def _request_json(self, method, url, **kwargs):
def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
return self._request(method, url, **kwargs).json()

def _stream(self, method, url, **kwargs):
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
for line in r.iter_lines():
part = json.loads(line)
if e := part.get('error'):
raise Exception(e)
yield part

def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')

Expand All @@ -51,7 +74,14 @@ def generate(self, model='', prompt='', system='', template='', context=None, st
'options': options or {},
})

def chat(self, model='', messages=None, stream=False, format='', options=None):
def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
stream: bool = False,
format: str = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')

Expand All @@ -74,25 +104,41 @@ def chat(self, model='', messages=None, stream=False, format='', options=None):
'options': options or {},
})

def pull(self, model, insecure=False, stream=False):
def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn('POST', '/api/pull', json={
'model': model,
'insecure': insecure,
'stream': stream,
})

def push(self, model, insecure=False, stream=False):
def push(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn('POST', '/api/push', json={
'model': model,
'insecure': insecure,
'stream': stream,
})

def create(self, model, path=None, modelfile=None, stream=False):
if (path := _as_path(path)) and path.exists():
modelfile = self._parse_modelfile(path.read_text(), base=path.parent)
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists():
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = self._parse_modelfile(modelfile)
else:
Expand All @@ -105,7 +151,7 @@ def create(self, model, path=None, modelfile=None, stream=False):
'stream': stream,
})

def _parse_modelfile(self, modelfile, base=None):
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base

out = io.StringIO()
Expand All @@ -120,7 +166,7 @@ def _parse_modelfile(self, modelfile, base=None):
print(command, args, file=out)
return out.getvalue()

def _create_blob(self, path):
def _create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
Expand All @@ -142,36 +188,36 @@ def _create_blob(self, path):

return digest

def delete(self, model):
response = self._request_json('DELETE', '/api/delete', json={'model': model})
def delete(self, model: str) -> Mapping[str, Any]:
response = self._request('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}

def list(self):
def list(self) -> Mapping[str, Any]:
return self._request_json('GET', '/api/tags').get('models', [])

def copy(self, source, target):
response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}

def show(self, model):
def show(self, model: str) -> Mapping[str, Any]:
return self._request_json('GET', '/api/show', json={'model': model})


class AsyncClient(BaseClient):

def __init__(self, base='http://localhost:11434'):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base)

async def _request(self, method, url, **kwargs):
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
response.raise_for_status()
return response

async def _request_json(self, method, url, **kwargs):
async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
response = await self._request(method, url, **kwargs)
return response.json()

async def _stream(self, method, url, **kwargs):
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
async for line in r.aiter_lines():
Expand All @@ -181,7 +227,19 @@ async def inner():
yield part
return inner()

async def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
async def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')

Expand All @@ -199,7 +257,14 @@ async def generate(self, model='', prompt='', system='', template='', context=No
'options': options or {},
})

async def chat(self, model='', messages=None, stream=False, format='', options=None):
async def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
stream: bool = False,
format: str = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')

Expand All @@ -222,25 +287,41 @@ async def chat(self, model='', messages=None, stream=False, format='', options=N
'options': options or {},
})

async def pull(self, model, insecure=False, stream=False):
async def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn('POST', '/api/pull', json={
'model': model,
'insecure': insecure,
'stream': stream,
})

async def push(self, model, insecure=False, stream=False):
async def push(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn('POST', '/api/push', json={
'model': model,
'insecure': insecure,
'stream': stream,
})

async def create(self, model, path=None, modelfile=None, stream=False):
if (path := _as_path(path)) and path.exists():
modelfile = await self._parse_modelfile(path.read_text(), base=path.parent)
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
Expand All @@ -253,7 +334,7 @@ async def create(self, model, path=None, modelfile=None, stream=False):
'stream': stream,
})

async def _parse_modelfile(self, modelfile, base=None):
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base

out = io.StringIO()
Expand All @@ -268,7 +349,7 @@ async def _parse_modelfile(self, modelfile, base=None):
print(command, args, file=out)
return out.getvalue()

async def _create_blob(self, path):
async def _create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
Expand Down Expand Up @@ -297,23 +378,23 @@ async def upload_bytes():

return digest

async def delete(self, model):
response = await self._request_json('DELETE', '/api/delete', json={'model': model})
async def delete(self, model: str) -> Mapping[str, Any]:
response = await self._request('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}

async def list(self):
async def list(self) -> Mapping[str, Any]:
response = await self._request_json('GET', '/api/tags')
return response.get('models', [])

async def copy(self, source, target):
response = await self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
async def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}

async def show(self, model):
async def show(self, model: str) -> Mapping[str, Any]:
return await self._request_json('GET', '/api/show', json={'model': model})


def _encode_image(image):
def _encode_image(image) -> str:
if p := _as_path(image):
b64 = b64encode(p.read_bytes())
elif b := _as_bytesio(image):
Expand All @@ -324,12 +405,13 @@ def _encode_image(image):
return b64.decode('utf-8')


def _as_path(s):
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
if isinstance(s, str) or isinstance(s, Path):
return Path(s)
return None

def _as_bytesio(s):

def _as_bytesio(s: Any) -> Union[io.BytesIO, None]:
if isinstance(s, io.BytesIO):
return s
elif isinstance(s, bytes):
Expand Down
Loading

0 comments on commit dabcca6

Please sign in to comment.