Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix session sharing with all cloned client instances #531

Merged
merged 4 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/advanced_usage/add_user_to_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def main() -> None:
mod_list = client.app.bsky.graph.get_list(models.AppBskyGraphGetList.Params(list=mod_list_uri))
mod_list_users = [item.subject.did for item in mod_list.items]
print(f'List users: {mod_list_users}')
assert user_to_add in mod_list_users, f'User {user_to_add} not found in the list {mod_list_uri}' # noqa: S101
assert user_to_add in mod_list_users, f'User {user_to_add} not found in the list {mod_list_uri}'

deleted_success = client.app.bsky.graph.listitem.delete(mod_list_owner, AtUri.from_str(created_list_item.uri).rkey)
print(f'Deleted list item: {deleted_success}')
Expand Down
6 changes: 3 additions & 3 deletions examples/advanced_usage/validate_string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
strict_validation_context = {'strict_string_format': True}
HandleTypeAdapter = TypeAdapter(string_formats.Handle)

assert string_formats._OPT_IN_KEY == 'strict_string_format' # noqa: S101
assert string_formats._OPT_IN_KEY == 'strict_string_format'

# values will not be validated if not opting in
sneaky_bad_handle = HandleTypeAdapter.validate_python(some_bad_handle)

assert sneaky_bad_handle == some_bad_handle # noqa: S101
assert sneaky_bad_handle == some_bad_handle

print(f'{sneaky_bad_handle=}\n\n')

# values will be validated if opting in
validated_good_handle = HandleTypeAdapter.validate_python(some_good_handle, context=strict_validation_context)

assert validated_good_handle == some_good_handle # noqa: S101
assert validated_good_handle == some_good_handle

print(f'{validated_good_handle=}\n\n')

Expand Down
22 changes: 13 additions & 9 deletions packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response
return await super()._invoke(invoke_type, **kwargs)

async with self._refresh_lock:
if self._access_jwt and self._should_refresh_session():
if self._session and self._session.access_jwt and self._should_refresh_session():
await self._refresh_and_set_session()

return await super()._invoke(invoke_type, **kwargs)

async def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
session = self._set_session_common(session, self._base_url)
await self._call_on_session_change_callbacks(event, session.copy())
self._set_session_common(session, self._base_url)
await self._call_on_session_change_callbacks(event)

async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = await self.com.atproto.server.create_session(
Expand All @@ -60,11 +60,11 @@ async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAt
return session

async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
if not self._session or not self._session.refresh_jwt:
raise LoginRequiredError

refresh_session = await self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
headers=self._get_refresh_auth_headers(), session_refreshing=True
)
await self._set_session(SessionEvent.REFRESH, refresh_session)

