Skip to content

Commit

Permalink
rename classes; add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX committed Feb 9, 2024
1 parent 5e6b142 commit 98ed3db
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 35 deletions.
6 changes: 3 additions & 3 deletions packages/atproto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from atproto_client import AsyncClient, Client, SessionChangeEvent, SessionString, models
from atproto_client import AsyncClient, Client, Session, SessionEvent, models
from atproto_client import utils as client_utils
from atproto_core.car import CAR
from atproto_core.cid import CID, CIDType
Expand Down Expand Up @@ -34,8 +34,8 @@
# client
'AsyncClient',
'Client',
'SessionChangeEvent',
'SessionString',
'SessionEvent',
'Session',
'client_utils',
'models',
# core
Expand Down
6 changes: 3 additions & 3 deletions packages/atproto_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from atproto_client import models
from atproto_client.client.async_client import AsyncClient
from atproto_client.client.client import Client
from atproto_client.client.session import SessionChangeEvent, SessionString
from atproto_client.client.session import Session, SessionEvent

__all__ = [
'AsyncClient',
'Client',
'models',
'SessionChangeEvent',
'SessionString',
'SessionEvent',
'Session',
]
14 changes: 7 additions & 7 deletions packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from atproto_client.client.methods_mixin import SessionMethodsMixin, TimeMethodsMixin
from atproto_client.client.methods_mixin.session import AsyncSessionDispatchMixin
from atproto_client.client.methods_mixin.strong_ref_arg_backward_compatibility import _StrongRefArgBackwardCompatibility
from atproto_client.client.session import SessionChangeEvent, SessionResponse, SessionString
from atproto_client.client.session import Session, SessionEvent, SessionResponse
from atproto_client.models.languages import DEFAULT_LANGUAGE_CODE1
from atproto_client.utils import TextBuilder

Expand Down Expand Up @@ -46,28 +46,28 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response

return await super()._invoke(invoke_type, **kwargs)

async def _set_session(self, event: SessionChangeEvent, session: SessionResponse) -> None:
async def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
self._set_session_common(session)
await self._call_on_session_change_callbacks(event, self._session.copy())

async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = await self.com.atproto.server.create_session(
models.ComAtprotoServerCreateSession.Data(identifier=login, password=password)
)
await self._set_session(SessionChangeEvent.CREATE, session)
await self._set_session(SessionEvent.CREATE, session)
return session

async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
refresh_session = await self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
)
await self._set_session(SessionChangeEvent.REFRESH, refresh_session)
await self._set_session(SessionEvent.REFRESH, refresh_session)

return refresh_session

async def _import_session_string(self, session_string: str) -> SessionString:
import_session = SessionString.decode(session_string)
await self._set_session(SessionChangeEvent.IMPORT, import_session)
async def _import_session_string(self, session_string: str) -> Session:
import_session = Session.decode(session_string)
await self._set_session(SessionEvent.IMPORT, import_session)

return import_session

Expand Down
14 changes: 7 additions & 7 deletions packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from atproto_client.client.methods_mixin.session import SessionDispatchMixin
from atproto_client.client.methods_mixin.strong_ref_arg_backward_compatibility import _StrongRefArgBackwardCompatibility
from atproto_client.client.raw import ClientRaw
from atproto_client.client.session import SessionChangeEvent, SessionResponse, SessionString
from atproto_client.client.session import Session, SessionEvent, SessionResponse
from atproto_client.models.languages import DEFAULT_LANGUAGE_CODE1
from atproto_client.utils import TextBuilder

Expand Down Expand Up @@ -40,28 +40,28 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response':

return super()._invoke(invoke_type, **kwargs)

def _set_session(self, event: SessionChangeEvent, session: SessionResponse) -> None:
def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
self._set_session_common(session)
self._call_on_session_change_callbacks(event, self._session.copy())

def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
session = self.com.atproto.server.create_session(
models.ComAtprotoServerCreateSession.Data(identifier=login, password=password)
)
self._set_session(SessionChangeEvent.CREATE, session)
self._set_session(SessionEvent.CREATE, session)
return session

def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
refresh_session = self.com.atproto.server.refresh_session(
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
)
self._set_session(SessionChangeEvent.REFRESH, refresh_session)
self._set_session(SessionEvent.REFRESH, refresh_session)

return refresh_session

def _import_session_string(self, session_string: str) -> SessionString:
import_session = SessionString.decode(session_string)
self._set_session(SessionChangeEvent.IMPORT, import_session)
def _import_session_string(self, session_string: str) -> Session:
import_session = Session.decode(session_string)
self._set_session(SessionEvent.IMPORT, import_session)

return import_session

Expand Down
64 changes: 57 additions & 7 deletions packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from atproto_client.client.session import (
AsyncSessionChangeCallback,
Session,
SessionChangeCallback,
SessionChangeEvent,
SessionEvent,
SessionResponse,
SessionString,
)

if t.TYPE_CHECKING:
Expand All @@ -24,9 +24,28 @@ def __init__(self, *args, **kwargs: t.Any) -> None:
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []

