Skip to content

Commit

Permalink
Make TrussServer single-worker, remove faulty middleware. (#1172)
Browse files Browse the repository at this point in the history
* Remove N workers

* Fix tests

* Cleanup

* Add cancellaiton integration test

* Reivew comments

* Update

* Fix tests
  • Loading branch information
marius-baseten authored Nov 5, 2024
1 parent d9e1c01 commit d28a29a
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 330 deletions.
1 change: 1 addition & 0 deletions truss/templates/server/common/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable


# TODO: replace with tenacity.
def retry(
fn: Callable,
count: int,
Expand Down
64 changes: 0 additions & 64 deletions truss/templates/server/common/termination_handler_middleware.py

This file was deleted.

10 changes: 3 additions & 7 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,10 @@ def start_load_thread(self):
thread = Thread(target=self.load)
thread.start()

def load(self) -> bool:
def load(self):
if self.ready:
return True

# if we are already loading, block on aquiring the lock;
return
# if we are already loading, block on acquiring the lock;
# this worker will return 503 while the worker with the lock is loading
with self._load_lock:
self._status = ModelWrapper.Status.LOADING
Expand All @@ -344,13 +343,10 @@ def load(self) -> bool:
self._logger.info(
f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
)
return True
except Exception:
self._logger.exception("Exception while loading model")
self._status = ModelWrapper.Status.FAILED

return False

def _load_impl(self):
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
Expand Down
126 changes: 36 additions & 90 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,28 @@
import asyncio
import json
import logging
import multiprocessing
import os
import signal
import socket
import sys
import time
from http import HTTPStatus
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Union

import pydantic
import uvicorn
import yaml
from common import errors, tracing
from common.schema import TrussSchema
from common.termination_handler_middleware import TerminationHandlerMiddleware
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from fastapi.routing import APIRoute as FastAPIRoute
from model_wrapper import ModelWrapper
from opentelemetry import propagate as otel_propagate
from opentelemetry import trace
from opentelemetry.sdk import trace as sdk_trace
from shared import serialization, util
from shared import serialization
from shared.logging import setup_logging
from shared.secrets_resolver import SecretsResolver
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import ClientDisconnect
from starlette.responses import Response

Expand All @@ -36,14 +31,8 @@
else:
from typing_extensions import AsyncGenerator, Generator

# [IMPORTANT] A lot of things depend on this currently.
# Please consider the following when increasing this:
# 1. Self-termination on model load fail.
# 2. Graceful termination.
DEFAULT_NUM_WORKERS = 1
DEFAULT_NUM_SERVER_PROCESSES = 1
WORKER_TERMINATION_TIMEOUT_SECS = 120.0
WORKER_TERMINATION_CHECK_INTERVAL_SECS = 0.5
# [IMPORTANT] A lot of things depend on this currently, change with extreme care.
TIMEOUT_GRACEFUL_SHUTDOWN = 120
INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

Expand All @@ -60,22 +49,6 @@ async def parse_body(request: Request) -> bytes:
raise HTTPException(status_code=499, detail=error_message) from exc


class UvicornCustomServer(multiprocessing.Process):
def __init__(
self, config: uvicorn.Config, sockets: Optional[List[socket.socket]] = None
):
super().__init__()
self.sockets = sockets
self.config = config

def stop(self):
self.terminate()

def run(self):
server = uvicorn.Server(config=self.config)
asyncio.run(server.serve(sockets=self.sockets))


class BasetenEndpoints:
"""The implementation of the model server endpoints.
Expand Down Expand Up @@ -170,6 +143,11 @@ async def predict(
"""
This method calls the user-provided predict method
"""
if await request.is_disconnected():
msg = "Client disconnected. Skipping `predict`."
logging.info(msg)
raise ClientDisconnect(msg)

model: ModelWrapper = self._safe_lookup_model(model_name)

self.check_healthy(model)
Expand Down Expand Up @@ -248,6 +226,8 @@ class TrussServer:
main loop.
"""

_server: Optional[uvicorn.Server]

def __init__(
self,
http_port: int,
Expand All @@ -263,10 +243,11 @@ def __init__(
secrets = SecretsResolver.get_secrets(config)
tracer = tracing.get_truss_tracer(secrets, config)
self._setup_json_logger = setup_json_logger
self.http_port = http_port
self._http_port = http_port
self._config = config
self._model = ModelWrapper(self._config, tracer)
self._endpoints = BasetenEndpoints(self._model, tracer)
self._server = None

def cleanup(self):
if INFERENCE_SERVER_FAILED_FILE.exists():
Expand All @@ -281,8 +262,18 @@ def on_startup(self):
if self._setup_json_logger:
setup_logging()
self._model.start_load_thread()
asyncio.create_task(self._shutdown_if_load_fails())
self._model.setup_polling_for_environment_updates()

async def _shutdown_if_load_fails(self):
while not self._model.ready:
await asyncio.sleep(0.5)
if self._model.load_failed:
assert self._server is not None
logging.info("Trying shut down.")
self._server.should_exit = True
return

def create_application(self):
app = FastAPI(
title="Baseten Inference Server",
Expand Down Expand Up @@ -331,29 +322,24 @@ def create_application(self):
# This here is a fallback to add our custom headers in all other cases.
app.add_exception_handler(Exception, errors.exception_handler)

def exit_self():
# Note that this kills the current process, the worker process, not
# the main truss_server process.
util.kill_child_processes(os.getpid())
sys.exit()

termination_handler_middleware = TerminationHandlerMiddleware(
on_stop=lambda: None,
on_term=exit_self,
)
app.add_middleware(BaseHTTPMiddleware, dispatch=termination_handler_middleware)
return app

def start(self):
log_level = (
"DEBUG"
if self._config["runtime"].get("enable_debug_logs", False)
else "INFO"
)
cfg = uvicorn.Config(
self.create_application(),
# We hard-code the http parser as h11 (the default) in case the user has
# httptools installed, which does not work with our requests & version
# of uvicorn.
http="h11",
host="0.0.0.0",
port=self.http_port,
workers=DEFAULT_NUM_WORKERS,
port=self._http_port,
workers=1,
timeout_graceful_shutdown=TIMEOUT_GRACEFUL_SHUTDOWN,
log_config={
"version": 1,
"formatters": {
Expand Down Expand Up @@ -384,7 +370,7 @@ def start(self):
},
},
"loggers": {
"uvicorn": {"handlers": ["default"], "level": "INFO"},
"uvicorn": {"handlers": ["default"], "level": log_level},
"uvicorn.error": {"level": "INFO"},
"uvicorn.access": {
"handlers": ["access"],
Expand All @@ -394,47 +380,7 @@ def start(self):
},
},
)

# Call this so uvloop gets used
cfg.setup_event_loop()

async def serve() -> None:
serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
serversocket.bind((cfg.host, cfg.port))
serversocket.listen(5)

num_server_procs = self._config.get("runtime", {}).get(
"num_workers", DEFAULT_NUM_SERVER_PROCESSES
)
logging.info(f"starting {num_server_procs} uvicorn server processes")
servers: List[UvicornCustomServer] = []
for _ in range(num_server_procs):
server = UvicornCustomServer(config=cfg, sockets=[serversocket])
server.start()
servers.append(server)

def stop_servers():
# Send stop signal, then wait for all to exit
for server in servers:
# Sends term signal to the process, which should be handled
# by the termination handler.
server.stop()

termination_check_attempts = int(
WORKER_TERMINATION_TIMEOUT_SECS
/ WORKER_TERMINATION_CHECK_INTERVAL_SECS
)
for _ in range(termination_check_attempts):
time.sleep(WORKER_TERMINATION_CHECK_INTERVAL_SECS)
if util.all_processes_dead(servers):
return

for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
signal.signal(sig, lambda sig, frame: stop_servers())

async def servers_task():
servers = [serve()]
await asyncio.gather(*servers)

asyncio.run(servers_task())
cfg.setup_event_loop() # Call this so uvloop gets used
server = uvicorn.Server(config=cfg)
self._server = server
asyncio.run(server.serve())
Loading

0 comments on commit d28a29a

Please sign in to comment.