Skip to content

Commit

Permalink
Fix session sharing with all cloned client instances (#531)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored Jan 25, 2025
1 parent 850f872 commit df93518
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 99 deletions.
2 changes: 1 addition & 1 deletion examples/advanced_usage/add_user_to_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def main() -> None:
mod_list = client.app.bsky.graph.get_list(models.AppBskyGraphGetList.Params(list=mod_list_uri))
mod_list_users = [item.subject.did for item in mod_list.items]
print(f'List users: {mod_list_users}')
assert user_to_add in mod_list_users, f'User {user_to_add} not found in the list {mod_list_uri}' # noqa: S101
assert user_to_add in mod_list_users, f'User {user_to_add} not found in the list {mod_list_uri}'

deleted_success = client.app.bsky.graph.listitem.delete(mod_list_owner, AtUri.from_str(created_list_item.uri).rkey)
print(f'Deleted list item: {deleted_success}')
Expand Down
6 changes: 3 additions & 3 deletions examples/advanced_usage/validate_string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
strict_validation_context = {'strict_string_format': True}
HandleTypeAdapter = TypeAdapter(string_formats.Handle)

assert string_formats._OPT_IN_KEY == 'strict_string_format' # noqa: S101
assert string_formats._OPT_IN_KEY == 'strict_string_format'

# values will not be validated if not opting in
sneaky_bad_handle = HandleTypeAdapter.validate_python(some_bad_handle)

assert sneaky_bad_handle == some_bad_handle # noqa: S101
assert sneaky_bad_handle == some_bad_handle

print(f'{sneaky_bad_handle=}\n\n')

# values will be validated if opting in
validated_good_handle = HandleTypeAdapter.validate_python(some_good_handle, context=strict_validation_context)

assert validated_good_handle == some_good_handle # noqa: S101
assert validated_good_handle == some_good_handle

print(f'{validated_good_handle=}\n\n')

Expand Down
22 changes: 13 additions & 9 deletions packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response
return await super()._invoke(invoke_type, **kwargs)

async with self._refresh_lock:
if self._access_jwt and self._should_refresh_session():
if self._session and self._session.access_jwt and self._should_refresh_session():
await self._refresh_and_set_session()

return await super()._invoke(invoke_type, **kwargs)

async def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
session = self._set_session_common(session, self._base_url)
await self._call_on_session_change_callbacks(event, session.copy())
self._set_session_common(session, self._base_url)
await self._call_on_session_change_callbacks(event)

async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = await self.com.atproto.server.create_session(
Expand All @@ -60,11 +60,11 @@ async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAt
return session

async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
if not self._session or not self._session.refresh_jwt:
raise LoginRequiredError

refresh_session = await self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
headers=self._get_refresh_auth_headers(), session_refreshing=True
)
await self._set_session(SessionEvent.REFRESH, refresh_session)

Expand Down Expand Up @@ -225,16 +225,16 @@ async def send_images(
image_alts = image_alts + [''] * diff # [''] * (minus) => []

if image_aspect_ratios is None:
image_aspect_ratios = [None] * len(images)
aligned_image_aspect_ratios = [None] * len(images)
else:
# padding with None if len is insufficient
diff = len(images) - len(image_aspect_ratios)
image_aspect_ratios = image_aspect_ratios + [None] * diff
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff

uploads = await asyncio.gather(*[self.upload_blob(image) for image in images])
embed_images = [
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
for alt, upload, aspect_ratio in zip(image_alts, uploads, image_aspect_ratios)
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
]

return await self.send_post(
Expand Down Expand Up @@ -278,6 +278,10 @@ async def send_image(
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
image_aspect_ratios = None
if image_aspect_ratio:
image_aspect_ratios = [image_aspect_ratio]

return await self.send_images(
text,
images=[image],
Expand All @@ -286,7 +290,7 @@ async def send_image(
reply_to=reply_to,
langs=langs,
facets=facets,
image_aspect_ratios=[image_aspect_ratio],
image_aspect_ratios=image_aspect_ratios,
)

async def send_video(
Expand Down
22 changes: 13 additions & 9 deletions packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response':
return super()._invoke(invoke_type, **kwargs)

with self._refresh_lock:
if self._access_jwt and self._should_refresh_session():
if self._session and self._session.access_jwt and self._should_refresh_session():
self._refresh_and_set_session()

return super()._invoke(invoke_type, **kwargs)

def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
session = self._set_session_common(session, self._base_url)
self._call_on_session_change_callbacks(event, session.copy())
self._set_session_common(session, self._base_url)
self._call_on_session_change_callbacks(event)

def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = self.com.atproto.server.create_session(
Expand All @@ -51,11 +51,11 @@ def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoS
return session

def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
if not self._session or not self._session.refresh_jwt:
raise LoginRequiredError

refresh_session = self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
headers=self._get_refresh_auth_headers(), session_refreshing=True
)
self._set_session(SessionEvent.REFRESH, refresh_session)

Expand Down Expand Up @@ -216,16 +216,16 @@ def send_images(
image_alts = image_alts + [''] * diff # [''] * (minus) => []

if image_aspect_ratios is None:
image_aspect_ratios = [None] * len(images)
aligned_image_aspect_ratios = [None] * len(images)
else:
# padding with None if len is insufficient
diff = len(images) - len(image_aspect_ratios)
image_aspect_ratios = image_aspect_ratios + [None] * diff
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff

uploads = [self.upload_blob(image) for image in images]
embed_images = [
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
for alt, upload, aspect_ratio in zip(image_alts, uploads, image_aspect_ratios)
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
]

return self.send_post(
Expand Down Expand Up @@ -269,6 +269,10 @@ def send_image(
Raises:
:class:`atproto.exceptions.AtProtocolError`: Base exception.
"""
image_aspect_ratios = None
if image_aspect_ratio:
image_aspect_ratios = [image_aspect_ratio]

return self.send_images(
text,
images=[image],
Expand All @@ -277,7 +281,7 @@ def send_image(
reply_to=reply_to,
langs=langs,
facets=facets,
image_aspect_ratios=[image_aspect_ratio],
image_aspect_ratios=image_aspect_ratios,
)

def send_video(
Expand Down
5 changes: 5 additions & 0 deletions packages/atproto_client/client/methods_mixin/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def clone(self) -> te.Self:
Cloned client instance.
"""
cloned_client = super().clone()

# share the same objects to avoid conflicts with session changes
cloned_client.me = self.me
cloned_client._session = self._session
cloned_client._session_dispatcher = self._session_dispatcher

return cloned_client

def with_proxy(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> te.Self:
Expand Down
120 changes: 59 additions & 61 deletions packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,57 @@
import asyncio
import typing as t
from datetime import timedelta

from atproto_server.auth.jwt import get_jwt_payload

from atproto_client.client.methods_mixin.time import TimeMethodsMixin
from atproto_client.client.session import (
AsyncSessionChangeCallback,
Session,
SessionChangeCallback,
SessionDispatcher,
SessionEvent,
SessionResponse,
get_session_pds_endpoint,
)
from atproto_client.exceptions import LoginRequiredError

if t.TYPE_CHECKING:
from atproto_server.auth.jwt import JwtPayload


class SessionDispatchMixin:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._on_session_change_callbacks: t.List[SessionChangeCallback] = []
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []

def on_session_change(self, callback: SessionChangeCallback) -> None:
"""Register a callback for session change event.
Args:
callback: A callback to be called when the session changes.
The callback must accept two arguments: event and session.
Note:
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
Tip:
You should save the session string to persistent storage
on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
Example:
>>> from atproto import Client, SessionEvent, Session
>>>
>>> client = Client()
>>>
>>> @client.on_session_change
>>> def on_session_change(event: SessionEvent, session: Session):
>>> print(event, session)
>>>
>>> client.on_session_change(on_session_change)
>>> # or you can use this syntax:
>>> # client.on_session_change(on_session_change)
Returns:
:obj:`None`
"""
self._on_session_change_callbacks.append(callback)
self._session_dispatcher.on_session_change(callback)

def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
for on_session_change_callback in self._on_session_change_callbacks:
on_session_change_callback(event, session)
def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
self._session_dispatcher.dispatch_session_change(event)


class AsyncSessionDispatchMixin:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []

def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
def on_session_change(self, callback: t.Union['AsyncSessionChangeCallback', 'SessionChangeCallback']) -> None:
"""Register a callback for session change event.
Args:
Expand All @@ -69,6 +61,9 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
Note:
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
Note:
You can register both synchronous and asynchronous callbacks.
Tip:
You should save the session string to persistent storage
on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
Expand All @@ -78,78 +73,81 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
>>>
>>> client = AsyncClient()
>>>
>>> @client.on_session_change
>>> async def on_session_change(event: SessionEvent, session: Session):
>>> print(event, session)
>>>
>>> client.on_session_change(on_session_change)
>>> # or you can use this syntax:
>>> # client.on_session_change(on_session_change)
Returns:
:obj:`None`
"""
self._on_session_change_async_callbacks.append(callback)
self._session_dispatcher.on_session_change(callback)

async def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
coroutines: t.List[t.Coroutine[t.Any, t.Any, None]] = []
for on_session_change_async_callback in self._on_session_change_async_callbacks:
coroutines.append(on_session_change_async_callback(event, session))

await asyncio.gather(*coroutines)
async def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
await self._session_dispatcher.dispatch_session_change_async(event)


class SessionMethodsMixin(TimeMethodsMixin):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._access_jwt: t.Optional[str] = None
self._access_jwt_payload: t.Optional['JwtPayload'] = None

self._refresh_jwt: t.Optional[str] = None
self._refresh_jwt_payload: t.Optional['JwtPayload'] = None

self._session: t.Optional[Session] = None
self._session_dispatcher = SessionDispatcher()

def _register_auth_headers_source(self) -> None:
self.request.add_additional_headers_source(self._get_access_auth_headers)

def _should_refresh_session(self) -> bool:
if not self._access_jwt_payload or not self._access_jwt_payload.exp:
if not self._session or not self._session.access_jwt_payload or not self._session.access_jwt_payload.exp:
raise LoginRequiredError

expired_at = self.get_time_from_timestamp(self._access_jwt_payload.exp)
expired_at = self.get_time_from_timestamp(self._session.access_jwt_payload.exp)
expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit earlier than required

return self.get_current_time() > expired_at

def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
self._access_jwt = session.access_jwt
self._access_jwt_payload = get_jwt_payload(session.access_jwt)
def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> 'Session':
if not self._session:
self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)
self._session_dispatcher.set_session(self._session)
self._register_auth_headers_source()
else:
self._session.access_jwt = session.access_jwt
self._session.refresh_jwt = session.refresh_jwt
self._session.did = session.did
self._session.handle = session.handle
self._session.pds_endpoint = pds_endpoint

self._refresh_jwt = session.refresh_jwt
self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt)
return self._session

def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
pds_endpoint = get_session_pds_endpoint(session)
if not pds_endpoint:
# current_pds ends with xrpc endpoint, but this is not a problem
# overhead is only 4-5 symbols in the exported session string
pds_endpoint = current_pds

self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)

self._set_auth_headers(session.access_jwt)
self._update_pds_endpoint(pds_endpoint)
return self._set_or_update_session(session, pds_endpoint)

return self._session
def _get_access_auth_headers(self) -> t.Dict[str, str]:
if not self._session:
return {}

@staticmethod
def _get_auth_headers(token: str) -> t.Dict[str, str]:
return {'Authorization': f'Bearer {token}'}
return {'Authorization': f'Bearer {self._session.access_jwt}'}

def _get_refresh_auth_headers(self) -> t.Dict[str, str]:
if not self._session:
return {}

def _set_auth_headers(self, token: str) -> None:
for header_name, header_value in self._get_auth_headers(token).items():
self.request.add_additional_header(header_name, header_value)
return {'Authorization': f'Bearer {self._session.refresh_jwt}'}

def _update_pds_endpoint(self, pds_endpoint: str) -> None:
self.update_base_url(pds_endpoint)
Expand Down
Loading

0 comments on commit df93518

Please sign in to comment.