Skip to content

Commit

Permalink
rename base_domain and remove it from cli
Browse files Browse the repository at this point in the history
  • Loading branch information
Anupreet Walia committed Jun 17, 2024
1 parent 9bcfb2b commit befca55
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 30 deletions.
11 changes: 2 additions & 9 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ def run(target_directory: str, build_dir: Path, tag, port, attach) -> None:
help="Name of the remote in .trussrc to patch changes to",
)
@click.option(
"--remote_base_domain",
"--remote_inference_base_domain",
type=str,
required=False,
help="Name of the remote_base_domain in .trussrc to patch changes to",
help="Name of the remote_inference_base_domain in .trussrc to patch changes to",
)
@error_handling
def watch(
Expand Down Expand Up @@ -470,12 +470,6 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
required=False,
help="Name of the remote in .trussrc to push to",
)
@click.option(
"--remote_base_domain",
type=str,
required=False,
help="Name of the remote_base_domain in .trussrc for invoking onferences from",
)
@click.option(
"-d",
"--data",
Expand Down Expand Up @@ -522,7 +516,6 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
def predict(
target_directory: str,
remote: str,
remote_base_domain: str,
data: Optional[str],
file: Optional[Path],
published: Optional[bool],
Expand Down
8 changes: 4 additions & 4 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def __init__(
self,
remote_url: str,
api_key: str,
remote_base_domain: Optional[str] = None,
remote_inference_base_domain: Optional[str] = None,
**kwargs,
):
super().__init__(remote_url, **kwargs)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(remote_url, self._auth_service)
self.remote_base_domain = remote_base_domain
self.remote_inference_base_domain = remote_inference_base_domain

@property
def api(self) -> BasetenApi:
Expand Down Expand Up @@ -118,7 +118,7 @@ def push( # type: ignore
api_key=self._auth_service.authenticate().value,
service_url=f"{self._remote_url}/model_versions/{model_version_id}",
truss_handle=truss_handle,
remote_base_domain=self.remote_base_domain,
remote_inference_base_domain=self.remote_inference_base_domain,
api=self._api,
)

Expand Down Expand Up @@ -204,7 +204,7 @@ def get_service(self, **kwargs) -> BasetenService:
is_draft=not published,
api_key=self._auth_service.authenticate().value,
service_url=f"{self._remote_url}{service_url_path}",
remote_base_domain=self.remote_base_domain,
remote_inference_base_domain=self.remote_inference_base_domain,
api=self._api,
)

Expand Down
9 changes: 5 additions & 4 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def __init__(
api_key: str,
service_url: str,
api: BasetenApi,
remote_base_domain: Optional[str] = None,
remote_inference_base_domain: Optional[str] = None,
truss_handle: Optional[TrussHandle] = None,
):
super().__init__(is_draft=is_draft, service_url=service_url)
self._model_id = model_id
self._model_version_id = model_version_id
self._auth_service = AuthService(api_key=api_key)
self.remote_base_domain = remote_base_domain
self.remote_inference_base_domain = remote_inference_base_domain
self._api = api
self._truss_handle = truss_handle

Expand Down Expand Up @@ -131,9 +131,10 @@ def predict_url(self) -> str:
"""
# E.g. `https://api.baseten.co` -> `https://model-{model_id}.api.baseten.co`
url = _add_model_domain(
self._api.rest_api_url, f"model-{self.model_id}", self.remote_base_domain
self._api.rest_api_url,
f"model-{self.model_id}",
self.remote_inference_base_domain,
)
print(url)
if self.is_draft:
# "https://model-{model_id}.api.baseten.co/development".
url = f"{url}/development/predict"
Expand Down
4 changes: 2 additions & 2 deletions truss/remote/remote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def inquire_remote_config() -> RemoteConfig:
validate=NonEmptyValidator(),
).execute()

remote_base_domain = inquirer.text(
remote_inference_base_domain = inquirer.text(
"🌐 What is the base domain for your deployed models?",
qmark="",
).execute()
Expand All @@ -39,7 +39,7 @@ def inquire_remote_config() -> RemoteConfig:
"remote_provider": "baseten",
"api_key": api_key,
"remote_url": remote_url,
"remote_base_domain": remote_base_domain,
"remote_inference_base_domain": remote_inference_base_domain,
},
)

Expand Down
34 changes: 25 additions & 9 deletions truss/tests/remote/baseten/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

_TEST_REMOTE_URL = "http://test_remote.com"
_TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/"
_TEST_REMOTE_BASE_DOMAIN = "test_remote_base.co"
_TEST_REMOTE_INFERENCE_BASE_DOMAIN = "test_remote_base.co"


def test_get_service_by_version_id():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)

version = {
"id": "version_id",
Expand Down Expand Up @@ -45,7 +47,9 @@ def test_get_service_by_version_id_no_version():


def test_get_service_by_model_name():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)

versions = [
{"id": "1", "is_draft": False, "is_primary": False},
Expand Down Expand Up @@ -121,7 +125,9 @@ def test_get_service_by_model_name_no_dev_version():


def test_get_service_by_model_name_no_prod_version():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)

versions = [
{"id": "1", "is_draft": True, "is_primary": False},
Expand Down Expand Up @@ -156,7 +162,9 @@ def test_get_service_by_model_name_no_prod_version():


def test_get_service_by_model_id():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)

model_response = {
"data": {
Expand All @@ -180,7 +188,9 @@ def test_get_service_by_model_id():


def test_get_service_by_model_id_no_model():
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)
model_response = {"errors": [{"message": "error"}]}
with requests_mock.Mocker() as m:
m.post(
Expand All @@ -194,7 +204,9 @@ def test_get_service_by_model_id_no_model():
def test_push_raised_value_error_when_deployment_name_and_not_publish(
custom_model_truss_dir_with_pre_and_post,
):
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)
model_response = {
"data": {
"model": {
Expand All @@ -221,7 +233,9 @@ def test_push_raised_value_error_when_deployment_name_and_not_publish(
def test_push_raised_value_error_when_deployment_name_is_not_valid(
custom_model_truss_dir_with_pre_and_post,
):
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)
model_response = {
"data": {
"model": {
Expand All @@ -248,7 +262,9 @@ def test_push_raised_value_error_when_deployment_name_is_not_valid(
def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promote(
custom_model_truss_dir_with_pre_and_post,
):
remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN)
remote = BasetenRemote(
_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN
)
model_response = {
"data": {
"model": {
Expand Down
4 changes: 2 additions & 2 deletions truss/tests/remote/test_remote_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
remote_provider=test_remote
"""

SAMPLE_TRUSSRC_WITH_REMOTE_BASE_DOMAIN = """
SAMPLE_TRUSSRC_WITH_REMOTE_INFERENCE_BASE_DOMAIN = """
[test]
api_key=test_key
remote_url=http://test.com
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_load_remote_config_no_params(mock_exists, mock_open):
@mock.patch(
"builtins.open",
new_callable=mock.mock_open,
read_data=SAMPLE_TRUSSRC_WITH_REMOTE_BASE_DOMAIN,
read_data=SAMPLE_TRUSSRC_WITH_REMOTE_INFERENCE_BASE_DOMAIN,
)
@mock.patch("pathlib.Path.exists", return_value=True)
def test_load_remote_config_with_remote_base_domain(mock_exists, mock_open):
Expand Down

0 comments on commit befca55

Please sign in to comment.