Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean Chains Stack Traces and consolidate logging config. Fixes BT-13465 BT-13378 #1353

Merged
merged 3 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import math

from user_package import shared_chainlet
from user_package.nested_package import io_types

import truss_chains as chains

logger = logging.getLogger(__name__)

IMAGE_BASETEN = chains.DockerImage(
base_image=chains.BasetenImage.PY310,
pip_requirements_file=chains.make_abs_path_here("requirements.txt"),
Expand Down Expand Up @@ -103,6 +106,7 @@ def __init__(
text_to_num: TextToNum = chains.depends(TextToNum),
context=chains.depends_context(),
) -> None:
logging.info("User log root during load.")
self._context = context
self._data_generator = data_generator
self._data_splitter = splitter
Expand All @@ -117,6 +121,8 @@ async def run_remote(
),
simple_default_arg: list[str] = ["a", "b"],
) -> tuple[int, str, int, shared_chainlet.SplitTextOutput, list[str]]:
logging.info("User log root.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you intend to leave these in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's in the integration test chain; I like to have this around when debugging/changing anything.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

logger.info("User log module.")
data = self._data_generator.run_remote(length)
text_parts, number, items = await self._data_splitter.run_remote(
io_types.SplitTextInput(
Expand Down
34 changes: 10 additions & 24 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,10 @@ def push(
class DockerChainletService(b10_service.TrussService):
"""This service is for Chainlets (not for Chains)."""

def __init__(self, port: int, is_draft: bool, **kwargs):
def __init__(self, port: int, **kwargs):
remote_url = f"http://localhost:{port}"
self._port = port

super().__init__(remote_url, is_draft, **kwargs)
super().__init__(remote_url, is_draft=False, **kwargs)

def authenticate(self) -> Dict[str, str]:
    # Local Docker-based Chainlet services need no auth headers; return none.
    return {}
Expand All @@ -246,10 +245,6 @@ def is_ready(self) -> bool:
def logs_url(self) -> str:
raise NotImplementedError()

@property
def port(self) -> int:
return self._port

@property
def predict_url(self) -> str:
    # Predict endpoint of the single model served by this local container.
    return f"{self._service_url}/v1/models/model:predict"
Expand All @@ -272,6 +267,7 @@ def _push_service_docker(
wait_for_server_ready=True,
network="host",
container_name_prefix=chainlet_display_name,
disable_json_logging=True,
)


Expand Down Expand Up @@ -309,12 +305,13 @@ def _create_docker_chain(
entrypoint_artifact: b10_types.ChainletArtifact,
dependency_artifacts: list[b10_types.ChainletArtifact],
) -> DockerChainService:
chainlet_artifacts = [entrypoint_artifact, *dependency_artifacts]
chainlet_artifacts = [*dependency_artifacts, entrypoint_artifact]
chainlet_to_predict_url: Dict[str, Dict[str, str]] = {}
chainlet_to_service: Dict[str, DockerChainletService] = {}
for chainlet_artifact in chainlet_artifacts:
port = utils.get_free_port()
service = DockerChainletService(is_draft=True, port=port)
service = DockerChainletService(port)

docker_internal_url = service.predict_url.replace(
"localhost", "host.docker.internal"
)
Expand All @@ -323,27 +320,16 @@ def _create_docker_chain(
}
chainlet_to_service[chainlet_artifact.name] = service

local_config_handler.LocalConfigHandler.set_dynamic_config(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY, json.dumps(chainlet_to_predict_url)
)
local_config_handler.LocalConfigHandler.set_dynamic_config(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY, json.dumps(chainlet_to_predict_url)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, this doesn't have to have all the predict URLs of dependencies populated before mounting and running? I noticed you put entrypoint_artifact at the end of the chainlet_artifacts list but what about the topology of other Chainlets?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chains are always DAGs by construction and _get_ordered_dependencies gives you a topological sorting.

That's also how it worked before we had dynamic config map, because then all URLs were directly baked into the generated truss config.


# TODO(Tyron): We run the Docker containers in a
# separate for-loop to make sure that the dynamic
# config is populated (the same one gets mounted
# on all the containers). We should look into
# consolidating the logic into a single for-loop.
# One approach might be to use separate config
# paths for each container under the `/tmp` dir.
for chainlet_artifact in chainlet_artifacts:
truss_dir = chainlet_artifact.truss_dir
logging.info(
f"Building Chainlet `{chainlet_artifact.display_name}` docker image."
)
_push_service_docker(
truss_dir,
chainlet_artifact.display_name,
docker_options,
chainlet_to_service[chainlet_artifact.name].port,
truss_dir, chainlet_artifact.display_name, docker_options, port
)
logging.info(
f"Pushed Chainlet `{chainlet_artifact.display_name}` as docker container."
Expand Down
6 changes: 4 additions & 2 deletions truss/templates/control/control/application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import logging.config
import re
from pathlib import Path
from typing import Dict
Expand All @@ -13,7 +14,7 @@
from helpers.inference_server_process_controller import InferenceServerProcessController
from helpers.inference_server_starter import async_inference_server_startup_flow
from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
from shared.logging import setup_logging
from shared import log_config
from starlette.datastructures import State


Expand All @@ -35,10 +36,11 @@ async def handle_model_load_failed(_, error):

def create_app(base_config: Dict):
app_state = State()
setup_logging()
app_logger = logging.getLogger(__name__)
app_state.logger = app_logger

logging.config.dictConfig(log_config.make_log_config("INFO"))

for k, v in base_config.items():
setattr(app_state, k, v)

Expand Down
25 changes: 24 additions & 1 deletion truss/templates/server/common/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,30 @@ def intercept_exceptions(
yield
# Note that logger.error logs the stacktrace, such that the user can
# debug this error from the logs.
except fastapi.HTTPException:
except fastapi.HTTPException as e:
# TODO: we try to avoid any dependency of the truss server on chains, but for
# the purpose of getting readable chained-stack traces in the server logs,
# we have to add a special-case here.
if "user_stack_trace" in e.detail:
try:
from truss_chains import definitions

chains_error = definitions.RemoteErrorDetail.model_validate(e.detail)
# The formatted error contains a (potentially chained) stack trace
# with all framework code removed, see
# truss_chains/remote_chainlet/utils.py::response_raise_errors.
logger.error(f"Chainlet raised Exception:\n{chains_error.format()}")
except: # If we cannot import chains or parse the error.
logger.error(
"Model raised HTTPException",
exc_info=filter_traceback(model_file_name),
)
raise
# If error was extracted successfully, the customized stack trace is
# already printed above, so we raise with a clear traceback.
e.__traceback__ = None
raise e from None

logger.error(
"Model raised HTTPException", exc_info=filter_traceback(model_file_name)
)
Expand Down
2 changes: 0 additions & 2 deletions truss/templates/server/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os

from shared.logging import setup_logging
from truss_server import TrussServer

CONFIG_FILE = "config.yaml"

if __name__ == "__main__":
setup_logging()
http_port = int(os.environ.get("INFERENCE_SERVER_PORT", "8080"))
TrussServer(http_port, CONFIG_FILE).start()
4 changes: 2 additions & 2 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
# We need a logger that has all our server JSON logging setup applied in its
# handlers and where this also hold in the loading thread. Creating a new
# instance does not carry over the setup into the thread and using unspecified
# `getLogger` may return non-compliant loggers if depdencies override the root
# `getLogger` may return non-compliant loggers if dependencies override the root
# logger (e.g. https://github.com/numpy/numpy/issues/24213). We chose to get
# the uvicorn logger that is setup in `truss_server`.
# the uvicorn logger that is set up in `truss_server`.
self._logger = logging.getLogger("uvicorn")
self.name = MODEL_BASENAME
self._load_lock = Lock()
Expand Down
61 changes: 7 additions & 54 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import logging.config
import os
import signal
import sys
Expand All @@ -21,8 +22,7 @@
from opentelemetry import trace
from opentelemetry.sdk import trace as sdk_trace
from pydantic import BaseModel
from shared import serialization
from shared.logging import setup_logging
from shared import log_config, serialization
from shared.secrets_resolver import SecretsResolver
from starlette.requests import ClientDisconnect
from starlette.responses import Response
Expand All @@ -37,7 +37,6 @@
# [IMPORTANT] A lot of things depend on this currently, change with extreme care.
TIMEOUT_GRACEFUL_SHUTDOWN = 120
INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


async def parse_body(request: Request) -> bytes:
Expand Down Expand Up @@ -260,12 +259,10 @@ class TrussServer:

_server: Optional[uvicorn.Server]

def __init__(
self,
http_port: int,
config_or_path: Union[str, Path, Dict],
setup_json_logger: bool = True,
):
def __init__(self, http_port: int, config_or_path: Union[str, Path, Dict]):
# This is run before uvicorn is up. Need explicit logging config here.
logging.config.dictConfig(log_config.make_log_config("INFO"))

if isinstance(config_or_path, (str, Path)):
with open(config_or_path, encoding="utf-8") as config_file:
config = yaml.safe_load(config_file)
Expand All @@ -274,7 +271,6 @@ def __init__(

secrets = SecretsResolver.get_secrets(config)
tracer = tracing.get_truss_tracer(secrets, config)
self._setup_json_logger = setup_json_logger
self._http_port = http_port
self._config = config
self._model = ModelWrapper(self._config, tracer)
Expand All @@ -291,8 +287,6 @@ def on_startup(self):
we want to set up our logging and model.
"""
self.cleanup()
if self._setup_json_logger:
setup_logging()
self._model.start_load_thread()
asyncio.create_task(self._shutdown_if_load_fails())
self._model.setup_polling_for_environment_updates()
Expand Down Expand Up @@ -366,9 +360,6 @@ def start(self):
if self._config["runtime"].get("enable_debug_logs", False)
else "INFO"
)
# Warning: `ModelWrapper` depends on correctly setup `uvicorn` logger,
# if you change/remove that logger, make sure `ModelWrapper` has a suitable
# alternative logger that is also correctly setup in the load thread.
cfg = uvicorn.Config(
self.create_application(),
# We hard-code the http parser as h11 (the default) in case the user has
Expand All @@ -379,45 +370,7 @@ def start(self):
port=self._http_port,
workers=1,
timeout_graceful_shutdown=TIMEOUT_GRACEFUL_SHUTDOWN,
log_config={
"version": 1,
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"datefmt": DATE_FORMAT,
"fmt": "%(asctime)s.%(msecs)03d %(name)s %(levelprefix)s %(message)s",
"use_colors": None,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"datefmt": DATE_FORMAT,
"fmt": "%(asctime)s.%(msecs)03d %(name)s %(levelprefix)s %(client_addr)s %(process)s - "
'"%(request_line)s" %(status_code)s',
# noqa: E501
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
},
},
"loggers": {
"uvicorn": {"handlers": ["default"], "level": log_level},
"uvicorn.error": {"level": "INFO"},
"uvicorn.access": {
"handlers": ["access"],
"level": "INFO",
"propagate": False,
},
},
},
log_config=log_config.make_log_config(log_level),
)
cfg.setup_event_loop() # Call this so uvloop gets used
server = uvicorn.Server(config=cfg)
Expand Down
3 changes: 0 additions & 3 deletions truss/templates/shared/README.md

This file was deleted.

Loading
Loading