diff --git a/src/millipds/auth_bearer.py b/src/millipds/auth_bearer.py index 63440ca..53f63a5 100644 --- a/src/millipds/auth_bearer.py +++ b/src/millipds/auth_bearer.py @@ -4,6 +4,7 @@ from aiohttp import web from .app_util import * +from . import crypto logger = logging.getLogger(__name__) @@ -11,6 +12,13 @@ def authenticated(handler): + """ + There are three types of auth: + - bearer token signed by symmetric secret (generated by us during the password login flow) + - "service" bearer token signed by (asymmetric) repo signing key, scoped to a specific lxm + - whatever I do for oauth (TODO) + """ + async def authentication_handler(request: web.Request, *args, **kwargs): # extract the auth token auth = request.headers.get("Authorization") @@ -25,30 +33,69 @@ async def authentication_handler(request: web.Request, *args, **kwargs): # validate it TODO: this needs rigorous testing, I'm not 100% sure I'm # verifying all the things that need verifying db = get_db(request) - try: - payload: dict = jwt.decode( - jwt=token, - key=db.config["jwt_access_secret"], - algorithms=["HS256"], - audience=db.config["pds_did"], - options={ - "require": ["exp", "iat", "scope"], # consider iat? - "verify_exp": True, - "verify_iat": True, - "strict_aud": True, # may be unnecessary - }, - ) - except jwt.exceptions.PyJWTError: - raise web.HTTPUnauthorized(text="invalid jwt") - # if we reached this far, the payload must've been signed by us - if payload.get("scope") != "com.atproto.access": - raise web.HTTPUnauthorized(text="invalid jwt scope") + unverified = jwt.api_jwt.decode_complete( + token, options={"verify_signature": False} + ) + # logger.info(unverified) + if unverified["header"]["alg"] == "HS256": # symmetric secret + try: + payload: dict = jwt.decode( + jwt=token, + key=db.config["jwt_access_secret"], + algorithms=["HS256"], + audience=db.config["pds_did"], + options={ + "require": ["exp", "iat", "scope"], # consider iat? + "verify_exp": True, + "verify_iat": True, + "strict_aud": True, # may be unnecessary + }, + ) + except jwt.exceptions.PyJWTError: + raise web.HTTPUnauthorized(text="invalid jwt") + + # if we reached this far, the payload must've been signed by us + if payload.get("scope") != "com.atproto.access": + raise web.HTTPUnauthorized(text="invalid jwt scope") + + subject: str = payload.get("sub", "") + if not subject.startswith("did:"): + raise web.HTTPUnauthorized(text="invalid jwt: invalid subject") + request["authed_did"] = subject + else: # asymmetric service auth (scoped to a specific lxm) + did: str = unverified["payload"]["iss"] + if not did.startswith("did:"): + raise web.HTTPUnauthorized(text="invalid jwt: invalid issuer") + signing_key_pem = db.signing_key_pem_by_did(did) + if signing_key_pem is None: + raise web.HTTPUnauthorized(text="invalid jwt: unknown issuer") + alg = crypto.jwt_signature_alg_for_pem(signing_key_pem) + verifying_key = crypto.privkey_from_pem( + signing_key_pem + ).public_key() + try: + payload = jwt.decode( + jwt=token, + key=verifying_key, + algorithms=[alg], + audience=db.config["pds_did"], + options={ + "require": ["exp", "lxm"], # consider iat? + "verify_exp": True, + "strict_aud": True, # may be unnecessary + }, + ) + except jwt.exceptions.PyJWTError: + raise web.HTTPUnauthorized(text="invalid jwt") + + request_lxm = request.path.rpartition("/")[2].partition("?")[0] + if request_lxm != payload.get("lxm"): + raise web.HTTPUnauthorized(text="invalid jwt: bad lxm") + + # everything checks out + request["authed_did"] = did - subject: str = payload.get("sub", "") - if not subject.startswith("did:"): - raise web.HTTPUnauthorized(text="invalid jwt: invalid subject") - request["authed_did"] = subject return await handler(request, *args, **kwargs) return authentication_handler diff --git a/src/millipds/service.py b/src/millipds/service.py index 5a38c00..ccd89f7 100644 --- a/src/millipds/service.py +++ b/src/millipds/service.py @@ -272,10 +272,26 @@ async def server_create_session(request: web.Request): async def server_get_service_auth(request: web.Request): aud = request.query.get("aud") lxm = request.query.get("lxm") - # note, we ignore exp for now + + # default to 60s into the future + now = int(time.time()) + exp = int(request.query.get("exp", now + 60)) + + # lxm is not required by the lexicon but I'm requiring it anyway if not (aud and lxm): raise web.HTTPBadRequest(text="missing aud or lxm") - # TODO: validation of aud and lxm? + if lxm == "com.atproto.server.getServiceAuth": + raise web.HTTPBadRequest(text="can't generate auth tokens recursively!") + + max_exp = now + 60 * 30 # 30 mins + if exp > max_exp: + logger.info( + f"requested exp too far into the future, truncating to {max_exp}" + ) + exp = max_exp + + # TODO: strict validation of aud and lxm? + db = get_db(request) signing_key = db.signing_key_pem_by_did(request["authed_did"]) return web.json_response( @@ -285,7 +301,7 @@ async def server_get_service_auth(request: web.Request): "iss": request["authed_did"], "aud": aud, "lxm": lxm, - "exp": int(time.time()) + 60, # 60s + "exp": exp, }, signing_key, algorithm=crypto.jwt_signature_alg_for_pem(signing_key), diff --git a/tests/integration_test.py b/tests/integration_test.py index 7c8cf37..3657184 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -373,3 +373,24 @@ async def test_sync_getRecord_existent(s, populated_pds_host): assert proof_car # nonempty # TODO: make sure the proof is valid, and contains the record assert b"test record" in proof_car + + +async def test_seviceauth(s, test_pds, auth_headers): + async with s.get( + test_pds.endpoint + "/xrpc/com.atproto.server.getServiceAuth", + headers=auth_headers, + params={ + "aud": test_pds.db.config["pds_did"], + "lxm": "com.atproto.server.getSession", + }, + ) as r: + assert r.status == 200 + token = (await r.json())["token"] + + # test if the token works + async with s.get( + test_pds.endpoint + "/xrpc/com.atproto.server.getSession", + headers={"Authorization": "Bearer " + token}, + ) as r: + assert r.status == 200 + await r.json()