Skip to content

Commit

Permalink
implement generic service proxying, with hardcoded DID resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Dec 13, 2024
1 parent de236c2 commit 8619b6a
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"pyjwt[crypto]",
"cryptography",
"aiohttp",
"aiohttp_cors",
"aiohttp-middlewares", # cors
"docopt",
"apsw",
"argon2-cffi",
Expand Down
31 changes: 25 additions & 6 deletions src/millipds/appview_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,47 @@

logger = logging.getLogger(__name__)


# TODO: this should be done via actual DID resolution, not hardcoded!
SERVICE_ROUTES = {
"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat"
}

@authenticated
async def static_appview_proxy(request: web.Request):
async def service_proxy(request: web.Request, service: Optional[str]=None):
"""
If `service` is None, default to bsky appview (per details in db config)
"""
lxm = request.path.rpartition("/")[2].partition("?")[0]
# TODO: verify valid lexicon method?
logger.info(f"proxying lxm {lxm}")
db = get_db(request)
if service:
service_did = service.partition("#")[0]
service_route = SERVICE_ROUTES.get(service)
if service_route is None:
return web.HTTPBadRequest(f"unable to resolve service {service!r}")
else:
service_did = db.config["bsky_appview_did"]
service_route = db.config["bsky_appview_pfx"]

signing_key = db.signing_key_pem_by_did(request["authed_did"])
authn = {
"Authorization": "Bearer "
+ jwt.encode(
{
"iss": request["authed_did"],
"aud": db.config["bsky_appview_did"],
"aud": service_did,
"lxm": lxm,
"exp": int(time.time()) + 60 * 60 * 24, # 24h
"exp": int(time.time()) + 5 * 60, # 5 mins
},
signing_key,
algorithm=crypto.jwt_signature_alg_for_pem(signing_key),
)
} # TODO: cache this!
appview_pfx = db.config["bsky_appview_pfx"]
if request.method == "GET":
async with get_client(request).get(
appview_pfx + request.path, params=request.query, headers=authn
service_route + request.path, params=request.query, headers=authn
) as r:
body_bytes = await r.read() # TODO: streaming?
return web.Response(
Expand All @@ -42,11 +59,13 @@ async def static_appview_proxy(request: web.Request):
elif request.method == "POST":
request_body = await request.read() # TODO: streaming?
async with get_client(request).post(
appview_pfx + request.path, data=request_body, headers=(authn|{"Content-Type": request.content_type})
service_route + request.path, data=request_body, headers=(authn|{"Content-Type": request.content_type})
) as r:
body_bytes = await r.read() # TODO: streaming?
return web.Response(
body=body_bytes, content_type=r.content_type, status=r.status
) # XXX: allowlist safe content types!
elif request.method == "PUT":
raise NotImplementedError("TODO: PUT")
else:
raise NotImplementedError("TODO")
4 changes: 2 additions & 2 deletions src/millipds/atproto_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import apsw

from . import repo_ops
from .appview_proxy import static_appview_proxy
from .appview_proxy import service_proxy
from .auth_bearer import authenticated
from .app_util import *

Expand Down Expand Up @@ -156,7 +156,7 @@ async def repo_get_record(request: web.Request):
(did_or_handle, did_or_handle, collection, rkey)
).fetchone()
if row is None:
return await static_appview_proxy(request) # forward to appview
return await service_proxy(request) # forward to appview
#return web.HTTPNotFound(text="record not found")
cid_out, value = row
cid_out = cbrrr.CID(cid_out)
Expand Down
4 changes: 2 additions & 2 deletions src/millipds/auth_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

routes = web.RouteTableDef()
def authenticated(handler):
def authentication_handler(request: web.Request):
def authentication_handler(request: web.Request, *args, **kwargs):
# extract the auth token
auth = request.headers.get("Authorization")
if auth is None:
Expand Down Expand Up @@ -43,6 +43,6 @@ def authentication_handler(request: web.Request):
if not subject.startswith("did:"):
raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")
request["authed_did"] = subject
return handler(request)
return handler(request, *args, **kwargs)

return authentication_handler
105 changes: 45 additions & 60 deletions src/millipds/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import importlib.metadata
import logging
import asyncio
import aiohttp_cors
import time
import os
import io
Expand All @@ -11,6 +10,7 @@

import apsw
import aiohttp
from aiohttp_middlewares import cors_middleware
from aiohttp import web
import jwt

Expand All @@ -22,7 +22,7 @@
from . import atproto_sync
from . import atproto_repo
from . import util
from .appview_proxy import static_appview_proxy
from .appview_proxy import service_proxy
from .auth_bearer import authenticated
from .app_util import *

Expand All @@ -33,11 +33,10 @@

@web.middleware
async def atproto_service_proxy_middleware(request: web.Request, handler):
# TODO: if service proxying header is present, do service proxying!
# https://atproto.com/specs/xrpc#service-proxying
# (this implies having a DID resolver!!!) (probably with a cache!)
# if request.headers.get("atproto-proxy"):
# pass
atproto_proxy = request.headers.get("atproto-proxy")
if atproto_proxy:
return await service_proxy(request, atproto_proxy)

# else, normal response
res: web.Response = await handler(request)
Expand All @@ -51,13 +50,6 @@ async def atproto_service_proxy_middleware(request: web.Request, handler):
return res


# inject permissive CORS headers unconditionally
# async def prepare_cors_headers(request, response: web.Response):
# response.headers["Access-Control-Allow-Origin"] = "*"
# response.headers["Access-Control-Allow-Headers"] = "atproto-accept-labelers,authorization" # TODO: tighten?
# response.headers["Access-Control-Allow-Methods"] = "GET,HEAD,PUT,PATCH,POST,DELETE"


