Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refreshSession, deleteSession #38

Merged
merged 12 commits into from
Jan 2, 2025
10 changes: 8 additions & 2 deletions migration_scripts/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from millipds import static_config

with apsw.Connection(static_config.MAIN_DB_PATH) as con:

def migrate(con):
version_now, *_ = con.execute("SELECT db_version FROM config").fetchone()

assert version_now == 1
Expand Down Expand Up @@ -36,4 +37,9 @@

con.execute("UPDATE config SET db_version=2")

print("v1 -> v2 Migration successful")

if __name__ == "__main__":
with apsw.Connection(static_config.MAIN_DB_PATH) as con:
migrate(con)

print("v1 -> v2 Migration successful")
34 changes: 34 additions & 0 deletions migration_scripts/v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# TODO: some smarter way of handling migrations

import apsw
import apsw.bestpractice

apsw.bestpractice.apply(apsw.bestpractice.recommended)

from millipds import static_config


def migrate(con: apsw.Connection):
version_now, *_ = con.execute("SELECT db_version FROM config").fetchone()

assert version_now == 2

con.execute(
"""
CREATE TABLE revoked_token(
did TEXT NOT NULL,
jti TEXT NOT NULL,
expires_at INTEGER NOT NULL,
PRIMARY KEY (did, jti)
) STRICT, WITHOUT ROWID
"""
)

con.execute("UPDATE config SET db_version=3")


if __name__ == "__main__":
with apsw.Connection(static_config.MAIN_DB_PATH) as con:
migrate(con)

print("v2 -> v3 Migration successful")
12 changes: 7 additions & 5 deletions src/millipds/appview_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ async def service_proxy(request: web.Request, service: Optional[str] = None):
)
if did_doc is None:
return web.HTTPInternalServerError(
f"unable to resolve service {service!r}"
text=f"unable to resolve service {service!r}"
)
for service in did_doc.get("service", []):
if service.get("id") == fragment:
service_route = service["serviceEndpoint"]
for service_info in did_doc.get("service", []):
if service_info.get("id") == fragment:
service_route = service_info["serviceEndpoint"]
break
else:
return web.HTTPBadRequest(f"unable to resolve service {service!r}")
return web.HTTPBadRequest(
text=f"unable to resolve service {service!r}"
)
else: # fall thru to assuming bsky appview
service_did = db.config["bsky_appview_did"]
service_route = db.config["bsky_appview_pfx"]
Expand Down
75 changes: 50 additions & 25 deletions src/millipds/auth_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,44 @@
routes = web.RouteTableDef()


def verify_symmetric_token(
request: web.Request, token: str, expected_scope: str
) -> dict:
db = get_db(request)
try:
payload: dict = jwt.decode(
jwt=token,
key=db.config["jwt_access_secret"],
algorithms=["HS256"],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "scope", "jti", "sub"],
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
},
)
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

revoked = db.con.execute(
"SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?",
(payload["sub"], payload["jti"]),
).fetchone()[0]

if revoked:
raise web.HTTPUnauthorized(text="revoked token")

# if we reached this far, the payload must've been signed by us
if payload.get("scope") != expected_scope:
raise web.HTTPUnauthorized(text="invalid jwt scope")

if not payload.get("sub", "").startswith("did:"):
raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")

return payload


def authenticated(handler):
"""
There are three types of auth:
Expand Down Expand Up @@ -39,30 +77,9 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
)
# logger.info(unverified)
if unverified["header"]["alg"] == "HS256": # symmetric secret
try:
payload: dict = jwt.decode(
jwt=token,
key=db.config["jwt_access_secret"],
algorithms=["HS256"],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "scope"], # consider iat?
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
},
)
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

# if we reached this far, the payload must've been signed by us
if payload.get("scope") != "com.atproto.access":
raise web.HTTPUnauthorized(text="invalid jwt scope")

subject: str = payload.get("sub", "")
if not subject.startswith("did:"):
raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")
request["authed_did"] = subject
request["authed_did"] = verify_symmetric_token(
request, token, "com.atproto.access"
)["sub"]
else: # asymmetric service auth (scoped to a specific lxm)
did: str = unverified["payload"]["iss"]
if not did.startswith("did:"):
Expand All @@ -81,7 +98,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
algorithms=[alg],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "lxm"],
"require": ["exp", "iat", "lxm", "jti", "iss"],
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
Expand All @@ -90,6 +107,14 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

revoked = db.con.execute(
"SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?",
(payload["iss"], payload["jti"]),
).fetchone()[0]

if revoked:
raise web.HTTPUnauthorized(text="revoked token")

request_lxm = request.path.rpartition("/")[2].partition("?")[0]
if request_lxm != payload.get("lxm"):
raise web.HTTPUnauthorized(text="invalid jwt: bad lxm")
Expand Down
13 changes: 13 additions & 0 deletions src/millipds/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def _init_tables(self):
"""
)

# this is only for the tokens *we* issue, dpop jti will be tracked separately
# there's no point remembering that an expired token was revoked, and we'll garbage-collect these periodically
self.con.execute(
"""
CREATE TABLE revoked_token(
did TEXT NOT NULL,
jti TEXT NOT NULL,
expires_at INTEGER NOT NULL,
PRIMARY KEY (did, jti)
) STRICT, WITHOUT ROWID
"""
)

def update_config(
self,
pds_pfx: Optional[str] = None,
Expand Down
130 changes: 89 additions & 41 deletions src/millipds/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from . import crypto
from . import util
from .appview_proxy import service_proxy
from .auth_bearer import authenticated
from .auth_bearer import authenticated, verify_symmetric_token
from .app_util import *
from .did import DIDResolver

Expand Down Expand Up @@ -203,6 +203,53 @@ async def server_describe_server(request: web.Request):
)


