-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from DavidBuchanan314/token-revocation
refreshSession, deleteSession
- Loading branch information
Showing
8 changed files
with
304 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# TODO: some smarter way of handling migrations | ||
|
||
import apsw | ||
import apsw.bestpractice | ||
|
||
apsw.bestpractice.apply(apsw.bestpractice.recommended) | ||
|
||
from millipds import static_config | ||
|
||
|
||
def migrate(con: apsw.Connection): | ||
version_now, *_ = con.execute("SELECT db_version FROM config").fetchone() | ||
|
||
assert version_now == 2 | ||
|
||
con.execute( | ||
""" | ||
CREATE TABLE revoked_token( | ||
did TEXT NOT NULL, | ||
jti TEXT NOT NULL, | ||
expires_at INTEGER NOT NULL, | ||
PRIMARY KEY (did, jti) | ||
) STRICT, WITHOUT ROWID | ||
""" | ||
) | ||
|
||
con.execute("UPDATE config SET db_version=3") | ||
|
||
|
||
if __name__ == "__main__": | ||
with apsw.Connection(static_config.MAIN_DB_PATH) as con: | ||
migrate(con) | ||
|
||
print("v2 -> v3 Migration successful") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ | |
from . import crypto | ||
from . import util | ||
from .appview_proxy import service_proxy | ||
from .auth_bearer import authenticated | ||
from .auth_bearer import authenticated, verify_symmetric_token | ||
from .app_util import * | ||
from .did import DIDResolver | ||
|
||
|
@@ -203,6 +203,53 @@ async def server_describe_server(request: web.Request): | |
) | ||
|
||
|
||
def session_info(request: web.Request) -> dict: | ||
return { | ||
"handle": get_db(request).handle_by_did(request["authed_did"]), | ||
"did": request["authed_did"], | ||
"email": "[email protected]", # this and below are just here for testing lol | ||
"emailConfirmed": True, | ||
# "didDoc": {}, # iiuc this is only used for entryway usecase? | ||
} | ||
|
||
|
||
def generate_session_tokens(request: web.Request) -> dict: | ||
db = get_db(request) | ||
unix_seconds_now = int(time.time()) | ||
# use the same jti for both tokens, so revoking one revokes both | ||
jti = str(uuid.uuid4()) | ||
access_jwt = jwt.encode( | ||
{ | ||
"scope": "com.atproto.access", | ||
"aud": db.config["pds_did"], | ||
"sub": request["authed_did"], | ||
"iat": unix_seconds_now, | ||
"exp": unix_seconds_now + static_config.ACCESS_EXP, | ||
"jti": jti, | ||
}, | ||
db.config["jwt_access_secret"], | ||
"HS256", | ||
) | ||
|
||
refresh_jwt = jwt.encode( | ||
{ | ||
"scope": "com.atproto.refresh", | ||
"aud": db.config["pds_did"], | ||
"sub": request["authed_did"], | ||
"iat": unix_seconds_now, | ||
"exp": unix_seconds_now + static_config.REFRESH_EXP, | ||
"jti": jti, | ||
}, | ||
db.config["jwt_access_secret"], | ||
"HS256", | ||
) | ||
|
||
return { | ||
"accessJwt": access_jwt, | ||
"refreshJwt": refresh_jwt, | ||
} | ||
|
||
|
||
# TODO: ratelimit this!!! | ||
@routes.post("/xrpc/com.atproto.server.createSession") | ||
async def server_create_session(request: web.Request): | ||
|
@@ -228,44 +275,53 @@ async def server_create_session(request: web.Request): | |
except ValueError: | ||
raise web.HTTPUnauthorized(text="incorrect identifier or password") | ||
|
||
# prepare access tokens | ||
unix_seconds_now = int(time.time()) | ||
access_jwt = jwt.encode( | ||
{ | ||
"scope": "com.atproto.access", | ||
"aud": db.config["pds_did"], | ||
"sub": did, | ||
"iat": unix_seconds_now, | ||
"exp": unix_seconds_now + 60 * 60 * 24, # 24h | ||
"jti": str(uuid.uuid4()), | ||
}, | ||
db.config["jwt_access_secret"], | ||
"HS256", | ||
# both generate_session_tokens and session_info need this | ||
request["authed_did"] = did | ||
|
||
return web.json_response( | ||
session_info(request) | generate_session_tokens(request) | ||
) | ||
|
||
refresh_jwt = jwt.encode( | ||
{ | ||
"scope": "com.atproto.refresh", | ||
"aud": db.config["pds_did"], | ||
"sub": did, | ||
"iat": unix_seconds_now, | ||
"exp": unix_seconds_now + 60 * 60 * 24 * 90, # 90 days! | ||
"jti": str(uuid.uuid4()), | ||
}, | ||
db.config["jwt_access_secret"], | ||
"HS256", | ||
|
||
@routes.post("/xrpc/com.atproto.server.refreshSession") | ||
async def server_refresh_session(request: web.Request): | ||
auth = request.headers.get("Authorization", "") | ||
if not auth.startswith("Bearer "): | ||
raise web.HTTPUnauthorized(text="invalid auth type") | ||
token = auth.removeprefix("Bearer ") | ||
token_payload = verify_symmetric_token( | ||
request, token, "com.atproto.refresh" | ||
) | ||
request["authed_did"] = token_payload["sub"] | ||
|
||
get_db(request).con.execute( | ||
"INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)", | ||
(token_payload["sub"], token_payload["jti"], token_payload["exp"]), | ||
) | ||
return web.json_response( | ||
{ | ||
"did": did, | ||
"handle": handle, | ||
"accessJwt": access_jwt, | ||
"refreshJwt": refresh_jwt, | ||
} | ||
session_info(request) | generate_session_tokens(request) | ||
) | ||
|
||
|
||
# NOTE: deleteSession requires refresh token as auth, not access token | ||
@routes.post("/xrpc/com.atproto.server.deleteSession") | ||
async def server_delete_session(request: web.Request): | ||
auth = request.headers.get("Authorization", "") | ||
if not auth.startswith("Bearer "): | ||
raise web.HTTPUnauthorized(text="invalid auth type") | ||
token = auth.removeprefix("Bearer ") | ||
token_payload = verify_symmetric_token( | ||
request, token, "com.atproto.refresh" | ||
) | ||
|
||
get_db(request).con.execute( | ||
"INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)", | ||
(token_payload["sub"], token_payload["jti"], token_payload["exp"]), | ||
) | ||
|
||
return web.Response() | ||
|
||
|
||
@routes.get("/xrpc/com.atproto.server.getServiceAuth") | ||
@authenticated | ||
async def server_get_service_auth(request: web.Request): | ||
|
@@ -302,7 +358,7 @@ async def server_get_service_auth(request: web.Request): | |
"lxm": lxm, | ||
"exp": exp, | ||
"iat": now, | ||
"jti": str(uuid.uuid4()) | ||
"jti": str(uuid.uuid4()), | ||
}, | ||
signing_key, | ||
algorithm=crypto.jwt_signature_alg_for_pem(signing_key), | ||
|
@@ -381,15 +437,7 @@ async def identity_update_handle(request: web.Request): | |
@routes.get("/xrpc/com.atproto.server.getSession") | ||
@authenticated | ||
async def server_get_session(request: web.Request): | ||
return web.json_response( | ||
{ | ||
"handle": get_db(request).handle_by_did(request["authed_did"]), | ||
"did": request["authed_did"], | ||
"email": "[email protected]", # this and below are just here for testing lol | ||
"emailConfirmed": True, | ||
# "didDoc": {}, # iiuc this is only used for entryway usecase? | ||
} | ||
) | ||
return web.json_response(session_info(request)) | ||
|
||
|
||
def construct_app( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.