Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set model_origin when creating chains #972

Merged
merged 8 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.15rc20"
version = "0.9.15rc22"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def _deploy_to_baseten(
trusted=True,
publish=options.publish,
promote=options.promote,
origin=b10_types.ModelOrigin.CHAINS,
)
return cast(b10_service.BasetenService, service)

Expand Down
8 changes: 7 additions & 1 deletion truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def create_model_from_truss(
client_version: str,
is_trusted: bool,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
):
query_string = f"""
mutation {{
Expand All @@ -119,13 +120,15 @@ def create_model_from_truss(
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'},
{f'version_name: "{deployment_name}"' if deployment_name else ""}
{f'model_origin: {origin.value}' if origin else ""}
) {{
id,
name,
version_id
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_from_truss"]

Expand Down Expand Up @@ -158,6 +161,7 @@ def create_model_version_from_truss(
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_version_from_truss"]

Expand All @@ -168,14 +172,16 @@ def create_development_model_from_truss(
config,
client_version,
is_trusted=False,
origin: Optional[b10_types.ModelOrigin] = None,
):
query_string = f"""
mutation {{
deploy_draft_truss(name: "{model_name}",
s3_key: "{s3_key}",
config: "{config}",
client_version: "{client_version}",
is_trusted: {'true' if is_trusted else 'false'}
is_trusted: {'true' if is_trusted else 'false'},
{f'model_origin: {origin.value}' if origin else ""}
) {{
id,
name,
Expand Down
3 changes: 3 additions & 0 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def create_truss_service(
is_draft: Optional[bool] = False,
model_id: Optional[str] = None,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
) -> Tuple[str, str]:
"""
Create a model in the Baseten remote.
Expand All @@ -229,6 +230,7 @@ def create_truss_service(
config,
f"truss=={truss.version()}",
is_trusted,
origin=origin,
)

return (model_version_json["id"], model_version_json["version_id"])
Expand All @@ -242,6 +244,7 @@ def create_truss_service(
client_version=f"truss=={truss.version()}",
is_trusted=is_trusted,
deployment_name=deployment_name,
origin=origin,
)
return (model_version_json["id"], model_version_json["version_id"])

Expand Down
10 changes: 8 additions & 2 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import yaml
from requests import ReadTimeout
from truss.local.local_config_handler import LocalConfigHandler
from truss.remote.baseten import types as b10_types
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.auth import AuthService
from truss.remote.baseten.core import (
Expand All @@ -28,7 +29,6 @@
)
from truss.remote.baseten.error import ApiError
from truss.remote.baseten.service import BasetenService
from truss.remote.baseten.types import ChainletData
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
from truss.remote.truss_remote import TrussRemote
from truss.truss_config import ModelServer
Expand All @@ -48,8 +48,12 @@ def api(self) -> BasetenApi:
return self._api

def create_chain(
self, chain_name: str, chainlets: List[ChainletData], publish: bool = False
self,
chain_name: str,
chainlets: List[b10_types.ChainletData],
publish: bool = False,
) -> str:

chain_id = get_chain_id_by_name(self._api, chain_name)
return create_chain(
self._api,
Expand All @@ -68,6 +72,7 @@ def push( # type: ignore
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
) -> BasetenService:
if model_name.isspace():
raise ValueError("Model name cannot be empty")
Expand Down Expand Up @@ -117,6 +122,7 @@ def push( # type: ignore
promote=promote,
preserve_previous_prod_deployment=preserve_previous_prod_deployment,
deployment_name=deployment_name,
origin=origin,
)

return BasetenService(
Expand Down
7 changes: 7 additions & 0 deletions truss/remote/baseten/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from enum import Enum

import pydantic


class ChainletData(pydantic.BaseModel):
name: str
oracle_version_id: str
is_entrypoint: bool


class ModelOrigin(Enum):
BASETEN = "BASETEN"
CHAINS = "CHAINS"
Loading