
Commit

Merge pull request #31 from DavidBuchanan314/did-resolution
DID resolution
DavidBuchanan314 authored Dec 24, 2024
2 parents b837ab6 + 865c056 commit 648d065
Showing 12 changed files with 367 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -3,7 +3,7 @@ run-name: ${{ github.actor }} is running tests
on:
  push:
    branches:
      - '*'
      - main
  pull_request:
    branches:
      - main
39 changes: 39 additions & 0 deletions migration_scripts/v2.py
@@ -0,0 +1,39 @@
# TODO: some smarter way of handling migrations

import apsw
import apsw.bestpractice

apsw.bestpractice.apply(apsw.bestpractice.recommended)

from millipds import static_config

with apsw.Connection(static_config.MAIN_DB_PATH) as con:
    version_now, *_ = con.execute("SELECT db_version FROM config").fetchone()

    assert version_now == 1

    con.execute(
        """
        CREATE TABLE did_cache(
            did TEXT PRIMARY KEY NOT NULL,
            doc TEXT,
            created_at INTEGER NOT NULL,
            expires_at INTEGER NOT NULL
        )
        """
    )

    con.execute(
        """
        CREATE TABLE handle_cache(
            handle TEXT PRIMARY KEY NOT NULL,
            did TEXT,
            created_at INTEGER NOT NULL,
            expires_at INTEGER NOT NULL
        )
        """
    )

    con.execute("UPDATE config SET db_version=2")

print("v1 -> v2 Migration successful")
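For illustration (this snippet is not part of the commit), the migration's result can be sanity-checked afterwards using the same apsw plumbing the script itself uses:

# illustrative post-migration check, not part of this commit
import apsw
from millipds import static_config

with apsw.Connection(static_config.MAIN_DB_PATH) as con:
    version, *_ = con.execute("SELECT db_version FROM config").fetchone()
    assert version == 2  # bumped by the script above
    tables = [
        name
        for (name,) in con.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '%_cache'"
        )
    ]
    assert sorted(tables) == ["did_cache", "handle_cache"]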
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
    "pyjwt[crypto]",
    "cryptography",
    "aiohttp",
    "aiodns", # goes faster, apparently
    "aiohttp-middlewares", # cors
    "docopt",
    "apsw",
6 changes: 4 additions & 2 deletions src/millipds/__main__.py
@@ -67,7 +67,8 @@
from getpass import getpass

from docopt import docopt
import aiohttp
from .ssrf import get_ssrf_safe_client


import cbrrr

@@ -234,7 +235,8 @@ def main():
    elif args["run"]:

        async def run_service_with_client():
            async with aiohttp.ClientSession() as client:
            # TODO: option to use regular unsafe client for local dev testing
            async with get_ssrf_safe_client() as client:
                await service.run(
                    db=db,
                    client=client,
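The new `ssrf` module is one of the 12 changed files but isn't expanded in this view. A minimal sketch of one way such a client can be built, assuming the approach is a custom resolver that refuses non-global addresses (the names below are assumptions, and the real module may differ):

# sketch only: not the real src/millipds/ssrf.py, which this view doesn't show.
# the idea: do DNS ourselves and refuse to hand back non-global addresses, so
# a hostile DID document can't point the PDS at internal services.
import ipaddress
import socket

import aiohttp
from aiohttp.abc import AbstractResolver


class SSRFSafeResolver(AbstractResolver):
    def __init__(self) -> None:
        self._inner = aiohttp.ThreadedResolver()

    async def resolve(self, host: str, port: int = 0, family: int = socket.AF_INET):
        results = await self._inner.resolve(host, port, family)
        for res in results:
            if not ipaddress.ip_address(res["host"]).is_global:
                raise aiohttp.ClientConnectionError(
                    f"refusing non-global address for {host!r}"
                )
        return results

    async def close(self) -> None:
        await self._inner.close()


def get_ssrf_safe_client() -> aiohttp.ClientSession:
    # a full implementation would also handle redirect targets and literal
    # IP addresses in URLs, which never hit the resolver
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(resolver=SSRFSafeResolver())
    )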
8 changes: 8 additions & 0 deletions src/millipds/app_util.py
@@ -5,6 +5,7 @@
from aiohttp import web

from . import database
from .did import DIDResolver

MILLIPDS_DB = web.AppKey("MILLIPDS_DB", database.Database)
MILLIPDS_AIOHTTP_CLIENT = web.AppKey(
@@ -16,6 +17,7 @@
MILLIPDS_FIREHOSE_QUEUES_LOCK = web.AppKey(
"MILLIPDS_FIREHOSE_QUEUES_LOCK", asyncio.Lock
)
MILLIPDS_DID_RESOLVER = web.AppKey("MILLIPDS_DID_RESOLVER", DIDResolver)


# these helpers are useful for conciseness and type hinting
@@ -35,13 +37,19 @@ def get_firehose_queues_lock(req: web.Request):
    return req.app[MILLIPDS_FIREHOSE_QUEUES_LOCK]


def get_did_resolver(req: web.Request):
    return req.app[MILLIPDS_DID_RESOLVER]


__all__ = [
"MILLIPDS_DB",
"MILLIPDS_AIOHTTP_CLIENT",
"MILLIPDS_FIREHOSE_QUEUES",
"MILLIPDS_FIREHOSE_QUEUES_LOCK",
"MILLIPDS_DID_RESOLVER",
"get_db",
"get_client",
"get_firehose_queues",
"get_firehose_queues_lock",
"get_did_resolver",
]
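The new AppKey presumably gets populated during startup (the corresponding service.py change is part of this PR but isn't expanded here); a hypothetical sketch of that wiring, consistent with the keys and getters above:

# hypothetical startup wiring (the actual change lives in service.py,
# which isn't expanded in this view)
import aiohttp
from aiohttp import web

from millipds import app_util, database
from millipds.did import DIDResolver


def build_app(client: aiohttp.ClientSession, db: database.Database) -> web.Application:
    app = web.Application()
    app[app_util.MILLIPDS_DB] = db
    app[app_util.MILLIPDS_AIOHTTP_CLIENT] = client
    app[app_util.MILLIPDS_DID_RESOLVER] = DIDResolver(client)
    # request handlers can then fetch it back, fully typed:
    # resolver = app_util.get_did_resolver(request)
    return app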
27 changes: 15 additions & 12 deletions src/millipds/appview_proxy.py
@@ -12,14 +12,6 @@
logger = logging.getLogger(__name__)


# TODO: this should be done via actual DID resolution, not hardcoded!
SERVICE_ROUTES = {
"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat",
"did:web:discover.bsky.app#bsky_fg": "https://discover.bsky.app",
"did:plc:ar7c4by46qjdydhdevvrndac#atproto_labeler": "https://mod.bsky.app",
}


@authenticated
async def service_proxy(request: web.Request, service: Optional[str] = None):
"""
Expand All @@ -30,11 +22,22 @@ async def service_proxy(request: web.Request, service: Optional[str] = None):
logger.info(f"proxying lxm {lxm}")
db = get_db(request)
if service:
service_did = service.partition("#")[0]
service_route = SERVICE_ROUTES.get(service)
if service_route is None:
service_did, _, fragment = service.partition("#")
fragment = "#" + fragment
did_doc = await get_did_resolver(request).resolve_with_db_cache(
db, service_did
)
if did_doc is None:
return web.HTTPInternalServerError(
f"unable to resolve service {service!r}"
)
for service in did_doc.get("service", []):
if service.get("id") == fragment:
service_route = service["serviceEndpoint"]
break
else:
return web.HTTPBadRequest(f"unable to resolve service {service!r}")
else:
else: # fall thru to assuming bsky appview
service_did = db.config["bsky_appview_did"]
service_route = db.config["bsky_appview_pfx"]

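For reference, the lookup above splits `service` into a DID and a `#fragment`, then matches the fragment against the `id` of each entry in the resolved document's `service` array. A DID document for one of the previously hardcoded routes would look roughly like this (the `type` value is illustrative):

# roughly what resolve_with_db_cache might return for "did:web:api.bsky.chat"
did_doc = {
    "id": "did:web:api.bsky.chat",
    "service": [
        {
            "id": "#bsky_chat",  # matched against the "#bsky_chat" fragment
            "type": "BskyChatService",  # type value is illustrative
            "serviceEndpoint": "https://api.bsky.chat",
        }
    ],
}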
39 changes: 33 additions & 6 deletions src/millipds/database.py
@@ -58,19 +58,21 @@ class Database:
    def __init__(self, path: str = static_config.MAIN_DB_PATH) -> None:
        logger.info(f"opening database at {path}")
        self.path = path
        util.mkdirs_for_file(path)
        if "/" in path:
            util.mkdirs_for_file(path)
        self.con = self.new_con()
        self.pw_hasher = argon2.PasswordHasher()

        try:
            config_exists = self.con.execute(
                "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='config'"
            ).fetchone()[0]

            if config_exists:
                if self.config["db_version"] != static_config.MILLIPDS_DB_VERSION:
                    raise Exception(
                        "unrecognised db version (TODO: db migrations?!)"
                    )

        except apsw.SQLError as e:  # no such table, so lets create it
            if "no such table" not in str(e):
                raise
        else:
            with self.con:
                self._init_tables()

@@ -216,6 +218,31 @@ def _init_tables(self):
"""
)

# we cache failures too, represented as a null doc (with shorter TTL)
# timestamps are unix timestamp ints, in seconds
self.con.execute(
"""
CREATE TABLE did_cache(
did TEXT PRIMARY KEY NOT NULL,
doc TEXT,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
)
"""
)

# likewise, a null did represents a failed resolution
self.con.execute(
"""
CREATE TABLE handle_cache(
handle TEXT PRIMARY KEY NOT NULL,
did TEXT,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
)
"""
)

def update_config(
self,
pds_pfx: Optional[str] = None,
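As the comments note, both tables cache failures: a NULL `doc` (or `did`) means a recent resolution attempt failed. Illustratively, given a `Database` instance `db` (values hypothetical; the real TTLs live in `static_config`):

# illustration only: what a cached failure looks like
import time

now = int(time.time())
db.con.execute(
    "INSERT OR REPLACE INTO did_cache (did, doc, created_at, expires_at) VALUES (?, ?, ?, ?)",
    ("did:plc:aaaaaaaaaaaaaaaaaaaaaaaa", None, now, now + 300),  # TTL illustrative
)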
180 changes: 180 additions & 0 deletions src/millipds/did.py
@@ -0,0 +1,180 @@
import aiohttp
import asyncio
from typing import Dict, Callable, Any, Awaitable, Optional
import re
import json
import time
import logging

from .database import Database
from . import util
from . import static_config

logger = logging.getLogger(__name__)

DIDDoc = Dict[str, Any]

"""
Security considerations for DID resolution:
- SSRF - not handled here!!! - caller must pass in an "SSRF safe" ClientSession
- Overly long DID strings (handled here via a hard limit (2KiB))
- Overly long DID document responses (handled here via a hard limit (64KiB))
- Servers that are slow to respond (handled via timeouts configured in the ClientSession)
- Non-canonically-encoded DIDs (handled here via strict regex - for now we don't support percent-encoding at all)
"""


class DIDResolver:
    DID_LENGTH_LIMIT = 2048
    DIDDOC_LENGTH_LIMIT = 0x10000

    def __init__(
        self,
        session: aiohttp.ClientSession,
        plc_directory_host: str = static_config.PLC_DIRECTORY_HOST,
    ) -> None:
        self.session: aiohttp.ClientSession = session
        self.plc_directory_host: str = plc_directory_host
        self.did_methods: Dict[str, Callable[[str], Awaitable[DIDDoc]]] = {
            "web": self.resolve_did_web,
            "plc": self.resolve_did_plc,
        }

        self._concurrent_query_locks = util.PartitionedLock()

        # keep stats for logging
        self.hits = 0
        self.misses = 0

    # note: the uncached methods raise exceptions on failure, but this one returns None
    async def resolve_with_db_cache(
        self, db: Database, did: str
    ) -> Optional[DIDDoc]:
        """
        If we fired off two concurrent queries for the same DID, the second would
        be a waste of resources. By using a per-DID locking scheme, we ensure that
        any subsequent queries wait for the first one to complete - by which time
        the cache will be primed and the second query can return the cached result.
        TODO: maybe consider an in-memory cache, too? Probably not worth it.
        """
        async with self._concurrent_query_locks.get_lock(did):
            # try the db first
            now = int(time.time())
            row = db.con.execute(
                "SELECT doc FROM did_cache WHERE did=? AND ?<expires_at",
                (did, now),
            ).fetchone()

            # cache hit
            if row is not None:
                self.hits += 1
                doc = row[0]
                return None if doc is None else json.loads(doc)

            # cache miss
            self.misses += 1
            logger.info(
                f"DID cache miss for {did}. Total hits: {self.hits}, Total misses: {self.misses}"
            )
            try:
                doc = await self.resolve_uncached(did)
                logger.info(f"Successfully resolved {did}")
            except Exception as e:
                logger.exception(f"Error resolving {did}: {e}")
                doc = None

            # update "now" because resolution might've taken a while
            now = int(time.time())
            expires_at = now + (
                static_config.DID_CACHE_ERROR_TTL
                if doc is None
                else static_config.DID_CACHE_TTL
            )

            # update the cache (note: we cache failures too, but with a shorter TTL)
            # TODO: if current doc is None, only replace if the existing entry is also None
            db.con.execute(
                "INSERT OR REPLACE INTO did_cache (did, doc, created_at, expires_at) VALUES (?, ?, ?, ?)",
                (
                    did,
                    None if doc is None else util.compact_json(doc),
                    now,
                    expires_at,
                ),
            )

            return doc

    async def resolve_uncached(self, did: str) -> DIDDoc:
        if len(did) > self.DID_LENGTH_LIMIT:
            raise ValueError("DID too long for atproto")
        scheme, method, *_ = did.split(":")
        if scheme != "did":
            raise ValueError("not a valid DID")
        resolver = self.did_methods.get(method)
        if resolver is None:
            raise ValueError(f"Unsupported DID method: {method}")
        return await resolver(did)

    # 64k ought to be enough for anyone!
    async def _get_json_with_limit(self, url: str, limit: int) -> DIDDoc:
        async with self.session.get(url) as r:
            r.raise_for_status()
            try:
                await r.content.readexactly(limit)
                raise ValueError("DID document too large")
            except asyncio.IncompleteReadError as e:
                # this is actually the happy path
                return json.loads(e.partial)

    async def resolve_did_web(self, did: str) -> DIDDoc:
        # TODO: support port numbers on localhost?
        if not re.match(r"^did:web:[a-z0-9\.\-]+$", did):
            raise ValueError("Invalid did:web")
        host = did.rpartition(":")[2]

        return await self._get_json_with_limit(
            f"https://{host}/.well-known/did.json", self.DIDDOC_LENGTH_LIMIT
        )

    async def resolve_did_plc(self, did: str) -> DIDDoc:
        if not re.match(r"^did:plc:[a-z2-7]+$", did):  # base32-sortable
            raise ValueError("Invalid did:plc")

        return await self._get_json_with_limit(
            f"{self.plc_directory_host}/{did}", self.DIDDOC_LENGTH_LIMIT
        )


async def main() -> None:
    # TODO: move these tests into a proper pytest file

    async with aiohttp.ClientSession() as session:
        TEST_DIDWEB = "did:web:retr0.id"  # TODO: don't rely on external infra
        resolver = DIDResolver(session)
        print(await resolver.resolve_uncached(TEST_DIDWEB))
        print(
            await resolver.resolve_uncached("did:plc:vwzwgnygau7ed7b7wt5ux7y2")
        )

        db = Database(":memory:")
        a = resolver.resolve_with_db_cache(db, TEST_DIDWEB)
        b = resolver.resolve_with_db_cache(db, TEST_DIDWEB)
        res_a, res_b = await asyncio.gather(a, b)
        assert res_a == res_b

        # if not for _concurrent_query_locks, we'd have 2 misses and 0 hits
        # (because the second query would start before the first one finishes
        # and primes the cache)
        assert resolver.hits == 1
        assert resolver.misses == 1

        # check that the WeakValueDictionary is doing its thing (i.e. no memory leaks)
        assert list(resolver._concurrent_query_locks._locks.keys()) == []


if __name__ == "__main__":
    asyncio.run(main())
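`util.PartitionedLock` is added elsewhere in this PR and isn't expanded in this view; a plausible minimal version, consistent with the `WeakValueDictionary` assertion in the test above (the real implementation may differ):

# plausible sketch of util.PartitionedLock, not the committed code
import asyncio
import weakref


class PartitionedLock:
    def __init__(self) -> None:
        # weak refs mean a lock vanishes once no task holds it,
        # which is exactly what the test's final assert checks for
        self._locks: "weakref.WeakValueDictionary[str, asyncio.Lock]" = (
            weakref.WeakValueDictionary()
        )

    def get_lock(self, key: str) -> asyncio.Lock:
        # no await between lookup and insert, so this is atomic
        # from the event loop's point of view
        lock = self._locks.get(key)
        if lock is None:
            lock = asyncio.Lock()
            self._locks[key] = lock
        return lock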