diff --git a/changes/3680.feature.md b/changes/3680.feature.md new file mode 100644 index 0000000000..f7b633f290 --- /dev/null +++ b/changes/3680.feature.md @@ -0,0 +1 @@ +Make security policy configurable \ No newline at end of file diff --git a/configs/webserver/halfstack.conf b/configs/webserver/halfstack.conf index 3dc7c3a97c..435e5517ae 100644 --- a/configs/webserver/halfstack.conf +++ b/configs/webserver/halfstack.conf @@ -28,6 +28,10 @@ max_cuda_shares_per_container = 16 max_shm_per_container = 2 max_file_upload_size = 4294967296 +[security] +request_policies = ["reject_metadata_local_link_policy", "reject_access_for_unsafe_file_policy"] +response_policies = ["add_self_content_security_policy", "set_content_type_nosniff_policy"] + [environments] [plugin] diff --git a/src/ai/backend/web/config.py b/src/ai/backend/web/config.py index 47e35b053c..80821dc83e 100644 --- a/src/ai/backend/web/config.py +++ b/src/ai/backend/web/config.py @@ -26,6 +26,11 @@ }, } +_default_security_config: Mapping[str, list[str]] = { + "request_policies": [], + "response_policies": [], +} + config_iv = t.Dict({ t.Key("service"): t.Dict({ t.Key("ip", default="0.0.0.0"): tx.IPAddress, @@ -67,6 +72,10 @@ t.Key("enable_model_store", default=True): t.ToBool(), t.Key("enable_extend_login_session", default=False): t.ToBool(), }).allow_extra("*"), + t.Key("security", default=_default_security_config): t.Dict({ + t.Key("request_policies", default=[]): t.List(t.String), + t.Key("response_policies", default=[]): t.List(t.String), + }).allow_extra("*"), t.Key("resources"): t.Dict({ t.Key("open_port_to_public", default=False): t.ToBool, t.Key("allow_non_auth_tcp", default=False): t.ToBool, diff --git a/src/ai/backend/web/security.py b/src/ai/backend/web/security.py index 33a39edb18..00eddce14a 100644 --- a/src/ai/backend/web/security.py +++ b/src/ai/backend/web/security.py @@ -3,6 +3,10 @@ from aiohttp import web from aiohttp.typedefs import Handler +type RequestPolicy = Callable[[web.Request], None] + +type ResponsePolicy = Callable[[web.StreamResponse], web.StreamResponse] + @web.middleware async def security_policy_middleware(request: web.Request, handler: Handler) -> web.StreamResponse: @@ -13,17 +17,40 @@ async def security_policy_middleware(request: web.Request, handler: Handler) -> class SecurityPolicy: - _request_policies: Iterable[Callable[[web.Request], None]] - _response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]] + _request_policies: Iterable[RequestPolicy] + _response_policies: Iterable[ResponsePolicy] def __init__( self, - request_policies: Iterable[Callable[[web.Request], None]], - response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]], + request_policies: Iterable[RequestPolicy], + response_policies: Iterable[ResponsePolicy], ) -> None: self._request_policies = request_policies self._response_policies = response_policies + @classmethod + def from_config( + cls, request_policy_config: list[str], response_policy_config: list[str] + ) -> Self: + request_policy_map = { + "reject_metadata_local_link_policy": reject_metadata_local_link_policy, + "reject_access_for_unsafe_file_policy": reject_access_for_unsafe_file_policy, + } + response_policy_map = { + "add_self_content_security_policy": add_self_content_security_policy, + "set_content_type_nosniff_policy": set_content_type_nosniff_policy, + } + try: + request_policies = [ + request_policy_map[policy_name] for policy_name in request_policy_config + ] + response_policies = [ + response_policy_map[policy_name] for policy_name in response_policy_config + ] + except KeyError as e: + raise ValueError(f"Unknown security policy name: {e}") + return cls(request_policies, response_policies) + @classmethod def default_policy(cls) -> Self: request_policies = [reject_metadata_local_link_policy, reject_access_for_unsafe_file_policy] diff --git a/src/ai/backend/web/server.py b/src/ai/backend/web/server.py index b8f545c904..42f581ffcc 100644 --- a/src/ai/backend/web/server.py +++ b/src/ai/backend/web/server.py @@ -609,7 +609,11 @@ async def server_main( middlewares=[decrypt_payload, track_active_handlers, security_policy_middleware] ) app["config"] = config - app["security_policy"] = SecurityPolicy.default_policy() + request_policy_config = config["security"]["request_policies"] + response_policy_config = config["security"]["response_policies"] + app["security_policy"] = SecurityPolicy.from_config( + request_policy_config, response_policy_config + ) j2env = jinja2.Environment( extensions=[ "ai.backend.web.template.TOMLField",