def session_info(request: web.Request) -> dict:
return {
"handle": get_db(request).handle_by_did(request["authed_did"]),
"did": request["authed_did"],
"email": "[email protected]", # this and below are just here for testing lol
"emailConfirmed": True,
# "didDoc": {}, # iiuc this is only used for entryway usecase?
}


def generate_session_tokens(request: web.Request) -> dict:
db = get_db(request)
unix_seconds_now = int(time.time())
# use the same jti for both tokens, so revoking one revokes both
jti = str(uuid.uuid4())
access_jwt = jwt.encode(
{
"scope": "com.atproto.access",
"aud": db.config["pds_did"],
"sub": request["authed_did"],
"iat": unix_seconds_now,
"exp": unix_seconds_now + static_config.ACCESS_EXP,
"jti": jti,
},
db.config["jwt_access_secret"],
"HS256",
)

refresh_jwt = jwt.encode(
{
"scope": "com.atproto.refresh",
"aud": db.config["pds_did"],
"sub": request["authed_did"],
"iat": unix_seconds_now,
"exp": unix_seconds_now + static_config.REFRESH_EXP,
"jti": jti,
},
db.config["jwt_access_secret"],
"HS256",
)

return {
"accessJwt": access_jwt,
"refreshJwt": refresh_jwt,
}


# TODO: ratelimit this!!!
@routes.post("/xrpc/com.atproto.server.createSession")
async def server_create_session(request: web.Request):
Expand All @@ -228,44 +275,53 @@ async def server_create_session(request: web.Request):
except ValueError:
raise web.HTTPUnauthorized(text="incorrect identifier or password")

# prepare access tokens
unix_seconds_now = int(time.time())
access_jwt = jwt.encode(
{
"scope": "com.atproto.access",
"aud": db.config["pds_did"],
"sub": did,
"iat": unix_seconds_now,
"exp": unix_seconds_now + 60 * 60 * 24, # 24h
"jti": str(uuid.uuid4()),
},
db.config["jwt_access_secret"],
"HS256",
# both generate_session_tokens and session_info need this
request["authed_did"] = did

return web.json_response(
session_info(request) | generate_session_tokens(request)
)

refresh_jwt = jwt.encode(
{
"scope": "com.atproto.refresh",
"aud": db.config["pds_did"],
"sub": did,
"iat": unix_seconds_now,
"exp": unix_seconds_now + 60 * 60 * 24 * 90, # 90 days!
"jti": str(uuid.uuid4()),
},
db.config["jwt_access_secret"],
"HS256",

@routes.post("/xrpc/com.atproto.server.refreshSession")
async def server_refresh_session(request: web.Request):
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise web.HTTPUnauthorized(text="invalid auth type")
token = auth.removeprefix("Bearer ")
token_payload = verify_symmetric_token(
request, token, "com.atproto.refresh"
)
request["authed_did"] = token_payload["sub"]

get_db(request).con.execute(
"INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)",
(token_payload["sub"], token_payload["jti"], token_payload["exp"]),
)
return web.json_response(
{
"did": did,
"handle": handle,
"accessJwt": access_jwt,
"refreshJwt": refresh_jwt,
}
session_info(request) | generate_session_tokens(request)
)


# NOTE: deleteSession requires refresh token as auth, not access token
@routes.post("/xrpc/com.atproto.server.deleteSession")
async def server_delete_session(request: web.Request):
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise web.HTTPUnauthorized(text="invalid auth type")
token = auth.removeprefix("Bearer ")
token_payload = verify_symmetric_token(
request, token, "com.atproto.refresh"
)

get_db(request).con.execute(
"INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)",
(token_payload["sub"], token_payload["jti"], token_payload["exp"]),
)

return web.Response()


@routes.get("/xrpc/com.atproto.server.getServiceAuth")
@authenticated
async def server_get_service_auth(request: web.Request):
Expand Down Expand Up @@ -302,7 +358,7 @@ async def server_get_service_auth(request: web.Request):
"lxm": lxm,
"exp": exp,
"iat": now,
"jti": str(uuid.uuid4())
"jti": str(uuid.uuid4()),
},
signing_key,
algorithm=crypto.jwt_signature_alg_for_pem(signing_key),
Expand Down Expand Up @@ -381,15 +437,7 @@ async def identity_update_handle(request: web.Request):
@routes.get("/xrpc/com.atproto.server.getSession")
@authenticated
async def server_get_session(request: web.Request):
return web.json_response(
{
"handle": get_db(request).handle_by_did(request["authed_did"]),
"did": request["authed_did"],
"email": "[email protected]", # this and below are just here for testing lol
"emailConfirmed": True,
# "didDoc": {}, # iiuc this is only used for entryway usecase?
}
)
return web.json_response(session_info(request))


def construct_app(
Expand Down
5 changes: 4 additions & 1 deletion src/millipds/static_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
GROUPNAME = "millipds-sock"

# this gets bumped if we make breaking changes to the db schema
MILLIPDS_DB_VERSION = 2
MILLIPDS_DB_VERSION = 3

ATPROTO_REPO_VERSION_3 = 3 # might get bumped if the atproto spec changes
CAR_VERSION_1 = 1
Expand All @@ -29,3 +29,6 @@
DID_CACHE_ERROR_TTL = 60 * 5 # 5 mins

PLC_DIRECTORY_HOST = "https://plc.directory"

ACCESS_EXP = 60 * 60 * 2 # 2 h
REFRESH_EXP = 60 * 60 * 24 * 90 # 90 days
Loading
Loading