diff --git a/src/ltdproxy/config.py b/src/ltdproxy/config.py index 693f787..619f275 100644 --- a/src/ltdproxy/config.py +++ b/src/ltdproxy/config.py @@ -4,6 +4,7 @@ import os from enum import Enum +from typing import Optional from pydantic import BaseSettings, Field, FilePath, HttpUrl, SecretStr @@ -69,6 +70,16 @@ class Configuration(BaseSettings): rewrites_config_path: FilePath = Field(env="LTDPROXY_REWRITES_CONFIG") + healthcheck_bucket_key: Optional[str] = Field( + None, + description=( + "A key in the bucket that the healthcheck endpoint will attempt " + "to stream. This is an actual bucket key, and is independent " + "of the s3_bucket_prefix configuration." + ), + env="LTDPROXY_S3_HEALTHCHECK_KEY", + ) + config = Configuration(_env_file=os.getenv("LTD_PROXY_ENV")) """Configuration for ltd-proxy.""" diff --git a/src/ltdproxy/handlers/healthcheck.py b/src/ltdproxy/handlers/healthcheck.py index bea2efc..0c60b46 100644 --- a/src/ltdproxy/handlers/healthcheck.py +++ b/src/ltdproxy/handlers/healthcheck.py @@ -1,11 +1,37 @@ """Handler for the Kubernetes health check.""" -from fastapi import APIRouter +import httpx +from fastapi import APIRouter, Depends +from safir.dependencies.http_client import http_client_dependency +from safir.dependencies.logger import logger_dependency from starlette.responses import PlainTextResponse +from structlog.stdlib import BoundLogger + +from ltdproxy.config import config +from ltdproxy.s3 import Bucket, bucket_dependency health_router = APIRouter() @health_router.get("/__healthz", name="healthz") -def healthy() -> PlainTextResponse: +async def healthy( + bucket: Bucket = Depends(bucket_dependency), + logger: BoundLogger = Depends(logger_dependency), + http_client: httpx.AsyncClient = Depends(http_client_dependency), +) -> PlainTextResponse: + if config.healthcheck_bucket_key: + # enter mode for testing S3 streaming + stream = await bucket.stream_object( + http_client, config.healthcheck_bucket_key + ) + if stream.status_code != httpx.codes.OK: + logger.error( + "Health check got bad S3 response code", + status_code=stream.status_code, + ) + return PlainTextResponse("ERROR", status_code=500) + async for _ in stream.aiter_bytes(): + pass + await stream.aclose() + return PlainTextResponse("OK", status_code=200)