Expand Down Expand Up @@ -225,16 +225,16 @@ async def send_images(
image_alts = image_alts + [''] * diff # [''] * (minus) => []

if image_aspect_ratios is None:
image_aspect_ratios = [None] * len(images)
aligned_image_aspect_ratios = [None] * len(images)
else:
# padding with None if len is insufficient
diff = len(images) - len(image_aspect_ratios)
image_aspect_ratios = image_aspect_ratios + [None] * diff
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff

uploads = await asyncio.gather(*[self.upload_blob(image) for image in images])
embed_images = [
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
for alt, upload, aspect_ratio in zip(image_alts, uploads, image_aspect_ratios)
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
]

return await self.send_post(
Expand Down Expand Up @@ -278,6 +278,10 @@ async def send_image(
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
image_aspect_ratios = None
if image_aspect_ratio:
image_aspect_ratios = [image_aspect_ratio]

return await self.send_images(
text,
images=[image],
Expand All @@ -286,7 +290,7 @@ async def send_image(
reply_to=reply_to,
langs=langs,
facets=facets,
image_aspect_ratios=[image_aspect_ratio],
image_aspect_ratios=image_aspect_ratios,
)

async def send_video(
Expand Down
22 changes: 13 additions & 9 deletions packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response':
return super()._invoke(invoke_type, **kwargs)

with self._refresh_lock:
if self._access_jwt and self._should_refresh_session():
if self._session and self._session.access_jwt and self._should_refresh_session():
self._refresh_and_set_session()

return super()._invoke(invoke_type, **kwargs)

def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
session = self._set_session_common(session, self._base_url)
self._call_on_session_change_callbacks(event, session.copy())
self._set_session_common(session, self._base_url)
self._call_on_session_change_callbacks(event)

def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = self.com.atproto.server.create_session(
Expand All @@ -51,11 +51,11 @@ def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoS
return session

def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
if not self._session or not self._session.refresh_jwt:
raise LoginRequiredError

refresh_session = self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
headers=self._get_refresh_auth_headers(), session_refreshing=True
)
self._set_session(SessionEvent.REFRESH, refresh_session)

Expand Down Expand Up @@ -216,16 +216,16 @@ def send_images(
image_alts = image_alts + [''] * diff # [''] * (minus) => []

if image_aspect_ratios is None:
image_aspect_ratios = [None] * len(images)
aligned_image_aspect_ratios = [None] * len(images)
else:
# padding with None if len is insufficient
diff = len(images) - len(image_aspect_ratios)
image_aspect_ratios = image_aspect_ratios + [None] * diff
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff

uploads = [self.upload_blob(image) for image in images]
embed_images = [
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
for alt, upload, aspect_ratio in zip(image_alts, uploads, image_aspect_ratios)
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
]

return self.send_post(
Expand Down Expand Up @@ -269,6 +269,10 @@ def send_image(
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
image_aspect_ratios = None
if image_aspect_ratio:
image_aspect_ratios = [image_aspect_ratio]

return self.send_images(
text,
images=[image],
Expand All @@ -277,7 +281,7 @@ def send_image(
reply_to=reply_to,
langs=langs,
facets=facets,
image_aspect_ratios=[image_aspect_ratio],
image_aspect_ratios=image_aspect_ratios,
)

def send_video(
Expand Down
5 changes: 5 additions & 0 deletions packages/atproto_client/client/methods_mixin/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def clone(self) -> te.Self:
Cloned client instance.
"""
cloned_client = super().clone()

# share the same objects to avoid conflicts with session changes
cloned_client.me = self.me
cloned_client._session = self._session
cloned_client._session_dispatcher = self._session_dispatcher

return cloned_client

def with_proxy(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> te.Self:
Expand Down
120 changes: 59 additions & 61 deletions packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,57 @@
import asyncio
import typing as t
from datetime import timedelta

from atproto_server.auth.jwt import get_jwt_payload

from atproto_client.client.methods_mixin.time import TimeMethodsMixin
from atproto_client.client.session import (
AsyncSessionChangeCallback,
Session,
SessionChangeCallback,
SessionDispatcher,
SessionEvent,
SessionResponse,
get_session_pds_endpoint,
)
from atproto_client.exceptions import LoginRequiredError

if t.TYPE_CHECKING:
from atproto_server.auth.jwt import JwtPayload


class SessionDispatchMixin:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._on_session_change_callbacks: t.List[SessionChangeCallback] = []
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []

def on_session_change(self, callback: SessionChangeCallback) -> None:
"""Register a callback for session change event.

Args:
callback: A callback to be called when the session changes.
The callback must accept two arguments: event and session.

Note:
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.

Tip:
You should save the session string to persistent storage
on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.

Example:
>>> from atproto import Client, SessionEvent, Session
>>>
>>> client = Client()
>>>
>>> @client.on_session_change
>>> def on_session_change(event: SessionEvent, session: Session):
>>> print(event, session)
>>>
>>> client.on_session_change(on_session_change)
>>> # or you can use this syntax:
>>> # client.on_session_change(on_session_change)

Returns:
:obj:`None`
"""
self._on_session_change_callbacks.append(callback)
self._session_dispatcher.on_session_change(callback)

def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
for on_session_change_callback in self._on_session_change_callbacks:
on_session_change_callback(event, session)
def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
self._session_dispatcher.dispatch_session_change(event)


class AsyncSessionDispatchMixin:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []

def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
def on_session_change(self, callback: t.Union['AsyncSessionChangeCallback', 'SessionChangeCallback']) -> None:
"""Register a callback for session change event.

Args:
Expand All @@ -69,6 +61,9 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
Note:
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.

Note:
You can register both synchronous and asynchronous callbacks.

Tip:
You should save the session string to persistent storage
on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
Expand All @@ -78,78 +73,81 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
>>>
>>> client = AsyncClient()
>>>
>>> @client.on_session_change
>>> async def on_session_change(event: SessionEvent, session: Session):
>>> print(event, session)
>>>
>>> client.on_session_change(on_session_change)
>>> # or you can use this syntax:
>>> # client.on_session_change(on_session_change)

Returns:
:obj:`None`
"""
self._on_session_change_async_callbacks.append(callback)
self._session_dispatcher.on_session_change(callback)

async def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
coroutines: t.List[t.Coroutine[t.Any, t.Any, None]] = []
for on_session_change_async_callback in self._on_session_change_async_callbacks:
coroutines.append(on_session_change_async_callback(event, session))

await asyncio.gather(*coroutines)
async def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
await self._session_dispatcher.dispatch_session_change_async(event)


class SessionMethodsMixin(TimeMethodsMixin):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._access_jwt: t.Optional[str] = None
self._access_jwt_payload: t.Optional['JwtPayload'] = None

self._refresh_jwt: t.Optional[str] = None
self._refresh_jwt_payload: t.Optional['JwtPayload'] = None

self._session: t.Optional[Session] = None
self._session_dispatcher = SessionDispatcher()

def _register_auth_headers_source(self) -> None:
self.request.add_additional_headers_source(self._get_access_auth_headers)

def _should_refresh_session(self) -> bool:
if not self._access_jwt_payload or not self._access_jwt_payload.exp:
if not self._session or not self._session.access_jwt_payload or not self._session.access_jwt_payload.exp:
raise LoginRequiredError

expired_at = self.get_time_from_timestamp(self._access_jwt_payload.exp)
expired_at = self.get_time_from_timestamp(self._session.access_jwt_payload.exp)
expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit earlier than required

return self.get_current_time() > expired_at

def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
self._access_jwt = session.access_jwt
self._access_jwt_payload = get_jwt_payload(session.access_jwt)
def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> 'Session':
if not self._session:
self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)
self._session_dispatcher.set_session(self._session)
self._register_auth_headers_source()
else:
self._session.access_jwt = session.access_jwt
self._session.refresh_jwt = session.refresh_jwt
self._session.did = session.did
self._session.handle = session.handle
self._session.pds_endpoint = pds_endpoint

self._refresh_jwt = session.refresh_jwt
self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt)
return self._session

def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
pds_endpoint = get_session_pds_endpoint(session)
if not pds_endpoint:
# current_pds ends with xrpc endpoint, but this is not a problem
# overhead is only 4-5 symbols in the exported session string
pds_endpoint = current_pds

self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)

self._set_auth_headers(session.access_jwt)
self._update_pds_endpoint(pds_endpoint)
return self._set_or_update_session(session, pds_endpoint)

return self._session
def _get_access_auth_headers(self) -> t.Dict[str, str]:
if not self._session:
return {}

@staticmethod
def _get_auth_headers(token: str) -> t.Dict[str, str]:
return {'Authorization': f'Bearer {token}'}
return {'Authorization': f'Bearer {self._session.access_jwt}'}

def _get_refresh_auth_headers(self) -> t.Dict[str, str]:
if not self._session:
return {}

def _set_auth_headers(self, token: str) -> None:
for header_name, header_value in self._get_auth_headers(token).items():
self.request.add_additional_header(header_name, header_value)
return {'Authorization': f'Bearer {self._session.refresh_jwt}'}

def _update_pds_endpoint(self, pds_endpoint: str) -> None:
self.update_base_url(pds_endpoint)
Expand Down
Loading
Loading