Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

atmst refactoring #25

Merged
merged 10 commits into from
Dec 21, 2024
7 changes: 5 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestArgs": [
"."
]
}
11 changes: 8 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
]
dependencies = [
"cbrrr >= 1.0.0, < 2",
"atmst >= 0.0.4",
"atmst >= 0.0.6",
"pyjwt[crypto]",
"cryptography",
"aiohttp",
Expand All @@ -42,6 +42,7 @@ test = [
"pytest",
"pytest-aio",
"pytest-aiohttp",
"pytest-depends",
]

[project.urls]
Expand All @@ -51,7 +52,11 @@ Issues = "https://github.com/DavidBuchanan314/millipds/issues"
[project.scripts]
millipds = "millipds.__main__:main"

[tool.ruff.format]
indent-style = "tab"
[tool.ruff]
line-length = 80
format.indent-style = "tab"

[tool.setuptools_scm]

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "session"
67 changes: 44 additions & 23 deletions src/millipds/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from getpass import getpass

from docopt import docopt
import aiohttp

import cbrrr

Expand All @@ -85,7 +86,8 @@ def main():
"""

args = docopt(
__doc__, version=f"millipds version {importlib.metadata.version('millipds')}"
__doc__,
version=f"millipds version {importlib.metadata.version('millipds')}",
)

if args["init"]:
Expand All @@ -96,14 +98,16 @@ def main():
" or manually delete the db and try again."
)
return
if args["--dev"]: # like prod but http://
if args["--dev"]: # like prod but http://
db.update_config(
pds_pfx=f'http://{args["<hostname>"]}',
pds_did=f'did:web:{urllib.parse.quote(args["<hostname>"])}',
bsky_appview_pfx="https://api.bsky.app",
bsky_appview_did="did:web:api.bsky.app",
)
elif args["--sandbox"]: # now-defunct, need to figure out how to point at local infra
elif args[
"--sandbox"
]: # now-defunct, need to figure out how to point at local infra
db.update_config(
pds_pfx=f'https://{args["<hostname>"]}',
pds_did=f'did:web:{urllib.parse.quote(args["<hostname>"])}',
Expand All @@ -121,11 +125,15 @@ def main():
db.print_config()
return
elif args["util"]:
if args["keygen"]: # TODO: deprecate in favour of openssl?
if args["keygen"]: # TODO: deprecate in favour of openssl?
if args["--k256"]:
privkey = crypto.keygen_k256() # openssl ecparam -name secp256k1 -genkey -noout
else: # default
privkey = crypto.keygen_p256() # openssl ecparam -name prime256v1 -genkey -noout
privkey = (
crypto.keygen_k256()
) # openssl ecparam -name secp256k1 -genkey -noout
else: # default
privkey = (
crypto.keygen_p256()
) # openssl ecparam -name prime256v1 -genkey -noout
print(crypto.privkey_to_pem(privkey), end="")
elif args["print_pubkey"]:
with open(args["<pem>"]) as pem:
Expand All @@ -142,20 +150,27 @@ def main():
raise ValueError("invalid did:key")
genesis = {
"type": "plc_operation",
"rotationKeys": [ crypto.encode_pubkey_as_did_key(rotation_key.public_key()) ],
"verificationMethods": { "atproto": args["--repo_pubkey"] },
"alsoKnownAs": [ "at://" + args["--handle"] ],
"rotationKeys": [
crypto.encode_pubkey_as_did_key(rotation_key.public_key())
],
"verificationMethods": {"atproto": args["--repo_pubkey"]},
"alsoKnownAs": ["at://" + args["--handle"]],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": args["--pds_host"]
"endpoint": args["--pds_host"],
}
},
"prev": None,
}
genesis["sig"] = crypto.plc_sign(rotation_key, genesis)
genesis_digest = hashlib.sha256(cbrrr.encode_dag_cbor(genesis)).digest()
plc = "did:plc:" + base64.b32encode(genesis_digest)[:24].lower().decode()
genesis_digest = hashlib.sha256(
cbrrr.encode_dag_cbor(genesis)
).digest()
plc = (
"did:plc:"
+ base64.b32encode(genesis_digest)[:24].lower().decode()
)
with open(args["--genesis_json"], "w") as out:
json.dump(genesis, out, indent=4)
print(plc)
Expand All @@ -167,8 +182,10 @@ def main():
if args["--prev_op"]:
with open(args["--prev_op"]) as op_json:
prev_op = json.load(op_json)
op["prev"] = cbrrr.CID.cidv1_dag_cbor_sha256_32_from(cbrrr.encode_dag_cbor(prev_op)).encode()
del op["sig"] # remove any existing sig
op["prev"] = cbrrr.CID.cidv1_dag_cbor_sha256_32_from(
cbrrr.encode_dag_cbor(prev_op)
).encode()
del op["sig"] # remove any existing sig
op["sig"] = crypto.plc_sign(rotation_key, op)
print(json.dumps(op, indent=4))
else:
Expand Down Expand Up @@ -215,14 +232,18 @@ def main():
else:
print("invalid account subcommand")
elif args["run"]:
asyncio.run(
service.run(
db=db,
sock_path=args["--sock_path"],
host=args["--listen_host"],
port=int(args["--listen_port"]),
)
)

async def run_service_with_client():
async with aiohttp.ClientSession() as client:
await service.run(
db=db,
client=client,
sock_path=args["--sock_path"],
host=args["--listen_host"],
port=int(args["--listen_port"]),
)

asyncio.run(run_service_with_client())
else:
print("CLI arg parse error?!")

Expand Down
42 changes: 34 additions & 8 deletions src/millipds/app_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,42 @@

from . import database

MILLIPDS_DB = web.AppKey("MILLIPDS_DB", database.Database)
MILLIPDS_AIOHTTP_CLIENT = web.AppKey(
"MILLIPDS_AIOHTTP_CLIENT", aiohttp.ClientSession
)
MILLIPDS_FIREHOSE_QUEUES = web.AppKey(
"MILLIPDS_FIREHOSE_QUEUES", Set[asyncio.Queue[Optional[Tuple[int, bytes]]]]
)
MILLIPDS_FIREHOSE_QUEUES_LOCK = web.AppKey(
"MILLIPDS_FIREHOSE_QUEUES_LOCK", asyncio.Lock
)


# these helpers are useful for conciseness and type hinting
def get_db(req: web.Request) -> database.Database:
return req.app["MILLIPDS_DB"]
def get_db(req: web.Request):
return req.app[MILLIPDS_DB]


def get_client(req: web.Request):
return req.app[MILLIPDS_AIOHTTP_CLIENT]


def get_firehose_queues(req: web.Request):
return req.app[MILLIPDS_FIREHOSE_QUEUES]


def get_client(req: web.Request) -> aiohttp.ClientSession:
return req.app["MILLIPDS_AIOHTTP_CLIENT"]
def get_firehose_queues_lock(req: web.Request):
return req.app[MILLIPDS_FIREHOSE_QUEUES_LOCK]

def get_firehose_queues(req: web.Request) -> Set[asyncio.Queue[Optional[Tuple[int, bytes]]]]:
return req.app["MILLIPDS_FIREHOSE_QUEUES"]

def get_firehose_queues_lock(req: web.Request) -> asyncio.Lock:
return req.app["MILLIPDS_FIREHOSE_QUEUES_LOCK"]
__all__ = [
"MILLIPDS_DB",
"MILLIPDS_AIOHTTP_CLIENT",
"MILLIPDS_FIREHOSE_QUEUES",
"MILLIPDS_FIREHOSE_QUEUES_LOCK",
"get_db",
"get_client",
"get_firehose_queues",
"get_firehose_queues_lock",
]
12 changes: 7 additions & 5 deletions src/millipds/appview_proxy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import logging
import time

Expand All @@ -12,12 +13,11 @@


# TODO: this should be done via actual DID resolution, not hardcoded!
SERVICE_ROUTES = {
"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat"
}
SERVICE_ROUTES = {"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat"}


@authenticated
async def service_proxy(request: web.Request, service: Optional[str]=None):
async def service_proxy(request: web.Request, service: Optional[str] = None):
"""
If `service` is None, default to bsky appview (per details in db config)
"""
Expand Down Expand Up @@ -59,7 +59,9 @@ async def service_proxy(request: web.Request, service: Optional[str]=None):
elif request.method == "POST":
request_body = await request.read() # TODO: streaming?
async with get_client(request).post(
service_route + request.path, data=request_body, headers=(authn|{"Content-Type": request.content_type})
service_route + request.path,
data=request_body,
headers=(authn | {"Content-Type": request.content_type}),
) as r:
body_bytes = await r.read() # TODO: streaming?
return web.Response(
Expand Down
Loading
Loading