diff --git a/.gitignore b/.gitignore index 42336d3f..528233de 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ docs/_build/ # Caches .ruff_cache/ +.mypy_cache/ diff --git a/atproto/car/__init__.py b/atproto/car/__init__.py index 083f92c0..eb0b6a1c 100644 --- a/atproto/car/__init__.py +++ b/atproto/car/__init__.py @@ -3,6 +3,7 @@ from atproto import cbor, leb128 from atproto.cid import CID +from atproto.exceptions import InvalidCARFile Blocks = t.Dict[CID, dict] @@ -12,12 +13,12 @@ class CAR: _CID_V1_BYTES_LEN = 36 - def __init__(self, root: str, blocks: Blocks) -> None: + def __init__(self, root: CID, blocks: Blocks) -> None: self._root = root self._blocks = blocks @property - def root(self) -> str: + def root(self) -> CID: """Get root.""" return self._root @@ -32,7 +33,7 @@ def from_bytes(cls, data: bytes) -> 'CAR': Note: You could pass as `data` response of `client.com.atproto.sync.get_repo`, for example. - And another responses of methods in the `sync` namespace. + And other responses of methods in the `sync` namespace. Example: >>> from atproto import CAR, Client @@ -53,7 +54,12 @@ def from_bytes(cls, data: bytes) -> 'CAR': header_len, _ = leb128.u.decode_reader(stream) header = cbor.decode_dag(stream.read(header_len)) - root = header.get('roots')[0] + + roots = header.get('roots') + if isinstance(roots, list) and len(roots): + root: CID = roots[0] + else: + raise InvalidCARFile('Invalid CAR file. Expected at least one root.') blocks = {} while stream.tell() != len(data): diff --git a/atproto/cbor/__init__.py b/atproto/cbor/__init__.py index d83a5e77..df6f1d7a 100644 --- a/atproto/cbor/__init__.py +++ b/atproto/cbor/__init__.py @@ -1,5 +1,5 @@ from io import BytesIO -from typing import List, Union +from typing import Dict, List, Union import dag_cbor as _dag_cbor from dag_cbor.decoding import CBORDecodingError as _CBORDecodingError @@ -20,7 +20,7 @@ def __int__(self) -> int: return self._num_bytes_read -def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None) -> _dag_cbor.IPLDKind: +def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None) -> Dict[str, _dag_cbor.IPLDKind]: """Decodes and returns a single data item from the given data, with the DAG-CBOR codec. Args: @@ -32,7 +32,10 @@ def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None) :obj:`dag_cbor.IPLDKind`: Decoded DAG-CBOR. """ try: - return _dag_cbor.decode(data, allow_concat=allow_concat, callback=callback) + decoded_data = _dag_cbor.decode(data, allow_concat=allow_concat, callback=callback) + if isinstance(decoded_data, dict): + return decoded_data + raise DAGCBORDecodingError(f'Invalid DAG-CBOR data. Expected dict instead of {type(decoded_data).__name__}') except _DAGCBORDecodingError as e: raise DAGCBORDecodingError from e except _CBORDecodingError as e: @@ -41,7 +44,7 @@ def decode_dag(data: DagCborData, *, allow_concat: bool = False, callback=None) raise e -def decode_dag_multi(data: DagCborData) -> List[_dag_cbor.IPLDKind]: +def decode_dag_multi(data: DagCborData) -> List[Dict[str, _dag_cbor.IPLDKind]]: """Decodes and returns many data items from the given data, with the DAG-CBOR codec. Args: @@ -56,7 +59,7 @@ def decode_dag_multi(data: DagCborData) -> List[_dag_cbor.IPLDKind]: data_size = data.getbuffer().nbytes - data_parts = [] + data_parts: List[Dict[str, _dag_cbor.IPLDKind]] = [] if data_size == 0: return data_parts diff --git a/atproto/cli/__init__.py b/atproto/cli/__init__.py index 4df94d9e..397005a3 100644 --- a/atproto/cli/__init__.py +++ b/atproto/cli/__init__.py @@ -31,7 +31,12 @@ def resolve_command( ) -> t.Tuple[t.Optional[str], t.Optional[click.Command], t.List[str]]: # always return the full command name _, cmd, args = super().resolve_command(ctx, args) - return cmd.name, cmd, args + + name = None + if cmd: + name = cmd.name + + return name, cmd, args @click.group(cls=AliasedGroup) @@ -46,9 +51,8 @@ def atproto_cli(ctx: click.Context) -> None: @click.option('--lexicon-dir', type=click.Path(exists=True), default=None, help='Path to dir with .JSON lexicon files.') @click.pass_context def gen(ctx: click.Context, lexicon_dir: t.Optional[str]) -> None: - if lexicon_dir: - lexicon_dir = Path(lexicon_dir) - ctx.obj['lexicon_dir'] = lexicon_dir + lexicon_dir_path = Path(lexicon_dir) if lexicon_dir else None + ctx.obj['lexicon_dir'] = lexicon_dir_path @gen.command(name='all', help='Generated models, namespaces, and async clients with default configs.') @@ -86,13 +90,13 @@ def gen_models(ctx: click.Context, output_dir: t.Optional[str]) -> None: if output_dir: # FIXME(MarshalX): remove hardcoded imports - output_dir = Path(output_dir) click.secho( "It doesn't work with '--output-dir' option very well because of hardcoded imports! Replace by yourself", fg='red', ) - - _gen_models(ctx.obj.get('lexicon_dir'), output_dir) + _gen_models(ctx.obj.get('lexicon_dir'), Path(output_dir)) + else: + _gen_models(ctx.obj.get('lexicon_dir')) click.echo('Done!') @@ -107,10 +111,8 @@ def gen_namespaces( ) -> None: click.echo('Generating namespaces...') - if output_dir: - output_dir = Path(output_dir) - - _gen_namespaces(ctx.obj.get('lexicon_dir'), output_dir, async_filename, sync_filename) + output_dir_path = Path(output_dir) if output_dir else None + _gen_namespaces(ctx.obj.get('lexicon_dir'), output_dir_path, async_filename, sync_filename) click.echo('Done!') diff --git a/atproto/codegen/__init__.py b/atproto/codegen/__init__.py index b381b867..18183cba 100644 --- a/atproto/codegen/__init__.py +++ b/atproto/codegen/__init__.py @@ -5,13 +5,13 @@ from atproto.nsid import NSID -DISCLAIMER = [ +_DISCLAIMER_LINES = [ "# THIS IS THE AUTO-GENERATED CODE. DON'T EDIT IT BY HANDS!", '# Copyright (C) 2023 Ilya (Marshal) .', '# This file is part of Python atproto SDK. Licenced under MIT.', ] -_MAX_DISCLAIMER_LEN = max([len(s) for s in DISCLAIMER]) -DISCLAIMER = '\n'.join(DISCLAIMER) +_MAX_DISCLAIMER_LEN = max([len(s) for s in _DISCLAIMER_LINES]) +DISCLAIMER = '\n'.join(_DISCLAIMER_LINES) DISCLAIMER = f'{"#" * _MAX_DISCLAIMER_LEN}\n{DISCLAIMER}\n{"#" * _MAX_DISCLAIMER_LEN}\n\n' PARAMS_MODEL = 'Params' diff --git a/atproto/codegen/models/builder.py b/atproto/codegen/models/builder.py index d4d4b875..e261665a 100644 --- a/atproto/codegen/models/builder.py +++ b/atproto/codegen/models/builder.py @@ -5,32 +5,12 @@ from atproto.lexicon.parser import lexicon_parse_dir from atproto.nsid import NSID -_LEX_DEF_TYPES_FOR_PARAMS = { - models.LexDefinitionType.QUERY, - models.LexDefinitionType.PROCEDURE, - models.LexDefinitionType.SUBSCRIPTION, -} -_LEX_DEF_TYPES_FOR_RESPONSES = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE} -_LEX_DEF_TYPES_FOR_REFS = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE} -_LEX_DEF_TYPES_FOR_DATA = {models.LexDefinitionType.PROCEDURE} -_LEX_DEF_TYPES_FOR_RECORDS = {models.LexDefinitionType.RECORD} -_LEX_DEF_TYPES_FOR_DEF = { - models.LexDefinitionType.OBJECT, - models.LexPrimitiveType.STRING, - models.LexDefinitionType.TOKEN, - models.LexDefinitionType.ARRAY, -} +if t.TYPE_CHECKING: + from enum import Enum LexDefs = t.Dict[ str, - t.Union[ - models.LexXrpcProcedure, - models.LexXrpcQuery, - models.LexObject, - models.LexToken, - models.LexString, - models.LexRecord, - ], + t.Any, ] LexDB = t.Dict[NSID, LexDefs] @@ -38,24 +18,28 @@ class _LexiconDir: dir_path: t.Optional[Path] - def __init__(self, default_path: Path = None) -> None: + def __init__(self, default_path: t.Optional[Path] = None) -> None: self.dir_path = default_path def set(self, path: Path) -> None: self.dir_path = path - def get(self) -> Path: + def get(self) -> t.Optional[Path]: return self.dir_path lexicon_dir = _LexiconDir() -def _filter_defs_by_type(defs: t.Dict[str, models.LexDefinition], def_types: set) -> LexDefs: +def _filter_defs_by_type( + defs: t.Dict[str, models.LexDefinition], def_types: t.Union[t.Set['models.LexDefinitionType'], t.Set['Enum']] +) -> LexDefs: return {k: v for k, v in defs.items() if v.type in def_types} -def _build_nsid_to_defs_map(lexicons: t.List[models.LexiconDoc], def_types: set) -> LexDB: +def _build_nsid_to_defs_map( + lexicons: t.List[models.LexiconDoc], def_types: t.Union[t.Set['models.LexDefinitionType'], t.Set['Enum']] +) -> LexDB: result = {} for lexicon in lexicons: @@ -67,27 +51,72 @@ def _build_nsid_to_defs_map(lexicons: t.List[models.LexiconDoc], def_types: set) return result -def build_params_models() -> LexDB: +BuiltParamsModels = t.Dict[ + NSID, + t.Dict[ + str, + t.Union[ + models.LexXrpcQuery, + models.LexXrpcProcedure, + models.LexSubscription, + ], + ], +] + + +def build_params_models() -> BuiltParamsModels: + _LEX_DEF_TYPES_FOR_PARAMS = { + models.LexDefinitionType.QUERY, + models.LexDefinitionType.PROCEDURE, + models.LexDefinitionType.SUBSCRIPTION, + } return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_PARAMS) -def build_data_models() -> LexDB: +BuiltDataModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexXrpcProcedure]]] + + +def build_data_models() -> BuiltDataModels: + _LEX_DEF_TYPES_FOR_DATA = {models.LexDefinitionType.PROCEDURE} return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_DATA) -def build_response_models() -> LexDB: +BuiltResponseModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexXrpcQuery, models.LexXrpcProcedure]]] + + +def build_response_models() -> BuiltResponseModels: + _LEX_DEF_TYPES_FOR_RESPONSES = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE} return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RESPONSES) -def build_def_models() -> LexDB: +BuiltDefModels = t.Dict[ + NSID, t.Dict[str, t.Union[models.LexObject, models.LexString, models.LexToken, models.LexArray]] +] + + +def build_def_models() -> BuiltDefModels: + _LEX_DEF_TYPES_FOR_DEF = { + models.LexDefinitionType.OBJECT, + models.LexPrimitiveType.STRING, + models.LexDefinitionType.TOKEN, + models.LexDefinitionType.ARRAY, + } return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_DEF) -def build_record_models() -> LexDB: +BuiltRecordModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexRecord]]] + + +def build_record_models() -> BuiltRecordModels: + _LEX_DEF_TYPES_FOR_RECORDS = {models.LexDefinitionType.RECORD} return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RECORDS) -def build_refs_models() -> LexDB: +BuiltRefsModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexXrpcQuery, models.LexXrpcProcedure]]] + + +def build_refs_models() -> BuiltRefsModels: + _LEX_DEF_TYPES_FOR_REFS = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE} return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_REFS) diff --git a/atproto/codegen/models/generator.py b/atproto/codegen/models/generator.py index 46af1f91..62ab5e8e 100644 --- a/atproto/codegen/models/generator.py +++ b/atproto/codegen/models/generator.py @@ -190,7 +190,7 @@ def _get_model_field_typehint(nsid: NSID, field_name: str, field_type_def, *, op # yes, it returns blob,but actually it's blob ref here return _get_optional_typehint('BlobRef', optional=optional) - raise ValueError(f'Unknown field type {field_name.__name__}') + raise ValueError(f'Unknown field type {field_type.__name__}') def _get_req_fields_set(lex_obj: t.Union[models.LexObject, models.LexXrpcParameters]) -> set: @@ -216,7 +216,9 @@ def _get_field_docstring(field_name: str, field_type) -> str: def _get_model_docstring( - nsid: t.Union[str, NSID], lex_object: t.Union[models.LexObject, models.LexXrpcParameters], model_type: ModelType + nsid: t.Union[str, NSID], + lex_object: t.Union[models.LexXrpcQuery, models.LexSubscription, models.LexObject, models.LexXrpcParameters], + model_type: ModelType, ) -> str: model_desc = lex_object.description or '' model_desc = f'{model_type.value} model for :obj:`{nsid}`. {model_desc}' @@ -257,12 +259,11 @@ def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcP def _get_model_ref(nsid: NSID, ref: models.LexRef) -> str: # FIXME(MarshalX): "local=True" Is it works well? ;d ref_class, _ = _resolve_nsid_ref(nsid, ref.ref, local=True) - ref_typehint = f't.Type[{ref_class}]' # "Ref" suffix required to fix name collisions from different namespaces lines = [ f'#: {OUTPUT_MODEL} reference to :obj:`{ref_class}` model.', - f'{OUTPUT_MODEL}Ref: {ref_typehint} = {ref_class}', + f'{OUTPUT_MODEL}Ref = {ref_class}', '', '', ] @@ -271,11 +272,11 @@ def _get_model_ref(nsid: NSID, ref: models.LexRef) -> str: def _get_model_raw_data(name: str) -> str: - lines = [f'#: {name} raw data type.', f'{name}: t.Union[t.Type[str], t.Type[bytes]] = bytes\n\n'] + lines = [f'#: {name} raw data type.', f'{name}: te.TypeAlias = bytes\n\n'] return join_code(lines) -def _generate_params_model(nsid: NSID, definition: t.Union[models.LexXrpcProcedure, models.LexXrpcQuery]) -> str: +def _generate_params_model(nsid: NSID, definition: t.Union[models.LexXrpcQuery, models.LexSubscription]) -> str: lines = [_get_model_class_def(nsid.name, ModelType.PARAMS)] if definition.parameters: @@ -351,8 +352,8 @@ def _generate_def_string(def_name: str, def_model: models.LexString) -> str: return '' # FIXME(MarshalX): Use ref resolver - known_values = ["'" + get_def_model_name(v.split('#', 1)[1]) + "'" for v in def_model.knownValues] - known_values = ', '.join(known_values) + known_values_list = ["'" + get_def_model_name(v.split('#', 1)[1]) + "'" for v in def_model.knownValues] + known_values = ', '.join(known_values_list) type_ = f'te.Literal[{known_values}]' @@ -365,16 +366,20 @@ def _generate_def_string(def_name: str, def_model: models.LexString) -> str: return join_code(lines) -def _generate_params_models(lex_db: builder.LexDB) -> None: +def _generate_params_models(lex_db: builder.BuiltParamsModels) -> None: for nsid, defs in lex_db.items(): _save_code_import_if_not_exist(nsid) definition = defs['main'] if definition.parameters: - save_code_part(nsid, _generate_params_model(nsid, definition)) + if isinstance(definition.parameters, models.LexXrpcParameters): + save_code_part(nsid, _generate_params_model(nsid, definition)) + else: + # LexXrpcProcedure has parameters using another model + raise ValueError('Wrong parameters type or not implemented') -def _generate_data_models(lex_db: builder.LexDB) -> None: +def _generate_data_models(lex_db: builder.BuiltDataModels) -> None: for nsid, defs in lex_db.items(): _save_code_import_if_not_exist(nsid) @@ -383,7 +388,7 @@ def _generate_data_models(lex_db: builder.LexDB) -> None: save_code_part(nsid, _generate_data_model(nsid, definition.input)) -def _generate_response_models(lex_db: builder.LexDB) -> None: +def _generate_response_models(lex_db: builder.BuiltResponseModels) -> None: for nsid, defs in lex_db.items(): _save_code_import_if_not_exist(nsid) @@ -392,7 +397,7 @@ def _generate_response_models(lex_db: builder.LexDB) -> None: save_code_part(nsid, _generate_response_model(nsid, definition.output)) -def _generate_def_models(lex_db: builder.LexDB) -> None: +def _generate_def_models(lex_db: builder.BuiltDefModels) -> None: for nsid, defs in lex_db.items(): _save_code_import_if_not_exist(nsid) @@ -409,7 +414,7 @@ def _generate_def_models(lex_db: builder.LexDB) -> None: raise ValueError(f'Unhandled type {type(def_model)}') -def _generate_record_models(lex_db: builder.LexDB) -> None: +def _generate_record_models(lex_db: builder.BuiltRecordModels) -> None: for nsid, defs in lex_db.items(): _save_code_import_if_not_exist(nsid) @@ -419,7 +424,7 @@ def _generate_record_models(lex_db: builder.LexDB) -> None: save_code_part(nsid, _generate_def_model(nsid, def_name, def_record.record, ModelType.RECORD)) -def _generate_record_type_database(lex_db: builder.LexDB) -> None: +def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None: lines = ['from atproto.xrpc_client import models', 'RECORD_TYPE_TO_MODEL_CLASS = {'] for nsid, defs in lex_db.items(): @@ -441,7 +446,7 @@ def _generate_record_type_database(lex_db: builder.LexDB) -> None: write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(lines)) -def _generate_ref_models(lex_db: builder.LexDB) -> None: +def _generate_ref_models(lex_db: builder.BuiltRefsModels) -> None: for nsid, defs in lex_db.items(): definition = defs['main'] if ( @@ -462,16 +467,16 @@ def _generate_ref_models(lex_db: builder.LexDB) -> None: def _generate_init_files(root_package_path: Path) -> None: - # one of the ways that I tried. Doesn't work well due to circular imports + # One of the ways that I tried. Doesn't work well due to circular imports for root, dirs, files in os.walk(root_package_path): - root = Path(root) + root_path = Path(root) import_lines = [] for dir_name in dirs: if dir_name.startswith('__'): continue - import_parts = root.parts[root.joinpath(dir_name).parts.index(_MODELS_OUTPUT_DIR.parent.name) :] + import_parts = root_path.parts[root_path.joinpath(dir_name).parts.index(_MODELS_OUTPUT_DIR.parent.name) :] from_import = '.'.join(import_parts) if dir_name in {'app', 'com'}: @@ -483,24 +488,24 @@ def _generate_init_files(root_package_path: Path) -> None: if file_name.startswith('__'): continue - import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.parent.name) :] + import_parts = root_path.parts[root_path.parts.index(_MODELS_OUTPUT_DIR.parent.name) :] from_import = '.'.join(import_parts) import_lines.append(f'from atproto.{from_import} import {file_name[:-3]}') - if root.name == 'models': + if root_path.name == 'models': # FIXME skip for now. should be generated too continue - if root.name == '__pycache__': + if root_path.name == '__pycache__': continue - write_code(root.joinpath('__init__.py'), join_code(import_lines)) + write_code(root_path.joinpath('__init__.py'), join_code(import_lines)) def _generate_empty_init_files(root_package_path: Path): for root, dirs, files in os.walk(root_package_path): - root = Path(root) + root_path = Path(root) for dir_name in dirs: if dir_name.startswith('__'): @@ -513,14 +518,14 @@ def _generate_empty_init_files(root_package_path: Path): if file_name.startswith('__'): continue - if root.name == 'models': + if root_path.name == 'models': # FIXME skip for now. should be generated too continue - if root.name == '__pycache__': + if root_path.name == '__pycache__': continue - write_code(root.joinpath('__init__.py'), DISCLAIMER) + write_code(root_path.joinpath('__init__.py'), DISCLAIMER) def _generate_import_aliases(root_package_path: Path) -> None: @@ -530,9 +535,9 @@ def _generate_import_aliases(root_package_path: Path) -> None: import_lines = [] ids_db = ['class _Ids:'] for root, __, files in os.walk(root_package_path): - root = Path(root) + root_path = Path(root) - if root == root_package_path: + if root_path == root_package_path: continue for file in files: @@ -541,10 +546,10 @@ def _generate_import_aliases(root_package_path: Path) -> None: if '.cpython-' in file: continue - import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.parent.name) :] + import_parts = root_path.parts[root_path.parts.index(_MODELS_OUTPUT_DIR.parent.name) :] from_import = '.'.join(import_parts) - nsid_parts = list(root.parts[root.parts.index('models') + 1 :]) + nsid_parts = list(root_path.parts[root_path.parts.index('models') + 1 :]) method_name_parts = file[:-3].split('_') alias_name = ''.join([p.capitalize() for p in [*nsid_parts, *method_name_parts]]) diff --git a/atproto/codegen/namespaces/builder.py b/atproto/codegen/namespaces/builder.py index db512f8c..f73bf8f1 100644 --- a/atproto/codegen/namespaces/builder.py +++ b/atproto/codegen/namespaces/builder.py @@ -1,4 +1,5 @@ import typing as t +from dataclasses import dataclass from atproto.lexicon.models import ( LexDefinition, @@ -12,7 +13,6 @@ from atproto.nsid import NSID _VALID_LEX_DEF_TYPES = {LexDefinitionType.QUERY, LexDefinitionType.PROCEDURE, LexDefinitionType.RECORD} -_VALID_LEX_DEF_TYPE = t.Union[LexXrpcProcedure, LexXrpcQuery, LexRecord] def _filter_namespace_valid_definitions(definitions: t.Dict[str, LexDefinition]) -> t.Dict[str, LexDefinition]: @@ -32,43 +32,63 @@ def get_definition_by_name(name: str, defs: t.Dict[str, LexDefinition]) -> LexDe return defs['main'] -class ObjectInfo(t.NamedTuple): +@dataclass +class ObjectInfo: name: str nsid: NSID - definition: _VALID_LEX_DEF_TYPE -class MethodInfo(ObjectInfo): - pass +@dataclass +class ProcedureInfo(ObjectInfo): + definition: LexXrpcProcedure +@dataclass +class QueryInfo(ObjectInfo): + definition: LexXrpcQuery + + +MethodInfo = t.Union[ProcedureInfo, QueryInfo] + + +@dataclass class RecordInfo(ObjectInfo): - pass + definition: LexRecord def _enrich_namespace_tree(root: dict, nsid: NSID, defs: t.Dict[str, LexDefinition]) -> None: + root_node: t.Union[dict, list] = root + segments_count = len(nsid.segments) for path_level, segment in enumerate(nsid.segments): # if method if path_level == segments_count - 1: definition = get_definition_by_name(segment, defs) - model_class = MethodInfo - if definition.type is LexDefinitionType.RECORD: + model_class: t.Type[ObjectInfo] + if definition.type is LexDefinitionType.PROCEDURE: + model_class = ProcedureInfo + elif definition.type is LexDefinitionType.QUERY: + model_class = QueryInfo + elif definition.type is LexDefinitionType.RECORD: model_class = RecordInfo + else: + raise RuntimeError(f'Unknown definition type: {definition.type}') # TODO(MarshalX): fake records as namespaces with methods to be able to reuse code of generator? - root.append(model_class(name=segment, nsid=nsid, definition=definition)) + if model_class: + root_node.append(model_class(name=segment, nsid=nsid, definition=definition)) + continue - if segment not in root: + if segment not in root_node: # if end of method's path if path_level == segments_count - 2: - root[segment] = [] + root_node[segment] = [] else: - root[segment] = {} - root = root[segment] + root_node[segment] = {} + root_node = root_node[segment] def build_namespace_tree(lexicons: t.List[LexiconDoc]) -> dict: diff --git a/atproto/codegen/namespaces/generator.py b/atproto/codegen/namespaces/generator.py index e5a2215b..bcd241c2 100644 --- a/atproto/codegen/namespaces/generator.py +++ b/atproto/codegen/namespaces/generator.py @@ -16,12 +16,12 @@ write_code, ) from atproto.codegen import get_code_intent as _ -from atproto.codegen.namespaces.builder import MethodInfo, RecordInfo, build_namespaces +from atproto.codegen.namespaces.builder import MethodInfo, ProcedureInfo, QueryInfo, RecordInfo, build_namespaces from atproto.lexicon.models import ( - LexDefinitionType, LexObject, LexRef, LexXrpcProcedure, + LexXrpcQuery, ) from atproto.nsid import NSID @@ -45,37 +45,30 @@ def get_record_name(path_part: str) -> str: def _get_namespace_imports() -> str: lines = [ DISCLAIMER, - 'from dataclasses import dataclass, field', 'import typing as t', '', 'from atproto.xrpc_client import models', - 'from atproto.xrpc_client.models.utils import get_or_create, get_response_model', - 'from atproto.xrpc_client.namespaces.base import DefaultNamespace, NamespaceBase', + 'from atproto.xrpc_client.models.utils import get_or_create_model, get_response_model', + 'from atproto.xrpc_client.namespaces.base import AsyncNamespaceBase, DefaultNamespace, NamespaceBase', + '', + 'if t.TYPE_CHECKING:', + f'{_(1)}from atproto.xrpc_client.client.async_raw import AsyncClientRaw', + f'{_(1)}from atproto.xrpc_client.client.raw import ClientRaw', ] return join_code(lines) -def _get_namespace_class_def(name: str) -> str: - lines = ['@dataclass', f'class {get_namespace_name(name)}(NamespaceBase):'] +def _get_namespace_class_def(name: str, *, sync: bool) -> str: + base_class = 'NamespaceBase' if sync else 'AsyncNamespaceBase' + lines = [f'class {get_namespace_name(name)}({base_class}):'] return join_code(lines) -def _get_sub_namespaces_block(sub_namespaces: dict) -> str: - lines = [] - - sub_namespaces = sort_dict_by_key(sub_namespaces) - for sub_namespace in sub_namespaces: - lines.append( - f"{_(1)}{sub_namespace}: '{get_namespace_name(sub_namespace)}' = field(default_factory=DefaultNamespace)" - ) - - return join_code(lines) - - -def _get_post_init_method(sub_namespaces: dict) -> str: - lines = [f'{_(1)}def __post_init__(self) -> None:'] +def _get_init_method(sub_namespaces: dict, *, sync: bool) -> str: + client_typehint = "'ClientRaw'" if sync else "'AsyncClientRaw'" + lines = [f'{_(1)}def __init__(self, client: {client_typehint}) -> None:', f'{_(2)}super().__init__(client)'] sub_namespaces = sort_dict_by_key(sub_namespaces) for sub_namespace in sub_namespaces: @@ -95,7 +88,7 @@ def _get_method_docstring(method_info: MethodInfo) -> str: doc_string = [f'{_(2)}"""{method_desc}', '', f'{_(2)}Args:'] - presented_args = _get_namespace_method_signature_args_names(method_info) + presented_args = _get_namespace_method_signature_args_names(method_info.definition) if 'params' in presented_args: doc_string.append(f'{_(3)}params: Parameters.') if 'data_schema' in presented_args: @@ -126,25 +119,35 @@ def _get_method_docstring(method_info: MethodInfo) -> str: return join_code(doc_string) +@t.overload +def _get_namespace_method_body(method_info: ProcedureInfo, *, sync: bool) -> str: + ... + + +@t.overload +def _get_namespace_method_body(method_info: QueryInfo, *, sync: bool) -> str: + ... + + def _get_namespace_method_body(method_info: MethodInfo, *, sync: bool) -> str: d, c = get_sync_async_keywords(sync=sync) lines = [_get_method_docstring(method_info)] - presented_args = _get_namespace_method_signature_args_names(method_info) + presented_args = _get_namespace_method_signature_args_names(method_info.definition) presented_args.remove('self') def _override_arg_line(name: str, model_name: str) -> str: model_path = f'models.{get_import_path(method_info.nsid)}.{model_name}' - return f'{_(2)}{name} = get_or_create({name}, {model_path})' + return f'{_(2)}{name}_model = get_or_create_model({name}, {model_path})' invoke_args = [f"'{method_info.nsid}'"] if 'params' in presented_args: - invoke_args.append('params=params') + invoke_args.append('params=params_model') lines.append(_override_arg_line('params', PARAMS_MODEL)) if 'data_schema' in presented_args: - invoke_args.append('data=data') + invoke_args.append('data=data_model') lines.append(_override_arg_line('data', INPUT_MODEL)) if 'data_alias' in presented_args: invoke_args.append('data=data') @@ -185,21 +188,21 @@ def _get_namespace_method_signature_arg( return f'{name}: {type_hint}{default_value}' -def _get_namespace_method_signature_args_names(method_info: MethodInfo) -> t.Set[str]: +def _get_namespace_method_signature_args_names(definition: t.Union[LexXrpcProcedure, LexXrpcQuery]) -> t.Set[str]: args = {'self'} - if method_info.definition.parameters: + if definition.parameters: args.add('params') - if method_info.definition.type is LexDefinitionType.PROCEDURE and method_info.definition.input: - if method_info.definition.input.schema: + if isinstance(definition, LexXrpcProcedure) and definition.input: + if definition.input.schema: args.add('data_schema') else: args.add('data_alias') - if method_info.definition.input.encoding: + if definition.input.encoding: args.add('input_encoding') - if method_info.definition.output and method_info.definition.output.encoding: + if definition.output and definition.output.encoding: args.add('output_encoding') return args @@ -225,7 +228,7 @@ def is_optional_arg(lex_obj) -> bool: arg = _get_namespace_method_signature_arg('params', method_info.nsid, PARAMS_MODEL, optional=is_optional) _add_arg(arg, optional=is_optional) - if method_info.definition.type is LexDefinitionType.PROCEDURE and method_info.definition.input: + if isinstance(method_info, ProcedureInfo) and method_info.definition.input: schema = method_info.definition.input.schema if schema: is_optional = is_optional_arg(schema) @@ -296,22 +299,21 @@ def _get_namespace_records_block(records_info: t.List[RecordInfo]) -> str: return join_code(lines) -def _generate_namespace_in_output(namespace_tree: t.Union[dict, list], output: t.List[str], *, sync: bool) -> None: +def _generate_namespace_in_output(namespace_tree: dict, output: t.List[str], *, sync: bool) -> None: for node_name, sub_node in namespace_tree.items(): if isinstance(sub_node, dict): - output.append(_get_namespace_class_def(node_name)) - output.append(_get_sub_namespaces_block(sub_node)) - output.append(_get_post_init_method(sub_node)) + output.append(_get_namespace_class_def(node_name, sync=sync)) + output.append(_get_init_method(sub_node, sync=sync)) _generate_namespace_in_output(sub_node, output, sync=sync) if isinstance(sub_node, list): - output.append(_get_namespace_class_def(node_name)) + output.append(_get_namespace_class_def(node_name, sync=sync)) # TODO(MarshalX): gen namespace by RecordInfo later # TODO(MarshalX): generate namespace record classes! - methods = [info for info in sub_node if isinstance(info, MethodInfo)] + methods = [info for info in sub_node if isinstance(info, (ProcedureInfo, QueryInfo))] output.append(_get_namespace_methods_block(methods, sync=sync)) @@ -331,7 +333,7 @@ def generate_namespaces( namespace_tree = build_namespaces(lexicon_dir) for sync in (True, False): - generated_code_lines_buffer = [] + generated_code_lines_buffer: t.List[str] = [] _generate_namespace_in_output(namespace_tree, generated_code_lines_buffer, sync=sync) code = join_code([_get_namespace_imports(), *generated_code_lines_buffer]) diff --git a/atproto/exceptions.py b/atproto/exceptions.py index 14b4962e..272840d1 100644 --- a/atproto/exceptions.py +++ b/atproto/exceptions.py @@ -50,7 +50,7 @@ class ModelFieldNotFoundError(ModelError): class RequestErrorBase(AtProtocolError): def __init__(self, response: t.Optional['Response'] = None) -> None: - self.response: 'Response' = response + self.response: t.Optional['Response'] = response class NetworkError(RequestErrorBase): @@ -87,3 +87,7 @@ class CBORDecodingError(AtProtocolError): class DAGCBORDecodingError(AtProtocolError): ... + + +class InvalidCARFile(AtProtocolError): + ... diff --git a/atproto/firehose/__init__.py b/atproto/firehose/__init__.py index 3d713657..8e55abda 100644 --- a/atproto/firehose/__init__.py +++ b/atproto/firehose/__init__.py @@ -2,7 +2,7 @@ from atproto.firehose.client import AsyncFirehoseClient, FirehoseClient from atproto.xrpc_client import models -from atproto.xrpc_client.models.utils import get_model_as_dict, get_or_create +from atproto.xrpc_client.models.utils import get_model_as_dict, get_or_create_model if t.TYPE_CHECKING: from atproto.firehose.models import MessageFrame @@ -49,7 +49,7 @@ def parse_subscribe_repos_message(message: 'MessageFrame') -> SubscribeReposMess :obj:`SubscribeReposMessage`: Corresponding message model. """ model_class = _SUBSCRIBE_REPOS_MESSAGE_TYPE_TO_MODEL[message.type] - return get_or_create(message.body, model_class) + return get_or_create_model(message.body, model_class) def parse_subscribe_labels_message(message: 'MessageFrame') -> SubscribeLabelsMessage: @@ -62,35 +62,55 @@ def parse_subscribe_labels_message(message: 'MessageFrame') -> SubscribeLabelsMe :obj:`SubscribeLabelsMessage`: Corresponding message model. """ model_class = _SUBSCRIBE_LABELS_MESSAGE_TYPE_TO_MODEL[message.type] - return get_or_create(message.body, model_class) + return get_or_create_model(message.body, model_class) class FirehoseSubscribeReposClient(FirehoseClient): def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscribeRepos.Params']] = None) -> None: - params = get_or_create(params, models.ComAtprotoSyncSubscribeRepos.Params) - if params: - params = get_model_as_dict(params) - super().__init__(method='com.atproto.sync.subscribeRepos', params=params) + params_model = get_or_create_model(params, models.ComAtprotoSyncSubscribeRepos.Params) + + params_dict = None + if params_model: + params_dict = get_model_as_dict(params_model) + + super().__init__(method='com.atproto.sync.subscribeRepos', params=params_dict) class AsyncFirehoseSubscribeReposClient(AsyncFirehoseClient): def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscribeRepos.Params']] = None) -> None: - params = get_or_create(params, models.ComAtprotoSyncSubscribeRepos.Params) - if params: - params = get_model_as_dict(params) - super().__init__(method='com.atproto.sync.subscribeRepos', params=params) + params_model = get_or_create_model(params, models.ComAtprotoSyncSubscribeRepos.Params) + + params_dict = None + if params_model: + params_dict = get_model_as_dict(params_model) + + super().__init__(method='com.atproto.sync.subscribeRepos', params=params_dict) # TODO(MarshalX): SubscribeLabels doesn't work yet? HTTP 502 Error class FirehoseSubscribeLabelsClient(FirehoseClient): - def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels']] = None) -> None: - params = get_or_create(params, models.ComAtprotoLabelSubscribeLabels.Params) - super().__init__(method='com.atproto.label.subscribeLabels', params=params) + def __init__( + self, params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels.Params']] = None + ) -> None: + params_model = get_or_create_model(params, models.ComAtprotoLabelSubscribeLabels.Params) + + params_dict = None + if params_model: + params_dict = get_model_as_dict(params_model) + + super().__init__(method='com.atproto.label.subscribeLabels', params=params_dict) class AsyncFirehoseSubscribeLabelsClient(AsyncFirehoseClient): - def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels']] = None) -> None: - params = get_or_create(params, models.ComAtprotoLabelSubscribeLabels.Params) - super().__init__(method='com.atproto.label.subscribeLabels', params=params) + def __init__( + self, params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels.Params']] = None + ) -> None: + params_model = get_or_create_model(params, models.ComAtprotoLabelSubscribeLabels.Params) + + params_dict = None + if params_model: + params_dict = get_model_as_dict(params_model) + + super().__init__(method='com.atproto.label.subscribeLabels', params=params_dict) diff --git a/atproto/firehose/client.py b/atproto/firehose/client.py index 095d10ea..02f001c1 100644 --- a/atproto/firehose/client.py +++ b/atproto/firehose/client.py @@ -16,18 +16,13 @@ ) from atproto.exceptions import CBORDecodingError, DAGCBORDecodingError, FirehoseError -from atproto.firehose.models import Frame +from atproto.firehose.models import ErrorFrame, Frame, MessageFrame from atproto.xrpc_client.models.common import XrpcError -if t.TYPE_CHECKING: - from httpx_ws import AsyncWebSocketSession, WebSocketSession - - from atproto.firehose.models import MessageFrame - _BASE_WEBSOCKET_URL = 'https://bsky.social/xrpc' -OnMessageCallback = t.Callable[['MessageFrame'], None] -AsyncOnMessageCallback = t.Callable[['MessageFrame'], t.Awaitable[None]] +OnMessageCallback = t.Callable[['MessageFrame'], t.Generator[t.Any, None, t.Any]] +AsyncOnMessageCallback = t.Callable[['MessageFrame'], t.Coroutine[t.Any, t.Any, t.Any]] OnCallbackErrorCallback = t.Callable[[BaseException], None] @@ -74,28 +69,22 @@ def __init__( method: str, base_url: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None, - *, - async_version: bool = False, ) -> None: self._params = params self._url = _build_websocket_url(method, base_url) - self._async_version = async_version - self._reconnect_no = 0 self._max_reconnect_delay_sec = 64 self._on_message_callback: t.Optional[t.Union[OnMessageCallback, AsyncOnMessageCallback]] = None self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None - def _get_client( - self, - ) -> t.Union[t.AsyncGenerator['AsyncWebSocketSession', None], t.Generator['WebSocketSession', None, None]]: - if self._async_version: - return aconnect_ws(self._url, params=self._params) - + def _get_client(self): return connect_ws(self._url, params=self._params) + def _get_async_client(self): + return aconnect_ws(self._url, params=self._params) + def _get_reconnection_delay(self) -> int: base_sec = 2**self._reconnect_no rand_sec = random.uniform(-0.5, 0.5) # noqa: S311 @@ -104,9 +93,9 @@ def _get_reconnection_delay(self) -> int: def _process_raw_frame(self, data: bytes) -> None: frame = Frame.from_bytes(data) - if frame.is_error: + if isinstance(frame, ErrorFrame): raise FirehoseError(XrpcError(frame.body.error, frame.body.message)) - if frame.is_message: + if isinstance(frame, MessageFrame): self._process_message_frame(frame) else: raise FirehoseError('Unknown frame type') @@ -125,9 +114,10 @@ def start( Returns: :obj:`None` """ - raise NotImplementedError + self._on_message_callback = on_message_callback + self._on_callback_error_callback = on_callback_error_callback - def stop(self) -> None: + def stop(self): """Unsubscribe and stop Firehose client. Returns: @@ -159,13 +149,8 @@ def _process_message_frame(self, frame: 'MessageFrame') -> None: else: traceback.print_exc() - def start( - self, - on_message_callback: OnMessageCallback, - on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None, - ) -> None: - self._on_message_callback = on_message_callback - self._on_callback_error_callback = on_callback_error_callback + def start(self, *args, **kwargs): + super().start(*args, **kwargs) while not self._stop_lock.locked(): try: @@ -187,7 +172,7 @@ def start( self._stop_lock.release() - def stop(self) -> None: + def stop(self): if not self._stop_lock.locked(): self._stop_lock.acquire() @@ -196,45 +181,41 @@ class _AsyncWebsocketClient(_WebsocketClientBase): def __init__( self, method: str, base_url: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None ) -> None: - super().__init__(method, base_url, params, async_version=True) + super().__init__(method, base_url, params) self._loop = asyncio.get_event_loop() - self._on_message_tasks = set() + self._on_message_tasks: t.Set[asyncio.Task] = set() self._stop_lock = asyncio.Lock() def _on_message_callback_done(self, task: asyncio.Task) -> None: self._on_message_tasks.discard(task) - if task.exception(): + exception = task.exception() + if exception: if not self._on_callback_error_callback: traceback.print_exc() return try: - self._on_callback_error_callback(task.exception()) + self._on_callback_error_callback(exception) except: # noqa traceback.print_exc() def _process_message_frame(self, frame: 'MessageFrame') -> None: - task = self._loop.create_task(self._on_message_callback(frame)) + task: asyncio.Task = self._loop.create_task(self._on_message_callback(frame)) self._on_message_tasks.add(task) task.add_done_callback(self._on_message_callback_done) - async def start( - self, - on_message_callback: AsyncOnMessageCallback, - on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None, - ) -> None: - self._on_message_callback = on_message_callback - self._on_callback_error_callback = on_callback_error_callback + async def start(self, *args, **kwargs): + super().start(*args, **kwargs) while not self._stop_lock.locked(): try: if self._reconnect_no != 0: await asyncio.sleep(self._get_reconnection_delay()) - async with self._get_client() as client: + async with self._get_async_client() as client: self._reconnect_no = 0 while not self._stop_lock.locked(): @@ -249,7 +230,7 @@ async def start( self._stop_lock.release() - async def stop(self) -> None: + async def stop(self): if not self._stop_lock.locked(): await self._stop_lock.acquire() diff --git a/atproto/firehose/models.py b/atproto/firehose/models.py index b72970ce..ca2241d8 100644 --- a/atproto/firehose/models.py +++ b/atproto/firehose/models.py @@ -4,7 +4,7 @@ from atproto.cbor import decode_dag_multi from atproto.exceptions import AtProtocolError, FirehoseError -from atproto.xrpc_client.models.utils import get_or_create +from atproto.xrpc_client.models.utils import get_or_create_model class FrameType(Enum): @@ -53,8 +53,8 @@ def parse_frame_header(raw_header: dict) -> FrameHeader: frame_type = FrameType(header_op) if frame_type is FrameType.MESSAGE: - return get_or_create(raw_header, MessageFrameHeader) - return get_or_create(raw_header, ErrorFrameHeader) + return get_or_create_model(raw_header, MessageFrameHeader) + return get_or_create_model(raw_header, ErrorFrameHeader) except (ValueError, AtProtocolError) as e: raise FirehoseError('Invalid frame header') from e @@ -62,7 +62,7 @@ def parse_frame_header(raw_header: dict) -> FrameHeader: def parse_frame(header: FrameHeader, raw_body: dict) -> Union['ErrorFrame', 'MessageFrame']: try: if isinstance(header, ErrorFrameHeader): - body = get_or_create(raw_body, ErrorFrameBody) + body = get_or_create_model(raw_body, ErrorFrameBody) return ErrorFrame(header, body) if isinstance(header, MessageFrameHeader): return MessageFrame(header, raw_body) @@ -83,11 +83,6 @@ def operation(self) -> FrameType: """:obj:`FrameType`: Frame operation (frame type).""" return self.header.op - @property - def type(self) -> str: - """:obj:`str`: Frame type.""" - return self.header.t - @property def is_message(self) -> bool: """:obj:`bool`: Is frame the MessageFrame.""" diff --git a/atproto/leb128/__init__.py b/atproto/leb128/__init__.py index aa9bdd80..490afe37 100644 --- a/atproto/leb128/__init__.py +++ b/atproto/leb128/__init__.py @@ -1,5 +1,5 @@ """ -Original source code: https://github.com/mohanson/leb128 +Original source code: https://github.com/mohanson/leb128 (MIT license) https://en.wikipedia.org/wiki/LEB128 @@ -37,9 +37,9 @@ def decode(b: bytearray) -> int: return r @staticmethod - def decode_reader(r: typing.BinaryIO) -> (int, int): + def decode_reader(r: typing.BinaryIO) -> typing.Tuple[int, int]: """ - Decode the unsigned leb128 encoded from a reader, it will return two values, the actual number and the number + Decode the unsigned leb128 encoded from a reader, it will return two values, the actual number, and the number of bytes read. """ a = bytearray() @@ -75,9 +75,9 @@ def decode(b: bytearray) -> int: return r @staticmethod - def decode_reader(r: typing.BinaryIO) -> (int, int): + def decode_reader(r: typing.BinaryIO) -> typing.Tuple[int, int]: """ - Decode the signed leb128 encoded from a reader, it will return two values, the actual number and the number + Decode the signed leb128 encoded from a reader, it will return two values, the actual number, and the number of bytes read. """ a = bytearray() diff --git a/atproto/lexicon/models.py b/atproto/lexicon/models.py index 3adb4674..2b0bfdb6 100644 --- a/atproto/lexicon/models.py +++ b/atproto/lexicon/models.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from enum import Enum -LexRef = str Number = t.Union[int, float, complex] diff --git a/atproto/lexicon/parser.py b/atproto/lexicon/parser.py index 62f55022..7c8abcd8 100644 --- a/atproto/lexicon/parser.py +++ b/atproto/lexicon/parser.py @@ -67,10 +67,7 @@ def _lex_primitive_type_hook(data: dict) -> models.LexPrimitive: L = t.TypeVar('L') -def lexicon_parse(data: dict, data_class: t.Optional[t.Type[L]] = None) -> t.Union[L, models.LexiconDoc]: - if not data_class: - data_class = models.LexiconDoc - +def lexicon_parse(data: dict, data_class: t.Optional[t.Type[L]] = models.LexiconDoc) -> L: return dacite.from_dict(data_class=data_class, data=data, config=_DEFAULT_DACITE_CONFIG) @@ -86,18 +83,21 @@ def lexicon_parse_file(lexicon_path: t.Union[Path, str], *, soft_fail: bool = Fa raise LexiconParsingError("Can't parse lexicon") from e -def lexicon_parse_dir(path: t.Union[Path, str] = None, *, soft_fail: bool = False) -> t.List[models.LexiconDoc]: - if path is None: - path = _PATH_TO_LEXICONS +def lexicon_parse_dir( + lexicon_dir: t.Optional[t.Union[Path, str]] = None, *, soft_fail: bool = False +) -> t.List[models.LexiconDoc]: + lexicon_dir_path = Path(lexicon_dir) if isinstance(lexicon_dir, str) else lexicon_dir + if lexicon_dir_path is None: + lexicon_dir_path = _PATH_TO_LEXICONS parsed_lexicons = [] - for _, _, lexicons in os.walk(path): + for _, _, lexicons in os.walk(lexicon_dir_path): for lexicon in lexicons: if not lexicon.endswith(_LEXICON_FILE_EXT): continue - lexicon_path = path.joinpath(lexicon) + lexicon_path = lexicon_dir_path.joinpath(lexicon) parsed_lexicon = lexicon_parse_file(lexicon_path, soft_fail=soft_fail) if parsed_lexicon: parsed_lexicons.append(parsed_lexicon) diff --git a/atproto/xrpc_client/client/async_client.py b/atproto/xrpc_client/client/async_client.py index 1c0e6a13..a7b9e63e 100644 --- a/atproto/xrpc_client/client/async_client.py +++ b/atproto/xrpc_client/client/async_client.py @@ -14,23 +14,16 @@ from atproto.xrpc_client.models import ids if t.TYPE_CHECKING: - from atproto.xrpc_client.client.auth import JwtPayload from atproto.xrpc_client.client.base import InvokeType from atproto.xrpc_client.request import Response -class AsyncClient(AsyncClientRaw, SessionMethodsMixin): +class AsyncClient(SessionMethodsMixin, AsyncClientRaw): """High-level client for XRPC of ATProto.""" - def __init__(self, base_url: str = None) -> None: + def __init__(self, base_url: t.Optional[str] = None) -> None: super().__init__(base_url) - self._access_jwt: t.Optional[str] = None - self._access_jwt_payload: t.Optional['JwtPayload'] = None - - self._refresh_jwt: t.Optional[str] = None - self._refresh_jwt_payload: t.Optional['JwtPayload'] = None - self._refresh_lock = Lock() self.me: t.Optional[models.AppBskyActorDefs.ProfileViewDetailed] = None @@ -63,11 +56,12 @@ async def _refresh_and_set_session(self) -> models.ComAtprotoServerRefreshSessio return refresh_session async def login(self, login: str, password: str) -> models.AppBskyActorGetProfile.ResponseRef: - """Authorize client and get profile info. + """Authorize a client and get profile info. Args: login: Handle/username of the account. - password: Password of the account. Could be app specific one. + password: Password of the account. + Could be an app-specific one. Returns: :obj:`models.AppBskyActorGetProfile.ResponseRef`: Profile information. diff --git a/atproto/xrpc_client/client/base.py b/atproto/xrpc_client/client/base.py index 25b5ea7f..d57971ea 100644 --- a/atproto/xrpc_client/client/base.py +++ b/atproto/xrpc_client/client/base.py @@ -4,6 +4,9 @@ from atproto.xrpc_client.models.utils import get_model_as_dict, get_model_as_json from atproto.xrpc_client.request import AsyncRequest, Request, Response +if t.TYPE_CHECKING: + from atproto.xrpc_client.models.base import DataModelBase, ParamsModelBase + # TODO(MarshalX): Generate async version automatically! @@ -61,12 +64,20 @@ def _build_url(self, nsid: str) -> str: return f'{self._base_url}/{nsid}' def invoke_query( - self, nsid: str, params: t.Optional[dict] = None, data: t.Optional[dict] = None, **kwargs + self, + nsid: str, + params: t.Optional['ParamsModelBase'] = None, + data: t.Optional[t.Union['DataModelBase', bytes]] = None, + **kwargs, ) -> Response: return self._invoke(InvokeType.QUERY, url=self._build_url(nsid), params=params, data=data, **kwargs) def invoke_procedure( - self, nsid: str, params: t.Optional[dict] = None, data: t.Optional[dict] = None, **kwargs + self, + nsid: str, + params: t.Optional['ParamsModelBase'] = None, + data: t.Optional[t.Union['DataModelBase', bytes]] = None, + **kwargs, ) -> Response: return self._invoke(InvokeType.PROCEDURE, url=self._build_url(nsid), params=params, data=data, **kwargs) @@ -98,12 +109,20 @@ def _build_url(self, nsid: str) -> str: return f'{self._base_url}/{nsid}' async def invoke_query( - self, nsid: str, params: t.Optional[dict] = None, data: t.Optional[dict] = None, **kwargs + self, + nsid: str, + params: t.Optional['ParamsModelBase'] = None, + data: t.Optional[t.Union['DataModelBase', bytes]] = None, + **kwargs, ) -> Response: return await self._invoke(InvokeType.QUERY, url=self._build_url(nsid), params=params, data=data, **kwargs) async def invoke_procedure( - self, nsid: str, params: t.Optional[dict] = None, data: t.Optional[dict] = None, **kwargs + self, + nsid: str, + params: t.Optional['ParamsModelBase'] = None, + data: t.Optional[t.Union['DataModelBase', bytes]] = None, + **kwargs, ) -> Response: return await self._invoke(InvokeType.PROCEDURE, url=self._build_url(nsid), params=params, data=data, **kwargs) diff --git a/atproto/xrpc_client/client/client.py b/atproto/xrpc_client/client/client.py index 9dbecf93..ea46b4fb 100644 --- a/atproto/xrpc_client/client/client.py +++ b/atproto/xrpc_client/client/client.py @@ -8,23 +8,16 @@ from atproto.xrpc_client.models import ids if t.TYPE_CHECKING: - from atproto.xrpc_client.client.auth import JwtPayload from atproto.xrpc_client.client.base import InvokeType from atproto.xrpc_client.request import Response -class Client(ClientRaw, SessionMethodsMixin): +class Client(SessionMethodsMixin, ClientRaw): """High-level client for XRPC of ATProto.""" - def __init__(self, base_url: str = None) -> None: + def __init__(self, base_url: t.Optional[str] = None) -> None: super().__init__(base_url) - self._access_jwt: t.Optional[str] = None - self._access_jwt_payload: t.Optional['JwtPayload'] = None - - self._refresh_jwt: t.Optional[str] = None - self._refresh_jwt_payload: t.Optional['JwtPayload'] = None - self._refresh_lock = Lock() self.me: t.Optional[models.AppBskyActorDefs.ProfileViewDetailed] = None @@ -55,11 +48,12 @@ def _refresh_and_set_session(self) -> models.ComAtprotoServerRefreshSession.Resp return refresh_session def login(self, login: str, password: str) -> models.AppBskyActorGetProfile.ResponseRef: - """Authorize client and get profile info. + """Authorize a client and get profile info. Args: login: Handle/username of the account. - password: Password of the account. Could be app specific one. + password: Password of the account. + Could be an app-specific one. Returns: :obj:`models.AppBskyActorGetProfile.ResponseRef`: Profile information. diff --git a/atproto/xrpc_client/client/methods_mixin/session.py b/atproto/xrpc_client/client/methods_mixin/session.py index 8e207076..725424f4 100644 --- a/atproto/xrpc_client/client/methods_mixin/session.py +++ b/atproto/xrpc_client/client/methods_mixin/session.py @@ -5,12 +5,22 @@ if t.TYPE_CHECKING: from atproto.xrpc_client import models + from atproto.xrpc_client.client.auth import JwtPayload class SessionMethodsMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._access_jwt: t.Optional[str] = None + self._access_jwt_payload: t.Optional['JwtPayload'] = None + + self._refresh_jwt: t.Optional[str] = None + self._refresh_jwt_payload: t.Optional['JwtPayload'] = None + def _should_refresh_session(self) -> bool: expired_at = datetime.fromtimestamp(self._access_jwt_payload.exp, tz=timezone.utc) - expired_at = expired_at - timedelta(minutes=15) # let's update token a bit later than required + expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit later than required datetime_now = datetime.now(timezone.utc) diff --git a/atproto/xrpc_client/models/app/bsky/actor/get_profile.py b/atproto/xrpc_client/models/app/bsky/actor/get_profile.py index ef4aecb7..ba5604ee 100644 --- a/atproto/xrpc_client/models/app/bsky/actor/get_profile.py +++ b/atproto/xrpc_client/models/app/bsky/actor/get_profile.py @@ -5,7 +5,6 @@ ################################################################## -import typing as t from dataclasses import dataclass from atproto.xrpc_client import models @@ -21,4 +20,4 @@ class Params(base.ParamsModelBase): #: Response reference to :obj:`models.AppBskyActorDefs.ProfileViewDetailed` model. -ResponseRef: t.Type[models.AppBskyActorDefs.ProfileViewDetailed] = models.AppBskyActorDefs.ProfileViewDetailed +ResponseRef = models.AppBskyActorDefs.ProfileViewDetailed diff --git a/atproto/xrpc_client/models/com/atproto/admin/get_moderation_action.py b/atproto/xrpc_client/models/com/atproto/admin/get_moderation_action.py index b5a01e94..2ad4d733 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/get_moderation_action.py +++ b/atproto/xrpc_client/models/com/atproto/admin/get_moderation_action.py @@ -5,7 +5,6 @@ ################################################################## -import typing as t from dataclasses import dataclass from atproto.xrpc_client import models @@ -21,4 +20,4 @@ class Params(base.ParamsModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.ActionViewDetail` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.ActionViewDetail] = models.ComAtprotoAdminDefs.ActionViewDetail +ResponseRef = models.ComAtprotoAdminDefs.ActionViewDetail diff --git a/atproto/xrpc_client/models/com/atproto/admin/get_moderation_report.py b/atproto/xrpc_client/models/com/atproto/admin/get_moderation_report.py index 0247686c..4a1a7c19 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/get_moderation_report.py +++ b/atproto/xrpc_client/models/com/atproto/admin/get_moderation_report.py @@ -5,7 +5,6 @@ ################################################################## -import typing as t from dataclasses import dataclass from atproto.xrpc_client import models @@ -21,4 +20,4 @@ class Params(base.ParamsModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.ReportViewDetail` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.ReportViewDetail] = models.ComAtprotoAdminDefs.ReportViewDetail +ResponseRef = models.ComAtprotoAdminDefs.ReportViewDetail diff --git a/atproto/xrpc_client/models/com/atproto/admin/get_record.py b/atproto/xrpc_client/models/com/atproto/admin/get_record.py index d197e6f7..af5a5d55 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/get_record.py +++ b/atproto/xrpc_client/models/com/atproto/admin/get_record.py @@ -22,4 +22,4 @@ class Params(base.ParamsModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.RecordViewDetail` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.RecordViewDetail] = models.ComAtprotoAdminDefs.RecordViewDetail +ResponseRef = models.ComAtprotoAdminDefs.RecordViewDetail diff --git a/atproto/xrpc_client/models/com/atproto/admin/get_repo.py b/atproto/xrpc_client/models/com/atproto/admin/get_repo.py index fa54c8a3..e421efff 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/get_repo.py +++ b/atproto/xrpc_client/models/com/atproto/admin/get_repo.py @@ -5,7 +5,6 @@ ################################################################## -import typing as t from dataclasses import dataclass from atproto.xrpc_client import models @@ -21,4 +20,4 @@ class Params(base.ParamsModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.RepoViewDetail` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.RepoViewDetail] = models.ComAtprotoAdminDefs.RepoViewDetail +ResponseRef = models.ComAtprotoAdminDefs.RepoViewDetail diff --git a/atproto/xrpc_client/models/com/atproto/admin/resolve_moderation_reports.py b/atproto/xrpc_client/models/com/atproto/admin/resolve_moderation_reports.py index 1b766969..b215a8f4 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/resolve_moderation_reports.py +++ b/atproto/xrpc_client/models/com/atproto/admin/resolve_moderation_reports.py @@ -23,4 +23,4 @@ class Data(base.DataModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.ActionView` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.ActionView] = models.ComAtprotoAdminDefs.ActionView +ResponseRef = models.ComAtprotoAdminDefs.ActionView diff --git a/atproto/xrpc_client/models/com/atproto/admin/reverse_moderation_action.py b/atproto/xrpc_client/models/com/atproto/admin/reverse_moderation_action.py index 811ba146..381e594f 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/reverse_moderation_action.py +++ b/atproto/xrpc_client/models/com/atproto/admin/reverse_moderation_action.py @@ -5,7 +5,6 @@ ################################################################## -import typing as t from dataclasses import dataclass from atproto.xrpc_client import models @@ -23,4 +22,4 @@ class Data(base.DataModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.ActionView` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.ActionView] = models.ComAtprotoAdminDefs.ActionView +ResponseRef = models.ComAtprotoAdminDefs.ActionView diff --git a/atproto/xrpc_client/models/com/atproto/admin/take_moderation_action.py b/atproto/xrpc_client/models/com/atproto/admin/take_moderation_action.py index 5a5e92e7..d6d66ba2 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/take_moderation_action.py +++ b/atproto/xrpc_client/models/com/atproto/admin/take_moderation_action.py @@ -29,4 +29,4 @@ class Data(base.DataModelBase): #: Response reference to :obj:`models.ComAtprotoAdminDefs.ActionView` model. -ResponseRef: t.Type[models.ComAtprotoAdminDefs.ActionView] = models.ComAtprotoAdminDefs.ActionView +ResponseRef = models.ComAtprotoAdminDefs.ActionView diff --git a/atproto/xrpc_client/models/com/atproto/repo/upload_blob.py b/atproto/xrpc_client/models/com/atproto/repo/upload_blob.py index 00afa3b7..5a9c4533 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/upload_blob.py +++ b/atproto/xrpc_client/models/com/atproto/repo/upload_blob.py @@ -5,14 +5,15 @@ ################################################################## -import typing as t from dataclasses import dataclass +import typing_extensions as te + from atproto.xrpc_client.models import base from atproto.xrpc_client.models.blob_ref import BlobRef #: Data raw data type. -Data: t.Union[t.Type[str], t.Type[bytes]] = bytes +Data: te.TypeAlias = bytes @dataclass diff --git a/atproto/xrpc_client/models/com/atproto/server/create_app_password.py b/atproto/xrpc_client/models/com/atproto/server/create_app_password.py index e52691d5..23babcdc 100644 --- a/atproto/xrpc_client/models/com/atproto/server/create_app_password.py +++ b/atproto/xrpc_client/models/com/atproto/server/create_app_password.py @@ -5,7 +5,6 @@ ################################################################## -import typing as t from dataclasses import dataclass from atproto.xrpc_client.models import base @@ -32,4 +31,4 @@ class AppPassword(base.ModelBase): #: Response reference to :obj:`AppPassword` model. -ResponseRef: t.Type[AppPassword] = AppPassword +ResponseRef = AppPassword diff --git a/atproto/xrpc_client/models/com/atproto/sync/get_blob.py b/atproto/xrpc_client/models/com/atproto/sync/get_blob.py index d1082273..e39f5451 100644 --- a/atproto/xrpc_client/models/com/atproto/sync/get_blob.py +++ b/atproto/xrpc_client/models/com/atproto/sync/get_blob.py @@ -5,9 +5,10 @@ ################################################################## -import typing as t from dataclasses import dataclass +import typing_extensions as te + from atproto.xrpc_client.models import base @@ -21,4 +22,4 @@ class Params(base.ParamsModelBase): #: Response raw data type. -Response: t.Union[t.Type[str], t.Type[bytes]] = bytes +Response: te.TypeAlias = bytes diff --git a/atproto/xrpc_client/models/com/atproto/sync/get_blocks.py b/atproto/xrpc_client/models/com/atproto/sync/get_blocks.py index d77b12ff..f4d6d1a9 100644 --- a/atproto/xrpc_client/models/com/atproto/sync/get_blocks.py +++ b/atproto/xrpc_client/models/com/atproto/sync/get_blocks.py @@ -8,6 +8,8 @@ import typing as t from dataclasses import dataclass +import typing_extensions as te + from atproto.xrpc_client.models import base @@ -21,4 +23,4 @@ class Params(base.ParamsModelBase): #: Response raw data type. -Response: t.Union[t.Type[str], t.Type[bytes]] = bytes +Response: te.TypeAlias = bytes diff --git a/atproto/xrpc_client/models/com/atproto/sync/get_checkout.py b/atproto/xrpc_client/models/com/atproto/sync/get_checkout.py index d0ccdbf0..656f4373 100644 --- a/atproto/xrpc_client/models/com/atproto/sync/get_checkout.py +++ b/atproto/xrpc_client/models/com/atproto/sync/get_checkout.py @@ -8,6 +8,8 @@ import typing as t from dataclasses import dataclass +import typing_extensions as te + from atproto.xrpc_client.models import base @@ -21,4 +23,4 @@ class Params(base.ParamsModelBase): #: Response raw data type. -Response: t.Union[t.Type[str], t.Type[bytes]] = bytes +Response: te.TypeAlias = bytes diff --git a/atproto/xrpc_client/models/com/atproto/sync/get_record.py b/atproto/xrpc_client/models/com/atproto/sync/get_record.py index 9780a4c7..f18ac7b7 100644 --- a/atproto/xrpc_client/models/com/atproto/sync/get_record.py +++ b/atproto/xrpc_client/models/com/atproto/sync/get_record.py @@ -8,6 +8,8 @@ import typing as t from dataclasses import dataclass +import typing_extensions as te + from atproto.xrpc_client.models import base @@ -23,4 +25,4 @@ class Params(base.ParamsModelBase): #: Response raw data type. -Response: t.Union[t.Type[str], t.Type[bytes]] = bytes +Response: te.TypeAlias = bytes diff --git a/atproto/xrpc_client/models/com/atproto/sync/get_repo.py b/atproto/xrpc_client/models/com/atproto/sync/get_repo.py index f3e16f3b..4c14b799 100644 --- a/atproto/xrpc_client/models/com/atproto/sync/get_repo.py +++ b/atproto/xrpc_client/models/com/atproto/sync/get_repo.py @@ -8,6 +8,8 @@ import typing as t from dataclasses import dataclass +import typing_extensions as te + from atproto.xrpc_client.models import base @@ -22,4 +24,4 @@ class Params(base.ParamsModelBase): #: Response raw data type. -Response: t.Union[t.Type[str], t.Type[bytes]] = bytes +Response: te.TypeAlias = bytes diff --git a/atproto/xrpc_client/models/utils.py b/atproto/xrpc_client/models/utils.py index 7146b918..46938e78 100644 --- a/atproto/xrpc_client/models/utils.py +++ b/atproto/xrpc_client/models/utils.py @@ -3,6 +3,7 @@ import typing as t from enum import Enum +import typing_extensions as te from dacite import Config, exceptions, from_dict from atproto.cid import CID @@ -21,12 +22,13 @@ from atproto.xrpc_client.request import Response M = t.TypeVar('M') +ModelData: te.TypeAlias = t.Union[M, dict, None] def _record_model_type_hook(data: dict) -> RecordModelBase: # used for inner Record types record_type = data.pop('$type') - return get_or_create(data, RECORD_TYPE_TO_MODEL_CLASS[record_type]) + return get_or_create_model(data, RECORD_TYPE_TO_MODEL_CLASS[record_type]) def _decode_cid_hook(ref: t.Union[CID, str]) -> CID: @@ -45,7 +47,7 @@ def _decode_cid_hook(ref: t.Union[CID, str]) -> CID: def get_or_create( - model_data: t.Union[dict], model: t.Type[M] = None, *, strict: bool = True + model_data: ModelData, model: t.Type[M] = None, *, strict: bool = True ) -> t.Optional[t.Union[M, dict]]: """Get model instance from raw data. @@ -75,7 +77,7 @@ def get_or_create( return model_data -def _get_or_create(model_data: t.Union[dict], model: t.Type[M], *, strict: bool) -> t.Optional[t.Union[M, dict]]: +def _get_or_create(model_data: ModelData, model: t.Type[M], *, strict: bool) -> t.Optional[t.Union[M, dict]]: if model_data is None: return None @@ -110,16 +112,25 @@ def _get_or_create(model_data: t.Union[dict], model: t.Type[M], *, strict: bool) raise ModelError(str(e)) from e -def get_response_model(response: 'Response', model: t.Type[M]) -> t.Optional[M]: +def get_or_create_model(model_data: ModelData, model: t.Type[M]) -> t.Optional[M]: + model_instance = get_or_create(model_data, model) + if model_instance is not None and not isinstance(model_instance, model): + raise ModelError(f"Can't properly parse model of type {model}") + + return model_instance + + +def get_response_model(response: 'Response', model: t.Type[M]) -> M: if model is bool: # Could not be False? Because the exception with errors will be raised from the server return response.success - return get_or_create(response.content, model) + # return is optional if response.content is None, but doesn't occur in practice + return get_or_create_model(response.content, model) def _handle_dict_key(key: str) -> str: - if key == '_type': # System field. Replaced to original $ symbol because is not allowed in Python. + if key == '_type': # System field. Replaced to original $ symbol because it is not allowed in Python. return '$type' return key diff --git a/atproto/xrpc_client/namespaces/async_ns.py b/atproto/xrpc_client/namespaces/async_ns.py index e9b33480..6f35ffdc 100644 --- a/atproto/xrpc_client/namespaces/async_ns.py +++ b/atproto/xrpc_client/namespaces/async_ns.py @@ -6,30 +6,24 @@ import typing as t -from dataclasses import dataclass, field from atproto.xrpc_client import models -from atproto.xrpc_client.models.utils import get_or_create, get_response_model -from atproto.xrpc_client.namespaces.base import DefaultNamespace, NamespaceBase +from atproto.xrpc_client.models.utils import get_or_create_model, get_response_model +from atproto.xrpc_client.namespaces.base import AsyncNamespaceBase +if t.TYPE_CHECKING: + from atproto.xrpc_client.client.async_raw import AsyncClientRaw -@dataclass -class AppNamespace(NamespaceBase): - bsky: 'BskyNamespace' = field(default_factory=DefaultNamespace) - def __post_init__(self) -> None: +class AppNamespace(AsyncNamespaceBase): + def __init__(self, client: 'AsyncClientRaw') -> None: + super().__init__(client) self.bsky = BskyNamespace(self._client) -@dataclass -class BskyNamespace(NamespaceBase): - actor: 'ActorNamespace' = field(default_factory=DefaultNamespace) - feed: 'FeedNamespace' = field(default_factory=DefaultNamespace) - graph: 'GraphNamespace' = field(default_factory=DefaultNamespace) - notification: 'NotificationNamespace' = field(default_factory=DefaultNamespace) - unspecced: 'UnspeccedNamespace' = field(default_factory=DefaultNamespace) - - def __post_init__(self) -> None: +class BskyNamespace(AsyncNamespaceBase): + def __init__(self, client: 'AsyncClientRaw') -> None: + super().__init__(client) self.actor = ActorNamespace(self._client) self.feed = FeedNamespace(self._client) self.graph = GraphNamespace(self._client) @@ -37,8 +31,7 @@ def __post_init__(self) -> None: self.unspecced = UnspeccedNamespace(self._client) -@dataclass -class ActorNamespace(NamespaceBase): +class ActorNamespace(AsyncNamespaceBase): async def get_preferences( self, params: t.Optional[t.Union[dict, 'models.AppBskyActorGetPreferences.Params']] = None, **kwargs ) -> 'models.AppBskyActorGetPreferences.Response': @@ -55,9 +48,9 @@ async def get_preferences( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetPreferences.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetPreferences.Params) response = await self._client.invoke_query( - 'app.bsky.actor.getPreferences', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getPreferences', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetPreferences.Response) @@ -77,9 +70,9 @@ async def get_profile( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetProfile.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetProfile.Params) response = await self._client.invoke_query( - 'app.bsky.actor.getProfile', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getProfile', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetProfile.ResponseRef) @@ -99,9 +92,9 @@ async def get_profiles( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetProfiles.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetProfiles.Params) response = await self._client.invoke_query( - 'app.bsky.actor.getProfiles', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getProfiles', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetProfiles.Response) @@ -121,9 +114,9 @@ async def get_suggestions( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetSuggestions.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetSuggestions.Params) response = await self._client.invoke_query( - 'app.bsky.actor.getSuggestions', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getSuggestions', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetSuggestions.Response) @@ -141,9 +134,9 @@ async def put_preferences(self, data: t.Union[dict, 'models.AppBskyActorPutPrefe :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyActorPutPreferences.Data) + data_model = get_or_create_model(data, models.AppBskyActorPutPreferences.Data) response = await self._client.invoke_procedure( - 'app.bsky.actor.putPreferences', data=data, input_encoding='application/json', **kwargs + 'app.bsky.actor.putPreferences', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -163,9 +156,9 @@ async def search_actors( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorSearchActors.Params) + params_model = get_or_create_model(params, models.AppBskyActorSearchActors.Params) response = await self._client.invoke_query( - 'app.bsky.actor.searchActors', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.searchActors', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorSearchActors.Response) @@ -185,15 +178,14 @@ async def search_actors_typeahead( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorSearchActorsTypeahead.Params) + params_model = get_or_create_model(params, models.AppBskyActorSearchActorsTypeahead.Params) response = await self._client.invoke_query( - 'app.bsky.actor.searchActorsTypeahead', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.searchActorsTypeahead', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorSearchActorsTypeahead.Response) -@dataclass -class FeedNamespace(NamespaceBase): +class FeedNamespace(AsyncNamespaceBase): async def describe_feed_generator(self, **kwargs) -> 'models.AppBskyFeedDescribeFeedGenerator.Response': """Returns information about a given feed generator including TOS & offered feed URIs. @@ -228,9 +220,9 @@ async def get_actor_feeds( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetActorFeeds.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetActorFeeds.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getActorFeeds', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getActorFeeds', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetActorFeeds.Response) @@ -250,9 +242,9 @@ async def get_author_feed( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetAuthorFeed.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetAuthorFeed.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getAuthorFeed', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getAuthorFeed', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetAuthorFeed.Response) @@ -272,9 +264,9 @@ async def get_feed( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeed.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeed.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getFeed', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeed', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeed.Response) @@ -294,9 +286,9 @@ async def get_feed_generator( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeedGenerator.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeedGenerator.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getFeedGenerator', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeedGenerator', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeedGenerator.Response) @@ -316,9 +308,9 @@ async def get_feed_generators( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeedGenerators.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeedGenerators.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getFeedGenerators', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeedGenerators', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeedGenerators.Response) @@ -338,9 +330,9 @@ async def get_feed_skeleton( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeedSkeleton.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeedSkeleton.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getFeedSkeleton', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeedSkeleton', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeedSkeleton.Response) @@ -360,9 +352,9 @@ async def get_likes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetLikes.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetLikes.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getLikes', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getLikes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetLikes.Response) @@ -382,9 +374,9 @@ async def get_post_thread( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetPostThread.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetPostThread.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getPostThread', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getPostThread', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetPostThread.Response) @@ -404,9 +396,9 @@ async def get_posts( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetPosts.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetPosts.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getPosts', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getPosts', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetPosts.Response) @@ -426,9 +418,9 @@ async def get_reposted_by( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetRepostedBy.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetRepostedBy.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getRepostedBy', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getRepostedBy', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetRepostedBy.Response) @@ -448,15 +440,14 @@ async def get_timeline( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetTimeline.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetTimeline.Params) response = await self._client.invoke_query( - 'app.bsky.feed.getTimeline', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getTimeline', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetTimeline.Response) -@dataclass -class GraphNamespace(NamespaceBase): +class GraphNamespace(AsyncNamespaceBase): async def get_blocks( self, params: t.Optional[t.Union[dict, 'models.AppBskyGraphGetBlocks.Params']] = None, **kwargs ) -> 'models.AppBskyGraphGetBlocks.Response': @@ -473,9 +464,9 @@ async def get_blocks( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetBlocks.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetBlocks.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getBlocks', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getBlocks', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetBlocks.Response) @@ -495,9 +486,9 @@ async def get_followers( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetFollowers.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetFollowers.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getFollowers', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getFollowers', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetFollowers.Response) @@ -517,9 +508,9 @@ async def get_follows( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetFollows.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetFollows.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getFollows', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getFollows', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetFollows.Response) @@ -539,9 +530,9 @@ async def get_list( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetList.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetList.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getList', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getList', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetList.Response) @@ -561,9 +552,9 @@ async def get_list_mutes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetListMutes.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetListMutes.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getListMutes', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getListMutes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetListMutes.Response) @@ -583,9 +574,9 @@ async def get_lists( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetLists.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetLists.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getLists', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getLists', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetLists.Response) @@ -605,9 +596,9 @@ async def get_mutes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetMutes.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetMutes.Params) response = await self._client.invoke_query( - 'app.bsky.graph.getMutes', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getMutes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetMutes.Response) @@ -625,9 +616,9 @@ async def mute_actor(self, data: t.Union[dict, 'models.AppBskyGraphMuteActor.Dat :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphMuteActor.Data) + data_model = get_or_create_model(data, models.AppBskyGraphMuteActor.Data) response = await self._client.invoke_procedure( - 'app.bsky.graph.muteActor', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.muteActor', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -645,9 +636,9 @@ async def mute_actor_list(self, data: t.Union[dict, 'models.AppBskyGraphMuteActo :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphMuteActorList.Data) + data_model = get_or_create_model(data, models.AppBskyGraphMuteActorList.Data) response = await self._client.invoke_procedure( - 'app.bsky.graph.muteActorList', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.muteActorList', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -665,9 +656,9 @@ async def unmute_actor(self, data: t.Union[dict, 'models.AppBskyGraphUnmuteActor :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphUnmuteActor.Data) + data_model = get_or_create_model(data, models.AppBskyGraphUnmuteActor.Data) response = await self._client.invoke_procedure( - 'app.bsky.graph.unmuteActor', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.unmuteActor', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -685,15 +676,14 @@ async def unmute_actor_list(self, data: t.Union[dict, 'models.AppBskyGraphUnmute :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphUnmuteActorList.Data) + data_model = get_or_create_model(data, models.AppBskyGraphUnmuteActorList.Data) response = await self._client.invoke_procedure( - 'app.bsky.graph.unmuteActorList', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.unmuteActorList', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass -class UnspeccedNamespace(NamespaceBase): +class UnspeccedNamespace(AsyncNamespaceBase): async def get_popular( self, params: t.Optional[t.Union[dict, 'models.AppBskyUnspeccedGetPopular.Params']] = None, **kwargs ) -> 'models.AppBskyUnspeccedGetPopular.Response': @@ -710,9 +700,9 @@ async def get_popular( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyUnspeccedGetPopular.Params) + params_model = get_or_create_model(params, models.AppBskyUnspeccedGetPopular.Params) response = await self._client.invoke_query( - 'app.bsky.unspecced.getPopular', params=params, output_encoding='application/json', **kwargs + 'app.bsky.unspecced.getPopular', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyUnspeccedGetPopular.Response) @@ -735,8 +725,7 @@ async def get_popular_feed_generators(self, **kwargs) -> 'models.AppBskyUnspecce return get_response_model(response, models.AppBskyUnspeccedGetPopularFeedGenerators.Response) -@dataclass -class NotificationNamespace(NamespaceBase): +class NotificationNamespace(AsyncNamespaceBase): async def get_unread_count( self, params: t.Optional[t.Union[dict, 'models.AppBskyNotificationGetUnreadCount.Params']] = None, **kwargs ) -> 'models.AppBskyNotificationGetUnreadCount.Response': @@ -753,9 +742,9 @@ async def get_unread_count( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyNotificationGetUnreadCount.Params) + params_model = get_or_create_model(params, models.AppBskyNotificationGetUnreadCount.Params) response = await self._client.invoke_query( - 'app.bsky.notification.getUnreadCount', params=params, output_encoding='application/json', **kwargs + 'app.bsky.notification.getUnreadCount', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyNotificationGetUnreadCount.Response) @@ -775,9 +764,9 @@ async def list_notifications( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyNotificationListNotifications.Params) + params_model = get_or_create_model(params, models.AppBskyNotificationListNotifications.Params) response = await self._client.invoke_query( - 'app.bsky.notification.listNotifications', params=params, output_encoding='application/json', **kwargs + 'app.bsky.notification.listNotifications', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyNotificationListNotifications.Response) @@ -795,32 +784,22 @@ async def update_seen(self, data: t.Union[dict, 'models.AppBskyNotificationUpdat :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyNotificationUpdateSeen.Data) + data_model = get_or_create_model(data, models.AppBskyNotificationUpdateSeen.Data) response = await self._client.invoke_procedure( - 'app.bsky.notification.updateSeen', data=data, input_encoding='application/json', **kwargs + 'app.bsky.notification.updateSeen', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass -class ComNamespace(NamespaceBase): - atproto: 'AtprotoNamespace' = field(default_factory=DefaultNamespace) - - def __post_init__(self) -> None: +class ComNamespace(AsyncNamespaceBase): + def __init__(self, client: 'AsyncClientRaw') -> None: + super().__init__(client) self.atproto = AtprotoNamespace(self._client) -@dataclass -class AtprotoNamespace(NamespaceBase): - admin: 'AdminNamespace' = field(default_factory=DefaultNamespace) - identity: 'IdentityNamespace' = field(default_factory=DefaultNamespace) - label: 'LabelNamespace' = field(default_factory=DefaultNamespace) - moderation: 'ModerationNamespace' = field(default_factory=DefaultNamespace) - repo: 'RepoNamespace' = field(default_factory=DefaultNamespace) - server: 'ServerNamespace' = field(default_factory=DefaultNamespace) - sync: 'SyncNamespace' = field(default_factory=DefaultNamespace) - - def __post_init__(self) -> None: +class AtprotoNamespace(AsyncNamespaceBase): + def __init__(self, client: 'AsyncClientRaw') -> None: + super().__init__(client) self.admin = AdminNamespace(self._client) self.identity = IdentityNamespace(self._client) self.label = LabelNamespace(self._client) @@ -830,8 +809,7 @@ def __post_init__(self) -> None: self.sync = SyncNamespace(self._client) -@dataclass -class SyncNamespace(NamespaceBase): +class SyncNamespace(AsyncNamespaceBase): async def get_blob( self, params: t.Union[dict, 'models.ComAtprotoSyncGetBlob.Params'], **kwargs ) -> 'models.ComAtprotoSyncGetBlob.Response': @@ -848,9 +826,9 @@ async def get_blob( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetBlob.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetBlob.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getBlob', params=params, output_encoding='*/*', **kwargs + 'com.atproto.sync.getBlob', params=params_model, output_encoding='*/*', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetBlob.Response) @@ -870,9 +848,9 @@ async def get_blocks( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetBlocks.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetBlocks.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getBlocks', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getBlocks', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetBlocks.Response) @@ -892,9 +870,9 @@ async def get_checkout( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetCheckout.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetCheckout.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getCheckout', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getCheckout', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetCheckout.Response) @@ -914,9 +892,9 @@ async def get_commit_path( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetCommitPath.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetCommitPath.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getCommitPath', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.getCommitPath', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetCommitPath.Response) @@ -936,9 +914,9 @@ async def get_head( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetHead.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetHead.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getHead', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.getHead', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetHead.Response) @@ -958,9 +936,9 @@ async def get_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetRecord.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetRecord.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getRecord', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getRecord', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetRecord.Response) @@ -980,9 +958,9 @@ async def get_repo( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetRepo.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetRepo.Params) response = await self._client.invoke_query( - 'com.atproto.sync.getRepo', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getRepo', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetRepo.Response) @@ -1002,9 +980,9 @@ async def list_blobs( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncListBlobs.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncListBlobs.Params) response = await self._client.invoke_query( - 'com.atproto.sync.listBlobs', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.listBlobs', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncListBlobs.Response) @@ -1024,9 +1002,9 @@ async def list_repos( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncListRepos.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncListRepos.Params) response = await self._client.invoke_query( - 'com.atproto.sync.listRepos', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.listRepos', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncListRepos.Response) @@ -1046,8 +1024,8 @@ async def notify_of_update( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncNotifyOfUpdate.Params) - response = await self._client.invoke_query('com.atproto.sync.notifyOfUpdate', params=params, **kwargs) + params_model = get_or_create_model(params, models.ComAtprotoSyncNotifyOfUpdate.Params) + response = await self._client.invoke_query('com.atproto.sync.notifyOfUpdate', params=params_model, **kwargs) return get_response_model(response, bool) async def request_crawl(self, params: t.Union[dict, 'models.ComAtprotoSyncRequestCrawl.Params'], **kwargs) -> bool: @@ -1064,13 +1042,12 @@ async def request_crawl(self, params: t.Union[dict, 'models.ComAtprotoSyncReques :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncRequestCrawl.Params) - response = await self._client.invoke_query('com.atproto.sync.requestCrawl', params=params, **kwargs) + params_model = get_or_create_model(params, models.ComAtprotoSyncRequestCrawl.Params) + response = await self._client.invoke_query('com.atproto.sync.requestCrawl', params=params_model, **kwargs) return get_response_model(response, bool) -@dataclass -class ServerNamespace(NamespaceBase): +class ServerNamespace(AsyncNamespaceBase): async def create_account( self, data: t.Union[dict, 'models.ComAtprotoServerCreateAccount.Data'], **kwargs ) -> 'models.ComAtprotoServerCreateAccount.Response': @@ -1087,10 +1064,10 @@ async def create_account( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateAccount.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateAccount.Data) response = await self._client.invoke_procedure( 'com.atproto.server.createAccount', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1113,10 +1090,10 @@ async def create_app_password( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateAppPassword.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateAppPassword.Data) response = await self._client.invoke_procedure( 'com.atproto.server.createAppPassword', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1139,10 +1116,10 @@ async def create_invite_code( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateInviteCode.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateInviteCode.Data) response = await self._client.invoke_procedure( 'com.atproto.server.createInviteCode', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1165,10 +1142,10 @@ async def create_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateInviteCodes.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateInviteCodes.Data) response = await self._client.invoke_procedure( 'com.atproto.server.createInviteCodes', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1191,10 +1168,10 @@ async def create_session( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateSession.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateSession.Data) response = await self._client.invoke_procedure( 'com.atproto.server.createSession', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1215,9 +1192,9 @@ async def delete_account(self, data: t.Union[dict, 'models.ComAtprotoServerDelet :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerDeleteAccount.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerDeleteAccount.Data) response = await self._client.invoke_procedure( - 'com.atproto.server.deleteAccount', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.deleteAccount', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1271,9 +1248,12 @@ async def get_account_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoServerGetAccountInviteCodes.Params) + params_model = get_or_create_model(params, models.ComAtprotoServerGetAccountInviteCodes.Params) response = await self._client.invoke_query( - 'com.atproto.server.getAccountInviteCodes', params=params, output_encoding='application/json', **kwargs + 'com.atproto.server.getAccountInviteCodes', + params=params_model, + output_encoding='application/json', + **kwargs, ) return get_response_model(response, models.ComAtprotoServerGetAccountInviteCodes.Response) @@ -1363,9 +1343,9 @@ async def request_password_reset( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerRequestPasswordReset.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerRequestPasswordReset.Data) response = await self._client.invoke_procedure( - 'com.atproto.server.requestPasswordReset', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.requestPasswordReset', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1383,9 +1363,9 @@ async def reset_password(self, data: t.Union[dict, 'models.ComAtprotoServerReset :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerResetPassword.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerResetPassword.Data) response = await self._client.invoke_procedure( - 'com.atproto.server.resetPassword', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.resetPassword', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1405,15 +1385,14 @@ async def revoke_app_password( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerRevokeAppPassword.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerRevokeAppPassword.Data) response = await self._client.invoke_procedure( - 'com.atproto.server.revokeAppPassword', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.revokeAppPassword', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass -class RepoNamespace(NamespaceBase): +class RepoNamespace(AsyncNamespaceBase): async def apply_writes(self, data: t.Union[dict, 'models.ComAtprotoRepoApplyWrites.Data'], **kwargs) -> bool: """Apply a batch transaction of creates, updates, and deletes. @@ -1428,9 +1407,9 @@ async def apply_writes(self, data: t.Union[dict, 'models.ComAtprotoRepoApplyWrit :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoApplyWrites.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoApplyWrites.Data) response = await self._client.invoke_procedure( - 'com.atproto.repo.applyWrites', data=data, input_encoding='application/json', **kwargs + 'com.atproto.repo.applyWrites', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1450,10 +1429,10 @@ async def create_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoCreateRecord.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoCreateRecord.Data) response = await self._client.invoke_procedure( 'com.atproto.repo.createRecord', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1474,9 +1453,9 @@ async def delete_record(self, data: t.Union[dict, 'models.ComAtprotoRepoDeleteRe :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoDeleteRecord.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoDeleteRecord.Data) response = await self._client.invoke_procedure( - 'com.atproto.repo.deleteRecord', data=data, input_encoding='application/json', **kwargs + 'com.atproto.repo.deleteRecord', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1496,9 +1475,9 @@ async def describe_repo( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoRepoDescribeRepo.Params) + params_model = get_or_create_model(params, models.ComAtprotoRepoDescribeRepo.Params) response = await self._client.invoke_query( - 'com.atproto.repo.describeRepo', params=params, output_encoding='application/json', **kwargs + 'com.atproto.repo.describeRepo', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoRepoDescribeRepo.Response) @@ -1518,9 +1497,9 @@ async def get_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoRepoGetRecord.Params) + params_model = get_or_create_model(params, models.ComAtprotoRepoGetRecord.Params) response = await self._client.invoke_query( - 'com.atproto.repo.getRecord', params=params, output_encoding='application/json', **kwargs + 'com.atproto.repo.getRecord', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoRepoGetRecord.Response) @@ -1540,9 +1519,9 @@ async def list_records( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoRepoListRecords.Params) + params_model = get_or_create_model(params, models.ComAtprotoRepoListRecords.Params) response = await self._client.invoke_query( - 'com.atproto.repo.listRecords', params=params, output_encoding='application/json', **kwargs + 'com.atproto.repo.listRecords', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoRepoListRecords.Response) @@ -1562,10 +1541,10 @@ async def put_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoPutRecord.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoPutRecord.Data) response = await self._client.invoke_procedure( 'com.atproto.repo.putRecord', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1586,9 +1565,9 @@ async def rebase_repo(self, data: t.Union[dict, 'models.ComAtprotoRepoRebaseRepo :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoRebaseRepo.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoRebaseRepo.Data) response = await self._client.invoke_procedure( - 'com.atproto.repo.rebaseRepo', data=data, input_encoding='application/json', **kwargs + 'com.atproto.repo.rebaseRepo', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1614,8 +1593,7 @@ async def upload_blob( return get_response_model(response, models.ComAtprotoRepoUploadBlob.Response) -@dataclass -class AdminNamespace(NamespaceBase): +class AdminNamespace(AsyncNamespaceBase): async def disable_account_invites( self, data: t.Union[dict, 'models.ComAtprotoAdminDisableAccountInvites.Data'], **kwargs ) -> bool: @@ -1632,9 +1610,9 @@ async def disable_account_invites( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminDisableAccountInvites.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminDisableAccountInvites.Data) response = await self._client.invoke_procedure( - 'com.atproto.admin.disableAccountInvites', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.disableAccountInvites', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1654,9 +1632,9 @@ async def disable_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminDisableInviteCodes.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminDisableInviteCodes.Data) response = await self._client.invoke_procedure( - 'com.atproto.admin.disableInviteCodes', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.disableInviteCodes', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1676,9 +1654,9 @@ async def enable_account_invites( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminEnableAccountInvites.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminEnableAccountInvites.Data) response = await self._client.invoke_procedure( - 'com.atproto.admin.enableAccountInvites', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.enableAccountInvites', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1698,9 +1676,9 @@ async def get_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetInviteCodes.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetInviteCodes.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getInviteCodes', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getInviteCodes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetInviteCodes.Response) @@ -1720,9 +1698,9 @@ async def get_moderation_action( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationAction.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationAction.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getModerationAction', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationAction', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationAction.ResponseRef) @@ -1742,9 +1720,9 @@ async def get_moderation_actions( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationActions.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationActions.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getModerationActions', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationActions', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationActions.Response) @@ -1764,9 +1742,9 @@ async def get_moderation_report( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationReport.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationReport.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getModerationReport', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationReport', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationReport.ResponseRef) @@ -1786,9 +1764,9 @@ async def get_moderation_reports( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationReports.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationReports.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getModerationReports', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationReports', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationReports.Response) @@ -1808,9 +1786,9 @@ async def get_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetRecord.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetRecord.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getRecord', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getRecord', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetRecord.ResponseRef) @@ -1830,9 +1808,9 @@ async def get_repo( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetRepo.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetRepo.Params) response = await self._client.invoke_query( - 'com.atproto.admin.getRepo', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getRepo', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetRepo.ResponseRef) @@ -1852,10 +1830,10 @@ async def resolve_moderation_reports( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminResolveModerationReports.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminResolveModerationReports.Data) response = await self._client.invoke_procedure( 'com.atproto.admin.resolveModerationReports', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1878,10 +1856,10 @@ async def reverse_moderation_action( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminReverseModerationAction.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminReverseModerationAction.Data) response = await self._client.invoke_procedure( 'com.atproto.admin.reverseModerationAction', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1904,9 +1882,9 @@ async def search_repos( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminSearchRepos.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminSearchRepos.Params) response = await self._client.invoke_query( - 'com.atproto.admin.searchRepos', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.searchRepos', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminSearchRepos.Response) @@ -1926,10 +1904,10 @@ async def take_moderation_action( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminTakeModerationAction.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminTakeModerationAction.Data) response = await self._client.invoke_procedure( 'com.atproto.admin.takeModerationAction', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1952,9 +1930,9 @@ async def update_account_email( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminUpdateAccountEmail.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminUpdateAccountEmail.Data) response = await self._client.invoke_procedure( - 'com.atproto.admin.updateAccountEmail', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.updateAccountEmail', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1974,15 +1952,14 @@ async def update_account_handle( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminUpdateAccountHandle.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminUpdateAccountHandle.Data) response = await self._client.invoke_procedure( - 'com.atproto.admin.updateAccountHandle', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.updateAccountHandle', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass -class IdentityNamespace(NamespaceBase): +class IdentityNamespace(AsyncNamespaceBase): async def resolve_handle( self, params: t.Union[dict, 'models.ComAtprotoIdentityResolveHandle.Params'], **kwargs ) -> 'models.ComAtprotoIdentityResolveHandle.Response': @@ -1999,9 +1976,9 @@ async def resolve_handle( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoIdentityResolveHandle.Params) + params_model = get_or_create_model(params, models.ComAtprotoIdentityResolveHandle.Params) response = await self._client.invoke_query( - 'com.atproto.identity.resolveHandle', params=params, output_encoding='application/json', **kwargs + 'com.atproto.identity.resolveHandle', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoIdentityResolveHandle.Response) @@ -2019,15 +1996,14 @@ async def update_handle(self, data: t.Union[dict, 'models.ComAtprotoIdentityUpda :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoIdentityUpdateHandle.Data) + data_model = get_or_create_model(data, models.ComAtprotoIdentityUpdateHandle.Data) response = await self._client.invoke_procedure( - 'com.atproto.identity.updateHandle', data=data, input_encoding='application/json', **kwargs + 'com.atproto.identity.updateHandle', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass -class ModerationNamespace(NamespaceBase): +class ModerationNamespace(AsyncNamespaceBase): async def create_report( self, data: t.Union[dict, 'models.ComAtprotoModerationCreateReport.Data'], **kwargs ) -> 'models.ComAtprotoModerationCreateReport.Response': @@ -2044,10 +2020,10 @@ async def create_report( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoModerationCreateReport.Data) + data_model = get_or_create_model(data, models.ComAtprotoModerationCreateReport.Data) response = await self._client.invoke_procedure( 'com.atproto.moderation.createReport', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -2055,8 +2031,7 @@ async def create_report( return get_response_model(response, models.ComAtprotoModerationCreateReport.Response) -@dataclass -class LabelNamespace(NamespaceBase): +class LabelNamespace(AsyncNamespaceBase): async def query_labels( self, params: t.Union[dict, 'models.ComAtprotoLabelQueryLabels.Params'], **kwargs ) -> 'models.ComAtprotoLabelQueryLabels.Response': @@ -2073,8 +2048,8 @@ async def query_labels( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoLabelQueryLabels.Params) + params_model = get_or_create_model(params, models.ComAtprotoLabelQueryLabels.Params) response = await self._client.invoke_query( - 'com.atproto.label.queryLabels', params=params, output_encoding='application/json', **kwargs + 'com.atproto.label.queryLabels', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoLabelQueryLabels.Response) diff --git a/atproto/xrpc_client/namespaces/base.py b/atproto/xrpc_client/namespaces/base.py index cf07c1fe..bb67eee9 100644 --- a/atproto/xrpc_client/namespaces/base.py +++ b/atproto/xrpc_client/namespaces/base.py @@ -6,9 +6,14 @@ from atproto.xrpc_client.client.raw import ClientRaw -@dataclass class NamespaceBase: - _client: t.Union['ClientRaw', 'AsyncClientRaw'] + def __init__(self, client: 'ClientRaw') -> None: + self._client: 'ClientRaw' = client + + +class AsyncNamespaceBase: + def __init__(self, client: 'AsyncClientRaw') -> None: + self._client: 'AsyncClientRaw' = client @dataclass diff --git a/atproto/xrpc_client/namespaces/sync_ns.py b/atproto/xrpc_client/namespaces/sync_ns.py index 4e93d851..0a8205bb 100644 --- a/atproto/xrpc_client/namespaces/sync_ns.py +++ b/atproto/xrpc_client/namespaces/sync_ns.py @@ -6,30 +6,24 @@ import typing as t -from dataclasses import dataclass, field from atproto.xrpc_client import models -from atproto.xrpc_client.models.utils import get_or_create, get_response_model -from atproto.xrpc_client.namespaces.base import DefaultNamespace, NamespaceBase +from atproto.xrpc_client.models.utils import get_or_create_model, get_response_model +from atproto.xrpc_client.namespaces.base import NamespaceBase +if t.TYPE_CHECKING: + from atproto.xrpc_client.client.raw import ClientRaw -@dataclass -class AppNamespace(NamespaceBase): - bsky: 'BskyNamespace' = field(default_factory=DefaultNamespace) - def __post_init__(self) -> None: +class AppNamespace(NamespaceBase): + def __init__(self, client: 'ClientRaw') -> None: + super().__init__(client) self.bsky = BskyNamespace(self._client) -@dataclass class BskyNamespace(NamespaceBase): - actor: 'ActorNamespace' = field(default_factory=DefaultNamespace) - feed: 'FeedNamespace' = field(default_factory=DefaultNamespace) - graph: 'GraphNamespace' = field(default_factory=DefaultNamespace) - notification: 'NotificationNamespace' = field(default_factory=DefaultNamespace) - unspecced: 'UnspeccedNamespace' = field(default_factory=DefaultNamespace) - - def __post_init__(self) -> None: + def __init__(self, client: 'ClientRaw') -> None: + super().__init__(client) self.actor = ActorNamespace(self._client) self.feed = FeedNamespace(self._client) self.graph = GraphNamespace(self._client) @@ -37,7 +31,6 @@ def __post_init__(self) -> None: self.unspecced = UnspeccedNamespace(self._client) -@dataclass class ActorNamespace(NamespaceBase): def get_preferences( self, params: t.Optional[t.Union[dict, 'models.AppBskyActorGetPreferences.Params']] = None, **kwargs @@ -55,9 +48,9 @@ def get_preferences( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetPreferences.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetPreferences.Params) response = self._client.invoke_query( - 'app.bsky.actor.getPreferences', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getPreferences', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetPreferences.Response) @@ -77,9 +70,9 @@ def get_profile( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetProfile.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetProfile.Params) response = self._client.invoke_query( - 'app.bsky.actor.getProfile', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getProfile', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetProfile.ResponseRef) @@ -99,9 +92,9 @@ def get_profiles( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetProfiles.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetProfiles.Params) response = self._client.invoke_query( - 'app.bsky.actor.getProfiles', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getProfiles', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetProfiles.Response) @@ -121,9 +114,9 @@ def get_suggestions( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorGetSuggestions.Params) + params_model = get_or_create_model(params, models.AppBskyActorGetSuggestions.Params) response = self._client.invoke_query( - 'app.bsky.actor.getSuggestions', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.getSuggestions', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorGetSuggestions.Response) @@ -141,9 +134,9 @@ def put_preferences(self, data: t.Union[dict, 'models.AppBskyActorPutPreferences :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyActorPutPreferences.Data) + data_model = get_or_create_model(data, models.AppBskyActorPutPreferences.Data) response = self._client.invoke_procedure( - 'app.bsky.actor.putPreferences', data=data, input_encoding='application/json', **kwargs + 'app.bsky.actor.putPreferences', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -163,9 +156,9 @@ def search_actors( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorSearchActors.Params) + params_model = get_or_create_model(params, models.AppBskyActorSearchActors.Params) response = self._client.invoke_query( - 'app.bsky.actor.searchActors', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.searchActors', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorSearchActors.Response) @@ -185,14 +178,13 @@ def search_actors_typeahead( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyActorSearchActorsTypeahead.Params) + params_model = get_or_create_model(params, models.AppBskyActorSearchActorsTypeahead.Params) response = self._client.invoke_query( - 'app.bsky.actor.searchActorsTypeahead', params=params, output_encoding='application/json', **kwargs + 'app.bsky.actor.searchActorsTypeahead', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyActorSearchActorsTypeahead.Response) -@dataclass class FeedNamespace(NamespaceBase): def describe_feed_generator(self, **kwargs) -> 'models.AppBskyFeedDescribeFeedGenerator.Response': """Returns information about a given feed generator including TOS & offered feed URIs. @@ -228,9 +220,9 @@ def get_actor_feeds( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetActorFeeds.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetActorFeeds.Params) response = self._client.invoke_query( - 'app.bsky.feed.getActorFeeds', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getActorFeeds', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetActorFeeds.Response) @@ -250,9 +242,9 @@ def get_author_feed( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetAuthorFeed.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetAuthorFeed.Params) response = self._client.invoke_query( - 'app.bsky.feed.getAuthorFeed', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getAuthorFeed', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetAuthorFeed.Response) @@ -272,9 +264,9 @@ def get_feed( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeed.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeed.Params) response = self._client.invoke_query( - 'app.bsky.feed.getFeed', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeed', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeed.Response) @@ -294,9 +286,9 @@ def get_feed_generator( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeedGenerator.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeedGenerator.Params) response = self._client.invoke_query( - 'app.bsky.feed.getFeedGenerator', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeedGenerator', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeedGenerator.Response) @@ -316,9 +308,9 @@ def get_feed_generators( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeedGenerators.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeedGenerators.Params) response = self._client.invoke_query( - 'app.bsky.feed.getFeedGenerators', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeedGenerators', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeedGenerators.Response) @@ -338,9 +330,9 @@ def get_feed_skeleton( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetFeedSkeleton.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetFeedSkeleton.Params) response = self._client.invoke_query( - 'app.bsky.feed.getFeedSkeleton', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getFeedSkeleton', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetFeedSkeleton.Response) @@ -360,9 +352,9 @@ def get_likes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetLikes.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetLikes.Params) response = self._client.invoke_query( - 'app.bsky.feed.getLikes', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getLikes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetLikes.Response) @@ -382,9 +374,9 @@ def get_post_thread( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetPostThread.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetPostThread.Params) response = self._client.invoke_query( - 'app.bsky.feed.getPostThread', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getPostThread', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetPostThread.Response) @@ -404,9 +396,9 @@ def get_posts( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetPosts.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetPosts.Params) response = self._client.invoke_query( - 'app.bsky.feed.getPosts', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getPosts', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetPosts.Response) @@ -426,9 +418,9 @@ def get_reposted_by( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetRepostedBy.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetRepostedBy.Params) response = self._client.invoke_query( - 'app.bsky.feed.getRepostedBy', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getRepostedBy', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetRepostedBy.Response) @@ -448,14 +440,13 @@ def get_timeline( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyFeedGetTimeline.Params) + params_model = get_or_create_model(params, models.AppBskyFeedGetTimeline.Params) response = self._client.invoke_query( - 'app.bsky.feed.getTimeline', params=params, output_encoding='application/json', **kwargs + 'app.bsky.feed.getTimeline', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyFeedGetTimeline.Response) -@dataclass class GraphNamespace(NamespaceBase): def get_blocks( self, params: t.Optional[t.Union[dict, 'models.AppBskyGraphGetBlocks.Params']] = None, **kwargs @@ -473,9 +464,9 @@ def get_blocks( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetBlocks.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetBlocks.Params) response = self._client.invoke_query( - 'app.bsky.graph.getBlocks', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getBlocks', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetBlocks.Response) @@ -495,9 +486,9 @@ def get_followers( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetFollowers.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetFollowers.Params) response = self._client.invoke_query( - 'app.bsky.graph.getFollowers', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getFollowers', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetFollowers.Response) @@ -517,9 +508,9 @@ def get_follows( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetFollows.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetFollows.Params) response = self._client.invoke_query( - 'app.bsky.graph.getFollows', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getFollows', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetFollows.Response) @@ -539,9 +530,9 @@ def get_list( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetList.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetList.Params) response = self._client.invoke_query( - 'app.bsky.graph.getList', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getList', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetList.Response) @@ -561,9 +552,9 @@ def get_list_mutes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetListMutes.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetListMutes.Params) response = self._client.invoke_query( - 'app.bsky.graph.getListMutes', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getListMutes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetListMutes.Response) @@ -583,9 +574,9 @@ def get_lists( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetLists.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetLists.Params) response = self._client.invoke_query( - 'app.bsky.graph.getLists', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getLists', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetLists.Response) @@ -605,9 +596,9 @@ def get_mutes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyGraphGetMutes.Params) + params_model = get_or_create_model(params, models.AppBskyGraphGetMutes.Params) response = self._client.invoke_query( - 'app.bsky.graph.getMutes', params=params, output_encoding='application/json', **kwargs + 'app.bsky.graph.getMutes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyGraphGetMutes.Response) @@ -625,9 +616,9 @@ def mute_actor(self, data: t.Union[dict, 'models.AppBskyGraphMuteActor.Data'], * :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphMuteActor.Data) + data_model = get_or_create_model(data, models.AppBskyGraphMuteActor.Data) response = self._client.invoke_procedure( - 'app.bsky.graph.muteActor', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.muteActor', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -645,9 +636,9 @@ def mute_actor_list(self, data: t.Union[dict, 'models.AppBskyGraphMuteActorList. :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphMuteActorList.Data) + data_model = get_or_create_model(data, models.AppBskyGraphMuteActorList.Data) response = self._client.invoke_procedure( - 'app.bsky.graph.muteActorList', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.muteActorList', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -665,9 +656,9 @@ def unmute_actor(self, data: t.Union[dict, 'models.AppBskyGraphUnmuteActor.Data' :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphUnmuteActor.Data) + data_model = get_or_create_model(data, models.AppBskyGraphUnmuteActor.Data) response = self._client.invoke_procedure( - 'app.bsky.graph.unmuteActor', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.unmuteActor', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -685,14 +676,13 @@ def unmute_actor_list(self, data: t.Union[dict, 'models.AppBskyGraphUnmuteActorL :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyGraphUnmuteActorList.Data) + data_model = get_or_create_model(data, models.AppBskyGraphUnmuteActorList.Data) response = self._client.invoke_procedure( - 'app.bsky.graph.unmuteActorList', data=data, input_encoding='application/json', **kwargs + 'app.bsky.graph.unmuteActorList', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass class UnspeccedNamespace(NamespaceBase): def get_popular( self, params: t.Optional[t.Union[dict, 'models.AppBskyUnspeccedGetPopular.Params']] = None, **kwargs @@ -710,9 +700,9 @@ def get_popular( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyUnspeccedGetPopular.Params) + params_model = get_or_create_model(params, models.AppBskyUnspeccedGetPopular.Params) response = self._client.invoke_query( - 'app.bsky.unspecced.getPopular', params=params, output_encoding='application/json', **kwargs + 'app.bsky.unspecced.getPopular', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyUnspeccedGetPopular.Response) @@ -735,7 +725,6 @@ def get_popular_feed_generators(self, **kwargs) -> 'models.AppBskyUnspeccedGetPo return get_response_model(response, models.AppBskyUnspeccedGetPopularFeedGenerators.Response) -@dataclass class NotificationNamespace(NamespaceBase): def get_unread_count( self, params: t.Optional[t.Union[dict, 'models.AppBskyNotificationGetUnreadCount.Params']] = None, **kwargs @@ -753,9 +742,9 @@ def get_unread_count( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyNotificationGetUnreadCount.Params) + params_model = get_or_create_model(params, models.AppBskyNotificationGetUnreadCount.Params) response = self._client.invoke_query( - 'app.bsky.notification.getUnreadCount', params=params, output_encoding='application/json', **kwargs + 'app.bsky.notification.getUnreadCount', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyNotificationGetUnreadCount.Response) @@ -775,9 +764,9 @@ def list_notifications( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.AppBskyNotificationListNotifications.Params) + params_model = get_or_create_model(params, models.AppBskyNotificationListNotifications.Params) response = self._client.invoke_query( - 'app.bsky.notification.listNotifications', params=params, output_encoding='application/json', **kwargs + 'app.bsky.notification.listNotifications', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.AppBskyNotificationListNotifications.Response) @@ -795,32 +784,22 @@ def update_seen(self, data: t.Union[dict, 'models.AppBskyNotificationUpdateSeen. :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.AppBskyNotificationUpdateSeen.Data) + data_model = get_or_create_model(data, models.AppBskyNotificationUpdateSeen.Data) response = self._client.invoke_procedure( - 'app.bsky.notification.updateSeen', data=data, input_encoding='application/json', **kwargs + 'app.bsky.notification.updateSeen', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass class ComNamespace(NamespaceBase): - atproto: 'AtprotoNamespace' = field(default_factory=DefaultNamespace) - - def __post_init__(self) -> None: + def __init__(self, client: 'ClientRaw') -> None: + super().__init__(client) self.atproto = AtprotoNamespace(self._client) -@dataclass class AtprotoNamespace(NamespaceBase): - admin: 'AdminNamespace' = field(default_factory=DefaultNamespace) - identity: 'IdentityNamespace' = field(default_factory=DefaultNamespace) - label: 'LabelNamespace' = field(default_factory=DefaultNamespace) - moderation: 'ModerationNamespace' = field(default_factory=DefaultNamespace) - repo: 'RepoNamespace' = field(default_factory=DefaultNamespace) - server: 'ServerNamespace' = field(default_factory=DefaultNamespace) - sync: 'SyncNamespace' = field(default_factory=DefaultNamespace) - - def __post_init__(self) -> None: + def __init__(self, client: 'ClientRaw') -> None: + super().__init__(client) self.admin = AdminNamespace(self._client) self.identity = IdentityNamespace(self._client) self.label = LabelNamespace(self._client) @@ -830,7 +809,6 @@ def __post_init__(self) -> None: self.sync = SyncNamespace(self._client) -@dataclass class SyncNamespace(NamespaceBase): def get_blob( self, params: t.Union[dict, 'models.ComAtprotoSyncGetBlob.Params'], **kwargs @@ -848,8 +826,10 @@ def get_blob( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetBlob.Params) - response = self._client.invoke_query('com.atproto.sync.getBlob', params=params, output_encoding='*/*', **kwargs) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetBlob.Params) + response = self._client.invoke_query( + 'com.atproto.sync.getBlob', params=params_model, output_encoding='*/*', **kwargs + ) return get_response_model(response, models.ComAtprotoSyncGetBlob.Response) def get_blocks( @@ -868,9 +848,9 @@ def get_blocks( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetBlocks.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetBlocks.Params) response = self._client.invoke_query( - 'com.atproto.sync.getBlocks', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getBlocks', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetBlocks.Response) @@ -890,9 +870,9 @@ def get_checkout( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetCheckout.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetCheckout.Params) response = self._client.invoke_query( - 'com.atproto.sync.getCheckout', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getCheckout', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetCheckout.Response) @@ -912,9 +892,9 @@ def get_commit_path( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetCommitPath.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetCommitPath.Params) response = self._client.invoke_query( - 'com.atproto.sync.getCommitPath', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.getCommitPath', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetCommitPath.Response) @@ -934,9 +914,9 @@ def get_head( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetHead.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetHead.Params) response = self._client.invoke_query( - 'com.atproto.sync.getHead', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.getHead', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetHead.Response) @@ -956,9 +936,9 @@ def get_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetRecord.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetRecord.Params) response = self._client.invoke_query( - 'com.atproto.sync.getRecord', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getRecord', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetRecord.Response) @@ -978,9 +958,9 @@ def get_repo( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncGetRepo.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncGetRepo.Params) response = self._client.invoke_query( - 'com.atproto.sync.getRepo', params=params, output_encoding='application/vnd.ipld.car', **kwargs + 'com.atproto.sync.getRepo', params=params_model, output_encoding='application/vnd.ipld.car', **kwargs ) return get_response_model(response, models.ComAtprotoSyncGetRepo.Response) @@ -1000,9 +980,9 @@ def list_blobs( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncListBlobs.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncListBlobs.Params) response = self._client.invoke_query( - 'com.atproto.sync.listBlobs', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.listBlobs', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncListBlobs.Response) @@ -1022,9 +1002,9 @@ def list_repos( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncListRepos.Params) + params_model = get_or_create_model(params, models.ComAtprotoSyncListRepos.Params) response = self._client.invoke_query( - 'com.atproto.sync.listRepos', params=params, output_encoding='application/json', **kwargs + 'com.atproto.sync.listRepos', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoSyncListRepos.Response) @@ -1042,8 +1022,8 @@ def notify_of_update(self, params: t.Union[dict, 'models.ComAtprotoSyncNotifyOfU :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncNotifyOfUpdate.Params) - response = self._client.invoke_query('com.atproto.sync.notifyOfUpdate', params=params, **kwargs) + params_model = get_or_create_model(params, models.ComAtprotoSyncNotifyOfUpdate.Params) + response = self._client.invoke_query('com.atproto.sync.notifyOfUpdate', params=params_model, **kwargs) return get_response_model(response, bool) def request_crawl(self, params: t.Union[dict, 'models.ComAtprotoSyncRequestCrawl.Params'], **kwargs) -> bool: @@ -1060,12 +1040,11 @@ def request_crawl(self, params: t.Union[dict, 'models.ComAtprotoSyncRequestCrawl :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoSyncRequestCrawl.Params) - response = self._client.invoke_query('com.atproto.sync.requestCrawl', params=params, **kwargs) + params_model = get_or_create_model(params, models.ComAtprotoSyncRequestCrawl.Params) + response = self._client.invoke_query('com.atproto.sync.requestCrawl', params=params_model, **kwargs) return get_response_model(response, bool) -@dataclass class ServerNamespace(NamespaceBase): def create_account( self, data: t.Union[dict, 'models.ComAtprotoServerCreateAccount.Data'], **kwargs @@ -1083,10 +1062,10 @@ def create_account( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateAccount.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateAccount.Data) response = self._client.invoke_procedure( 'com.atproto.server.createAccount', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1109,10 +1088,10 @@ def create_app_password( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateAppPassword.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateAppPassword.Data) response = self._client.invoke_procedure( 'com.atproto.server.createAppPassword', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1135,10 +1114,10 @@ def create_invite_code( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateInviteCode.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateInviteCode.Data) response = self._client.invoke_procedure( 'com.atproto.server.createInviteCode', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1161,10 +1140,10 @@ def create_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateInviteCodes.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateInviteCodes.Data) response = self._client.invoke_procedure( 'com.atproto.server.createInviteCodes', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1187,10 +1166,10 @@ def create_session( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerCreateSession.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerCreateSession.Data) response = self._client.invoke_procedure( 'com.atproto.server.createSession', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1211,9 +1190,9 @@ def delete_account(self, data: t.Union[dict, 'models.ComAtprotoServerDeleteAccou :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerDeleteAccount.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerDeleteAccount.Data) response = self._client.invoke_procedure( - 'com.atproto.server.deleteAccount', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.deleteAccount', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1267,9 +1246,12 @@ def get_account_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoServerGetAccountInviteCodes.Params) + params_model = get_or_create_model(params, models.ComAtprotoServerGetAccountInviteCodes.Params) response = self._client.invoke_query( - 'com.atproto.server.getAccountInviteCodes', params=params, output_encoding='application/json', **kwargs + 'com.atproto.server.getAccountInviteCodes', + params=params_model, + output_encoding='application/json', + **kwargs, ) return get_response_model(response, models.ComAtprotoServerGetAccountInviteCodes.Response) @@ -1359,9 +1341,9 @@ def request_password_reset( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerRequestPasswordReset.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerRequestPasswordReset.Data) response = self._client.invoke_procedure( - 'com.atproto.server.requestPasswordReset', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.requestPasswordReset', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1379,9 +1361,9 @@ def reset_password(self, data: t.Union[dict, 'models.ComAtprotoServerResetPasswo :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerResetPassword.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerResetPassword.Data) response = self._client.invoke_procedure( - 'com.atproto.server.resetPassword', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.resetPassword', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1401,14 +1383,13 @@ def revoke_app_password( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoServerRevokeAppPassword.Data) + data_model = get_or_create_model(data, models.ComAtprotoServerRevokeAppPassword.Data) response = self._client.invoke_procedure( - 'com.atproto.server.revokeAppPassword', data=data, input_encoding='application/json', **kwargs + 'com.atproto.server.revokeAppPassword', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass class RepoNamespace(NamespaceBase): def apply_writes(self, data: t.Union[dict, 'models.ComAtprotoRepoApplyWrites.Data'], **kwargs) -> bool: """Apply a batch transaction of creates, updates, and deletes. @@ -1424,9 +1405,9 @@ def apply_writes(self, data: t.Union[dict, 'models.ComAtprotoRepoApplyWrites.Dat :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoApplyWrites.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoApplyWrites.Data) response = self._client.invoke_procedure( - 'com.atproto.repo.applyWrites', data=data, input_encoding='application/json', **kwargs + 'com.atproto.repo.applyWrites', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1446,10 +1427,10 @@ def create_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoCreateRecord.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoCreateRecord.Data) response = self._client.invoke_procedure( 'com.atproto.repo.createRecord', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1470,9 +1451,9 @@ def delete_record(self, data: t.Union[dict, 'models.ComAtprotoRepoDeleteRecord.D :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoDeleteRecord.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoDeleteRecord.Data) response = self._client.invoke_procedure( - 'com.atproto.repo.deleteRecord', data=data, input_encoding='application/json', **kwargs + 'com.atproto.repo.deleteRecord', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1492,9 +1473,9 @@ def describe_repo( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoRepoDescribeRepo.Params) + params_model = get_or_create_model(params, models.ComAtprotoRepoDescribeRepo.Params) response = self._client.invoke_query( - 'com.atproto.repo.describeRepo', params=params, output_encoding='application/json', **kwargs + 'com.atproto.repo.describeRepo', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoRepoDescribeRepo.Response) @@ -1514,9 +1495,9 @@ def get_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoRepoGetRecord.Params) + params_model = get_or_create_model(params, models.ComAtprotoRepoGetRecord.Params) response = self._client.invoke_query( - 'com.atproto.repo.getRecord', params=params, output_encoding='application/json', **kwargs + 'com.atproto.repo.getRecord', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoRepoGetRecord.Response) @@ -1536,9 +1517,9 @@ def list_records( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoRepoListRecords.Params) + params_model = get_or_create_model(params, models.ComAtprotoRepoListRecords.Params) response = self._client.invoke_query( - 'com.atproto.repo.listRecords', params=params, output_encoding='application/json', **kwargs + 'com.atproto.repo.listRecords', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoRepoListRecords.Response) @@ -1558,10 +1539,10 @@ def put_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoPutRecord.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoPutRecord.Data) response = self._client.invoke_procedure( 'com.atproto.repo.putRecord', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1582,9 +1563,9 @@ def rebase_repo(self, data: t.Union[dict, 'models.ComAtprotoRepoRebaseRepo.Data' :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoRepoRebaseRepo.Data) + data_model = get_or_create_model(data, models.ComAtprotoRepoRebaseRepo.Data) response = self._client.invoke_procedure( - 'com.atproto.repo.rebaseRepo', data=data, input_encoding='application/json', **kwargs + 'com.atproto.repo.rebaseRepo', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1610,7 +1591,6 @@ def upload_blob( return get_response_model(response, models.ComAtprotoRepoUploadBlob.Response) -@dataclass class AdminNamespace(NamespaceBase): def disable_account_invites( self, data: t.Union[dict, 'models.ComAtprotoAdminDisableAccountInvites.Data'], **kwargs @@ -1628,9 +1608,9 @@ def disable_account_invites( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminDisableAccountInvites.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminDisableAccountInvites.Data) response = self._client.invoke_procedure( - 'com.atproto.admin.disableAccountInvites', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.disableAccountInvites', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1650,9 +1630,9 @@ def disable_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminDisableInviteCodes.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminDisableInviteCodes.Data) response = self._client.invoke_procedure( - 'com.atproto.admin.disableInviteCodes', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.disableInviteCodes', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1672,9 +1652,9 @@ def enable_account_invites( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminEnableAccountInvites.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminEnableAccountInvites.Data) response = self._client.invoke_procedure( - 'com.atproto.admin.enableAccountInvites', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.enableAccountInvites', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1694,9 +1674,9 @@ def get_invite_codes( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetInviteCodes.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetInviteCodes.Params) response = self._client.invoke_query( - 'com.atproto.admin.getInviteCodes', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getInviteCodes', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetInviteCodes.Response) @@ -1716,9 +1696,9 @@ def get_moderation_action( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationAction.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationAction.Params) response = self._client.invoke_query( - 'com.atproto.admin.getModerationAction', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationAction', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationAction.ResponseRef) @@ -1738,9 +1718,9 @@ def get_moderation_actions( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationActions.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationActions.Params) response = self._client.invoke_query( - 'com.atproto.admin.getModerationActions', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationActions', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationActions.Response) @@ -1760,9 +1740,9 @@ def get_moderation_report( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationReport.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationReport.Params) response = self._client.invoke_query( - 'com.atproto.admin.getModerationReport', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationReport', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationReport.ResponseRef) @@ -1782,9 +1762,9 @@ def get_moderation_reports( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetModerationReports.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetModerationReports.Params) response = self._client.invoke_query( - 'com.atproto.admin.getModerationReports', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getModerationReports', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetModerationReports.Response) @@ -1804,9 +1784,9 @@ def get_record( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetRecord.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetRecord.Params) response = self._client.invoke_query( - 'com.atproto.admin.getRecord', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getRecord', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetRecord.ResponseRef) @@ -1826,9 +1806,9 @@ def get_repo( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminGetRepo.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminGetRepo.Params) response = self._client.invoke_query( - 'com.atproto.admin.getRepo', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.getRepo', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminGetRepo.ResponseRef) @@ -1848,10 +1828,10 @@ def resolve_moderation_reports( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminResolveModerationReports.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminResolveModerationReports.Data) response = self._client.invoke_procedure( 'com.atproto.admin.resolveModerationReports', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1874,10 +1854,10 @@ def reverse_moderation_action( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminReverseModerationAction.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminReverseModerationAction.Data) response = self._client.invoke_procedure( 'com.atproto.admin.reverseModerationAction', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1900,9 +1880,9 @@ def search_repos( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoAdminSearchRepos.Params) + params_model = get_or_create_model(params, models.ComAtprotoAdminSearchRepos.Params) response = self._client.invoke_query( - 'com.atproto.admin.searchRepos', params=params, output_encoding='application/json', **kwargs + 'com.atproto.admin.searchRepos', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoAdminSearchRepos.Response) @@ -1922,10 +1902,10 @@ def take_moderation_action( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminTakeModerationAction.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminTakeModerationAction.Data) response = self._client.invoke_procedure( 'com.atproto.admin.takeModerationAction', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -1948,9 +1928,9 @@ def update_account_email( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminUpdateAccountEmail.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminUpdateAccountEmail.Data) response = self._client.invoke_procedure( - 'com.atproto.admin.updateAccountEmail', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.updateAccountEmail', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) @@ -1970,14 +1950,13 @@ def update_account_handle( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoAdminUpdateAccountHandle.Data) + data_model = get_or_create_model(data, models.ComAtprotoAdminUpdateAccountHandle.Data) response = self._client.invoke_procedure( - 'com.atproto.admin.updateAccountHandle', data=data, input_encoding='application/json', **kwargs + 'com.atproto.admin.updateAccountHandle', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass class IdentityNamespace(NamespaceBase): def resolve_handle( self, params: t.Union[dict, 'models.ComAtprotoIdentityResolveHandle.Params'], **kwargs @@ -1995,9 +1974,9 @@ def resolve_handle( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoIdentityResolveHandle.Params) + params_model = get_or_create_model(params, models.ComAtprotoIdentityResolveHandle.Params) response = self._client.invoke_query( - 'com.atproto.identity.resolveHandle', params=params, output_encoding='application/json', **kwargs + 'com.atproto.identity.resolveHandle', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoIdentityResolveHandle.Response) @@ -2015,14 +1994,13 @@ def update_handle(self, data: t.Union[dict, 'models.ComAtprotoIdentityUpdateHand :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoIdentityUpdateHandle.Data) + data_model = get_or_create_model(data, models.ComAtprotoIdentityUpdateHandle.Data) response = self._client.invoke_procedure( - 'com.atproto.identity.updateHandle', data=data, input_encoding='application/json', **kwargs + 'com.atproto.identity.updateHandle', data=data_model, input_encoding='application/json', **kwargs ) return get_response_model(response, bool) -@dataclass class ModerationNamespace(NamespaceBase): def create_report( self, data: t.Union[dict, 'models.ComAtprotoModerationCreateReport.Data'], **kwargs @@ -2040,10 +2018,10 @@ def create_report( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - data = get_or_create(data, models.ComAtprotoModerationCreateReport.Data) + data_model = get_or_create_model(data, models.ComAtprotoModerationCreateReport.Data) response = self._client.invoke_procedure( 'com.atproto.moderation.createReport', - data=data, + data=data_model, input_encoding='application/json', output_encoding='application/json', **kwargs, @@ -2051,7 +2029,6 @@ def create_report( return get_response_model(response, models.ComAtprotoModerationCreateReport.Response) -@dataclass class LabelNamespace(NamespaceBase): def query_labels( self, params: t.Union[dict, 'models.ComAtprotoLabelQueryLabels.Params'], **kwargs @@ -2069,8 +2046,8 @@ def query_labels( :class:`atproto.exceptions.AtProtocolError`: Base exception. """ - params = get_or_create(params, models.ComAtprotoLabelQueryLabels.Params) + params_model = get_or_create_model(params, models.ComAtprotoLabelQueryLabels.Params) response = self._client.invoke_query( - 'com.atproto.label.queryLabels', params=params, output_encoding='application/json', **kwargs + 'com.atproto.label.queryLabels', params=params_model, output_encoding='application/json', **kwargs ) return get_response_model(response, models.ComAtprotoLabelQueryLabels.Response) diff --git a/atproto/xrpc_client/request.py b/atproto/xrpc_client/request.py index 4962a2a4..340d9871 100644 --- a/atproto/xrpc_client/request.py +++ b/atproto/xrpc_client/request.py @@ -98,8 +98,9 @@ def _send_request(self, method: str, url: str, **kwargs) -> httpx.Response: try: response = self._client.request(method=method, url=url, headers=headers, **kwargs) return _handle_response(response) - except Exception as e: # noqa: BLE001 + except Exception as e: _handle_request_errors(e) + raise e def close(self) -> None: self._client.close() @@ -124,8 +125,9 @@ async def _send_request(self, method: str, url: str, **kwargs) -> httpx.Response try: response = await self._client.request(method=method, url=url, headers=headers, **kwargs) return _handle_response(response) - except Exception as e: # noqa: BLE001 + except Exception as e: _handle_request_errors(e) + raise e async def close(self) -> None: await self._client.aclose() diff --git a/poetry.lock b/poetry.lock index f82575d1..426fab11 100644 --- a/poetry.lock +++ b/poetry.lock @@ -826,6 +826,54 @@ multiformats = "*" [package.extras] dev = ["mypy", "pylint", "pytest", "pytest-cov"] +[[package]] +name = "mypy" +version = "1.3.0" +description = "Optional static typing for Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mypy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c1eb485cea53f4f5284e5baf92902cd0088b24984f4209e25981cc359d64448d"}, + {file = "mypy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4c99c3ecf223cf2952638da9cd82793d8f3c0c5fa8b6ae2b2d9ed1e1ff51ba85"}, + {file = "mypy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550a8b3a19bb6589679a7c3c31f64312e7ff482a816c96e0cecec9ad3a7564dd"}, + {file = "mypy-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cbc07246253b9e3d7d74c9ff948cd0fd7a71afcc2b77c7f0a59c26e9395cb152"}, + {file = "mypy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:a22435632710a4fcf8acf86cbd0d69f68ac389a3892cb23fbad176d1cddaf228"}, + {file = "mypy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6e33bb8b2613614a33dff70565f4c803f889ebd2f859466e42b46e1df76018dd"}, + {file = "mypy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7d23370d2a6b7a71dc65d1266f9a34e4cde9e8e21511322415db4b26f46f6b8c"}, + {file = "mypy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:658fe7b674769a0770d4b26cb4d6f005e88a442fe82446f020be8e5f5efb2fae"}, + {file = "mypy-1.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e42d29e324cdda61daaec2336c42512e59c7c375340bd202efa1fe0f7b8f8ca"}, + {file = "mypy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:d0b6c62206e04061e27009481cb0ec966f7d6172b5b936f3ead3d74f29fe3dcf"}, + {file = "mypy-1.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:76ec771e2342f1b558c36d49900dfe81d140361dd0d2df6cd71b3db1be155409"}, + {file = "mypy-1.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc95f8386314272bbc817026f8ce8f4f0d2ef7ae44f947c4664efac9adec929"}, + {file = "mypy-1.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:faff86aa10c1aa4a10e1a301de160f3d8fc8703b88c7e98de46b531ff1276a9a"}, + {file = "mypy-1.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8c5979d0deb27e0f4479bee18ea0f83732a893e81b78e62e2dda3e7e518c92ee"}, + {file = "mypy-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c5d2cc54175bab47011b09688b418db71403aefad07cbcd62d44010543fc143f"}, + {file = "mypy-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:87df44954c31d86df96c8bd6e80dfcd773473e877ac6176a8e29898bfb3501cb"}, + {file = "mypy-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473117e310febe632ddf10e745a355714e771ffe534f06db40702775056614c4"}, + {file = "mypy-1.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:74bc9b6e0e79808bf8678d7678b2ae3736ea72d56eede3820bd3849823e7f305"}, + {file = "mypy-1.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:44797d031a41516fcf5cbfa652265bb994e53e51994c1bd649ffcd0c3a7eccbf"}, + {file = "mypy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ddae0f39ca146972ff6bb4399f3b2943884a774b8771ea0a8f50e971f5ea5ba8"}, + {file = "mypy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c4c42c60a8103ead4c1c060ac3cdd3ff01e18fddce6f1016e08939647a0e703"}, + {file = "mypy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e86c2c6852f62f8f2b24cb7a613ebe8e0c7dc1402c61d36a609174f63e0ff017"}, + {file = "mypy-1.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f9dca1e257d4cc129517779226753dbefb4f2266c4eaad610fc15c6a7e14283e"}, + {file = "mypy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:95d8d31a7713510685b05fbb18d6ac287a56c8f6554d88c19e73f724a445448a"}, + {file = "mypy-1.3.0-py3-none-any.whl", hash = "sha256:a8763e72d5d9574d45ce5881962bc8e9046bf7b375b0abf031f3e6811732a897"}, + {file = "mypy-1.3.0.tar.gz", hash = "sha256:e1f4d16e296f5135624b34e8fb741eb0eadedca90862405b1f1fde2040b9bd11"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} +typing-extensions = ">=3.10" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1627,4 +1675,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.7,<3.12" -content-hash = "6bef83653705d4124ae965f2cdfd0a1bf6b2488800ebae921d4d676134696e40" +content-hash = "83969db8069a22481474b71d64cd9932f19fe792a6b4db2e4e56a1928f4a1a0e" diff --git a/pyproject.toml b/pyproject.toml index a943e828..de1a7beb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ click = ">=8.1.3,<8.2.0" [tool.poetry.group.dev.dependencies] black = "23.3.0" ruff = ">=0.0.270,<0.1.0" +mypy = ">=1.3.0,<1.4.0" [tool.poetry.group.docs.dependencies] sphinx = "5.3.0" @@ -74,6 +75,11 @@ style = "pep440" requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" +[tool.mypy] +python_version = "3.7" +disallow_untyped_defs = false # TODO(MarshalX) enable + + [tool.ruff] extend-select = [ "E", # pycodestyle errors diff --git a/test.py b/test.py index f67dacee..f70b1171 100644 --- a/test.py +++ b/test.py @@ -45,6 +45,8 @@ def sync_main(): client = Client() client.login(os.environ['USERNAME'], os.environ['PASSWORD']) + # client.com.atproto.admin.get_moderation_actions() + # repo = client.com.atproto.sync.get_repo({'did': client.me.did}) did = client.com.atproto.identity.resolve_handle({'handle': 'bsky.app'}).did repo = client.com.atproto.sync.get_repo({'did': did}) @@ -109,32 +111,6 @@ async def main(): # assert resolve.did == profile.did -def test_strange_embed_images_type(): - d = { - 'text': 'Jack will save us from Elon I hope he doesn`t sell us out again @jack.bsky.social here`s to the future in the present moment #bluesky', - 'embed': { - '$type': 'app.bsky.embed.images', - 'images': [ - { - 'alt': '', - 'image': { - 'cid': 'bafkreib66ejhcuiomfqusm52xriallilizk6uqppymyaz7dmz7yargpwhi', - 'mimeType': 'image/jpeg', - }, - } - ], - }, - 'entities': [ - {'type': 'mention', 'index': {'end': 81, 'start': 64}, 'value': 'did:plc:6fktaamhhxdqb2ypum33kbkj'} - ], - 'createdAt': '2023-03-26T15:36:13.302Z', - } - from atproto.xrpc_client.models.utils import get_or_create - - m = get_or_create(d, models.AppBskyFeedPost.Main) - print(m) - - def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict: # noqa: C901 operation_by_type = { 'posts': {'created': [], 'deleted': []}, @@ -190,20 +166,8 @@ def on_message_handler(message: 'MessageFrame') -> None: return ops = _get_ops_by_type(commit) - - # Here we can filter, process, run ML with classificator, etc. - # After our feed alg we can save posts uri in our DB - # also we should process here deleted posts to remove it from our DB - # for example lets create our custom feed that will contain all posts that contains M letter - - posts_to_create = [] for post in ops['posts']['created']: - if 'M' in post['record'].text: - posts_to_create.append(post['uri']) - - if posts_to_create: - ... - # print('Posts with M letter:', posts_to_create) + print(post['record'].text) client.start(on_message_handler) @@ -258,9 +222,7 @@ async def _stop_after_n_sec(): if __name__ == '__main__': - # test_strange_embed_images_type() - - # sync_main() + sync_main() # asyncio.get_event_loop().run_until_complete(main()) _custom_feed_firehose()