Commit 86c2ec0

Address comments

nnarayen committed Feb 4, 2025
1 parent 21a88eb commit 86c2ec0
Showing 8 changed files with 164 additions and 143 deletions.
13 changes: 0 additions & 13 deletions truss-chains/tests/openai/config.yaml

This file was deleted.

18 changes: 0 additions & 18 deletions truss-chains/tests/openai/model/model.py

This file was deleted.

29 changes: 0 additions & 29 deletions truss-chains/tests/test_e2e.py
@@ -328,32 +328,3 @@ def test_custom_health_checks_chain():
        assert response.status_code == 503
        container_logs = get_container_logs_from_prefix(entrypoint.name)
        assert container_logs.count("Health check failed.") == 2


@pytest.mark.integration
def test_custom_openai_endpoints():
    with ensure_kill_all():
        model_root = TEST_ROOT / "openai"
        truss_handle = load(model_root)

        assert truss_handle.spec.config.model_name == "OpenAIModel"

        port = utils.get_free_port()
        truss_handle.docker_run(local_port=port, detach=True, network="host")

        base_url = f"http://localhost:{port}"
        response = requests.post(
            f"{base_url}/v1/models/model:predict", json={"increment": 1}
        )
        assert response.status_code == 200
        assert response.json() == 1

        response = requests.post(f"{base_url}/v1/completions", json={"increment": 2})
        assert response.status_code == 200
        assert response.json() == 2

        # Written model intentionally does not support chat completions
        response = requests.post(
            f"{base_url}/v1/chat/completions", json={"increment": 3}
        )
        assert response.status_code == 404
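For context: this deleted test drove a Truss model exposing the optional OpenAI-compatible endpoints next to `predict`. The deleted `truss-chains/tests/openai/model/model.py` is not rendered above; the following is a plausible minimal reconstruction, inferred purely from the test assertions (hypothetical, not the actual deleted file):

class Model:
    def predict(self, inputs: dict) -> int:
        # Served at /v1/models/model:predict; echoes the increment back,
        # so {"increment": 1} -> 1.
        return inputs["increment"]

    def completions(self, inputs: dict) -> int:
        # Served at /v1/completions.
        return inputs["increment"]

    # No chat_completions method, so /v1/chat/completions returns 404.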
159 changes: 90 additions & 69 deletions truss/templates/server/model_wrapper.py
@@ -12,12 +12,11 @@
import time
import weakref
from contextlib import asynccontextmanager
from enum import Enum
from functools import cached_property
from multiprocessing import Lock
from pathlib import Path
from threading import Thread
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import opentelemetry.sdk.trace as sdk_trace
import pydantic
@@ -29,7 +28,7 @@
from common.retry import retry
from common.schema import TrussSchema
from opentelemetry import trace
from shared import dynamic_config_resolver, serialization
from shared import dynamic_config_resolver, serialization, util
from shared.lazy_data_resolver import LazyDataResolver
from shared.secrets_resolver import SecretsResolver

@@ -50,14 +49,14 @@
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30


class ModelMethod(Enum):
    PREPROCESS = "preprocess"
    PREDICT = "predict"
    POSTPROCESS = "postprocess"
    IS_HEALTHY = "is_healthy"
    SETUP_ENVIRONMENT = "setup_environment"
    COMPLETIONS = "completions"
    CHAT_COMPLETIONS = "chat_completions"
class ModelMethodName(util.LowerStrEnum):
    CHAT_COMPLETIONS = enum.auto()
    COMPLETIONS = enum.auto()
    IS_HEALTHY = enum.auto()
    POSTPROCESS = enum.auto()
    PREDICT = enum.auto()
    PREPROCESS = enum.auto()
    SETUP_ENVIRONMENT = enum.auto()


InputType = Union[serialization.JSONType, serialization.MsgPackType, pydantic.BaseModel]
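Note: `ModelMethodName` builds on `shared.util.LowerStrEnum`, whose definition is not part of this diff. For `enum.auto()` to yield lowercase method names, and for the `hasattr`/`getattr` and f-string uses below to work, it presumably looks roughly like this sketch (an assumption, not the actual implementation in `shared/util.py`):

import enum

class LowerStrEnum(str, enum.Enum):
    # str-valued enum: members compare equal to, and format as, plain strings.

    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        # Called by enum.auto(): CHAT_COMPLETIONS -> "chat_completions".
        return name.lower()

    def __str__(self) -> str:
        # Render as the value ("predict"), not "ModelMethodName.PREDICT",
        # so the f-strings in error messages below stay readable.
        return self.value

Because members are `str` instances, `getattr(model_cls, ModelMethodName.PREDICT)` behaves exactly like `getattr(model_cls, "predict")`.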
@@ -221,63 +220,78 @@ def skip_input_parsing(self) -> bool:
    @classmethod
    def _gen_truss_schema(
        cls,
        model: Any,
        model_cls: Any,
        predict: MethodDescriptor,
        preprocess: Optional[MethodDescriptor],
        postprocess: Optional[MethodDescriptor],
    ) -> TrussSchema:
        if preprocess:
            parameters = inspect.signature(model.preprocess).parameters
            parameters = inspect.signature(model_cls.preprocess).parameters
        else:
            parameters = inspect.signature(model.predict).parameters
            parameters = inspect.signature(model_cls.predict).parameters

        if postprocess:
            return_annotation = inspect.signature(model.postprocess).return_annotation
            return_annotation = inspect.signature(
                model_cls.postprocess
            ).return_annotation
        else:
            return_annotation = inspect.signature(model.predict).return_annotation
            return_annotation = inspect.signature(model_cls.predict).return_annotation

        return TrussSchema.from_signature(parameters, return_annotation)

    @classmethod
    def _safe_extract_descriptor(
        cls, model: Any, method: str
        cls, model_cls: Any, method_name: ModelMethodName
    ) -> Union[MethodDescriptor, None]:
        if hasattr(model, method):
            return MethodDescriptor.from_method(getattr(model, method), method)
        if hasattr(model_cls, method_name):
            return MethodDescriptor.from_method(
                getattr(model_cls, method_name), method_name
            )
        return None

    @classmethod
    def from_model(cls, model) -> "ModelDescriptor":
        preprocess = cls._safe_extract_descriptor(model, ModelMethod.PREPROCESS.value)
        predict = cls._safe_extract_descriptor(model, ModelMethod.PREDICT.value)
    def from_model(cls, model_cls) -> "ModelDescriptor":
        preprocess = cls._safe_extract_descriptor(model_cls, ModelMethodName.PREPROCESS)
        predict = cls._safe_extract_descriptor(model_cls, ModelMethodName.PREDICT)
        if predict is None:
            raise errors.ModelDefinitionError(
                f"Truss model must have a `{ModelMethod.PREDICT.value}` method."
                f"Truss model must have a `{ModelMethodName.PREDICT}` method."
            )
        elif preprocess and predict.arg_config == ArgConfig.REQUEST_ONLY:
            raise errors.ModelDefinitionError(
                f"When using `{ModelMethod.PREPROCESS.value}`, the {ModelMethod.PREDICT.value} method "
                f"cannot only have the request argument (because the result of `{ModelMethod.PREPROCESS.value}` "
                f"When using `{ModelMethodName.PREPROCESS}`, the {ModelMethodName.PREDICT} method "
                f"cannot only have the request argument (because the result of `{ModelMethodName.PREPROCESS}` "
                "would be discarded)."
            )

        postprocess = cls._safe_extract_descriptor(model, ModelMethod.POSTPROCESS.value)
        postprocess = cls._safe_extract_descriptor(
            model_cls, ModelMethodName.POSTPROCESS
        )
        if postprocess and postprocess.arg_config == ArgConfig.REQUEST_ONLY:
            raise errors.ModelDefinitionError(
                f"The `{ModelMethod.POSTPROCESS.value}` method cannot only have the request "
                f"argument (because the result of `{ModelMethod.PREDICT.value}` would be discarded)."
                f"The `{ModelMethodName.POSTPROCESS}` method cannot only have the request "
                f"argument (because the result of `{ModelMethodName.PREDICT}` would be discarded)."
            )
        setup = cls._safe_extract_descriptor(model, ModelMethod.SETUP_ENVIRONMENT.value)
        completions = cls._safe_extract_descriptor(model, ModelMethod.COMPLETIONS.value)
        chats = cls._safe_extract_descriptor(model, ModelMethod.CHAT_COMPLETIONS.value)
        is_healthy = cls._safe_extract_descriptor(model, ModelMethod.IS_HEALTHY.value)
        setup = cls._safe_extract_descriptor(
            model_cls, ModelMethodName.SETUP_ENVIRONMENT
        )
        completions = cls._safe_extract_descriptor(
            model_cls, ModelMethodName.COMPLETIONS
        )
        chats = cls._safe_extract_descriptor(
            model_cls, ModelMethodName.CHAT_COMPLETIONS
        )
        is_healthy = cls._safe_extract_descriptor(model_cls, ModelMethodName.IS_HEALTHY)
        if is_healthy and is_healthy.arg_config != ArgConfig.NONE:
            raise errors.ModelDefinitionError(
                f"`{ModelMethod.IS_HEALTHY.value}` must have only one argument: `self`."
                f"`{ModelMethodName.IS_HEALTHY}` must have only one argument: `self`."
            )

        truss_schema = cls._gen_truss_schema(
            model=model, predict=predict, preprocess=preprocess, postprocess=postprocess
            model_cls=model_cls,
            predict=predict,
            preprocess=preprocess,
            postprocess=postprocess,
        )
        return cls(
            preprocess=preprocess,
@@ -302,7 +316,7 @@ class ModelWrapper:
    _poll_for_environment_updates_task: Optional[asyncio.Task]
    _environment: Optional[dict]

    class Status(Enum):
    class Status(enum.Enum):
        NOT_READY = 0
        LOADING = 1
        READY = 2
@@ -571,9 +585,15 @@ async def preprocess(
        self, inputs: InputType, request: starlette.requests.Request
    ) -> Any:
        descriptor = self.model_descriptor.preprocess
        assert descriptor, "`preprocess` must only be called if model has it."
        assert descriptor, (
            f"`{ModelMethodName.PREPROCESS}` must only be called if model has it."
        )
        return await self._execute_async_model_fn(
            descriptor, inputs, request, self._model.preprocess
            descriptor,
            inputs,
            request,
            self._model.preprocess,
            supports_generators=False,
        )

    async def predict(
@@ -595,9 +615,15 @@ async def postprocess(
        # and postprocess is skipped.
        # The result type can be the same as for predict.
        descriptor = self.model_descriptor.postprocess
        assert descriptor, "`postprocess` must only be called if model has it."
        assert descriptor, (
            f"`{ModelMethodName.POSTPROCESS}` must only be called if model has it."
        )
        return await self._execute_async_model_fn(
            descriptor,
            result,
            request,
            self._model.postprocess,
            supports_generators=False,
        )

    async def _write_response_to_queue(
@@ -623,7 +649,7 @@ async def _stream_with_background_task(
        generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
        span: trace.Span,
        trace_ctx: trace.Context,
        release_and_end: Callable[[], None] = lambda: None,
        release_and_end: Callable[[], None],
    ) -> AsyncGenerator[bytes, None]:
        # The streaming read timeout is the amount of time between streamed
        # chunks before a timeout is triggered.
@@ -672,10 +698,12 @@ async def _execute_async_model_fn(
        inputs: Union[InputType, Any],
        request: starlette.requests.Request,
        model_fn: Any,
        supports_generators: bool = True,
    ) -> OutputType:
        args = ArgConfig.prepare_args(descriptor, inputs, request)
        with errors.intercept_exceptions(self._logger, self._model_file_name):
            if descriptor.is_generator:
            if supports_generators and descriptor.is_generator:
                # Even for async generators, don't await here.
                return model_fn(*args)
            if descriptor.is_async:
                return await model_fn(*args)
@@ -685,23 +713,28 @@ async def _trace_and_process_model_fn(
        self,
        inputs: InputType,
        request: starlette.requests.Request,
        method: ModelMethod,
        exec_fn: Callable[
            [InputType, starlette.requests.Request], Awaitable[OutputType]
        ],
        method_name: ModelMethodName,
        descriptor: MethodDescriptor,
        model_fn: Any,
    ) -> OutputType:
        fn_span = self._tracer.start_span(f"call-{method.value}")
        fn_span = self._tracer.start_span(f"call-{method_name}")
        with tracing.section_as_event(
            fn_span, method.value
            fn_span, method_name
        ), tracing.detach_context() as detached_ctx:
            result = await exec_fn(inputs, request)
            result = await self._execute_async_model_fn(
                descriptor, inputs, request, model_fn
            )

            if inspect.isgenerator(result) or inspect.isasyncgen(result):
                if request.headers.get("accept") == "application/json":
                    return await _gather_generator(result)
                else:
                    return await self._stream_with_background_task(
                        result, fn_span, detached_ctx
                        result,
                        fn_span,
                        detached_ctx,
                        # No semaphores needed for non-predict model functions.
                        release_and_end=lambda: None,
                    )

        return result
@@ -711,43 +744,31 @@ async def completions(
    ) -> OutputType:
        descriptor = self.model_descriptor.completions
        assert descriptor, (
            f"`{ModelMethod.COMPLETIONS.value}` must only be called if model has it."
            f"`{ModelMethodName.COMPLETIONS}` must only be called if model has it."
        )

        async def exec_fn(
            inputs: InputType, request: starlette.requests.Request
        ) -> OutputType:
            return await self._execute_async_model_fn(
                descriptor, inputs, request, self._model.completions
            )

        return await self._trace_and_process_model_fn(
            inputs=inputs,
            request=request,
            method=ModelMethod.COMPLETIONS,
            exec_fn=exec_fn,
            method_name=ModelMethodName.COMPLETIONS,
            descriptor=descriptor,
            model_fn=self._model.completions,
        )

    async def chat_completions(
        self, inputs: InputType, request: starlette.requests.Request
    ) -> OutputType:
        descriptor = self.model_descriptor.chat_completions
        assert descriptor, (
            f"`{ModelMethod.CHAT_COMPLETIONS.value}` must only be called if model has it."
            f"`{ModelMethodName.CHAT_COMPLETIONS}` must only be called if model has it."
        )

        async def exec_fn(
            inputs: InputType, request: starlette.requests.Request
        ) -> OutputType:
            return await self._execute_async_model_fn(
                descriptor, inputs, request, self._model.chat_completions
            )

        return await self._trace_and_process_model_fn(
            inputs=inputs,
            request=request,
            method=ModelMethod.CHAT_COMPLETIONS,
            exec_fn=exec_fn,
            method_name=ModelMethodName.CHAT_COMPLETIONS,
            descriptor=descriptor,
            model_fn=self._model.chat_completions,
        )

    async def __call__(
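Taken together, a user model now opts into these OpenAI-compatible routes simply by defining methods with the `ModelMethodName` names; `ModelDescriptor.from_model` discovers them via `_safe_extract_descriptor`. A hedged sketch of such a model (illustrative only; the input and output shapes are up to the user):

class Model:
    def load(self) -> None:
        self._ready = True

    def is_healthy(self) -> bool:
        # Must take no argument besides `self` (ArgConfig.NONE is enforced).
        return self._ready

    def predict(self, inputs: dict) -> dict:
        return {"echo": inputs}

    async def completions(self, inputs: dict):
        # Routed through ModelWrapper.completions(). May also be a sync or
        # async generator, in which case the response streams unless the
        # client sends `accept: application/json`.
        return {"choices": [{"text": str(inputs)}]}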
(The remaining changed files in this commit are not shown.)