Skip to content

Commit

Permalink
feat: Make security policy configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
HyeockJinKim committed Feb 13, 2025
1 parent a12e32f commit 0699e9c
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
4 changes: 4 additions & 0 deletions configs/webserver/halfstack.conf
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ max_cuda_shares_per_container = 16
max_shm_per_container = 2
max_file_upload_size = 4294967296

[security]
request_policies = ["reject_metadata_local_link_policy", "reject_access_for_unsafe_file_policy"]
response_policies = ["add_self_content_security_policy", "set_content_type_nosniff_policy"]

[environments]

[plugin]
Expand Down
9 changes: 9 additions & 0 deletions src/ai/backend/web/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
},
}

_default_security_config: Mapping[str, list[str]] = {
"request_policies": [],
"response_policies": [],
}

config_iv = t.Dict({
t.Key("service"): t.Dict({
t.Key("ip", default="0.0.0.0"): tx.IPAddress,
Expand Down Expand Up @@ -67,6 +72,10 @@
t.Key("enable_model_store", default=True): t.ToBool(),
t.Key("enable_extend_login_session", default=False): t.ToBool(),
}).allow_extra("*"),
t.Key("security", default=_default_security_config): t.Dict({
t.Key("request_policies", default=[]): t.List(t.String),
t.Key("response_policies", default=[]): t.List(t.String),
}).allow_extra("*"),
t.Key("resources"): t.Dict({
t.Key("open_port_to_public", default=False): t.ToBool,
t.Key("allow_non_auth_tcp", default=False): t.ToBool,
Expand Down
35 changes: 31 additions & 4 deletions src/ai/backend/web/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from aiohttp import web
from aiohttp.typedefs import Handler

type RequestPolicy = Callable[[web.Request], None]

type ResponsePolicy = Callable[[web.StreamResponse], web.StreamResponse]


@web.middleware
async def security_policy_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
Expand All @@ -13,17 +17,40 @@ async def security_policy_middleware(request: web.Request, handler: Handler) ->


class SecurityPolicy:
_request_policies: Iterable[Callable[[web.Request], None]]
_response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]]
_request_policies: Iterable[RequestPolicy]
_response_policies: Iterable[ResponsePolicy]

def __init__(
self,
request_policies: Iterable[Callable[[web.Request], None]],
response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]],
request_policies: Iterable[RequestPolicy],
response_policies: Iterable[ResponsePolicy],
) -> None:
self._request_policies = request_policies
self._response_policies = response_policies

@classmethod
def from_config(
cls, request_policy_config: list[str], response_policy_config: list[str]
) -> Self:
request_policy_map = {
"reject_metadata_local_link_policy": reject_metadata_local_link_policy,
"reject_access_for_unsafe_file_policy": reject_access_for_unsafe_file_policy,
}
response_policy_map = {
"add_self_content_security_policy": add_self_content_security_policy,
"set_content_type_nosniff_policy": set_content_type_nosniff_policy,
}
try:
request_policies = [
request_policy_map[policy_name] for policy_name in request_policy_config
]
response_policies = [
response_policy_map[policy_name] for policy_name in response_policy_config
]
except KeyError as e:
raise ValueError(f"Unknown security policy name: {e}")
return cls(request_policies, response_policies)

@classmethod
def default_policy(cls) -> Self:
request_policies = [reject_metadata_local_link_policy, reject_access_for_unsafe_file_policy]
Expand Down
6 changes: 5 additions & 1 deletion src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,11 @@ async def server_main(
middlewares=[decrypt_payload, track_active_handlers, security_policy_middleware]
)
app["config"] = config
app["security_policy"] = SecurityPolicy.default_policy()
request_policy_config = config["security"]["request_policies"]
response_policy_config = config["security"]["response_policies"]
app["security_policy"] = SecurityPolicy.from_config(
request_policy_config, response_policy_config
)
j2env = jinja2.Environment(
extensions=[
"ai.backend.web.template.TOMLField",
Expand Down

0 comments on commit 0699e9c

Please sign in to comment.