diff --git a/packages/atproto_client/client/async_client.py b/packages/atproto_client/client/async_client.py index 12d836c0..d1b78112 100644 --- a/packages/atproto_client/client/async_client.py +++ b/packages/atproto_client/client/async_client.py @@ -43,7 +43,7 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response return await super()._invoke(invoke_type, **kwargs) async with self._refresh_lock: - if self._access_jwt and self._should_refresh_session(): + if self._session and self._session.access_jwt and self._should_refresh_session(): await self._refresh_and_set_session() return await super()._invoke(invoke_type, **kwargs) @@ -60,11 +60,11 @@ async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAt return session async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response': - if not self._refresh_jwt: + if not self._session or not self._session.refresh_jwt: raise LoginRequiredError refresh_session = await self.com.atproto.server.refresh_session( - headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True + headers=self._get_auth_headers(self._session.refresh_jwt), session_refreshing=True ) await self._set_session(SessionEvent.REFRESH, refresh_session) diff --git a/packages/atproto_client/client/client.py b/packages/atproto_client/client/client.py index f3ef8094..dd2be95f 100644 --- a/packages/atproto_client/client/client.py +++ b/packages/atproto_client/client/client.py @@ -34,7 +34,7 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response': return super()._invoke(invoke_type, **kwargs) with self._refresh_lock: - if self._access_jwt and self._should_refresh_session(): + if self._session and self._session.access_jwt and self._should_refresh_session(): self._refresh_and_set_session() return super()._invoke(invoke_type, **kwargs) @@ -51,11 +51,11 @@ def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoS return session def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response': - if not self._refresh_jwt: + if not self._session or not self._session.refresh_jwt: raise LoginRequiredError refresh_session = self.com.atproto.server.refresh_session( - headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True + headers=self._get_auth_headers(self._session.refresh_jwt), session_refreshing=True ) self._set_session(SessionEvent.REFRESH, refresh_session) diff --git a/packages/atproto_client/client/methods_mixin/headers.py b/packages/atproto_client/client/methods_mixin/headers.py index ab8ec039..d7e53841 100644 --- a/packages/atproto_client/client/methods_mixin/headers.py +++ b/packages/atproto_client/client/methods_mixin/headers.py @@ -30,6 +30,7 @@ def clone(self) -> te.Self: """ cloned_client = super().clone() cloned_client.me = self.me + cloned_client._session = self._session # share the same object to avoid conflicts with session changes return cloned_client def with_proxy(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> te.Self: diff --git a/packages/atproto_client/client/methods_mixin/session.py b/packages/atproto_client/client/methods_mixin/session.py index 5f6d7f1c..8c12fdd1 100644 --- a/packages/atproto_client/client/methods_mixin/session.py +++ b/packages/atproto_client/client/methods_mixin/session.py @@ -2,8 +2,6 @@ import typing as t from datetime import timedelta -from atproto_server.auth.jwt import get_jwt_payload - from atproto_client.client.methods_mixin.time import TimeMethodsMixin from atproto_client.client.session import ( AsyncSessionChangeCallback, @@ -15,9 +13,6 @@ ) from atproto_client.exceptions import LoginRequiredError -if t.TYPE_CHECKING: - from atproto_server.auth.jwt import JwtPayload - class SessionDispatchMixin: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: @@ -99,46 +94,43 @@ async def _call_on_session_change_callbacks(self, event: SessionEvent, session: class SessionMethodsMixin(TimeMethodsMixin): def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) - - self._access_jwt: t.Optional[str] = None - self._access_jwt_payload: t.Optional['JwtPayload'] = None - - self._refresh_jwt: t.Optional[str] = None - self._refresh_jwt_payload: t.Optional['JwtPayload'] = None - self._session: t.Optional[Session] = None def _should_refresh_session(self) -> bool: - if not self._access_jwt_payload or not self._access_jwt_payload.exp: + if not self._session.access_jwt_payload or not self._session.access_jwt_payload.exp: raise LoginRequiredError - expired_at = self.get_time_from_timestamp(self._access_jwt_payload.exp) + expired_at = self.get_time_from_timestamp(self._session.access_jwt_payload.exp) expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit earlier than required return self.get_current_time() > expired_at - def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session: - self._access_jwt = session.access_jwt - self._access_jwt_payload = get_jwt_payload(session.access_jwt) - - self._refresh_jwt = session.refresh_jwt - self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt) + def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> None: + if not self._session: + self._session = Session( + access_jwt=session.access_jwt, + refresh_jwt=session.refresh_jwt, + did=session.did, + handle=session.handle, + pds_endpoint=pds_endpoint, + ) + else: + self._session.access_jwt = session.access_jwt + self._session.refresh_jwt = session.refresh_jwt + self._session.did = session.did + self._session.handle = session.handle + self._session.pds_endpoint = pds_endpoint + def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session: pds_endpoint = get_session_pds_endpoint(session) if not pds_endpoint: # current_pds ends with xrpc endpoint, but this is not a problem # overhead is only 4-5 symbols in the exported session string pds_endpoint = current_pds - self._session = Session( - access_jwt=session.access_jwt, - refresh_jwt=session.refresh_jwt, - did=session.did, - handle=session.handle, - pds_endpoint=pds_endpoint, - ) + self._set_or_update_session(session, pds_endpoint) - self._set_auth_headers(session.access_jwt) + self._set_auth_headers(self._session.access_jwt) self._update_pds_endpoint(pds_endpoint) return self._session diff --git a/packages/atproto_client/client/session.py b/packages/atproto_client/client/session.py index f139d9b6..41f30fd2 100644 --- a/packages/atproto_client/client/session.py +++ b/packages/atproto_client/client/session.py @@ -4,8 +4,11 @@ import typing_extensions as te from atproto_core.did_doc import DidDocument, is_valid_did_doc +from atproto_server.auth.jwt import get_jwt_payload if t.TYPE_CHECKING: + from atproto_server.auth.jwt import JwtPayload + from atproto_client import models @@ -26,6 +29,14 @@ class Session: refresh_jwt: str pds_endpoint: t.Optional[str] = 'https://bsky.social' # Backward compatibility for old sessions + @property + def access_jwt_payload(self) -> 'JwtPayload': + return get_jwt_payload(self.access_jwt) + + @property + def refresh_jwt_payload(self) -> 'JwtPayload': + return get_jwt_payload(self.refresh_jwt) + def __repr__(self) -> str: return f''