@routes.get("/")
async def hello(request: web.Request):
version = importlib.metadata.version("millipds")
Expand Down Expand Up @@ -300,7 +292,16 @@ async def server_get_session(request: web.Request):


def construct_app(routes, db: database.Database) -> web.Application:
app = web.Application(middlewares=[atproto_service_proxy_middleware])
cors = cors_middleware( # TODO: review and reduce scope - and maybe just /xrpc/*?
allow_all=True,
expose_headers=["*"],
allow_headers=["*"],
allow_methods=["*"],
allow_credentials=True,
max_age=2_000_000_000
)

app = web.Application(middlewares=[cors, atproto_service_proxy_middleware])
app["MILLIPDS_DB"] = db
app["MILLIPDS_AIOHTTP_CLIENT"] = (
aiohttp.ClientSession()
Expand All @@ -318,55 +319,39 @@ def construct_app(routes, db: database.Database) -> web.Application:
# fmt off
# web.get ("/xrpc/app.bsky.actor.getPreferences", static_appview_proxy),
# web.post("/xrpc/app.bsky.actor.putPreferences", static_appview_proxy),
web.get("/xrpc/app.bsky.actor.getProfile", static_appview_proxy),
web.get("/xrpc/app.bsky.actor.getProfiles", static_appview_proxy),
web.get("/xrpc/app.bsky.actor.getSuggestions", static_appview_proxy),
web.get("/xrpc/app.bsky.actor.searchActorsTypeahead", static_appview_proxy),
web.get("/xrpc/app.bsky.labeler.getServices", static_appview_proxy),
web.get("/xrpc/app.bsky.notification.listNotifications", static_appview_proxy),
web.get("/xrpc/app.bsky.notification.getUnreadCount", static_appview_proxy),
web.post("/xrpc/app.bsky.notification.updateSeen", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getList", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getLists", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getFollows", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getFollowers", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getStarterPack", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getSuggestedFollowsByActor", static_appview_proxy),
web.get("/xrpc/app.bsky.graph.getActorStarterPacks", static_appview_proxy),
web.post("/xrpc/app.bsky.graph.muteActor", static_appview_proxy),
web.post("/xrpc/app.bsky.graph.unmuteActor", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getTimeline", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getAuthorFeed", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getActorFeeds", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getFeed", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getListFeed", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getFeedGenerator", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getFeedGenerators", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getPostThread", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getPosts", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getLikes", static_appview_proxy),
web.get("/xrpc/app.bsky.feed.getActorLikes", static_appview_proxy),
web.get("/xrpc/app.bsky.unspecced.getPopularFeedGenerators", static_appview_proxy),
web.get("/xrpc/chat.bsky.convo.listConvos", static_appview_proxy)
web.get("/xrpc/app.bsky.actor.getProfile", service_proxy),
web.get("/xrpc/app.bsky.actor.getProfiles", service_proxy),
web.get("/xrpc/app.bsky.actor.getSuggestions", service_proxy),
web.get("/xrpc/app.bsky.actor.searchActorsTypeahead", service_proxy),
web.get("/xrpc/app.bsky.labeler.getServices", service_proxy),
web.get("/xrpc/app.bsky.notification.listNotifications", service_proxy),
web.get("/xrpc/app.bsky.notification.getUnreadCount", service_proxy),
web.post("/xrpc/app.bsky.notification.updateSeen", service_proxy),
web.get("/xrpc/app.bsky.graph.getList", service_proxy),
web.get("/xrpc/app.bsky.graph.getLists", service_proxy),
web.get("/xrpc/app.bsky.graph.getFollows", service_proxy),
web.get("/xrpc/app.bsky.graph.getFollowers", service_proxy),
web.get("/xrpc/app.bsky.graph.getStarterPack", service_proxy),
web.get("/xrpc/app.bsky.graph.getSuggestedFollowsByActor", service_proxy),
web.get("/xrpc/app.bsky.graph.getActorStarterPacks", service_proxy),
web.post("/xrpc/app.bsky.graph.muteActor", service_proxy),
web.post("/xrpc/app.bsky.graph.unmuteActor", service_proxy),
web.get("/xrpc/app.bsky.feed.getTimeline", service_proxy),
web.get("/xrpc/app.bsky.feed.getAuthorFeed", service_proxy),
web.get("/xrpc/app.bsky.feed.getActorFeeds", service_proxy),
web.get("/xrpc/app.bsky.feed.getFeed", service_proxy),
web.get("/xrpc/app.bsky.feed.getListFeed", service_proxy),
web.get("/xrpc/app.bsky.feed.getFeedGenerator", service_proxy),
web.get("/xrpc/app.bsky.feed.getFeedGenerators", service_proxy),
web.get("/xrpc/app.bsky.feed.getPostThread", service_proxy),
web.get("/xrpc/app.bsky.feed.getPosts", service_proxy),
web.get("/xrpc/app.bsky.feed.getLikes", service_proxy),
web.get("/xrpc/app.bsky.feed.getActorLikes", service_proxy),
web.get("/xrpc/app.bsky.unspecced.getPopularFeedGenerators", service_proxy),
#web.get("/xrpc/chat.bsky.convo.listConvos", static_appview_proxy)
# fmt on
]
)
# app.on_response_prepare.append(prepare_cors_headers)

cors = aiohttp_cors.setup(
app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True, # TODO: restrict?
expose_headers="*", # TODO: restrict?
allow_headers="*", # TODO: restrict?
max_age=2_000_000_000, # forever (not really, browsers cap this because they're cowards https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age#delta-seconds )
)
},
)

for route in app.router.routes():
cors.add(route)

return app

Expand Down

0 comments on commit 8619b6a

Please sign in to comment.