Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Share the same session with all cloned client instances
Browse files Browse the repository at this point in the history
MarshalX committed Jan 23, 2025

Verified

This commit was signed with the committer’s verified signature.
W-Mai Benign X
1 parent 850f872 commit 32855c8
Showing 5 changed files with 38 additions and 34 deletions.
6 changes: 3 additions & 3 deletions packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response
return await super()._invoke(invoke_type, **kwargs)

async with self._refresh_lock:
if self._access_jwt and self._should_refresh_session():
if self._session and self._session.access_jwt and self._should_refresh_session():
await self._refresh_and_set_session()

return await super()._invoke(invoke_type, **kwargs)
@@ -60,11 +60,11 @@ async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAt
return session

async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
if not self._session or not self._session.refresh_jwt:
raise LoginRequiredError

refresh_session = await self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
headers=self._get_auth_headers(self._session.refresh_jwt), session_refreshing=True
)
await self._set_session(SessionEvent.REFRESH, refresh_session)

6 changes: 3 additions & 3 deletions packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response':
return super()._invoke(invoke_type, **kwargs)

with self._refresh_lock:
if self._access_jwt and self._should_refresh_session():
if self._session and self._session.access_jwt and self._should_refresh_session():
self._refresh_and_set_session()

return super()._invoke(invoke_type, **kwargs)
@@ -51,11 +51,11 @@ def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoS
return session

def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
if not self._refresh_jwt:
if not self._session or not self._session.refresh_jwt:
raise LoginRequiredError

refresh_session = self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
headers=self._get_auth_headers(self._session.refresh_jwt), session_refreshing=True
)
self._set_session(SessionEvent.REFRESH, refresh_session)

1 change: 1 addition & 0 deletions packages/atproto_client/client/methods_mixin/headers.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ def clone(self) -> te.Self:
"""
cloned_client = super().clone()
cloned_client.me = self.me
cloned_client._session = self._session # share the same object to avoid conflicts with session changes
return cloned_client

def with_proxy(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> te.Self:
48 changes: 20 additions & 28 deletions packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
@@ -2,8 +2,6 @@
import typing as t
from datetime import timedelta

from atproto_server.auth.jwt import get_jwt_payload

from atproto_client.client.methods_mixin.time import TimeMethodsMixin
from atproto_client.client.session import (
AsyncSessionChangeCallback,
@@ -15,9 +13,6 @@
)
from atproto_client.exceptions import LoginRequiredError

if t.TYPE_CHECKING:
from atproto_server.auth.jwt import JwtPayload


class SessionDispatchMixin:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
@@ -99,46 +94,43 @@ async def _call_on_session_change_callbacks(self, event: SessionEvent, session:
class SessionMethodsMixin(TimeMethodsMixin):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

self._access_jwt: t.Optional[str] = None
self._access_jwt_payload: t.Optional['JwtPayload'] = None

self._refresh_jwt: t.Optional[str] = None
self._refresh_jwt_payload: t.Optional['JwtPayload'] = None

self._session: t.Optional[Session] = None

def _should_refresh_session(self) -> bool:
if not self._access_jwt_payload or not self._access_jwt_payload.exp:
if not self._session.access_jwt_payload or not self._session.access_jwt_payload.exp:
raise LoginRequiredError

expired_at = self.get_time_from_timestamp(self._access_jwt_payload.exp)
expired_at = self.get_time_from_timestamp(self._session.access_jwt_payload.exp)
expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit earlier than required

return self.get_current_time() > expired_at

def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
self._access_jwt = session.access_jwt
self._access_jwt_payload = get_jwt_payload(session.access_jwt)

self._refresh_jwt = session.refresh_jwt
self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt)
def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> None:
if not self._session:
self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)
else:
self._session.access_jwt = session.access_jwt
self._session.refresh_jwt = session.refresh_jwt
self._session.did = session.did
self._session.handle = session.handle
self._session.pds_endpoint = pds_endpoint

def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
pds_endpoint = get_session_pds_endpoint(session)
if not pds_endpoint:
# current_pds ends with xrpc endpoint, but this is not a problem
# overhead is only 4-5 symbols in the exported session string
pds_endpoint = current_pds

self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)
self._set_or_update_session(session, pds_endpoint)

self._set_auth_headers(session.access_jwt)
self._set_auth_headers(self._session.access_jwt)
self._update_pds_endpoint(pds_endpoint)

return self._session
11 changes: 11 additions & 0 deletions packages/atproto_client/client/session.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,11 @@

import typing_extensions as te
from atproto_core.did_doc import DidDocument, is_valid_did_doc
from atproto_server.auth.jwt import get_jwt_payload

if t.TYPE_CHECKING:
from atproto_server.auth.jwt import JwtPayload

from atproto_client import models


@@ -26,6 +29,14 @@ class Session:
refresh_jwt: str
pds_endpoint: t.Optional[str] = 'https://bsky.social' # Backward compatibility for old sessions

@property
def access_jwt_payload(self) -> 'JwtPayload':
return get_jwt_payload(self.access_jwt)

@property
def refresh_jwt_payload(self) -> 'JwtPayload':
return get_jwt_payload(self.refresh_jwt)

def __repr__(self) -> str:
return f'<Session handle={self.handle} did={self.did}>'

0 comments on commit 32855c8

Please sign in to comment.