Skip to content

Commit

Permalink
Fix pyright (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored Feb 17, 2024
1 parent ccba1f3 commit cb605dc
Show file tree
Hide file tree
Showing 24 changed files with 259 additions and 132 deletions.
6 changes: 3 additions & 3 deletions packages/atproto_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def resolve_command(
return name, cmd, args


def echo(ctx: click.Context, *args) -> None:
def echo(ctx: click.Context, *args: t.Any) -> None:
if not ctx.obj.get('silent'):
click.echo(*args)

Expand Down Expand Up @@ -74,11 +74,11 @@ def gen_all(ctx: click.Context) -> None:
echo(ctx, 'Done!')


def _gen_models(*args) -> None:
def _gen_models(*args: t.Any) -> None:
generate_models(*args)


def _gen_namespaces(*args) -> None:
def _gen_namespaces(*args: t.Any) -> None:
generate_namespaces(*args)


Expand Down
53 changes: 39 additions & 14 deletions packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from atproto_client.client.methods_mixin.backward_compatibility import _BackwardCompatibility
from atproto_client.client.methods_mixin.session import AsyncSessionDispatchMixin
from atproto_client.client.session import Session, SessionEvent, SessionResponse
from atproto_client.exceptions import LoginRequiredError
from atproto_client.models.languages import DEFAULT_LANGUAGE_CODE1
from atproto_client.utils import TextBuilder

Expand All @@ -28,7 +29,7 @@ class AsyncClient(
):
"""High-level client for XRPC of ATProto."""

def __init__(self, base_url: t.Optional[str] = None, *args, **kwargs: t.Any) -> None:
def __init__(self, base_url: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(base_url, *args, **kwargs)

self._refresh_lock = Lock()
Expand All @@ -47,8 +48,8 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response
return await super()._invoke(invoke_type, **kwargs)

async def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
self._set_session_common(session)
await self._call_on_session_change_callbacks(event, self._session.copy())
session = self._set_session_common(session)
await self._call_on_session_change_callbacks(event, session.copy())

async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = await self.com.atproto.server.create_session(
Expand All @@ -58,6 +59,9 @@ async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAt
return session

async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
raise LoginRequiredError

refresh_session = await self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
)
Expand Down Expand Up @@ -104,7 +108,7 @@ async def send_post(
self,
text: t.Union[str, TextBuilder],
profile_identify: t.Optional[str] = None,
reply_to: t.Optional[t.Union['models.AppBskyFeedPost.ReplyRef', 'models.AppBskyFeedDefs.ReplyRef']] = None,
reply_to: t.Optional['models.AppBskyFeedPost.ReplyRef'] = None,
embed: t.Optional[
t.Union[
'models.AppBskyEmbedImages.Main',
Expand Down Expand Up @@ -142,10 +146,13 @@ async def send_post(
facets = text.build_facets()
text = text.build_text()

repo = self.me.did
repo = self.me and self.me.did
if profile_identify:
repo = profile_identify

if not repo:
raise LoginRequiredError

if not langs:
langs = [DEFAULT_LANGUAGE_CODE1]

Expand Down Expand Up @@ -180,7 +187,7 @@ async def send_image(
image: bytes,
image_alt: str,
profile_identify: t.Optional[str] = None,
reply_to: t.Optional[t.Union['models.AppBskyFeedPost.ReplyRef', 'models.AppBskyFeedDefs.ReplyRef']] = None,
reply_to: t.Optional['models.AppBskyFeedPost.ReplyRef'] = None,
langs: t.Optional[t.List[str]] = None,
facets: t.Optional[t.List['models.AppBskyRichtextFacet.Main']] = None,
) -> 'models.AppBskyFeedPost.CreateRecordResponse':
Expand Down Expand Up @@ -231,10 +238,13 @@ async def get_post(
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
repo = self.me.did
repo = self.me and self.me.did
if profile_identify:
repo = profile_identify

if not repo:
raise LoginRequiredError

return await self.app.bsky.feed.post.get(repo, post_rkey, cid)

async def get_posts(self, uris: t.List[str]) -> 'models.AppBskyFeedGetPosts.Response':
Expand Down Expand Up @@ -364,7 +374,7 @@ async def get_author_feed(
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
return await self.app.bsky.feed.get_author_feed(
models.AppBskyFeedGetAuthorFeed.Params(actor=actor, cursor=cursor, fitler=filter, limit=limit)
models.AppBskyFeedGetAuthorFeed.Params(actor=actor, cursor=cursor, filter=filter, limit=limit)
)

async def like(
Expand All @@ -373,13 +383,16 @@ async def like(
cid: t.Optional[str] = None,
subject: t.Optional['models.ComAtprotoRepoStrongRef.Main'] = None,
) -> 'models.AppBskyFeedLike.CreateRecordResponse':
"""Like the post.
"""Like the record.
Args:
cid: The CID of the post.
uri: The URI of the post.
cid: The CID of the record.
uri: The URI of the record.
subject: DEPRECATED.
Note:
Record could be post, custom feed, etc.
Returns:
:obj:`models.AppBskyFeedLike.CreateRecordResponse`: Reference to the created record.
Expand All @@ -388,8 +401,12 @@ async def like(
"""
subject_obj = self._strong_ref_arg_backward_compatibility(uri, cid, subject)

repo = self.me and self.me.did
if not repo:
raise LoginRequiredError

record = models.AppBskyFeedLike.Record(created_at=self.get_current_time_iso(), subject=subject_obj)
return await self.app.bsky.feed.like.create(self.me.did, record)
return await self.app.bsky.feed.like.create(repo, record)

async def unlike(self, like_uri: str) -> bool:
"""Unlike the post.
Expand Down Expand Up @@ -427,8 +444,12 @@ async def repost(
"""
subject_obj = self._strong_ref_arg_backward_compatibility(uri, cid, subject)

repo = self.me and self.me.did
if not repo:
raise LoginRequiredError

record = models.AppBskyFeedRepost.Record(created_at=self.get_current_time_iso(), subject=subject_obj)
return await self.app.bsky.feed.repost.create(self.me.did, record)
return await self.app.bsky.feed.repost.create(repo, record)

async def unrepost(self, repost_uri: str) -> bool:
"""Unrepost the post (delete repost).
Expand Down Expand Up @@ -457,8 +478,12 @@ async def follow(self, subject: str) -> 'models.AppBskyGraphFollow.CreateRecordR
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
repo = self.me and self.me.did
if not repo:
raise LoginRequiredError

record = models.AppBskyGraphFollow.Record(created_at=self.get_current_time_iso(), subject=subject)
return await self.app.bsky.graph.follow.create(self.me.did, record)
return await self.app.bsky.graph.follow.create(repo, record)

async def unfollow(self, follow_uri: str) -> bool:
"""Unfollow the profile.
Expand Down
2 changes: 1 addition & 1 deletion packages/atproto_client/client/async_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class AsyncClientRaw(AsyncClientBase):
com: 'async_ns.ComNamespace'
app: 'async_ns.AppNamespace'

def __init__(self, *args, **kwargs: t.Any) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self.com = async_ns.ComNamespace(self)
Expand Down
53 changes: 39 additions & 14 deletions packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from atproto_client.client.methods_mixin.session import SessionDispatchMixin
from atproto_client.client.raw import ClientRaw
from atproto_client.client.session import Session, SessionEvent, SessionResponse
from atproto_client.exceptions import LoginRequiredError
from atproto_client.models.languages import DEFAULT_LANGUAGE_CODE1
from atproto_client.utils import TextBuilder

Expand All @@ -20,7 +21,7 @@
class Client(_BackwardCompatibility, SessionDispatchMixin, SessionMethodsMixin, TimeMethodsMixin, ClientRaw):
"""High-level client for XRPC of ATProto."""

def __init__(self, base_url: t.Optional[str] = None, *args, **kwargs: t.Any) -> None:
def __init__(self, base_url: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(base_url, *args, **kwargs)

self._refresh_lock = Lock()
Expand All @@ -39,8 +40,8 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response':
return super()._invoke(invoke_type, **kwargs)

def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
self._set_session_common(session)
self._call_on_session_change_callbacks(event, self._session.copy())
session = self._set_session_common(session)
self._call_on_session_change_callbacks(event, session.copy())

def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = self.com.atproto.server.create_session(
Expand All @@ -50,6 +51,9 @@ def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoS
return session

def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
raise LoginRequiredError

refresh_session = self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
)
Expand Down Expand Up @@ -96,7 +100,7 @@ def send_post(
self,
text: t.Union[str, TextBuilder],
profile_identify: t.Optional[str] = None,
reply_to: t.Optional[t.Union['models.AppBskyFeedPost.ReplyRef', 'models.AppBskyFeedDefs.ReplyRef']] = None,
reply_to: t.Optional['models.AppBskyFeedPost.ReplyRef'] = None,
embed: t.Optional[
t.Union[
'models.AppBskyEmbedImages.Main',
Expand Down Expand Up @@ -134,10 +138,13 @@ def send_post(
facets = text.build_facets()
text = text.build_text()

repo = self.me.did
repo = self.me and self.me.did
if profile_identify:
repo = profile_identify

if not repo:
raise LoginRequiredError

if not langs:
langs = [DEFAULT_LANGUAGE_CODE1]

Expand Down Expand Up @@ -172,7 +179,7 @@ def send_image(
image: bytes,
image_alt: str,
profile_identify: t.Optional[str] = None,
reply_to: t.Optional[t.Union['models.AppBskyFeedPost.ReplyRef', 'models.AppBskyFeedDefs.ReplyRef']] = None,
reply_to: t.Optional['models.AppBskyFeedPost.ReplyRef'] = None,
langs: t.Optional[t.List[str]] = None,
facets: t.Optional[t.List['models.AppBskyRichtextFacet.Main']] = None,
) -> 'models.AppBskyFeedPost.CreateRecordResponse':
Expand Down Expand Up @@ -223,10 +230,13 @@ def get_post(
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
repo = self.me.did
repo = self.me and self.me.did
if profile_identify:
repo = profile_identify

if not repo:
raise LoginRequiredError

return self.app.bsky.feed.post.get(repo, post_rkey, cid)

def get_posts(self, uris: t.List[str]) -> 'models.AppBskyFeedGetPosts.Response':
Expand Down Expand Up @@ -356,7 +366,7 @@ def get_author_feed(
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
return self.app.bsky.feed.get_author_feed(
models.AppBskyFeedGetAuthorFeed.Params(actor=actor, cursor=cursor, fitler=filter, limit=limit)
models.AppBskyFeedGetAuthorFeed.Params(actor=actor, cursor=cursor, filter=filter, limit=limit)
)

def like(
Expand All @@ -365,13 +375,16 @@ def like(
cid: t.Optional[str] = None,
subject: t.Optional['models.ComAtprotoRepoStrongRef.Main'] = None,
) -> 'models.AppBskyFeedLike.CreateRecordResponse':
"""Like the post.
"""Like the record.
Args:
cid: The CID of the post.
uri: The URI of the post.
cid: The CID of the record.
uri: The URI of the record.
subject: DEPRECATED.
Note:
Record could be post, custom feed, etc.
Returns:
:obj:`models.AppBskyFeedLike.CreateRecordResponse`: Reference to the created record.
Expand All @@ -380,8 +393,12 @@ def like(
"""
subject_obj = self._strong_ref_arg_backward_compatibility(uri, cid, subject)

repo = self.me and self.me.did
if not repo:
raise LoginRequiredError

record = models.AppBskyFeedLike.Record(created_at=self.get_current_time_iso(), subject=subject_obj)
return self.app.bsky.feed.like.create(self.me.did, record)
return self.app.bsky.feed.like.create(repo, record)

def unlike(self, like_uri: str) -> bool:
"""Unlike the post.
Expand Down Expand Up @@ -419,8 +436,12 @@ def repost(
"""
subject_obj = self._strong_ref_arg_backward_compatibility(uri, cid, subject)

repo = self.me and self.me.did
if not repo:
raise LoginRequiredError

record = models.AppBskyFeedRepost.Record(created_at=self.get_current_time_iso(), subject=subject_obj)
return self.app.bsky.feed.repost.create(self.me.did, record)
return self.app.bsky.feed.repost.create(repo, record)

def unrepost(self, repost_uri: str) -> bool:
"""Unrepost the post (delete repost).
Expand Down Expand Up @@ -449,8 +470,12 @@ def follow(self, subject: str) -> 'models.AppBskyGraphFollow.CreateRecordRespons
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
repo = self.me and self.me.did
if not repo:
raise LoginRequiredError

record = models.AppBskyGraphFollow.Record(created_at=self.get_current_time_iso(), subject=subject)
return self.app.bsky.graph.follow.create(self.me.did, record)
return self.app.bsky.graph.follow.create(repo, record)

def unfollow(self, follow_uri: str) -> bool:
"""Unfollow the profile.
Expand Down
Loading

0 comments on commit cb605dc

Please sign in to comment.