Skip to content

Commit

Permalink
Add mypy; fix types; fix error handling of requests (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored Jun 1, 2023
1 parent 572e1e3 commit 1783a84
Show file tree
Hide file tree
Showing 43 changed files with 828 additions and 754 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ docs/_build/

# Caches
.ruff_cache/
.mypy_cache/
14 changes: 10 additions & 4 deletions atproto/car/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from atproto import cbor, leb128
from atproto.cid import CID
from atproto.exceptions import InvalidCARFile

Blocks = t.Dict[CID, dict]

Expand All @@ -12,12 +13,12 @@ class CAR:

_CID_V1_BYTES_LEN = 36

def __init__(self, root: str, blocks: Blocks) -> None:
def __init__(self, root: CID, blocks: Blocks) -> None:
self._root = root
self._blocks = blocks

@property
def root(self) -> str:
def root(self) -> CID:
"""Get root."""
return self._root

Expand All @@ -32,7 +33,7 @@ def from_bytes(cls, data: bytes) -> 'CAR':
Note:
You could pass as `data` response of `client.com.atproto.sync.get_repo`, for example.
And another responses of methods in the `sync` namespace.
And other responses of methods in the `sync` namespace.
Example:
>>> from atproto import CAR, Client
Expand All @@ -53,7 +54,12 @@ def from_bytes(cls, data: bytes) -> 'CAR':

header_len, _ = leb128.u.decode_reader(stream)
header = cbor.decode_dag(stream.read(header_len))
root = header.get('roots')[0]

roots = header.get('roots')
if isinstance(roots, list) and len(roots):
root: CID = roots[0]
else:
raise InvalidCARFile('Invalid CAR file. Expected at least one root.')

blocks = {}
while stream.tell() != len(data):
Expand Down
13 changes: 8 additions & 5 deletions atproto/cbor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from io import BytesIO
from typing import List, Union
from typing import Dict, List, Union

import dag_cbor as _dag_cbor
from dag_cbor.decoding import CBORDecodingError as _CBORDecodingError
Expand All @@ -20,7 +20,7 @@ def __int__(self) -> int:
return self._num_bytes_read


def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None) -> _dag_cbor.IPLDKind:
def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None) -> Dict[str, _dag_cbor.IPLDKind]:
"""Decodes and returns a single data item from the given data, with the DAG-CBOR codec.
Args:
Expand All @@ -32,7 +32,10 @@ def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None)
:obj:`dag_cbor.IPLDKind`: Decoded DAG-CBOR.
"""
try:
return _dag_cbor.decode(data, allow_concat=allow_concat, callback=callback)
decoded_data = _dag_cbor.decode(data, allow_concat=allow_concat, callback=callback)
if isinstance(decoded_data, dict):
return decoded_data
raise DAGCBORDecodingError(f'Invalid DAG-CBOR data. Expected dict instead of {type(decoded_data).__name__}')
except _DAGCBORDecodingError as e:
raise DAGCBORDecodingError from e
except _CBORDecodingError as e:
Expand All @@ -41,7 +44,7 @@ def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None)
raise e


