diff --git a/truss/cli/cli.py b/truss/cli/cli.py index b418620e0..cb628beaf 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -227,10 +227,10 @@ def run(target_directory: str, build_dir: Path, tag, port, attach) -> None: help="Name of the remote in .trussrc to patch changes to", ) @click.option( - "--remote_base_domain", + "--remote_inference_base_domain", type=str, required=False, - help="Name of the remote_base_domain in .trussrc to patch changes to", + help="Name of the remote_inference_base_domain in .trussrc to patch changes to", ) @error_handling def watch( @@ -470,12 +470,6 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]): required=False, help="Name of the remote in .trussrc to push to", ) -@click.option( - "--remote_base_domain", - type=str, - required=False, - help="Name of the remote_base_domain in .trussrc for invoking onferences from", -) @click.option( "-d", "--data", @@ -522,7 +516,6 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]): def predict( target_directory: str, remote: str, - remote_base_domain: str, data: Optional[str], file: Optional[Path], published: Optional[bool], diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 52446c0f9..078647225 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -39,13 +39,13 @@ def __init__( self, remote_url: str, api_key: str, - remote_base_domain: Optional[str] = None, + remote_inference_base_domain: Optional[str] = None, **kwargs, ): super().__init__(remote_url, **kwargs) self._auth_service = AuthService(api_key=api_key) self._api = BasetenApi(remote_url, self._auth_service) - self.remote_base_domain = remote_base_domain + self.remote_inference_base_domain = remote_inference_base_domain @property def api(self) -> BasetenApi: @@ -118,7 +118,7 @@ def push( # type: ignore api_key=self._auth_service.authenticate().value, service_url=f"{self._remote_url}/model_versions/{model_version_id}", truss_handle=truss_handle, - remote_base_domain=self.remote_base_domain, + remote_inference_base_domain=self.remote_inference_base_domain, api=self._api, ) @@ -204,7 +204,7 @@ def get_service(self, **kwargs) -> BasetenService: is_draft=not published, api_key=self._auth_service.authenticate().value, service_url=f"{self._remote_url}{service_url_path}", - remote_base_domain=self.remote_base_domain, + remote_inference_base_domain=self.remote_inference_base_domain, api=self._api, ) diff --git a/truss/remote/baseten/service.py b/truss/remote/baseten/service.py index 182f6c8d7..5ef08130d 100644 --- a/truss/remote/baseten/service.py +++ b/truss/remote/baseten/service.py @@ -52,14 +52,14 @@ def __init__( api_key: str, service_url: str, api: BasetenApi, - remote_base_domain: Optional[str] = None, + remote_inference_base_domain: Optional[str] = None, truss_handle: Optional[TrussHandle] = None, ): super().__init__(is_draft=is_draft, service_url=service_url) self._model_id = model_id self._model_version_id = model_version_id self._auth_service = AuthService(api_key=api_key) - self.remote_base_domain = remote_base_domain + self.remote_inference_base_domain = remote_inference_base_domain self._api = api self._truss_handle = truss_handle @@ -131,9 +131,10 @@ def predict_url(self) -> str: """ # E.g. `https://api.baseten.co` -> `https://model-{model_id}.api.baseten.co` url = _add_model_domain( - self._api.rest_api_url, f"model-{self.model_id}", self.remote_base_domain + self._api.rest_api_url, + f"model-{self.model_id}", + self.remote_inference_base_domain, ) - print(url) if self.is_draft: # "https://model-{model_id}.api.baseten.co/development". url = f"{url}/development/predict" diff --git a/truss/remote/remote_cli.py b/truss/remote/remote_cli.py index 9f931f60e..1a9248a09 100644 --- a/truss/remote/remote_cli.py +++ b/truss/remote/remote_cli.py @@ -28,7 +28,7 @@ def inquire_remote_config() -> RemoteConfig: validate=NonEmptyValidator(), ).execute() - remote_base_domain = inquirer.text( + remote_inference_base_domain = inquirer.text( "🌐 What is the base domain for your deployed models?", qmark="", ).execute() @@ -39,7 +39,7 @@ def inquire_remote_config() -> RemoteConfig: "remote_provider": "baseten", "api_key": api_key, "remote_url": remote_url, - "remote_base_domain": remote_base_domain, + "remote_inference_base_domain": remote_inference_base_domain, }, ) diff --git a/truss/tests/remote/baseten/test_remote.py b/truss/tests/remote/baseten/test_remote.py index 216ea5599..781d8d1b6 100644 --- a/truss/tests/remote/baseten/test_remote.py +++ b/truss/tests/remote/baseten/test_remote.py @@ -7,11 +7,13 @@ _TEST_REMOTE_URL = "http://test_remote.com" _TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/" -_TEST_REMOTE_BASE_DOMAIN = "test_remote_base.co" +_TEST_REMOTE_INFERENCE_BASE_DOMAIN = "test_remote_base.co" def test_get_service_by_version_id(): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) version = { "id": "version_id", @@ -45,7 +47,9 @@ def test_get_service_by_version_id_no_version(): def test_get_service_by_model_name(): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) versions = [ {"id": "1", "is_draft": False, "is_primary": False}, @@ -121,7 +125,9 @@ def test_get_service_by_model_name_no_dev_version(): def test_get_service_by_model_name_no_prod_version(): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) versions = [ {"id": "1", "is_draft": True, "is_primary": False}, @@ -156,7 +162,9 @@ def test_get_service_by_model_name_no_prod_version(): def test_get_service_by_model_id(): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) model_response = { "data": { @@ -180,7 +188,9 @@ def test_get_service_by_model_id(): def test_get_service_by_model_id_no_model(): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) model_response = {"errors": [{"message": "error"}]} with requests_mock.Mocker() as m: m.post( @@ -194,7 +204,9 @@ def test_get_service_by_model_id_no_model(): def test_push_raised_value_error_when_deployment_name_and_not_publish( custom_model_truss_dir_with_pre_and_post, ): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) model_response = { "data": { "model": { @@ -221,7 +233,9 @@ def test_push_raised_value_error_when_deployment_name_and_not_publish( def test_push_raised_value_error_when_deployment_name_is_not_valid( custom_model_truss_dir_with_pre_and_post, ): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) model_response = { "data": { "model": { @@ -248,7 +262,9 @@ def test_push_raised_value_error_when_deployment_name_is_not_valid( def test_push_raised_value_error_when_keep_previous_prod_settings_and_not_promote( custom_model_truss_dir_with_pre_and_post, ): - remote = BasetenRemote(_TEST_REMOTE_URL, "api_key", _TEST_REMOTE_BASE_DOMAIN) + remote = BasetenRemote( + _TEST_REMOTE_URL, "api_key", _TEST_REMOTE_INFERENCE_BASE_DOMAIN + ) model_response = { "data": { "model": { diff --git a/truss/tests/remote/test_remote_factory.py b/truss/tests/remote/test_remote_factory.py index 74b3ddb99..55245dd24 100644 --- a/truss/tests/remote/test_remote_factory.py +++ b/truss/tests/remote/test_remote_factory.py @@ -24,7 +24,7 @@ remote_provider=test_remote """ -SAMPLE_TRUSSRC_WITH_REMOTE_BASE_DOMAIN = """ +SAMPLE_TRUSSRC_WITH_REMOTE_INFERENCE_BASE_DOMAIN = """ [test] api_key=test_key remote_url=http://test.com @@ -147,7 +147,7 @@ def test_load_remote_config_no_params(mock_exists, mock_open): @mock.patch( "builtins.open", new_callable=mock.mock_open, - read_data=SAMPLE_TRUSSRC_WITH_REMOTE_BASE_DOMAIN, + read_data=SAMPLE_TRUSSRC_WITH_REMOTE_INFERENCE_BASE_DOMAIN, ) @mock.patch("pathlib.Path.exists", return_value=True) def test_load_remote_config_with_remote_base_domain(mock_exists, mock_open):