def on_session_change(self, callback: SessionChangeCallback) -> None:
"""Register a callback for session change event.
Args:
callback: A callback to be called when the session changes.
The callback must accept two arguments: event and session.
Example:
>>> from atproto import Client, SessionEvent, Session
>>>
>>> client = Client()
>>>
>>> def on_session_change(event: SessionEvent, session: Session):
>>> print(event, session)
>>>
>>> client.on_session_change(on_session_change)
Returns:
:obj:`None`
"""
self._on_session_change_callbacks.append(callback)

def _call_on_session_change_callbacks(self, event: SessionChangeEvent, session: SessionString) -> None:
def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
for on_session_change_callback in self._on_session_change_callbacks:
on_session_change_callback(event, session)

Expand All @@ -38,9 +57,35 @@ def __init__(self, *args, **kwargs: t.Any) -> None:
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []

def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
"""Register a callback for session change event.
Args:
callback: A callback to be called when the session changes.
The callback must accept two arguments: event and session.
Note:
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
Tip:
You should save the session string to persistent storage
on `SessionEvent.CREATE` and `SessionEvent.REFRESN` event.
Example:
>>> from atproto import AsyncClient, SessionEvent, Session
>>>
>>> client = AsyncClient()
>>>
>>> async def on_session_change(event: SessionEvent, session: Session):
>>> print(event, session)
>>>
>>> client.on_session_change(on_session_change)
Returns:
:obj:`None`
"""
self._on_session_change_async_callbacks.append(callback)

async def _call_on_session_change_callbacks(self, event: SessionChangeEvent, session: SessionString) -> None:
async def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
coroutines = []
for on_session_change_async_callback in self._on_session_change_async_callbacks:
coroutines.append(on_session_change_async_callback(event, session))
Expand All @@ -58,7 +103,7 @@ def __init__(self, *args, **kwargs: t.Any) -> None:
self._refresh_jwt: t.Optional[str] = None
self._refresh_jwt_payload: t.Optional['JwtPayload'] = None

self._session: t.Optional[SessionString] = None
self._session: t.Optional[Session] = None

def _should_refresh_session(self) -> bool:
expired_at = self.get_time_from_timestamp(self._access_jwt_payload.exp)
Expand All @@ -73,7 +118,7 @@ def _set_session_common(self, session: SessionResponse) -> None:
self._refresh_jwt = session.refresh_jwt
self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt)

self._session = SessionString(
self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
Expand Down Expand Up @@ -101,6 +146,11 @@ def export_session_string(self) -> str:
Rate limited by handle.
30/5 min, 300/day.
Attention:
You must export session at the end of the Client`s life cycle!
Alternatively, you can subscribe to the session change event.
Use :py:attr:`~on_session_change` to register handler.
Example:
>>> from atproto import Client
>>> # the first time login with login and password
Expand All @@ -115,4 +165,4 @@ def export_session_string(self) -> str:
Returns:
:obj:`str`: Session string.
"""
return self._session.encode()
return self._session.export()
39 changes: 31 additions & 8 deletions packages/atproto_client/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from atproto_client import models


class SessionChangeEvent(Enum):
class SessionEvent(Enum):
IMPORT = 'import'
CREATE = 'creat'
REFRESH = 'refresh'
Expand All @@ -18,12 +18,18 @@ class SessionChangeEvent(Enum):


@dataclass
class SessionString:
class Session:
handle: str
did: str
access_jwt: str
refresh_jwt: str

def __repr__(self) -> str:
return f'<Session handle={self.handle} did={self.did}>'

def __str__(self) -> str:
return self.encode()

def encode(self) -> str:
payload = [
self.handle,
Expand All @@ -34,19 +40,36 @@ def encode(self) -> str:
return _SESSION_STRING_SEPARATOR.join(payload)

@classmethod
def decode(cls, session_string: str) -> 'SessionString':
def decode(cls, session_string: str) -> 'Session':
handle, did, access_jwt, refresh_jwt = session_string.split(_SESSION_STRING_SEPARATOR)
return cls(handle, did, access_jwt, refresh_jwt)

def copy(self) -> 'SessionString':
return SessionString(self.handle, self.did, self.access_jwt, self.refresh_jwt)
def copy(self) -> 'Session':
return Session(self.handle, self.did, self.access_jwt, self.refresh_jwt)

#: Alias for :attr:`encode`
export = encode


SessionResponse: te.TypeAlias = t.Union[
'models.ComAtprotoServerCreateSession.Response',
'models.ComAtprotoServerRefreshSession.Response',
'SessionString',
'Session',
]

SessionChangeCallback = t.Callable[[SessionChangeEvent, SessionString], None]
AsyncSessionChangeCallback = t.Callable[[SessionChangeEvent, SessionString], t.Coroutine[t.Any, t.Any, None]]
SessionChangeCallback = t.Callable[[SessionEvent, Session], None]
AsyncSessionChangeCallback = t.Callable[[SessionEvent, Session], t.Coroutine[t.Any, t.Any, None]]


class SessionString(Session):
def __init_subclass__(cls, *args, **kwargs: t.Any) -> None:
import warnings

warnings.warn('SessionString class is deprecated. Use Session class instead.', DeprecationWarning, stacklevel=2)
super().__init_subclass__(*args, **kwargs)

def __init__(self, *args, **kwargs: t.Any) -> None:
import warnings

warnings.warn('SessionString class is deprecated. Use Session class instead.', DeprecationWarning, stacklevel=2)
super().__init__(*args, **kwargs)

0 comments on commit 98ed3db

Please sign in to comment.