def decode_dag_multi(data: DagCborData) -> List[_dag_cbor.IPLDKind]:
def decode_dag_multi(data: DagCborData) -> List[Dict[str, _dag_cbor.IPLDKind]]:
"""Decodes and returns many data items from the given data, with the DAG-CBOR codec.
Args:
Expand All @@ -56,7 +59,7 @@ def decode_dag_multi(data: DagCborData) -> List[_dag_cbor.IPLDKind]:

data_size = data.getbuffer().nbytes

data_parts = []
data_parts: List[Dict[str, _dag_cbor.IPLDKind]] = []
if data_size == 0:
return data_parts

Expand Down
24 changes: 13 additions & 11 deletions atproto/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def resolve_command(
) -> t.Tuple[t.Optional[str], t.Optional[click.Command], t.List[str]]:
# always return the full command name
_, cmd, args = super().resolve_command(ctx, args)
return cmd.name, cmd, args

name = None
if cmd:
name = cmd.name

return name, cmd, args


@click.group(cls=AliasedGroup)
Expand All @@ -46,9 +51,8 @@ def atproto_cli(ctx: click.Context) -> None:
@click.option('--lexicon-dir', type=click.Path(exists=True), default=None, help='Path to dir with .JSON lexicon files.')
@click.pass_context
def gen(ctx: click.Context, lexicon_dir: t.Optional[str]) -> None:
if lexicon_dir:
lexicon_dir = Path(lexicon_dir)
ctx.obj['lexicon_dir'] = lexicon_dir
lexicon_dir_path = Path(lexicon_dir) if lexicon_dir else None
ctx.obj['lexicon_dir'] = lexicon_dir_path


@gen.command(name='all', help='Generated models, namespaces, and async clients with default configs.')
Expand Down Expand Up @@ -86,13 +90,13 @@ def gen_models(ctx: click.Context, output_dir: t.Optional[str]) -> None:

if output_dir:
# FIXME(MarshalX): remove hardcoded imports
output_dir = Path(output_dir)
click.secho(
"It doesn't work with '--output-dir' option very well because of hardcoded imports! Replace by yourself",
fg='red',
)

_gen_models(ctx.obj.get('lexicon_dir'), output_dir)
_gen_models(ctx.obj.get('lexicon_dir'), Path(output_dir))
else:
_gen_models(ctx.obj.get('lexicon_dir'))

click.echo('Done!')

Expand All @@ -107,10 +111,8 @@ def gen_namespaces(
) -> None:
click.echo('Generating namespaces...')

if output_dir:
output_dir = Path(output_dir)

_gen_namespaces(ctx.obj.get('lexicon_dir'), output_dir, async_filename, sync_filename)
output_dir_path = Path(output_dir) if output_dir else None
_gen_namespaces(ctx.obj.get('lexicon_dir'), output_dir_path, async_filename, sync_filename)

click.echo('Done!')

Expand Down
6 changes: 3 additions & 3 deletions atproto/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from atproto.nsid import NSID

# Banner prepended to every auto-generated file. Built in three steps:
# the raw lines, a '#' ruler as wide as the longest line, then the framed text.
_DISCLAIMER_LINES = [
    "# THIS IS THE AUTO-GENERATED CODE. DON'T EDIT IT BY HANDS!",
    '# Copyright (C) 2023 Ilya (Marshal) <https://github.com/MarshalX>.',
    '# This file is part of Python atproto SDK. Licenced under MIT.',
]
# Generator expression instead of a throwaway list inside max().
_MAX_DISCLAIMER_LEN = max(len(s) for s in _DISCLAIMER_LINES)
DISCLAIMER = '\n'.join(_DISCLAIMER_LINES)
DISCLAIMER = f'{"#" * _MAX_DISCLAIMER_LEN}\n{DISCLAIMER}\n{"#" * _MAX_DISCLAIMER_LEN}\n\n'

PARAMS_MODEL = 'Params'
Expand Down
95 changes: 62 additions & 33 deletions atproto/codegen/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,41 @@
from atproto.lexicon.parser import lexicon_parse_dir
from atproto.nsid import NSID

_LEX_DEF_TYPES_FOR_PARAMS = {
models.LexDefinitionType.QUERY,
models.LexDefinitionType.PROCEDURE,
models.LexDefinitionType.SUBSCRIPTION,
}
_LEX_DEF_TYPES_FOR_RESPONSES = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE}
_LEX_DEF_TYPES_FOR_REFS = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE}
_LEX_DEF_TYPES_FOR_DATA = {models.LexDefinitionType.PROCEDURE}
_LEX_DEF_TYPES_FOR_RECORDS = {models.LexDefinitionType.RECORD}
_LEX_DEF_TYPES_FOR_DEF = {
models.LexDefinitionType.OBJECT,
models.LexPrimitiveType.STRING,
models.LexDefinitionType.TOKEN,
models.LexDefinitionType.ARRAY,
}
if t.TYPE_CHECKING:
    # Imported only for type annotations to avoid a runtime dependency cycle.
    from enum import Enum

# Generic definition map: definition name -> parsed lexicon model.
# The value type is ``t.Any``; the Built*Models aliases narrow it per builder.
LexDefs = t.Dict[str, t.Any]
# Full database: lexicon NSID -> its definitions map.
LexDB = t.Dict[NSID, LexDefs]


class _LexiconDir:
dir_path: t.Optional[Path]

def __init__(self, default_path: Path = None) -> None:
def __init__(self, default_path: t.Optional[Path] = None) -> None:
self.dir_path = default_path

def set(self, path: Path) -> None:
self.dir_path = path

def get(self) -> Path:
def get(self) -> t.Optional[Path]:
return self.dir_path


lexicon_dir = _LexiconDir()


def _filter_defs_by_type(
    defs: t.Dict[str, models.LexDefinition],
    def_types: t.Union[t.Set['models.LexDefinitionType'], t.Set['Enum']],
) -> LexDefs:
    """Return only the definitions whose ``type`` field is one of ``def_types``."""
    return {name: definition for name, definition in defs.items() if definition.type in def_types}


def _build_nsid_to_defs_map(lexicons: t.List[models.LexiconDoc], def_types: set) -> LexDB:
def _build_nsid_to_defs_map(
lexicons: t.List[models.LexiconDoc], def_types: t.Union[t.Set['models.LexDefinitionType'], t.Set['Enum']]
) -> LexDB:
result = {}

for lexicon in lexicons:
Expand All @@ -67,27 +51,72 @@ def _build_nsid_to_defs_map(lexicons: t.List[models.LexiconDoc], def_types: set)
return result


# Narrowed result type: only queries, procedures, and subscriptions declare params.
BuiltParamsModels = t.Dict[
    NSID,
    t.Dict[
        str,
        t.Union[
            models.LexXrpcQuery,
            models.LexXrpcProcedure,
            models.LexSubscription,
        ],
    ],
]


def build_params_models() -> BuiltParamsModels:
    """Parse the lexicon dir and map each NSID to its query/procedure/subscription defs."""
    param_def_types = {
        models.LexDefinitionType.QUERY,
        models.LexDefinitionType.PROCEDURE,
        models.LexDefinitionType.SUBSCRIPTION,
    }
    return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), param_def_types)


# A one-member t.Union[X] is equivalent to plain X, so spell it directly.
BuiltDataModels = t.Dict[NSID, t.Dict[str, models.LexXrpcProcedure]]


def build_data_models() -> BuiltDataModels:
    """Parse the lexicon dir and map each NSID to its procedure definitions."""
    data_def_types = {models.LexDefinitionType.PROCEDURE}
    return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), data_def_types)


BuiltResponseModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexXrpcQuery, models.LexXrpcProcedure]]]


def build_response_models() -> BuiltResponseModels:
    """Parse the lexicon dir and map each NSID to its query and procedure definitions."""
    response_def_types = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE}
    return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), response_def_types)


BuiltDefModels = t.Dict[
    NSID, t.Dict[str, t.Union[models.LexObject, models.LexString, models.LexToken, models.LexArray]]
]


def build_def_models() -> BuiltDefModels:
    """Parse the lexicon dir and map each NSID to its object/string/token/array definitions."""
    # NOTE(review): LexPrimitiveType.STRING is mixed into a set of LexDefinitionType
    # members — presumably both enums appear as `.type` values; confirm in lexicon models.
    def_def_types = {
        models.LexDefinitionType.OBJECT,
        models.LexPrimitiveType.STRING,
        models.LexDefinitionType.TOKEN,
        models.LexDefinitionType.ARRAY,
    }
    return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), def_def_types)


# A one-member t.Union[X] is equivalent to plain X, so spell it directly.
BuiltRecordModels = t.Dict[NSID, t.Dict[str, models.LexRecord]]


def build_record_models() -> BuiltRecordModels:
    """Parse the lexicon dir and map each NSID to its record definitions."""
    record_def_types = {models.LexDefinitionType.RECORD}
    return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), record_def_types)


BuiltRefsModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexXrpcQuery, models.LexXrpcProcedure]]]


def build_refs_models() -> BuiltRefsModels:
    """Parse the lexicon dir and map each NSID to its query and procedure definitions."""
    refs_def_types = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE}
    return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), refs_def_types)


Expand Down
Loading

0 comments on commit 1783a84

Please sign in to comment.