Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into tianshuc/docker-serve…
Browse files Browse the repository at this point in the history
…r-stop-copying-truss
  • Loading branch information
Tianshu Cheng authored and Tianshu Cheng committed Oct 28, 2024
2 parents 412805f + 9889e04 commit edb93d1
Show file tree
Hide file tree
Showing 26 changed files with 848 additions and 299 deletions.
2 changes: 0 additions & 2 deletions .gitattributes

This file was deleted.

1 change: 0 additions & 1 deletion .python-version

This file was deleted.

495 changes: 275 additions & 220 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.44"
version = "0.9.45rc009"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -27,6 +27,7 @@ packages = [
"Baseten" = "https://baseten.co"

[tool.poetry.dependencies]
aiofiles = "^24.1.0"
blake3 = "^0.3.3"
boto3 = "^1.34.85"
fastapi = ">=0.109.1"
Expand Down Expand Up @@ -96,6 +97,7 @@ pytest = "7.2.0"
pytest-cov = "^3.0.0"
types-PyYAML = "^6.0.12.12"
types-setuptools = "^69.0.0.0"
types-aiofiles = "^24.1.0.20240626"

[tool.poetry.scripts]
truss = 'truss.cli:truss_cli'
Expand Down
12 changes: 9 additions & 3 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import pydantic
from truss import truss_config
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote import baseten as baseten_remote
from truss.remote import remote_cli, remote_factory

Expand Down Expand Up @@ -609,22 +610,27 @@ class PushOptions(SafeModelNonSerializable):
class PushOptionsBaseten(PushOptions):
remote_provider: baseten_remote.BasetenRemote
publish: bool
promote: bool
environment: Optional[str]

@classmethod
def create(
cls,
chain_name: str,
publish: bool,
promote: bool,
promote: Optional[bool],
only_generate_trusses: bool,
user_env: Mapping[str, str],
remote: Optional[str] = None,
environment: Optional[str] = None,
) -> "PushOptionsBaseten":
if not remote:
remote = remote_cli.inquire_remote_name(
remote_factory.RemoteFactory.get_available_config_names()
)
if promote and not environment:
environment = PRODUCTION_ENVIRONMENT_NAME
if environment:
publish = True
remote_provider = cast(
baseten_remote.BasetenRemote,
remote_factory.RemoteFactory.create(remote=remote),
Expand All @@ -633,9 +639,9 @@ def create(
remote_provider=remote_provider,
chain_name=chain_name,
publish=publish,
promote=promote,
only_generate_trusses=only_generate_trusses,
user_env=user_env,
environment=environment,
)


Expand Down
20 changes: 20 additions & 0 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pprint
import sys
import types
import warnings
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -268,6 +269,25 @@ def _validate_and_describe_endpoint(
is_async = False
is_generator = inspect.isgeneratorfunction(endpoint_method)

if not is_async:
warnings.warn(
"`run_remote` must be an async (coroutine) function in future releases. "
"Replace `def run_remote(...` with `async def run_remote(...`. "
"Local testing and execution can be done with "
"`asyncio.run(my_chainlet.run_remote(...))`.\n"
"Note on concurrency: previously sync functions were run in threads by the "
"Truss server.\bn"
"For some frameworks this was **unsafe** (e.g. in torch the CUDA context "
"is not thread-safe).\n"
"Additionally, python threads hold the GIL and therefore might not give "
"actual throughput gains.\n"
"To achieve safe and performant concurrency, use framework-specific async "
"APIs (e.g. AsyncLLMEngine for vLLM) or generic async batching like such "
"as https://github.com/hussein-awala/async-batcher.",
DeprecationWarning,
stacklevel=1,
)

return definitions.EndpointAPIDescriptor(
input_args=input_args,
output_types=output_types,
Expand Down
3 changes: 3 additions & 0 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def push(
user_env: Optional[Mapping[str, str]] = None,
only_generate_trusses: bool = False,
remote: Optional[str] = None,
environment: Optional[str] = None,
) -> chains_remote.BasetenChainService:
"""
Deploys a chain remotely (with all dependent chainlets).
Expand All @@ -144,6 +145,7 @@ def push(
``/tmp/.chains_generated``.
remote: name of a remote config in `.trussrc`. If not provided, it will be
inquired.
environment: The name of an environment to promote deployment into.
Returns:
A chain service handle to the deployed chain.
Expand All @@ -156,6 +158,7 @@ def push(
user_env=user_env or {},
only_generate_trusses=only_generate_trusses,
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint, options)
assert isinstance(service, chains_remote.BasetenChainService) # Per options above.
Expand Down
8 changes: 2 additions & 6 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,15 @@ def _push_to_baseten(
model_name = truss_handle.spec.config.model_name
assert model_name is not None
assert bool(_MODEL_NAME_RE.match(model_name))
if options.promote and not options.publish:
logging.info("`promote=True` overrides `publish` to `True`.")
logging.info(
f"Pushing chainlet `{model_name}` as a truss model on Baseten "
f"(publish={options.publish}, promote={options.promote})."
f"Pushing chainlet `{model_name}` as a truss model on Baseten (publish={options.publish})"
)
# Models must be trusted to use the API KEY secret.
service = options.remote_provider.push(
truss_handle,
model_name=model_name,
trusted=True,
publish=options.publish,
promote=options.promote,
origin=b10_types.ModelOrigin.CHAINS,
)
return cast(b10_service.BasetenService, service)
Expand Down Expand Up @@ -327,7 +323,7 @@ def _create_baseten_chain(
chain_name=baseten_options.chain_name,
chainlets=chainlet_data,
publish=baseten_options.publish,
promote=baseten_options.promote,
environment=baseten_options.environment,
)
return BasetenChainService(
baseten_options.chain_name,
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def override_chainlet_to_service_metadata(
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
):
# Override predict_urls in chainlet_to_service ServiceDescriptors if dynamic_chainlet_config exists
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value(
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
if dynamic_chainlet_config_str:
Expand Down
32 changes: 29 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,13 @@ def chains():
"""Subcommands for truss chains"""


def _make_chains_curl_snippet(run_remote_url: str) -> str:
def _make_chains_curl_snippet(run_remote_url: str, environment: Optional[str]) -> str:
if environment:
idx = run_remote_url.find("deployment")
if idx != -1:
run_remote_url = (
run_remote_url[:idx] + f"environments/{environment}/run_remote"
)
return (
f"curl -X POST '{run_remote_url}' \\\n"
' -H "Authorization: Api-Key $BASETEN_API_KEY" \\\n'
Expand Down Expand Up @@ -505,6 +511,15 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
default=False,
help="Replace production chainlets with newly deployed chainlets.",
)
@click.option(
"--environment",
type=str,
required=False,
help=(
"Deploy the chain as a published deployment to the specified environment."
"If specified, --publish is implied and the supplied value of --promote will be ignored."
),
)
@click.option(
"--wait/--no-wait",
type=bool,
Expand Down Expand Up @@ -557,6 +572,7 @@ def push_chain(
dryrun: bool,
user_env: Optional[str],
remote: Optional[str],
environment: Optional[str],
) -> None:
"""
Deploys a chain remotely.
Expand Down Expand Up @@ -597,6 +613,10 @@ def push_chain(
else:
user_env_parsed = {}

if promote and environment:
promote_warning = "`promote` flag and `environment` flag were both specified. Ignoring the value of `promote`"
console.print(promote_warning, style="yellow")

with framework.import_target(source, entrypoint) as entrypoint_cls:
chain_name = name or entrypoint_cls.__name__
options = chains_def.PushOptionsBaseten.create(
Expand All @@ -606,6 +626,7 @@ def push_chain(
only_generate_trusses=dryrun,
user_env=user_env_parsed,
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint_cls, options)

Expand All @@ -614,7 +635,9 @@ def push_chain(
return

assert isinstance(service, chains_remote.BasetenChainService)
curl_snippet = _make_chains_curl_snippet(service.run_remote_url)
curl_snippet = _make_chains_curl_snippet(
service.run_remote_url, options.environment
)

table, statuses = _create_chains_table(service)
status_check_wait_sec = 2
Expand Down Expand Up @@ -647,7 +670,10 @@ def push_chain(
for log in intercepted_logs:
console.print(f"\t{log}")
if success:
console.print("Deployment succeeded.", style="bold green")
deploy_success_text = "Deployment succeeded."
if environment:
deploy_success_text = f"Your chain has been deployed into the {options.environment} environment."
console.print(deploy_success_text, style="bold green")
console.print(f"You can run the chain with:\n{curl_snippet}")
if watch: # Note that this command will print a startup message.
chains_remote.watch(
Expand Down
12 changes: 12 additions & 0 deletions truss/local/local_config_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def bptr_data_resolution_dir_path():
bptr_data_dir.mkdir(exist_ok=True, parents=True)
return bptr_data_dir

@staticmethod
def dynamic_config_path():
dynamic_config_dir = LocalConfigHandler.TRUSS_CONFIG_DIR / "b10_dynamic_config"
dynamic_config_dir.mkdir(exist_ok=True, parents=True)
return dynamic_config_dir

@staticmethod
def set_dynamic_config(key: str, value: str):
key_path = LocalConfigHandler.dynamic_config_path() / key
with key_path.open("w") as key_file:
key_file.write(value)

@staticmethod
def _signatures_dir_path():
return LocalConfigHandler.TRUSS_CONFIG_DIR / "signatures"
Expand Down
21 changes: 12 additions & 9 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,25 @@ def deploy_draft_chain(
return resp["data"]["deploy_draft_chain"]

def deploy_chain_deployment(
self, chain_id: str, chainlet_data: List[b10_types.ChainletData], promote: bool
self,
chain_id: str,
chainlet_data: List[b10_types.ChainletData],
environment: Optional[str] = None,
):
chainlet_data_strings = [
_chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data
]
chainlets_string = ", ".join(chainlet_data_strings)
query_string = f"""
mutation {{
deploy_chain_deployment(
chain_id: "{chain_id}",
chainlets: [{chainlets_string}],
promote_after_deploy: {'true' if promote else 'false'},
) {{
chain_id
chain_deployment_id
}}
deploy_chain_deployment(
chain_id: "{chain_id}",
chainlets: [{chainlets_string}],
{f'environment_name: "{environment}"' if environment else ""}
) {{
chain_id
chain_deployment_id
}}
}}
"""
resp = self._post_graphql_query(query_string)
Expand Down
21 changes: 18 additions & 3 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def create_chain(
chain_name: str,
chainlets: List[b10_types.ChainletData],
is_draft: bool,
promote: bool,
environment: Optional[str],
) -> ChainDeploymentHandle:
if is_draft:
response = api.deploy_draft_chain(chain_name, chainlets)
Expand All @@ -93,8 +93,20 @@ def create_chain(
# if there is no chain already, the first deployment will
# already be production, and only published deployments can
# be promoted.
response = api.deploy_chain_deployment(chain_id, chainlets, promote)
try:
response = api.deploy_chain_deployment(chain_id, chainlets, environment)
except ApiError as e:
if (
e.graphql_error_code
== BasetenApi.GraphQLErrorCodes.RESOURCE_NOT_FOUND.value
):
raise ValueError(
f'Environment "{environment}" does not exist. You can create environments in the Chains UI.'
) from e
raise e
else:
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
response = api.deploy_chain(chain_name, chainlets)

return ChainDeploymentHandle(
Expand Down Expand Up @@ -299,7 +311,10 @@ def create_truss_service(
environment=environment,
)
except ApiError as e:
if "Environment matching query does not exist" in e.message:
if (
e.graphql_error_code
== BasetenApi.GraphQLErrorCodes.RESOURCE_NOT_FOUND.value
):
raise ValueError(
f'Environment "{environment}" does not exist. You can create environments in the Baseten UI.'
) from e
Expand Down
8 changes: 4 additions & 4 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def create_chain(
chain_name: str,
chainlets: List[custom_types.ChainletData],
publish: bool = False,
promote: bool = False,
environment: Optional[str] = None,
) -> ChainDeploymentHandle:
if promote:
# If we are promoting a model after deploy, it must be published.
if environment:
# If we are promoting a model to an environment after deploy, it must be published.
# Draft models cannot be promoted.
publish = True
# Returns tuple of (chain_id, chain_deployment_id)
Expand All @@ -81,7 +81,7 @@ def create_chain(
chain_name=chain_name,
chainlets=chainlets,
is_draft=not publish,
promote=promote,
environment=environment,
)

def get_chainlets(
Expand Down
Loading

0 comments on commit edb93d1

Please sign in to comment.