Skip to content

Commit

Permalink
Migrate lexicon parser from dacite to pydantic; enable ruff ANN
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX committed Dec 17, 2023
1 parent b0eb25f commit 162a7cf
Show file tree
Hide file tree
Showing 66 changed files with 412 additions and 396 deletions.
10 changes: 0 additions & 10 deletions atproto/cbor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@
from atproto.exceptions import DAGCBORDecodingError


class _BytesReadCounter:
_num_bytes_read = 0

def __call__(self, _, num_bytes_read: int) -> None:
self._num_bytes_read += num_bytes_read

def __int__(self) -> int:
return self._num_bytes_read


def decode_dag(data: bytes) -> dict:
"""Decodes and returns a single data item from the given data, with the DAG-CBOR codec.
Expand Down
2 changes: 1 addition & 1 deletion atproto/cid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def encode(self) -> str:
def __str__(self) -> str:
return self.encode()

def __hash__(self):
def __hash__(self) -> int:
return hash(self.encode())


Expand Down
63 changes: 38 additions & 25 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,18 @@ def _get_typeddict_class_def(name: str, model_type: TypedDictType) -> str:
}


def _get_optional_typehint(type_hint, *, optional: bool) -> str:
def _get_optional_typehint(type_hint: str, *, optional: bool) -> str:
if optional:
return f't.Optional[{type_hint}]'
return type_hint


def _get_ref_typehint(nsid: NSID, field_type_def, *, optional: bool) -> str:
def _get_ref_typehint(nsid: NSID, field_type_def: models.LexRef, *, optional: bool) -> str:
model_path, _ = _resolve_nsid_ref(nsid, field_type_def.ref)
return _get_optional_typehint(f"'{model_path}'", optional=optional)


def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> str:
def _get_ref_union_typehint(nsid: NSID, field_type_def: models.LexRefUnion, *, optional: bool) -> str:
def_names = []
for ref in field_type_def.refs:
import_path, _ = _resolve_nsid_ref(nsid, ref)
Expand All @@ -156,7 +156,13 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> st
return _get_optional_typehint(annotated_union, optional=optional)


def _get_model_field_typehint(nsid: NSID, field_type_def, *, optional: bool, is_input_type: bool = False) -> str:
def _get_model_field_typehint(
nsid: NSID,
field_type_def: t.Union[models.LexPrimitive, models.LexArray, models.LexBlob],
*,
optional: bool,
is_input_type: bool = False,
) -> str:
field_type = type(field_type_def)

if field_type == models.LexUnknown:
Expand Down Expand Up @@ -195,7 +201,12 @@ def _get_model_field_typehint(nsid: NSID, field_type_def, *, optional: bool, is_
raise ValueError(f'Unknown field type {field_type.__name__}')


def _get_model_field_value(field_type_def, alias_name: t.Optional[str] = None, *, optional: bool) -> str: # noqa: C901
def _get_model_field_value( # noqa: C901
field_type_def: t.Union[models.LexPrimitive, models.LexArray, models.LexBlob],
alias_name: t.Optional[str] = None,
*,
optional: bool,
) -> str:
not_set = object()

default: t.Any = not_set
Expand Down Expand Up @@ -223,10 +234,10 @@ def _get_model_field_value(field_type_def, alias_name: t.Optional[str] = None, *
default = field_type_def.default
if field_type_def.const is not None:
frozen = field_type_def.const
if field_type_def.minLength is not None:
min_length = field_type_def.minLength
if field_type_def.maxLength is not None:
max_length = field_type_def.maxLength
if field_type_def.min_length is not None:
min_length = field_type_def.min_length
if field_type_def.max_length is not None:
max_length = field_type_def.max_length
# TODO (MarshalX): support knownValue, format, enum?

elif field_type == models.LexBoolean:
Expand All @@ -236,10 +247,10 @@ def _get_model_field_value(field_type_def, alias_name: t.Optional[str] = None, *
frozen = field_type_def.const

elif field_type is models.LexArray:
if field_type_def.minLength is not None:
min_length = field_type_def.minLength
if field_type_def.maxLength is not None:
max_length = field_type_def.maxLength
if field_type_def.min_length is not None:
min_length = field_type_def.min_length
if field_type_def.max_length is not None:
max_length = field_type_def.max_length

if default is not_set and optional:
default = None
Expand Down Expand Up @@ -295,7 +306,9 @@ def _get_req_fields_set(lex_obj: t.Union[models.LexObject, models.LexXrpcParamet
return required_fields


def _get_field_docstring(field_name: str, field_type) -> str:
def _get_field_docstring(
field_name: str, field_type: t.Union[models.LexPrimitive, models.LexArray, models.LexBlob]
) -> str:
field_desc = field_type.description
if field_desc is None:
field_desc = gen_description_by_camel_case_name(field_name)
Expand Down Expand Up @@ -395,7 +408,7 @@ def _get_typeddict(
type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional, is_input_type=is_input_type)
description = _get_field_docstring(field_name, field_type_def)

# Allow optional params to actually be ommitted from the dict entirely
# Allow optional params to actually be omitted from the dict entirely
type_hint_defaulting = f'te.NotRequired[{type_hint}]' if is_optional else type_hint
field_def = f'{_(1)}{snake_cased_field_name}: {type_hint_defaulting} #: {description}'

Expand Down Expand Up @@ -439,11 +452,11 @@ def _generate_params_model(nsid: NSID, definition: t.Union[models.LexXrpcQuery,

def _generate_xrpc_body_model(nsid: NSID, body: models.LexXrpcBody, model_type: ModelType) -> str:
lines = []
if body.schema:
if isinstance(body.schema, models.LexObject):
if body.schema_:
if isinstance(body.schema_, models.LexObject):
lines.append(_get_model_class_def(nsid.name, model_type))
lines.append(_get_model_docstring(nsid, body.schema, model_type))
lines.append(_get_model(nsid, body.schema, is_input_type=(model_type is ModelType.DATA)))
lines.append(_get_model_docstring(nsid, body.schema_, model_type))
lines.append(_get_model(nsid, body.schema_, is_input_type=(model_type is ModelType.DATA)))
else:
if model_type is ModelType.DATA:
model_name = INPUT_MODEL
Expand All @@ -459,9 +472,9 @@ def _generate_xrpc_body_model(nsid: NSID, body: models.LexXrpcBody, model_type:

def _generate_data_typedict(nsid: NSID, body: models.LexXrpcBody) -> str:
lines: t.List[str] = []
if isinstance(body.schema, models.LexObject):
if isinstance(body.schema_, models.LexObject):
lines.append(_get_typeddict_class_def(nsid.name, TypedDictType.DATA))
lines.append(_get_typeddict(nsid, body.schema, is_input_type=True))
lines.append(_get_typeddict(nsid, body.schema_, is_input_type=True))
return join_code(lines)


Expand Down Expand Up @@ -510,13 +523,13 @@ def _generate_def_array(nsid: NSID, def_name: str, def_model: models.LexArray) -


def _generate_def_string(nsid: NSID, def_name: str, def_model: models.LexString) -> str:
# FIXME(MarshalX): support more fields. only knownValues field is supported for now
# FIXME(MarshalX): support more fields. only known_values field is supported for now

if not def_model.knownValues:
if not def_model.known_values:
return ''

union_types = []
for known_value in def_model.knownValues:
for known_value in def_model.known_values:
if '#' in known_value:
# reference to literal (token)
model_path, _ = _resolve_nsid_ref(nsid, known_value)
Expand Down Expand Up @@ -687,7 +700,7 @@ def _generate_init_files(root_package_path: Path) -> None:
write_code(root_path.joinpath('__init__.py'), join_code(import_lines))


def _generate_empty_init_files(root_package_path: Path):
def _generate_empty_init_files(root_package_path: Path) -> None:
for root, dirs, files in os.walk(root_package_path):
root_path = Path(root)

Expand Down
3 changes: 2 additions & 1 deletion atproto/codegen/namespaces/builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t
from dataclasses import dataclass
from pathlib import Path

from atproto.lexicon.models import (
LexDefinition,
Expand Down Expand Up @@ -104,7 +105,7 @@ def build_namespace_tree(lexicons: t.List[LexiconDoc]) -> dict:
return namespace_tree


def build_namespaces(lexicon_dir=None) -> dict:
def build_namespaces(lexicon_dir: t.Optional[Path] = None) -> dict:
lexicons = lexicon_parse_dir(lexicon_dir)
return build_namespace_tree(lexicons)

Expand Down
11 changes: 6 additions & 5 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from atproto.lexicon.models import (
LexObject,
LexRef,
LexXrpcParameters,
LexXrpcProcedure,
LexXrpcQuery,
)
Expand Down Expand Up @@ -201,7 +202,7 @@ def _get_namespace_method_signature_args_names(definition: t.Union[LexXrpcProced
args.add('params')

if isinstance(definition, LexXrpcProcedure) and definition.input:
if definition.input.schema:
if definition.input.schema_:
args.add('data_schema')
else:
args.add('data_alias')
Expand All @@ -225,7 +226,7 @@ def _add_arg(arg_def: str, *, optional: bool) -> None:
else:
args.append(arg_def)

def is_optional_arg(lex_obj) -> bool:
def is_optional_arg(lex_obj: t.Union[LexObject, LexXrpcParameters]) -> bool:
return lex_obj.required is None or len(lex_obj.required) == 0

if method_info.definition.parameters:
Expand All @@ -238,7 +239,7 @@ def is_optional_arg(lex_obj) -> bool:
_add_arg(arg, optional=is_optional)

if isinstance(method_info, ProcedureInfo) and method_info.definition.input:
schema = method_info.definition.input.schema
schema = method_info.definition.input.schema_
if schema:
is_optional = is_optional_arg(schema)

Expand All @@ -259,8 +260,8 @@ def is_optional_arg(lex_obj) -> bool:


def _get_namespace_method_return_type(method_info: MethodInfo) -> t.Tuple[str, bool]:
if method_info.definition.output and isinstance(method_info.definition.output.schema, LexRef):
ref_class, _ = _resolve_nsid_ref(method_info.nsid, method_info.definition.output.schema.ref)
if method_info.definition.output and isinstance(method_info.definition.output.schema_, LexRef):
ref_class, _ = _resolve_nsid_ref(method_info.nsid, method_info.definition.output.schema_.ref)
return ref_class, True

is_model = False
Expand Down
8 changes: 0 additions & 8 deletions atproto/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ class LexiconParsingError(AtProtocolError):
...


class UnknownPrimitiveTypeError(LexiconParsingError):
...


class UnknownDefinitionTypeError(LexiconParsingError):
...


class InvalidNsidError(AtProtocolError):
...

Expand Down
9 changes: 7 additions & 2 deletions atproto/firehose/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
AsyncOnCallbackErrorCallback = t.Callable[[BaseException], t.Coroutine[t.Any, t.Any, None]]


if t.TYPE_CHECKING:
from websockets.client import ClientConnection as SyncWebSocketClient
from websockets.legacy.client import Connect as AsyncConnect


def _build_websocket_uri(
method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> str:
Expand Down Expand Up @@ -109,10 +114,10 @@ def _websocket_uri(self) -> str:
# the user should care about updated params by himself
return _build_websocket_uri(self._method, self._base_uri, self._params)

def _get_client(self):
def _get_client(self) -> 'SyncWebSocketClient':
return connect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES)

def _get_async_client(self):
def _get_async_client(self) -> 'AsyncConnect':
# FIXME(DXsmiley): I've noticed that the close operation often takes the entire timeout for some reason
# By default, this is 10 seconds, which is pretty long.
# Maybe shorten it?
Expand Down
Loading

0 comments on commit 162a7cf

Please sign in to comment.