Skip to content

Commit

Permalink
added base_domain as an optional param
Browse files Browse the repository at this point in the history
  • Loading branch information
Anupreet Walia committed Jun 13, 2024
1 parent d9f3426 commit 9bcfb2b
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ asdf plugin add python
asdf plugin add poetry
asdf install

# Create a new shell session within the context of the virtual environment
poetry shell

# Install poetry dependencies
poetry install

Expand Down
13 changes: 13 additions & 0 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def run(target_directory: str, build_dir: Path, tag, port, attach) -> None:
required=False,
help="Name of the remote in .trussrc to patch changes to",
)
@click.option(
"--remote_base_domain",
type=str,
required=False,
help="Name of the remote_base_domain in .trussrc to patch changes to",
)
@error_handling
def watch(
target_directory: str,
Expand Down Expand Up @@ -464,6 +470,12 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
required=False,
help="Name of the remote in .trussrc to push to",
)
@click.option(
"--remote_base_domain",
type=str,
required=False,
help="Name of the remote_base_domain in .trussrc for invoking onferences from",
)
@click.option(
"-d",
"--data",
Expand Down Expand Up @@ -510,6 +522,7 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
def predict(
target_directory: str,
remote: str,
remote_base_domain: str,
data: Optional[str],
file: Optional[Path],
published: Optional[bool],
Expand Down
11 changes: 10 additions & 1 deletion truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@


class BasetenRemote(TrussRemote):
def __init__(self, remote_url: str, api_key: str, **kwargs):
def __init__(
self,
remote_url: str,
api_key: str,
remote_base_domain: Optional[str] = None,
**kwargs,
):
super().__init__(remote_url, **kwargs)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(remote_url, self._auth_service)
self.remote_base_domain = remote_base_domain

@property
def api(self) -> BasetenApi:
Expand Down Expand Up @@ -111,6 +118,7 @@ def push( # type: ignore
api_key=self._auth_service.authenticate().value,
service_url=f"{self._remote_url}/model_versions/{model_version_id}",
truss_handle=truss_handle,
remote_base_domain=self.remote_base_domain,
api=self._api,
)

Expand Down Expand Up @@ -196,6 +204,7 @@ def get_service(self, **kwargs) -> BasetenService:
is_draft=not published,
api_key=self._auth_service.authenticate().value,
service_url=f"{self._remote_url}{service_url_path}",
remote_base_domain=self.remote_base_domain,
api=self._api,
)

Expand Down
36 changes: 31 additions & 5 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,32 @@
DEFAULT_STREAM_ENCODING = "utf-8"


def _add_model_subdomain(rest_api_url: str, model_subdomain: str) -> str:
"""E.g. `https://api.baseten.co` -> `https://{model_subdomain}.api.baseten.co`"""
def _add_model_domain(
rest_api_url: str,
model_subdomain: str,
model_base_domain: Optional[str] = None,
) -> str:
"""E.g. `https://api.baseten.co` -> `https://{model_subdomain}.api.baseten.{model_base_domain}`"""
parsed_url = urllib.parse.urlparse(rest_api_url)
new_netloc = f"{model_subdomain}.{parsed_url.netloc}"

# Extract the existing subdomain parts and root domain parts
netloc_parts = parsed_url.netloc.split(".")
subdomain_parts = netloc_parts[:-1]
subdomain = ".".join(subdomain_parts) if subdomain_parts else ""

new_netloc = ""

if model_base_domain:
# Replace the root domain with the new base domain
new_netloc = f"{subdomain + '.' if subdomain else ''}{model_base_domain}"
else:
# Preserve the original root domain
new_netloc = parsed_url.netloc

if model_subdomain:
# Add the new subdomain
new_netloc = f"{model_subdomain}.{new_netloc}"

model_url = parsed_url._replace(netloc=new_netloc)
return str(urllib.parse.urlunparse(model_url))

Expand All @@ -30,12 +52,14 @@ def __init__(
api_key: str,
service_url: str,
api: BasetenApi,
remote_base_domain: Optional[str] = None,
truss_handle: Optional[TrussHandle] = None,
):
super().__init__(is_draft=is_draft, service_url=service_url)
self._model_id = model_id
self._model_version_id = model_version_id
self._auth_service = AuthService(api_key=api_key)
self.remote_base_domain = remote_base_domain
self._api = api
self._truss_handle = truss_handle

Expand Down Expand Up @@ -64,7 +88,6 @@ def predict(
response = self._send_request(
self.predict_url, "POST", data=model_request_body, stream=True
)

if response.headers.get("transfer-encoding") == "chunked":
# Case of streaming response, the backend does not set an encoding, so
# manually decode to the contents to utf-8 here.
Expand Down Expand Up @@ -107,7 +130,10 @@ def predict_url(self) -> str:
Get the URL for the prediction endpoint.
"""
# E.g. `https://api.baseten.co` -> `https://model-{model_id}.api.baseten.co`
url = _add_model_subdomain(self._api.rest_api_url, f"model-{self.model_id}")
url = _add_model_domain(
self._api.rest_api_url, f"model-{self.model_id}", self.remote_base_domain
)
print(url)
if self.is_draft:
# "https://model-{model_id}.api.baseten.co/development".
url = f"{url}/development/predict"
Expand Down
6 changes: 6 additions & 0 deletions truss/remote/remote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@ def inquire_remote_config() -> RemoteConfig:
validate=NonEmptyValidator(),
).execute()

remote_base_domain = inquirer.text(
"🌐 What is the base domain for your deployed models?",
qmark="",
).execute()

return RemoteConfig(
name="baseten",
configs={
"remote_provider": "baseten",
"api_key": api_key,
"remote_url": remote_url,
"remote_base_domain": remote_base_domain,
},
)

Expand Down
17 changes: 9 additions & 8 deletions truss/tests/remote/baseten/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

_TEST_REMOTE_URL = "http://test_remote.com"
_TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
_TEST_REMOTE_BASE_DOMAIN = "test_remote_base.co"


def test_get_service_by_version_id():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)

version = {
"id": "version_id",
Expand Down Expand Up @@ -44,7 +45,7 @@ def test_get_service_by_version_id_no_version():


def test_get_service_by_model_name():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)

versions = [
{"id": "1", "is_draft": False, "is_primary": False},
Expand Down Expand Up @@ -120,7 +121,7 @@ def test_get_service_by_model_name_no_dev_version():


def test_get_service_by_model_name_no_prod_version():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)

versions = [
{"id": "1", "is_draft": True, "is_primary": False},
Expand Down Expand Up @@ -155,7 +156,7 @@ def test_get_service_by_model_name_no_prod_version():


def test_get_service_by_model_id():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)

model_response = {
"data": {
Expand All @@ -179,7 +180,7 @@ def test_get_service_by_model_id():


def test_get_service_by_model_id_no_model():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
model_response = {"errors": [{"message": "error"}]}
with requests_mock.Mocker() as m:
m.post(
Expand All @@ -193,7 +194,7 @@ def test_get_service_by_model_id_no_model():
def test_push_raised_value_error_when_deployment_name_and_not_publish(
custom_model_truss_dir_with_pre_and_post,
):
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
model_response = {
"data": {
"model": {
Expand All @@ -220,7 +221,7 @@ def test_push_raised_value_error_when_deployment_name_and_not_publish(
def test_push_raised_value_error_when_deployment_name_is_not_valid(
custom_model_truss_dir_with_pre_and_post,
):
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
model_response = {
"data": {
"model": {
Expand All @@ -247,7 +248,7 @@ def test_push_raised_value_error_when_deployment_name_is_not_valid(
def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promote(
custom_model_truss_dir_with_pre_and_post,
):
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key")
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
model_response = {
"data": {
"model": {
Expand Down
28 changes: 27 additions & 1 deletion truss/tests/remote/test_remote_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@
remote_provider=test_remote
"""

SAMPLE_TRUSSRC_WITH_REMOTE_BASE_DOMAIN = """
[test]
api_key=test_key
remote_url=http://test.com
remote_base_domain=testbasedomain.co
"""


class TrussTestRemote(TrussRemote):
def __init__(self, api_key, remote_url):
def __init__(self, api_key, remote_url, remote_base_domain=None):
self.api_key = api_key
self.remote_url = remote_url
self.remote_base_domain = remote_base_domain

def authenticate(self):
return {"Authorization": self.api_key}
Expand Down Expand Up @@ -131,3 +139,21 @@ def test_load_remote_config_no_params(mock_exists, mock_open):
service = RemoteFactory.load_remote_config("test")
with pytest.raises(ValueError):
RemoteFactory.validate_remote_config(service.configs, "test")


@mock.patch.dict(
RemoteFactory.REGISTRY, {"test_remote_base_domain": TrussTestRemote}, clear=True
)
@mock.patch(
"builtins.open",
new_callable=mock.mock_open,
read_data=SAMPLE_TRUSSRC_WITH_REMOTE_BASE_DOMAIN,
)
@mock.patch("pathlib.Path.exists", return_value=True)
def test_load_remote_config_with_remote_base_domain(mock_exists, mock_open):
service = RemoteFactory.load_remote_config("test")
assert service.configs == {
"api_key": "test_key",
"remote_url": "http://test.com",
"remote_base_domain": "testbasedomain.co",
}

0 comments on commit 9bcfb2b

Please sign in to comment.