Skip to content

Commit

Permalink
Set model_origin when creating chains (#972)
Browse files Browse the repository at this point in the history
* Wire up the new chains mutations to truss chains deploy.

* Add comment.

* Use model origin.

* Use model origin.

* Use model origin.

* Resolve pr feedback.
  • Loading branch information
squidarth authored Jun 17, 2024
1 parent 62e037c commit 540ec5f
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.15rc20"
version = "0.9.15rc22"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def _deploy_to_baseten(
trusted=True,
publish=options.publish,
promote=options.promote,
origin=b10_types.ModelOrigin.CHAINS,
)
return cast(b10_service.BasetenService, service)

Expand Down
8 changes: 7 additions & 1 deletion truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def create_model_from_truss(
client_version: str,
is_trusted: bool,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
):
query_string = f"""
mutation {{
Expand All @@ -119,13 +120,15 @@ def create_model_from_truss(
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'},
{f'version_name: "{deployment_name}"' if deployment_name else ""}
{f'model_origin: {origin.value}' if origin else ""}
) {{
id,
name,
version_id
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_from_truss"]

Expand Down Expand Up @@ -158,6 +161,7 @@ def create_model_version_from_truss(
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_version_from_truss"]

Expand All @@ -168,14 +172,16 @@ def create_development_model_from_truss(
config,
client_version,
is_trusted=False,
origin: Optional[b10_types.ModelOrigin] = None,
):
query_string = f"""
mutation {{
deploy_draft_truss(name: "{model_name}",
s3_key: "{s3_key}",
config: "{config}",
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'}
is_trusted: {'true' if is_trusted else 'false'},
{f'model_origin: {origin.value}' if origin else ""}
) {{
id,
name,
Expand Down
3 changes: 3 additions & 0 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def create_truss_service(
is_draft: Optional[bool] = False,
model_id: Optional[str] = None,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
) -> Tuple[str, str]:
"""
Create a model in the Baseten remote.
Expand All @@ -229,6 +230,7 @@ def create_truss_service(
config,
f"truss=={truss.version()}",
is_trusted,
origin=origin,
)

return (model_version_json["id"], model_version_json["version_id"])
Expand All @@ -242,6 +244,7 @@ def create_truss_service(
client_version=f"truss=={truss.version()}",
is_trusted=is_trusted,
deployment_name=deployment_name,
origin=origin,
)
return (model_version_json["id"], model_version_json["version_id"])

Expand Down
10 changes: 8 additions & 2 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import yaml
from requests import ReadTimeout
from truss.local.local_config_handler import LocalConfigHandler
from truss.remote.baseten import types as b10_types
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.auth import AuthService
from truss.remote.baseten.core import (
Expand All @@ -28,7 +29,6 @@
)
from truss.remote.baseten.error import ApiError
from truss.remote.baseten.service import BasetenService
from truss.remote.baseten.types import ChainletData
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
from truss.remote.truss_remote import TrussRemote
from truss.truss_config import ModelServer
Expand All @@ -48,8 +48,12 @@ def api(self) -> BasetenApi:
return self._api

def create_chain(
self, chain_name: str, chainlets: List[ChainletData], publish: bool = False
self,
chain_name: str,
chainlets: List[b10_types.ChainletData],
publish: bool = False,
) -> str:

chain_id = get_chain_id_by_name(self._api, chain_name)
return create_chain(
self._api,
Expand All @@ -68,6 +72,7 @@ def push( # type: ignore
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
) -> BasetenService:
if model_name.isspace():
raise ValueError("Model name cannot be empty")
Expand Down Expand Up @@ -117,6 +122,7 @@ def push( # type: ignore
promote=promote,
preserve_previous_prod_deployment=preserve_previous_prod_deployment,
deployment_name=deployment_name,
origin=origin,
)

return BasetenService(
Expand Down
7 changes: 7 additions & 0 deletions truss/remote/baseten/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from enum import Enum

import pydantic


class ChainletData(pydantic.BaseModel):
name: str
oracle_version_id: str
is_entrypoint: bool


class ModelOrigin(Enum):
BASETEN = "BASETEN"
CHAINS = "CHAINS"

0 comments on commit 540ec5f

Please sign in to comment.