Skip to content

Commit

Permalink
Merge pull request #25 from DavidBuchanan314/atmst-next
Browse files Browse the repository at this point in the history
atmst refactoring, clean up warnings emitted during testing
  • Loading branch information
DavidBuchanan314 authored Dec 21, 2024
2 parents b0c1f9f + 5b52714 commit 2b4dc1c
Show file tree
Hide file tree
Showing 17 changed files with 1,055 additions and 631 deletions.
7 changes: 5 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestArgs": [
"."
]
}
11 changes: 8 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
]
dependencies = [
"cbrrr >= 1.0.0, < 2",
"atmst >= 0.0.4",
"atmst >= 0.0.6",
"pyjwt[crypto]",
"cryptography",
"aiohttp",
Expand All @@ -42,6 +42,7 @@ test = [
"pytest",
"pytest-aio",
"pytest-aiohttp",
"pytest-depends",
]

[project.urls]
Expand All @@ -51,7 +52,11 @@ Issues = "https://github.com/DavidBuchanan314/millipds/issues"
[project.scripts]
millipds = "millipds.__main__:main"

[tool.ruff.format]
indent-style = "tab"
[tool.ruff]
line-length = 80
format.indent-style = "tab"

[tool.setuptools_scm]

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "session"
67 changes: 44 additions & 23 deletions src/millipds/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from getpass import getpass

from docopt import docopt
import aiohttp

import cbrrr

Expand All @@ -85,7 +86,8 @@ def main():
"""

args = docopt(
__doc__, version=f"millipds version {importlib.metadata.version('millipds')}"
__doc__,
version=f"millipds version {importlib.metadata.version('millipds')}",
)

if args["init"]:
Expand All @@ -96,14 +98,16 @@ def main():
" or manually delete the db and try again."
)
return
if args["--dev"]: # like prod but http://
if args["--dev"]: # like prod but http://
db.update_config(
pds_pfx=f'http://{args["<hostname>"]}',
pds_did=f'did:web:{urllib.parse.quote(args["<hostname>"])}',
bsky_appview_pfx="https://api.bsky.app",
bsky_appview_did="did:web:api.bsky.app",
)
elif args["--sandbox"]: # now-defunct, need to figure out how to point at local infra
elif args[
"--sandbox"
]: # now-defunct, need to figure out how to point at local infra
db.update_config(
pds_pfx=f'https://{args["<hostname>"]}',
pds_did=f'did:web:{urllib.parse.quote(args["<hostname>"])}',
Expand All @@ -121,11 +125,15 @@ def main():
db.print_config()
return
elif args["util"]:
if args["keygen"]: # TODO: deprecate in favour of openssl?
if args["keygen"]: # TODO: deprecate in favour of openssl?
if args["--k256"]:
privkey = crypto.keygen_k256() # openssl ecparam -name secp256k1 -genkey -noout
else: # default
privkey = crypto.keygen_p256() # openssl ecparam -name prime256v1 -genkey -noout
privkey = (
crypto.keygen_k256()
) # openssl ecparam -name secp256k1 -genkey -noout
else: # default
privkey = (
crypto.keygen_p256()
) # openssl ecparam -name prime256v1 -genkey -noout
print(crypto.privkey_to_pem(privkey), end="")
elif args["print_pubkey"]:
with open(args["<pem>"]) as pem:
Expand All @@ -142,20 +150,27 @@ def main():
raise ValueError("invalid did:key")
genesis = {
"type": "plc_operation",
"rotationKeys": [ crypto.encode_pubkey_as_did_key(rotation_key.public_key()) ],
"verificationMethods": { "atproto": args["--repo_pubkey"] },
"alsoKnownAs": [ "at://" + args["--handle"] ],
"rotationKeys": [
crypto.encode_pubkey_as_did_key(rotation_key.public_key())
],
"verificationMethods": {"atproto": args["--repo_pubkey"]},
"alsoKnownAs": ["at://" + args["--handle"]],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": args["--pds_host"]
"endpoint": args["--pds_host"],
}
},
"prev": None,
}
genesis["sig"] = crypto.plc_sign(rotation_key, genesis)
genesis_digest = hashlib.sha256(cbrrr.encode_dag_cbor(genesis)).digest()
plc = "did:plc:" + base64.b32encode(genesis_digest)[:24].lower().decode()
genesis_digest = hashlib.sha256(
cbrrr.encode_dag_cbor(genesis)
).digest()
plc = (
"did:plc:"
+ base64.b32encode(genesis_digest)[:24].lower().decode()
)
with open(args["--genesis_json"], "w") as out:
json.dump(genesis, out, indent=4)
print(plc)
Expand All @@ -167,8 +182,10 @@ def main():
if args["--prev_op"]:
with open(args["--prev_op"]) as op_json:
prev_op = json.load(op_json)
op["prev"] = cbrrr.CID.cidv1_dag_cbor_sha256_32_from(cbrrr.encode_dag_cbor(prev_op)).encode()
del op["sig"] # remove any existing sig
op["prev"] = cbrrr.CID.cidv1_dag_cbor_sha256_32_from(
cbrrr.encode_dag_cbor(prev_op)
).encode()
del op["sig"] # remove any existing sig
op["sig"] = crypto.plc_sign(rotation_key, op)
print(json.dumps(op, indent=4))
else:
Expand Down Expand Up @@ -215,14 +232,18 @@ def main():
else:
print("invalid account subcommand")
elif args["run"]:
asyncio.run(
service.run(
db=db,
sock_path=args["--sock_path"],
host=args["--listen_host"],
port=int(args["--listen_port"]),
)
)

async def run_service_with_client():
async with aiohttp.ClientSession() as client:
await service.run(
db=db,
client=client,
sock_path=args["--sock_path"],
host=args["--listen_host"],
port=int(args["--listen_port"]),
)

asyncio.run(run_service_with_client())
else:
print("CLI arg parse error?!")

Expand Down
42 changes: 34 additions & 8 deletions src/millipds/app_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,42 @@

from . import database

MILLIPDS_DB = web.AppKey("MILLIPDS_DB", database.Database)
MILLIPDS_AIOHTTP_CLIENT = web.AppKey(
"MILLIPDS_AIOHTTP_CLIENT", aiohttp.ClientSession
)
MILLIPDS_FIREHOSE_QUEUES = web.AppKey(
"MILLIPDS_FIREHOSE_QUEUES", Set[asyncio.Queue[Optional[Tuple[int, bytes]]]]
)
MILLIPDS_FIREHOSE_QUEUES_LOCK = web.AppKey(
"MILLIPDS_FIREHOSE_QUEUES_LOCK", asyncio.Lock
)


# these helpers are useful for conciseness and type hinting
def get_db(req: web.Request) -> database.Database:
return req.app["MILLIPDS_DB"]
def get_db(req: web.Request):
return req.app[MILLIPDS_DB]


def get_client(req: web.Request):
return req.app[MILLIPDS_AIOHTTP_CLIENT]


def get_firehose_queues(req: web.Request):
return req.app[MILLIPDS_FIREHOSE_QUEUES]


def get_client(req: web.Request) -> aiohttp.ClientSession:
return req.app["MILLIPDS_AIOHTTP_CLIENT"]
def get_firehose_queues_lock(req: web.Request):
return req.app[MILLIPDS_FIREHOSE_QUEUES_LOCK]

def get_firehose_queues(req: web.Request) -> Set[asyncio.Queue[Optional[Tuple[int, bytes]]]]:
return req.app["MILLIPDS_FIREHOSE_QUEUES"]

def get_firehose_queues_lock(req: web.Request) -> asyncio.Lock:
return req.app["MILLIPDS_FIREHOSE_QUEUES_LOCK"]
__all__ = [
"MILLIPDS_DB",
"MILLIPDS_AIOHTTP_CLIENT",
"MILLIPDS_FIREHOSE_QUEUES",
"MILLIPDS_FIREHOSE_QUEUES_LOCK",
"get_db",
"get_client",
"get_firehose_queues",
"get_firehose_queues_lock",
]
12 changes: 7 additions & 5 deletions src/millipds/appview_proxy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import logging
import time

Expand All @@ -12,12 +13,11 @@


# TODO: this should be done via actual DID resolution, not hardcoded!
SERVICE_ROUTES = {
"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat"
}
SERVICE_ROUTES = {"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat"}


@authenticated
async def service_proxy(request: web.Request, service: Optional[str]=None):
async def service_proxy(request: web.Request, service: Optional[str] = None):
"""
If `service` is None, default to bsky appview (per details in db config)
"""
Expand Down Expand Up @@ -59,7 +59,9 @@ async def service_proxy(request: web.Request, service: Optional[str]=None):
elif request.method == "POST":
request_body = await request.read() # TODO: streaming?
async with get_client(request).post(
service_route + request.path, data=request_body, headers=(authn|{"Content-Type": request.content_type})
service_route + request.path,
data=request_body,
headers=(authn | {"Content-Type": request.content_type}),
) as r:
body_bytes = await r.read() # TODO: streaming?
return web.Response(
Expand Down
Loading

0 comments on commit 2b4dc1c

Please sign in to comment.