Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removing nested-asyncio from Model Registry client #802

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ Or you can set the `is_secure` flag to `False` to connect **without** TLS (not r
registry = ModelRegistry("http://server-address", 8080, author="Ada Lovelace", is_secure=False) # insecure port set to 8080
```

The ModelRegistry client executes its HTTP API calls asynchronously but returns the results synchronously. To do this,
the client's implementation uses [AsyncTaskRunnerThread](src/model_registry/_async_task_runner_thread.py),
based on this [gist](https://gist.github.com/blink1073/969aeba85f32c285235750626f2eadd8), which works with both
standard [asyncio](https://docs.python.org/3/library/asyncio.html) and [uvloop](https://github.com/MagicStack/uvloop).
If you would like to override it, you can create `ModelRegistry` using the following code:

```py
registry = ModelRegistry("http://server-address", 8080, author="Ada Lovelace", is_secure=False, async_task_runner=MyAsyncTaskRunner)
```

Here `MyAsyncTaskRunner` is an implementation that extends
[AsyncTaskRunnerBase](src/model_registry/_async_task_runner_base.py) and implements both the `get_instance`
and `run` methods.

### Registering models

To register your first model, you can use the `register_model` method:
Expand Down
2 changes: 2 additions & 0 deletions clients/python/src/model_registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

__version__ = "0.2.14"

from ._async_task_runner_base import AsyncTaskRunnerBase
from ._client import ModelRegistry

__all__ = [
"ModelRegistry",
"AsyncTaskRunnerBase",
]
19 changes: 19 additions & 0 deletions clients/python/src/model_registry/_async_task_runner_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections.abc import Coroutine
from typing import Any

# Message carried by the ValueError raised from every un-overridden method below.
NOT_IMPLEMENTED = "Must be implemented by subclass"


class AsyncTaskRunnerBase:
    """Base class describing the task-runner interface.

    Concrete runners subclass this and override both ``get_instance`` and
    ``run``. The implementations here do no work: each one raises
    ``ValueError`` to signal that the subclass has not supplied the method.
    """

    @staticmethod
    def get_instance():
        """Get an AsyncTaskRunner (singleton).

        Raises:
            ValueError: Always, unless overridden by a subclass.
        """
        raise ValueError(NOT_IMPLEMENTED)

    def run(self, coro: Coroutine) -> Any:
        """Synchronously run a coroutine on a background thread.

        Args:
            coro: The coroutine to execute.

        Returns:
            Nothing from this base implementation; subclasses define the
            returned value.

        Raises:
            ValueError: Always, unless overridden by a subclass.
        """
        raise ValueError(NOT_IMPLEMENTED)
92 changes: 92 additions & 0 deletions clients/python/src/model_registry/_async_task_runner_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Copyright (c) 2022 Steven Silvester.

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import asyncio
import atexit
from collections.abc import Coroutine
from threading import Lock, Thread
from typing import Any, Optional

from ._async_task_runner_base import AsyncTaskRunnerBase

# Error message raised when a second instance is constructed directly.
SINGLETON = "This class is a singleton!"


class AsyncTaskRunnerThread(AsyncTaskRunnerBase):
    """A singleton task runner that runs an asyncio event loop on a background thread.

    The event loop and its daemon thread are created lazily on the first call
    to ``run``. Obtain the runner through ``get_instance``; constructing the
    class directly a second time raises ``RuntimeError``.
    """

    __instance = None
    # Serializes first-time construction so concurrent get_instance() callers
    # cannot both reach the constructor.
    __instance_lock = Lock()

    @staticmethod
    def get_instance():
        """Get an AsyncTaskRunner (singleton).

        Returns:
            The single shared ``AsyncTaskRunnerThread`` instance, creating it
            on first use.
        """
        if AsyncTaskRunnerThread.__instance is None:
            with AsyncTaskRunnerThread.__instance_lock:
                # Re-check under the lock: another thread may have created the
                # instance while this one was waiting.
                if AsyncTaskRunnerThread.__instance is None:
                    AsyncTaskRunnerThread()
        instance = AsyncTaskRunnerThread.__instance
        assert instance is not None
        return instance

    def __init__(self):
        """Initialize.

        Raises:
            RuntimeError: If an instance already exists.
        """
        # make sure it is a singleton
        if AsyncTaskRunnerThread.__instance is not None:
            raise RuntimeError(SINGLETON)
        # initialize variables
        self.__io_loop: Optional[asyncio.AbstractEventLoop] = None
        self.__runner_thread: Optional[Thread] = None
        self.__lock = Lock()
        # register exit handler
        atexit.register(self._close)
        # Publish the instance last so no other thread can observe it before
        # its attributes exist.
        AsyncTaskRunnerThread.__instance = self

    def _close(self):
        """Clean up. Ask the background loop to stop if one was started."""
        loop = self.__io_loop
        if loop is None:
            return
        try:
            # This runs on a different thread than the loop (atexit handler),
            # and loop.stop() is not thread-safe: the request must be handed
            # to the loop thread, which also wakes it if it is idle.
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop has already been closed; there is nothing left to stop.
            return

    def _runner(self) -> None:
        """Function to run in a thread: drive the loop until stopped, then close it."""
        loop = self.__io_loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def run(self, coro: Coroutine) -> Any:
        """Synchronously run a coroutine on a background thread.

        Args:
            coro: The coroutine to execute on the background event loop.

        Returns:
            The coroutine's result. Exceptions raised by the coroutine
            propagate to the caller.

        Note:
            The calling thread blocks until the coroutine finishes, so this
            should not be called from code already running on this runner's
            own loop thread.
        """
        with self.__lock:
            if self.__io_loop is None:
                # First use: create the loop and the daemon thread driving it.
                self.__io_loop = asyncio.new_event_loop()
                self.__runner_thread = Thread(target=self._runner, daemon=True)
                self.__runner_thread.start()
        # Submit the coroutine to the loop thread; this returns a
        # concurrent.futures.Future.
        fut = asyncio.run_coroutine_threadsafe(coro, self.__io_loop)
        # Block until the coroutine completes and hand back its result.
        return fut.result()
53 changes: 23 additions & 30 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from typing import TypeVar, Union, get_args
from warnings import warn

from ._async_task_runner_base import AsyncTaskRunnerBase
from ._async_task_runner_thread import AsyncTaskRunnerThread
from .core import ModelRegistryAPIClient
from .exceptions import StoreError
from .types import (
Expand Down Expand Up @@ -63,6 +65,7 @@ def __init__(
custom_ca: str | None = None,
custom_ca_envvar: str | None = None,
log_level: int = logging.WARNING,
async_task_runner: type[AsyncTaskRunnerBase] = AsyncTaskRunnerThread
):
"""Constructor.

Expand All @@ -74,17 +77,16 @@ def __init__(
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a string.
user_token_envvar: Environment variable to read the user token from if it's not passed as an arg. Defaults to KF_PIPELINES_SA_TOKEN_PATH.
user_token_envvar: Environment variable to read the user token from if it's not passed as an arg.
Defaults to KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string.
custom_ca_envvar: Environment variable to read the custom CA from if it's not passed as an arg.
log_level: Log level. Defaults to logging.WARNING.
async_task_runner: implementation of async task runner. Default - AsyncTaskRunnerThread
"""
logger.setLevel(log_level)

import nest_asyncio

logger.debug("Setting up reentrant async event loop")
nest_asyncio.apply()
self.runner = async_task_runner.get_instance()

# TODO: get remaining args from env
self._author = author
Expand Down Expand Up @@ -127,16 +129,6 @@ def __init__(
)
self.get_registered_models().page_size(1)._next_page()

def async_runner(self, coro: Any) -> Any:
import asyncio

try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)

async def _register_model(self, name: str, **kwargs) -> RegisteredModel:
if rm := await self._api.get_registered_model_by_params(name):
return rm
Expand Down Expand Up @@ -210,8 +202,8 @@ def register_model(
Returns:
Registered model.
"""
rm = self.async_runner(self._register_model(name, owner=owner or self._author))
mv = self.async_runner(
rm = self.runner.run(self._register_model(name, owner=owner or self._author))
mv = self.runner.run(
self._register_new_version(
rm,
version,
Expand All @@ -220,7 +212,7 @@ def register_model(
custom_properties=metadata or {},
)
)
self.async_runner(
self.runner.run(
self._register_model_artifact(
mv,
name,
Expand All @@ -244,10 +236,10 @@ def update(self, model: TModel) -> TModel:
msg = f"Model must be one of {get_args(ModelTypes)}"
raise StoreError(msg)
if isinstance(model, RegisteredModel):
return self.async_runner(self._api.upsert_registered_model(model))
return self.runner.run(self._api.upsert_registered_model(model))
if isinstance(model, ModelVersion):
return self.async_runner(self._api.upsert_model_version(model, None))
return self.async_runner(self._api.upsert_model_artifact(model))
return self.runner.run(self._api.upsert_model_version(model, None))
return self.runner.run(self._api.upsert_model_artifact(model))

def register_hf_model(
self,
Expand Down Expand Up @@ -289,8 +281,8 @@ def register_hf_model(
from huggingface_hub import HfApi, hf_hub_url, utils
except ImportError as e:
msg = """package `huggingface-hub` is not installed.
To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an
extra (available as `model-registry[hf]`), e.g.:
To import models from Hugging Face Hub, start by installing the `huggingface-hub` package,
either directly or as an extra (available as `model-registry[hf]`), e.g.:
```sh
!pip install --pre model-registry[hf]
```
Expand Down Expand Up @@ -363,7 +355,7 @@ def get_registered_model(self, name: str) -> RegisteredModel | None:
Returns:
Registered model.
"""
return self.async_runner(self._api.get_registered_model_by_params(name))
return self.runner.run(self._api.get_registered_model_by_params(name))

def get_model_version(self, name: str, version: str) -> ModelVersion | None:
"""Get a model version.
Expand All @@ -382,7 +374,7 @@ def get_model_version(self, name: str, version: str) -> ModelVersion | None:
msg = f"Model {name} does not exist"
raise StoreError(msg)
assert rm.id
return self.async_runner(self._api.get_model_version_by_params(rm.id, version))
return self.runner.run(self._api.get_model_version_by_params(rm.id, version))

def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
"""Get a model artifact.
Expand All @@ -401,7 +393,7 @@ def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
msg = f"Version {version} does not exist"
raise StoreError(msg)
assert mv.id
return self.async_runner(self._api.get_model_artifact_by_params(name, mv.id))
return self.runner.run(self._api.get_model_artifact_by_params(name, mv.id))

def get_registered_models(self) -> Pager[RegisteredModel]:
"""Get a pager for registered models.
Expand All @@ -411,7 +403,7 @@ def get_registered_models(self) -> Pager[RegisteredModel]:
"""

def rm_list(options: ListOptions) -> list[RegisteredModel]:
return self.async_runner(self._api.get_registered_models(options))
return self.runner.run(self._api.get_registered_models(options))

return Pager[RegisteredModel](rm_list)

Expand All @@ -432,8 +424,9 @@ def get_model_versions(self, name: str) -> Pager[ModelVersion]:
raise StoreError(msg)

def rm_versions(options: ListOptions) -> list[ModelVersion]:
# type checkers can't restrict the type inside a nested function: https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
# type checkers can't restrict the type inside a nested function:
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert rm.id
return self.async_runner(self._api.get_model_versions(rm.id, options))
return self.runner.run(self._api.get_model_versions(rm.id, options))

return Pager[ModelVersion](rm_versions)