diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
new file mode 100644
index 0000000..a3728ca
--- /dev/null
+++ b/.github/workflows/publish.yaml
@@ -0,0 +1,24 @@
+name: publish
+
+on:
+  release:
+    types:
+      - created
+
+jobs:
+  publish:
+    runs-on: ubuntu-latest
+    environment: release
+    permissions:
+      id-token: write
+    steps:
+      - uses: actions/checkout@v4
+      - run: pipx install poetry
+      - uses: actions/setup-python@v5
+        with:
+          cache: poetry
+      - run: |
+          poetry version -- ${GITHUB_REF_NAME#v}
+          poetry build
+      - uses: pypa/gh-action-pypi-publish@release/v1
+      - run: gh release upload $GITHUB_REF_NAME dist/*
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 0000000..cc3a2d5
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,32 @@
+name: test
+
+on:
+  pull_request:
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - run: pipx install poetry
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: poetry
+      - run: poetry install --with=dev
+      - run: poetry run ruff --output-format=github .
+      - run: poetry run pytest . --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=ollama --cov-report=xml --cov-report=html
+      - name: check poetry.lock is up-to-date
+        run: poetry check --lock
+      - name: check requirements.txt is up-to-date
+        run: |
+          poetry export >requirements.txt
+          git diff --exit-code requirements.txt
+      - uses: actions/upload-artifact@v3
+        with:
+          name: pytest-results-${{ matrix.python-version }}
+          path: junit/test-results-${{ matrix.python-version }}.xml
+        if: ${{ always() }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..68bc17f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md index e69de29..0183092 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,70 @@ +# Ollama Python Library + +The Ollama Python library provides the easiest way to integrate your Python 3 project with [Ollama](https://github.com/jmorganca/ollama). + +## Getting Started + +Requires Python 3.8 or higher. + +```sh +pip install ollama +``` + +A global default client is provided for convenience and can be used in the same way as the synchronous client. + +```python +import ollama +response = ollama.chat(model='llama2', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) +``` + +```python +import ollama +message = {'role': 'user', 'content': 'Why is the sky blue?'} +for part in ollama.chat(model='llama2', messages=[message], stream=True): + print(part['message']['content'], end='', flush=True) +``` + + +### Using the Synchronous Client + +```python +from ollama import Client +message = {'role': 'user', 'content': 'Why is the sky blue?'} +response = Client().chat(model='llama2', messages=[message]) +``` + +Response streaming can be enabled by setting `stream=True`. This modifies the function to return a Python generator where each part is an object in the stream. + +```python +from ollama import Client +message = {'role': 'user', 'content': 'Why is the sky blue?'} +for part in Client().chat(model='llama2', messages=[message], stream=True): + print(part['message']['content'], end='', flush=True) +``` + +### Using the Asynchronous Client + +```python +import asyncio +from ollama import AsyncClient + +async def chat(): + message = {'role': 'user', 'content': 'Why is the sky blue?'} + response = await AsyncClient().chat(model='llama2', messages=[message]) + +asyncio.run(chat()) +``` + +Similar to the synchronous client, setting `stream=True` modifies the function to return a Python asynchronous generator. 
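+Each part of the stream is a mapping shaped like a complete response, so it can be consumed incrementally with an `async for` loop, as shown below.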
+ +```python +import asyncio +from ollama import AsyncClient + +async def chat(): + message = {'role': 'user', 'content': 'Why is the sky blue?'} + async for part in await AsyncClient().chat(model='llama2', messages=[message], stream=True): + print(part['message']['content'], end='', flush=True) + +asyncio.run(chat()) +``` diff --git a/examples/simple-fill-in-middle/main.py b/examples/simple-fill-in-middle/main.py new file mode 100644 index 0000000..67d7a74 --- /dev/null +++ b/examples/simple-fill-in-middle/main.py @@ -0,0 +1,22 @@ +from ollama import generate + +prefix = '''def remove_non_ascii(s: str) -> str: + """ ''' + +suffix = """ + return result +""" + + +response = generate( + model='codellama:7b-code', + prompt=f'
<PRE> {prefix} <SUF>{suffix} <MID>',
+  options={
+    'num_predict': 128,
+    'temperature': 0,
+    'top_p': 0.9,
+    # stop at codellama's end-of-infill token
+    'stop': ['<EOT>'],
+  },
+)
+
+print(response['response'])
diff --git a/ollama/__init__.py b/ollama/__init__.py
index 8e7dc22..a66f1d0 100644
--- a/ollama/__init__.py
+++ b/ollama/__init__.py
@@ -1,30 +1,56 @@
-from ollama.client import Client
+from ollama._client import Client, AsyncClient, Message, Options
+
+__all__ = [
+  'Client',
+  'AsyncClient',
+  'Message',
+  'Options',
+  'generate',
+  'chat',
+  'pull',
+  'push',
+  'create',
+  'delete',
+  'list',
+  'copy',
+  'show',
+]
+
 _default_client = Client()
 
+
 def generate(*args, **kwargs):
   return _default_client.generate(*args, **kwargs)
 
+
 def chat(*args, **kwargs):
   return _default_client.chat(*args, **kwargs)
 
+
 def pull(*args, **kwargs):
   return _default_client.pull(*args, **kwargs)
 
+
 def push(*args, **kwargs):
   return _default_client.push(*args, **kwargs)
 
+
 def create(*args, **kwargs):
   return _default_client.create(*args, **kwargs)
 
+
 def delete(*args, **kwargs):
   return _default_client.delete(*args, **kwargs)
 
+
 def list(*args, **kwargs):
   return _default_client.list(*args, **kwargs)
 
+
 def copy(*args, **kwargs):
   return _default_client.copy(*args, **kwargs)
 
+
 def show(*args, **kwargs):
   return _default_client.show(*args, **kwargs)
diff --git a/ollama/_client.py b/ollama/_client.py
new file mode 100644
index 0000000..d0fa30f
--- /dev/null
+++ b/ollama/_client.py
@@ -0,0 +1,458 @@
+import io
+import json
+import httpx
+from os import PathLike
+from pathlib import Path
+from hashlib import sha256
+from base64 import b64encode
+
+from typing import Any, AnyStr, Union, Optional, List, Mapping
+
+import sys
+
+if sys.version_info < (3, 9):
+  from typing import Iterator, AsyncIterator
+else:
+  from collections.abc import Iterator, AsyncIterator
+
+from ollama._types import Message, Options
+
+
+class BaseClient:
+  def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
+    self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
+
+
+class Client(BaseClient):
+  def __init__(self, base='http://localhost:11434') -> None:
+    super().__init__(httpx.Client, base)
+
+  def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
+    response = self._client.request(method, url, **kwargs)
+    response.raise_for_status()
+    return response
+
+  def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
+    return self._request(method, url, **kwargs).json()
+
+  def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
+    with self._client.stream(method, url, **kwargs) as r:
+      for line in r.iter_lines():
+        part = json.loads(line)
+        if e := part.get('error'):
+          raise Exception(e)
+        yield part
+
+  def generate(
+    self,
+    model: str = '',
+    prompt: str = '',
+    system: str = '',
+    template: str = '',
+    context: Optional[List[int]] = None,
+    stream: bool = False,
+    raw: bool = False,
+    format: str = '',
+    images: Optional[List[AnyStr]] = None,
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    if not model:
+      raise Exception('must provide a model')
+
+    fn = self._stream if stream else self._request_json
+    return fn(
+      'POST',
+      '/api/generate',
+      json={
+        'model': model,
+        'prompt': prompt,
+        'system': system,
+        'template': template,
+        'context': context or [],
+        'stream': stream,
+        'raw': raw,
+        'images': [_encode_image(image) for image in images or []],
+        'format': format,
+        'options': options or {},
+      },
+    )
+
+  def chat(
+    self,
+    model: str = '',
+    messages: Optional[List[Message]] = None,
+    stream: bool = False,
+    format: str = '',
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    if not model:
+      raise Exception('must provide a model')
+
+    for message in messages or []:
+      if not isinstance(message, dict):
+        raise TypeError('messages must be a list of dicts')
+      if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
+        raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
+      if not message.get('content'):
+        raise Exception('messages must contain content')
+      if images := message.get('images'):
+        message['images'] = [_encode_image(image) for image in images]
+
+    fn = self._stream if stream else self._request_json
+    return fn(
+      'POST',
+      '/api/chat',
+      json={
+        'model': model,
+        'messages': messages,
+        'stream': stream,
+        'format': format,
+        'options': options or {},
+      },
+    )
+
+  def pull(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    fn = self._stream if stream else self._request_json
+    return fn(
+      'POST',
+      '/api/pull',
+      json={
+        'model': model,
+        'insecure': insecure,
+        'stream': stream,
+      },
+    )
+
+  def push(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    fn = self._stream if stream else self._request_json
+    return fn(
+      'POST',
+      '/api/push',
+      json={
+        'model': model,
+        'insecure': insecure,
+        'stream': stream,
+      },
+    )
+
+  def create(
+    self,
+    model: str,
+    path: Optional[Union[str, PathLike]] = None,
+    modelfile: Optional[str] = None,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    if (realpath := _as_path(path)) and realpath.exists():
+      modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
+    elif modelfile:
+      modelfile = self._parse_modelfile(modelfile)
+    else:
+      raise Exception('must provide either path or modelfile')
+
+    fn = self._stream if stream else self._request_json
+    return fn(
+      'POST',
+      '/api/create',
+      json={
+        'model': model,
+        'modelfile': modelfile,
+        'stream': stream,
+      },
+    )
+
+  def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
+    base = Path.cwd() if base is None else base
+
+    out = io.StringIO()
+    for line in io.StringIO(modelfile):
+      command, _, args = line.partition(' ')
+      if command.upper() in ['FROM', 'ADAPTER']:
+        path = Path(args).expanduser()
+        path = path if path.is_absolute() else base / path
+        if path.exists():
+          args = f'@{self._create_blob(path)}'
+
+      print(command, args, file=out)
+    return out.getvalue()
+
+  def _create_blob(self, path: Union[str, Path]) -> str:
+    sha256sum = sha256()
+    with open(path, 'rb') as r:
+      while True:
+        chunk = r.read(32 * 1024)
+        if not chunk:
+          break
+        sha256sum.update(chunk)
+
+    digest = f'sha256:{sha256sum.hexdigest()}'
+
+    try:
+      self._request('HEAD', f'/api/blobs/{digest}')
+    except httpx.HTTPStatusError as e:
+      if e.response.status_code != 404:
+        raise
+
+      # the server does not have this blob yet; upload it
+      with open(path, 'rb') as r:
+        self._request('PUT', f'/api/blobs/{digest}', content=r)
+
+    return digest
+
+  def delete(self, model: str) -> Mapping[str, Any]:
+    response = self._request('DELETE', '/api/delete', json={'model': model})
+    return {'status': 'success' if response.status_code == 200 else 'error'}
+
+  def list(self) -> Mapping[str, Any]:
+    return self._request_json('GET', '/api/tags').get('models', [])
+
+  def copy(self, source: str, target: str) -> Mapping[str, Any]:
+    response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
+    return {'status': 'success' if response.status_code == 200 else 'error'}
+
+  def show(self, model: str) -> Mapping[str, Any]:
+    return self._request_json('GET', '/api/show', json={'model': model})
+
+
+class AsyncClient(BaseClient):
+  def __init__(self, base='http://localhost:11434') -> None:
+    super().__init__(httpx.AsyncClient, base)
+
+  async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
+    response = await self._client.request(method, url, **kwargs)
+    response.raise_for_status()
+    return response
+
+  async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
+    response = await self._request(method, url, **kwargs)
+    return response.json()
+
+  async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
+    async def inner():
+      async with self._client.stream(method, url, **kwargs) as r:
+        async for line in r.aiter_lines():
+          part = json.loads(line)
+          if e := part.get('error'):
+            raise Exception(e)
+          yield part
+
+    return inner()
+
+  async def generate(
+    self,
+    model: str = '',
+    prompt: str = '',
+    system: str = '',
+    template: str = '',
+    context: Optional[List[int]] = None,
+    stream: bool = False,
+    raw: bool = False,
+    format: str = '',
+    images: Optional[List[AnyStr]] = None,
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    if not model:
+      raise Exception('must provide a model')
+
+    fn = self._stream if stream else self._request_json
+    return await fn(
+      'POST',
+      '/api/generate',
+      json={
+        'model': model,
+        'prompt': prompt,
+        'system': system,
+        'template': template,
+        'context': context or [],
+        'stream': stream,
+        'raw': raw,
+        'images': [_encode_image(image) for image in images or []],
+        'format': format,
+        'options': options or {},
+      },
+    )
+
+  async def chat(
+    self,
+    model: str = '',
+    messages: Optional[List[Message]] = None,
+    stream: bool = False,
+    format: str = '',
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    if not model:
+      raise Exception('must provide a model')
+
+    for message in messages or []:
+      if not isinstance(message, dict):
+        raise TypeError('messages must be a list of dicts')
+      if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
+        raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
+      if not message.get('content'):
+        raise Exception('messages must contain content')
+      if images := message.get('images'):
+        message['images'] = [_encode_image(image) for image in images]
+
+    fn = self._stream if stream else self._request_json
+    return await fn(
+      'POST',
+      '/api/chat',
+      json={
+        'model': model,
+        'messages': messages,
+        'stream': stream,
+        'format': format,
+        'options': options or {},
+      },
+    )
+
+  async def pull(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    fn = self._stream if stream else self._request_json
+    return await fn(
+      'POST',
+      '/api/pull',
+      json={
+        'model': model,
+        'insecure': insecure,
+        'stream': stream,
+      },
+    )
+
+  async def push(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    fn = self._stream if stream else 
self._request_json + return await fn( + 'POST', + '/api/push', + json={ + 'model': model, + 'insecure': insecure, + 'stream': stream, + }, + ) + + async def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + stream: bool = False, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + if (realpath := _as_path(path)) and realpath.exists(): + modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent) + elif modelfile: + modelfile = await self._parse_modelfile(modelfile) + else: + raise Exception('must provide either path or modelfile') + + fn = self._stream if stream else self._request_json + return await fn( + 'POST', + '/api/create', + json={ + 'model': model, + 'modelfile': modelfile, + 'stream': stream, + }, + ) + + async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: + base = Path.cwd() if base is None else base + + out = io.StringIO() + for line in io.StringIO(modelfile): + command, _, args = line.partition(' ') + if command.upper() in ['FROM', 'ADAPTER']: + path = Path(args).expanduser() + path = path if path.is_absolute() else base / path + if path.exists(): + args = f'@{await self._create_blob(path)}' + + print(command, args, file=out) + return out.getvalue() + + async def _create_blob(self, path: Union[str, Path]) -> str: + sha256sum = sha256() + with open(path, 'rb') as r: + while True: + chunk = r.read(32 * 1024) + if not chunk: + break + sha256sum.update(chunk) + + digest = f'sha256:{sha256sum.hexdigest()}' + + try: + await self._request('HEAD', f'/api/blobs/{digest}') + except httpx.HTTPStatusError as e: + if e.response.status_code != 404: + raise + + async def upload_bytes(): + with open(path, 'rb') as r: + while True: + chunk = r.read(32 * 1024) + if not chunk: + break + yield chunk + + await self._request('PUT', f'/api/blobs/{digest}', content=upload_bytes()) + + return digest + + async def delete(self, model: str) -> Mapping[str, Any]: + response = await self._request('DELETE', '/api/delete', json={'model': model}) + return {'status': 'success' if response.status_code == 200 else 'error'} + + async def list(self) -> Mapping[str, Any]: + response = await self._request_json('GET', '/api/tags') + return response.get('models', []) + + async def copy(self, source: str, target: str) -> Mapping[str, Any]: + response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target}) + return {'status': 'success' if response.status_code == 200 else 'error'} + + async def show(self, model: str) -> Mapping[str, Any]: + return await self._request_json('GET', '/api/show', json={'model': model}) + + +def _encode_image(image) -> str: + if p := _as_path(image): + b64 = b64encode(p.read_bytes()) + elif b := _as_bytesio(image): + b64 = b64encode(b.read()) + else: + raise Exception('images must be a list of bytes, path-like objects, or file-like objects') + + return b64.decode('utf-8') + + +def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: + if isinstance(s, str) or isinstance(s, Path): + return Path(s) + return None + + +def _as_bytesio(s: Any) -> Union[io.BytesIO, None]: + if isinstance(s, io.BytesIO): + return s + elif isinstance(s, bytes): + return io.BytesIO(s) + return None diff --git a/ollama/_types.py b/ollama/_types.py new file mode 100644 index 0000000..d263269 --- /dev/null +++ b/ollama/_types.py @@ -0,0 +1,53 @@ +from typing import Any, TypedDict, List + +import sys + +if sys.version_info < (3, 11): + from 
typing_extensions import NotRequired +else: + from typing import NotRequired + + +class Message(TypedDict): + role: str + content: str + images: NotRequired[List[Any]] + + +class Options(TypedDict, total=False): + # load time options + numa: bool + num_ctx: int + num_batch: int + num_gqa: int + num_gpu: int + main_gpu: int + low_vram: bool + f16_kv: bool + logits_all: bool + vocab_only: bool + use_mmap: bool + use_mlock: bool + embedding_only: bool + rope_frequency_base: float + rope_frequency_scale: float + num_thread: int + + # runtime options + num_keep: int + seed: int + num_predict: int + top_k: int + top_p: float + tfs_z: float + typical_p: float + repeat_last_n: int + temperature: float + repeat_penalty: float + presence_penalty: float + frequency_penalty: float + mirostat: int + mirostat_tau: float + mirostat_eta: float + penalize_newline: bool + stop: List[str] diff --git a/ollama/client.py b/ollama/client.py deleted file mode 100644 index 578757f..0000000 --- a/ollama/client.py +++ /dev/null @@ -1,182 +0,0 @@ -import io -import json -import httpx -from pathlib import Path -from hashlib import sha256 -from base64 import b64encode - - -class BaseClient: - - def __init__(self, client, base_url='http://127.0.0.1:11434'): - self._client = client(base_url=base_url, follow_redirects=True, timeout=None) - - -class Client(BaseClient): - - def __init__(self, base='http://localhost:11434'): - super().__init__(httpx.Client, base) - - def _request(self, method, url, **kwargs): - response = self._client.request(method, url, **kwargs) - response.raise_for_status() - return response - - def _request_json(self, method, url, **kwargs): - return self._request(method, url, **kwargs).json() - - def stream(self, method, url, **kwargs): - with self._client.stream(method, url, **kwargs) as r: - for line in r.iter_lines(): - part = json.loads(line) - if e := part.get('error'): - raise Exception(e) - yield part - - def generate(self, model, prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None): - fn = self.stream if stream else self._request_json - return fn('POST', '/api/generate', json={ - 'model': model, - 'prompt': prompt, - 'system': system, - 'template': template, - 'context': context or [], - 'stream': stream, - 'raw': raw, - 'images': [_encode_image(image) for image in images or []], - 'format': format, - 'options': options or {}, - }) - - def chat(self, model, messages=None, stream=False, format='', options=None): - for message in messages or []: - if not isinstance(message, dict): - raise TypeError('messages must be a list of strings') - if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']: - raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"') - if not message.get('content'): - raise Exception('messages must contain content') - if images := message.get('images'): - message['images'] = [_encode_image(image) for image in images] - - fn = self.stream if stream else self._request_json - return fn('POST', '/api/chat', json={ - 'model': model, - 'messages': messages, - 'stream': stream, - 'format': format, - 'options': options or {}, - }) - - def pull(self, model, insecure=False, stream=False): - fn = self.stream if stream else self._request_json - return fn('POST', '/api/pull', json={ - 'model': model, - 'insecure': insecure, - 'stream': stream, - }) - - def push(self, model, insecure=False, stream=False): - fn = self.stream if stream else self._request_json - 
return fn('POST', '/api/push', json={ - 'model': model, - 'insecure': insecure, - 'stream': stream, - }) - - def create(self, model, path=None, modelfile=None, stream=False): - if (path := _as_path(path)) and path.exists(): - modelfile = _parse_modelfile(path.read_text(), self.create_blob, base=path.parent) - elif modelfile: - modelfile = _parse_modelfile(modelfile, self.create_blob) - else: - raise Exception('must provide either path or modelfile') - - fn = self.stream if stream else self._request_json - return fn('POST', '/api/create', json={ - 'model': model, - 'modelfile': modelfile, - 'stream': stream, - }) - - - def create_blob(self, path): - sha256sum = sha256() - with open(path, 'rb') as r: - while True: - chunk = r.read(32*1024) - if not chunk: - break - sha256sum.update(chunk) - - digest = f'sha256:{sha256sum.hexdigest()}' - - try: - self._request('HEAD', f'/api/blobs/{digest}') - except httpx.HTTPError: - with open(path, 'rb') as r: - self._request('PUT', f'/api/blobs/{digest}', content=r) - - return digest - - def delete(self, model): - response = self._request_json('DELETE', '/api/delete', json={'model': model}) - return {'status': 'success' if response.status_code == 200 else 'error'} - - def list(self): - return self._request_json('GET', '/api/tags').get('models', []) - - def copy(self, source, target): - response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target}) - return {'status': 'success' if response.status_code == 200 else 'error'} - - def show(self, model): - return self._request_json('GET', '/api/show', json={'model': model}).json() - - -def _encode_image(image): - ''' - _encode_images takes a list of images and returns a generator of base64 encoded images. - if the image is a bytes object, it is assumed to be the raw bytes of an image. - if the image is a string, it is assumed to be a path to a file. - if the image is a Path object, it is assumed to be a path to a file. - if the image is a file-like object, it is assumed to be a container to the raw bytes of an image. 
- ''' - - if p := _as_path(image): - b64 = b64encode(p.read_bytes()) - elif b := _as_bytesio(image): - b64 = b64encode(b.read()) - else: - raise Exception('images must be a list of bytes, path-like objects, or file-like objects') - - return b64.decode('utf-8') - - -def _parse_modelfile(modelfile, cb, base=None): - base = Path.cwd() if base is None else base - - out = io.StringIO() - for line in io.StringIO(modelfile): - command, _, args = line.partition(' ') - if command.upper() in ['FROM', 'ADAPTER']: - path = Path(args).expanduser() - path = path if path.is_absolute() else base / path - if path.exists(): - args = f'@{cb(path)}' - - print(command, args, file=out) - return out.getvalue() - - -def _as_path(s): - if isinstance(s, str) or isinstance(s, Path): - return Path(s) - return None - -def _as_bytesio(s): - if isinstance(s, io.BytesIO): - return s - elif isinstance(s, bytes): - return io.BytesIO(s) - return None diff --git a/ollama/client_test.py b/ollama/client_test.py deleted file mode 100644 index cb563dc..0000000 --- a/ollama/client_test.py +++ /dev/null @@ -1,292 +0,0 @@ -import pytest -import os -import io -import types -import tempfile -from pathlib import Path -from ollama.client import Client -from pytest_httpserver import HTTPServer, URIPattern -from werkzeug.wrappers import Response -from PIL import Image - - -class PrefixPattern(URIPattern): - def __init__(self, prefix: str): - self.prefix = prefix - - def match(self, uri): - return uri.startswith(self.prefix) - - -def test_client_chat(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/chat', method='POST', json={ - 'model': 'dummy', - 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], - 'stream': False, - 'format': '', - 'options': {}, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) - assert isinstance(response, dict) - - -def test_client_chat_stream(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/chat', method='POST', json={ - 'model': 'dummy', - 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], - 'stream': True, - 'format': '', - 'options': {}, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True) - assert isinstance(response, types.GeneratorType) - - -def test_client_chat_images(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/chat', method='POST', json={ - 'model': 'dummy', - 'messages': [ - { - 'role': 'user', - 'content': 'Why is the sky blue?', - 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], - }, - ], - 'stream': False, - 'format': '', - 'options': {}, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - with io.BytesIO() as b: - Image.new('RGB', (1, 1)).save(b, 'PNG') - response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}]) - assert isinstance(response, dict) - - -def test_client_generate(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/generate', method='POST', json={ - 'model': 'dummy', - 'prompt': 'Why is the sky blue?', - 'system': '', - 'template': '', - 'context': [], - 'stream': False, - 'raw': False, - 'images': [], - 'format': '', - 'options': {}, - }).respond_with_json({}) - - client = 
Client(httpserver.url_for('/')) - response = client.generate('dummy', 'Why is the sky blue?') - assert isinstance(response, dict) - - -def test_client_generate_stream(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/generate', method='POST', json={ - 'model': 'dummy', - 'prompt': 'Why is the sky blue?', - 'system': '', - 'template': '', - 'context': [], - 'stream': True, - 'raw': False, - 'images': [], - 'format': '', - 'options': {}, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.generate('dummy', 'Why is the sky blue?', stream=True) - assert isinstance(response, types.GeneratorType) - - -def test_client_generate_images(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/generate', method='POST', json={ - 'model': 'dummy', - 'prompt': 'Why is the sky blue?', - 'system': '', - 'template': '', - 'context': [], - 'stream': False, - 'raw': False, - 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], - 'format': '', - 'options': {}, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as temp: - Image.new('RGB', (1, 1)).save(temp, 'PNG') - response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name]) - assert isinstance(response, dict) - - -def test_client_pull(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/pull', method='POST', json={ - 'model': 'dummy', - 'insecure': False, - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.pull('dummy') - assert isinstance(response, dict) - - -def test_client_pull_stream(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/pull', method='POST', json={ - 'model': 'dummy', - 'insecure': False, - 'stream': True, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.pull('dummy', stream=True) - assert isinstance(response, types.GeneratorType) - - -def test_client_push(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/push', method='POST', json={ - 'model': 'dummy', - 'insecure': False, - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.push('dummy') - assert isinstance(response, dict) - - -def test_client_push_stream(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/push', method='POST', json={ - 'model': 'dummy', - 'insecure': False, - 'stream': True, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - response = client.push('dummy', stream=True) - assert isinstance(response, types.GeneratorType) - - -def test_client_create_path(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request('/api/create', method='POST', json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as modelfile: - with tempfile.NamedTemporaryFile() as blob: - modelfile.write(f'FROM {blob.name}'.encode('utf-8')) - modelfile.flush() - - response = client.create('dummy', path=modelfile.name) - assert isinstance(response, dict) - - -def test_client_create_path_relative(httpserver: HTTPServer): - 
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request('/api/create', method='POST', json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as modelfile: - with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob: - modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8')) - modelfile.flush() - - response = client.create('dummy', path=modelfile.name) - assert isinstance(response, dict) - - -@pytest.fixture -def userhomedir(): - with tempfile.TemporaryDirectory() as temp: - home = os.getenv('HOME', '') - os.environ['HOME'] = temp - yield Path(temp) - os.environ['HOME'] = home - - -def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request('/api/create', method='POST', json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as modelfile: - with tempfile.NamedTemporaryFile(dir=userhomedir) as blob: - modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8')) - modelfile.flush() - - response = client.create('dummy', path=modelfile.name) - assert isinstance(response, dict) - - -def test_client_create_modelfile(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request('/api/create', method='POST', json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as blob: - response = client.create('dummy', modelfile=f'FROM {blob.name}') - assert isinstance(response, dict) - - -def test_client_create_from_library(httpserver: HTTPServer): - httpserver.expect_ordered_request('/api/create', method='POST', json={ - 'model': 'dummy', - 'modelfile': 'FROM llama2\n', - 'stream': False, - }).respond_with_json({}) - - client = Client(httpserver.url_for('/')) - - response = client.create('dummy', modelfile='FROM llama2') - assert isinstance(response, dict) - - -def test_client_create_blob(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404)) - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201)) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as blob: - response = client.create_blob(blob.name) - assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' - - -def test_client_create_blob_exists(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) - - client = Client(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as blob: - response = 
client.create_blob(blob.name) - assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' diff --git a/poetry.lock b/poetry.lock index a61b83c..3db7b00 100644 --- a/poetry.lock +++ b/poetry.lock @@ -387,6 +387,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"}, + {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -498,4 +516,4 @@ watchdog = ["watchdog (>=2.3)"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "b9f64e1a5795a417d2dbff7286360f8d3f8f10fdfa9580411940d144c2561e92" +content-hash = "9416a897c95d3c80cf1bfd3cc61cd19f0143c9bd0bc7c219fcb31ee27c497c9d" diff --git a/pyproject.toml b/pyproject.toml index 5f3858a..76b5e15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,12 @@ [tool.poetry] name = "ollama" -version = "0.1.0" +version = "0.0.0" description = "The official Python client for Ollama." authors = ["Ollama "] +license = "MIT" readme = "README.md" +homepage = "https://ollama.ai" +repository = "https://github.com/jmorganca/ollama-python" [tool.poetry.dependencies] python = "^3.8" @@ -11,12 +14,18 @@ httpx = "^0.25.2" [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" +pytest-asyncio = "^0.23.2" pytest-cov = "^4.1.0" pytest-httpserver = "^1.0.8" pillow = "^10.1.0" ruff = "^0.1.8" +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + [tool.ruff] +line-length = 999 indent-width = 2 [tool.ruff.format] @@ -26,7 +35,3 @@ indent-style = "space" [tool.ruff.lint] select = ["E", "F", "B"] ignore = ["E501"] - -[build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..fe151dc --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,779 @@ +import os +import io +import json +import types +import pytest +import tempfile +from pathlib import Path +from pytest_httpserver import HTTPServer, URIPattern +from werkzeug.wrappers import Request, Response +from PIL import Image + +from ollama._client import Client, AsyncClient + + +class PrefixPattern(URIPattern): + def __init__(self, prefix: str): + self.prefix = prefix + + def match(self, uri): + return uri.startswith(self.prefix) + + +def test_client_chat(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'stream': False, + 'format': '', + 'options': {}, + }, + ).respond_with_json( + { + 'model': 'dummy', + 'message': { + 'role': 'assistant', + 'content': "I don't know.", + }, + } + ) + + client = Client(httpserver.url_for('/')) + response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky 
blue?'}])
+  assert response['model'] == 'dummy'
+  assert response['message']['role'] == 'assistant'
+  assert response['message']['content'] == "I don't know."
+
+
+def test_client_chat_stream(httpserver: HTTPServer):
+  def stream_handler(_: Request):
+    def generate():
+      for message in ['I ', "don't ", 'know.']:
+        yield (
+          json.dumps(
+            {
+              'model': 'dummy',
+              'message': {
+                'role': 'assistant',
+                'content': message,
+              },
+            }
+          )
+          + '\n'
+        )
+
+    return Response(generate())
+
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+      'stream': True,
+      'format': '',
+      'options': {},
+    },
+  ).respond_with_handler(stream_handler)
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
+  for part in response:
+    assert part['message']['role'] == 'assistant'
+    assert part['message']['content'] in ['I ', "don't ", 'know.']
+
+
+def test_client_chat_images(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [
+        {
+          'role': 'user',
+          'content': 'Why is the sky blue?',
+          'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
+        },
+      ],
+      'stream': False,
+      'format': '',
+      'options': {},
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': "I don't know.",
+      },
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+
+  with io.BytesIO() as b:
+    Image.new('RGB', (1, 1)).save(b, 'PNG')
+    response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
+    assert response['model'] == 'dummy'
+    assert response['message']['role'] == 'assistant'
+    assert response['message']['content'] == "I don't know."
+
+
+def test_client_generate(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why is the sky blue?',
+      'system': '',
+      'template': '',
+      'context': [],
+      'stream': False,
+      'raw': False,
+      'images': [],
+      'format': '',
+      'options': {},
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Because it is.',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why is the sky blue?')
+  assert response['model'] == 'dummy'
+  assert response['response'] == 'Because it is.'
+
+
+# The streaming tests below emulate Ollama's newline-delimited JSON stream
+# with a handler that yields one JSON object per line.
+def test_client_generate_stream(httpserver: HTTPServer):
+  def stream_handler(_: Request):
+    def generate():
+      for message in ['Because ', 'it ', 'is.']:
+        yield (
+          json.dumps(
+            {
+              'model': 'dummy',
+              'response': message,
+            }
+          )
+          + '\n'
+        )
+
+    return Response(generate())
+
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why is the sky blue?',
+      'system': '',
+      'template': '',
+      'context': [],
+      'stream': True,
+      'raw': False,
+      'images': [],
+      'format': '',
+      'options': {},
+    },
+  ).respond_with_handler(stream_handler)
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why is the sky blue?', stream=True)
+  for part in response:
+    assert part['model'] == 'dummy'
+    assert part['response'] in ['Because ', 'it ', 'is.']
+
+
+def test_client_generate_images(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why is the sky blue?',
+      'system': '',
+      'template': '',
+      'context': [],
+      'stream': False,
+      'raw': False,
+      'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
+      'format': '',
+      'options': {},
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Because it is.',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+
+  with tempfile.NamedTemporaryFile() as temp:
+    Image.new('RGB', (1, 1)).save(temp, 'PNG')
+    response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
+    assert response['model'] == 'dummy'
+    assert response['response'] == 'Because it is.'
+
+
+def test_client_pull(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/pull',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'insecure': False,
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'status': 'success',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.pull('dummy')
+  assert response['status'] == 'success'
+
+
+def test_client_pull_stream(httpserver: HTTPServer):
+  def stream_handler(_: Request):
+    def generate():
+      yield json.dumps({'status': 'pulling manifest'}) + '\n'
+      yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
+      yield json.dumps({'status': 'writing manifest'}) + '\n'
+      yield json.dumps({'status': 'removing any unused layers'}) + '\n'
+      yield json.dumps({'status': 'success'}) + '\n'
+
+    return Response(generate())
+
+  httpserver.expect_ordered_request(
+    '/api/pull',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'insecure': False,
+      'stream': True,
+    },
+  ).respond_with_handler(stream_handler)
+
+  client = Client(httpserver.url_for('/'))
+  response = client.pull('dummy', stream=True)
+  assert isinstance(response, types.GeneratorType)
+
+
+def test_client_push(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/push',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'insecure': False,
+      'stream': False,
+    },
+  ).respond_with_json({})
+
+  client = Client(httpserver.url_for('/'))
+  response = client.push('dummy')
+  assert isinstance(response, dict)
+
+
+def test_client_push_stream(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/push',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'insecure': False,
+      'stream': True,
+    },
+  ).respond_with_json({})
+
+  client = Client(httpserver.url_for('/'))
+  response = client.push('dummy', stream=True)
+  assert isinstance(response, types.GeneratorType)
+
+
+def test_client_create_path(httpserver: HTTPServer):
+  
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as modelfile: + with tempfile.NamedTemporaryFile() as blob: + modelfile.write(f'FROM {blob.name}'.encode('utf-8')) + modelfile.flush() + + response = client.create('dummy', path=modelfile.name) + assert isinstance(response, dict) + + +def test_client_create_path_relative(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as modelfile: + with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob: + modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8')) + modelfile.flush() + + response = client.create('dummy', path=modelfile.name) + assert isinstance(response, dict) + + +@pytest.fixture +def userhomedir(): + with tempfile.TemporaryDirectory() as temp: + home = os.getenv('HOME', '') + os.environ['HOME'] = temp + yield Path(temp) + os.environ['HOME'] = home + + +def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as modelfile: + with tempfile.NamedTemporaryFile(dir=userhomedir) as blob: + modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8')) + modelfile.flush() + + response = client.create('dummy', path=modelfile.name) + assert isinstance(response, dict) + + +def test_client_create_modelfile(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = client.create('dummy', modelfile=f'FROM {blob.name}') + assert isinstance(response, dict) + + +def test_client_create_from_library(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM llama2\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = Client(httpserver.url_for('/')) + + response = client.create('dummy', modelfile='FROM llama2') + assert isinstance(response, dict) + + +def test_client_create_blob(httpserver: HTTPServer): + 
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201)) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = client._create_blob(blob.name) + assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + + +def test_client_create_blob_exists(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = client._create_blob(blob.name) + assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + + +@pytest.mark.asyncio +async def test_async_client_chat(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'stream': False, + 'format': '', + 'options': {}, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_chat_stream(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'stream': True, + 'format': '', + 'options': {}, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True) + assert isinstance(response, types.AsyncGeneratorType) + + +@pytest.mark.asyncio +async def test_async_client_chat_images(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [ + { + 'role': 'user', + 'content': 'Why is the sky blue?', + 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], + }, + ], + 'stream': False, + 'format': '', + 'options': {}, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with io.BytesIO() as b: + Image.new('RGB', (1, 1)).save(b, 'PNG') + response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}]) + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_generate(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'Why is the sky blue?', + 'system': '', + 'template': '', + 'context': [], + 'stream': False, + 'raw': False, + 'images': [], + 'format': '', + 'options': {}, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.generate('dummy', 'Why is the sky blue?') + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_generate_stream(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'Why is the sky blue?', + 'system': '', + 'template': 
'', + 'context': [], + 'stream': True, + 'raw': False, + 'images': [], + 'format': '', + 'options': {}, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.generate('dummy', 'Why is the sky blue?', stream=True) + assert isinstance(response, types.AsyncGeneratorType) + + +@pytest.mark.asyncio +async def test_async_client_generate_images(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'Why is the sky blue?', + 'system': '', + 'template': '', + 'context': [], + 'stream': False, + 'raw': False, + 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], + 'format': '', + 'options': {}, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as temp: + Image.new('RGB', (1, 1)).save(temp, 'PNG') + response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name]) + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_pull(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/pull', + method='POST', + json={ + 'model': 'dummy', + 'insecure': False, + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.pull('dummy') + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_pull_stream(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/pull', + method='POST', + json={ + 'model': 'dummy', + 'insecure': False, + 'stream': True, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.pull('dummy', stream=True) + assert isinstance(response, types.AsyncGeneratorType) + + +@pytest.mark.asyncio +async def test_async_client_push(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/push', + method='POST', + json={ + 'model': 'dummy', + 'insecure': False, + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.push('dummy') + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_push_stream(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/push', + method='POST', + json={ + 'model': 'dummy', + 'insecure': False, + 'stream': True, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.push('dummy', stream=True) + assert isinstance(response, types.AsyncGeneratorType) + + +@pytest.mark.asyncio +async def test_async_client_create_path(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as modelfile: + with tempfile.NamedTemporaryFile() as blob: + modelfile.write(f'FROM {blob.name}'.encode('utf-8')) + modelfile.flush() + + response = await client.create('dummy', path=modelfile.name) + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def 
test_async_client_create_path_relative(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as modelfile: + with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob: + modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8')) + modelfile.flush() + + response = await client.create('dummy', path=modelfile.name) + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as modelfile: + with tempfile.NamedTemporaryFile(dir=userhomedir) as blob: + modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8')) + modelfile.flush() + + response = await client.create('dummy', path=modelfile.name) + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_create_modelfile(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = await client.create('dummy', modelfile=f'FROM {blob.name}') + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_create_from_library(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'model': 'dummy', + 'modelfile': 'FROM llama2\n', + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + response = await client.create('dummy', modelfile='FROM llama2') + assert isinstance(response, dict) + + +@pytest.mark.asyncio +async def test_async_client_create_blob(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201)) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = await client._create_blob(blob.name) + assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + + +@pytest.mark.asyncio +async def test_async_client_create_blob_exists(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), 
method='HEAD').respond_with_response(Response(status=200)) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = await client._create_blob(blob.name) + assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'