diff --git a/pyproject.toml b/pyproject.toml index 018c6fee5..49235e874 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.15rc20" +version = "0.9.15rc22" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss-chains/truss_chains/deploy.py b/truss-chains/truss_chains/deploy.py index 977732ace..1a92c84b6 100644 --- a/truss-chains/truss_chains/deploy.py +++ b/truss-chains/truss_chains/deploy.py @@ -46,6 +46,7 @@ def _deploy_to_baseten( trusted=True, publish=options.publish, promote=options.promote, + origin=b10_types.ModelOrigin.CHAINS, ) return cast(b10_service.BasetenService, service) diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 9130df78e..6c990ff96 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -108,6 +108,7 @@ def create_model_from_truss( client_version: str, is_trusted: bool, deployment_name: Optional[str] = None, + origin: Optional[b10_types.ModelOrigin] = None, ): query_string = f""" mutation {{ @@ -119,6 +120,7 @@ def create_model_from_truss( client_version: "{client_version}", is_trusted: {'true' if is_trusted else 'false'}, {f'version_name: "{deployment_name}"' if deployment_name else ""} + {f'model_origin: {origin.value}' if origin else ""} ) {{ id, name, @@ -126,6 +128,7 @@ def create_model_from_truss( }} }} """ + resp = self._post_graphql_query(query_string) return resp["data"]["create_model_from_truss"] @@ -158,6 +161,7 @@ def create_model_version_from_truss( }} }} """ + resp = self._post_graphql_query(query_string) return resp["data"]["create_model_version_from_truss"] @@ -168,6 +172,7 @@ def create_development_model_from_truss( config, client_version, is_trusted=False, + origin: Optional[b10_types.ModelOrigin] = None, ): query_string = f""" mutation {{ @@ -175,7 +180,8 @@ def create_development_model_from_truss( s3_key: "{s3_key}", config: "{config}", client_version: "{client_version}", - is_trusted: {'true' if is_trusted else 'false'} + is_trusted: {'true' if is_trusted else 'false'}, + {f'model_origin: {origin.value}' if origin else ""} ) {{ id, name, diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index 1c5594035..832db79ba 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -204,6 +204,7 @@ def create_truss_service( is_draft: Optional[bool] = False, model_id: Optional[str] = None, deployment_name: Optional[str] = None, + origin: Optional[b10_types.ModelOrigin] = None, ) -> Tuple[str, str]: """ Create a model in the Baseten remote. @@ -229,6 +230,7 @@ def create_truss_service( config, f"truss=={truss.version()}", is_trusted, + origin=origin, ) return (model_version_json["id"], model_version_json["version_id"]) @@ -242,6 +244,7 @@ def create_truss_service( client_version=f"truss=={truss.version()}", is_trusted=is_trusted, deployment_name=deployment_name, + origin=origin, ) return (model_version_json["id"], model_version_json["version_id"]) diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 25ba28a4f..fce7b297b 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -8,6 +8,7 @@ import yaml from requests import ReadTimeout from truss.local.local_config_handler import LocalConfigHandler +from truss.remote.baseten import types as b10_types from truss.remote.baseten.api import BasetenApi from truss.remote.baseten.auth import AuthService from truss.remote.baseten.core import ( @@ -28,7 +29,6 @@ ) from truss.remote.baseten.error import ApiError from truss.remote.baseten.service import BasetenService -from truss.remote.baseten.types import ChainletData from truss.remote.baseten.utils.transfer import base64_encoded_json_str from truss.remote.truss_remote import TrussRemote from truss.truss_config import ModelServer @@ -48,8 +48,12 @@ def api(self) -> BasetenApi: return self._api def create_chain( - self, chain_name: str, chainlets: List[ChainletData], publish: bool = False + self, + chain_name: str, + chainlets: List[b10_types.ChainletData], + publish: bool = False, ) -> str: + chain_id = get_chain_id_by_name(self._api, chain_name) return create_chain( self._api, @@ -68,6 +72,7 @@ def push( # type: ignore promote: bool = False, preserve_previous_prod_deployment: bool = False, deployment_name: Optional[str] = None, + origin: Optional[b10_types.ModelOrigin] = None, ) -> BasetenService: if model_name.isspace(): raise ValueError("Model name cannot be empty") @@ -117,6 +122,7 @@ def push( # type: ignore promote=promote, preserve_previous_prod_deployment=preserve_previous_prod_deployment, deployment_name=deployment_name, + origin=origin, ) return BasetenService( diff --git a/truss/remote/baseten/types.py b/truss/remote/baseten/types.py index 1331ab61f..b8e196e3a 100644 --- a/truss/remote/baseten/types.py +++ b/truss/remote/baseten/types.py @@ -1,3 +1,5 @@ +from enum import Enum + import pydantic @@ -5,3 +7,8 @@ class ChainletData(pydantic.BaseModel): name: str oracle_version_id: str is_entrypoint: bool + + +class ModelOrigin(Enum): + BASETEN = "BASETEN" + CHAINS = "CHAINS"