Skip to content

Commit

Permalink
Add automatically switching to PDS endpoint after login and session r…
Browse files Browse the repository at this point in the history
…esume (#344)
  • Loading branch information
MarshalX authored Jun 18, 2024
1 parent e9b0844 commit 4828cac
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 27 deletions.
17 changes: 5 additions & 12 deletions docs/source/dm.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ Bluesky Direct Messages were launched on May 22, 2024. It began as a simple chat

The Python SDK has supported the Bluesky Direct Messages API since day one. You can use the SDK to send messages to other users, create new conversations, list existing conversations, and perform all other functions available in the mobile app and web client.

The API is not very user-friendly and lacks high-level abstractions. However, thanks to the fully automated process of code generation, it is possible to use Python abstractions for the AT Protocol. Additionally, Bsky proxies requests to the Chat API in an unconventional manner, which adds extra steps to make it work.
**You need to grant access to direct messages when creating App Password!** Otherwise, you will get "Bad token scope" error.

## Example

This example demonstrates how to list conversations, create a new conversation, and send a message to it.

```Python
```python
from atproto import Client, IdResolver, models

USERNAME = 'example.com'
Expand All @@ -21,16 +21,8 @@ def main() -> None:
# create resolver instance with in-memory cache
id_resolver = IdResolver()

# resolve our DID from a handle
did = id_resolver.handle.resolve(USERNAME)
# resolve did document from DID
did_doc = id_resolver.did.resolve(did)
# get pds (where our account is hosted) endpoint from DID Document
pds_url = did_doc.get_pds_endpoint()

# create client instance with our pds url
client = Client(base_url=pds_url)
# login with our username and password
# create client instance and login
client = Client()
client.login(USERNAME, PASSWORD)

convo_list = client.chat.bsky.convo.list_convos() # use limit and cursor to paginate
Expand All @@ -39,6 +31,7 @@ def main() -> None:
members = ', '.join(member.display_name for member in convo.members)
print(f'- ID: {convo.id} ({members})')

# resolve DID
chat_to = id_resolver.handle.resolve('test.marshal.dev')

# create or get conversation with chat_to
Expand Down
13 changes: 3 additions & 10 deletions examples/advanced_usage/direct_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,8 @@ def main() -> None:
# create resolver instance with in-memory cache
id_resolver = IdResolver()

# resolve our DID from a handle
did = id_resolver.handle.resolve(USERNAME)
# resolve did document from DID
did_doc = id_resolver.did.resolve(did)
# get pds (where our account is hosted) endpoint from DID Document
pds_url = did_doc.get_pds_endpoint()

# create client instance with our pds url
client = Client(base_url=pds_url)
# login with our username and password
# create client instance and login
client = Client()
client.login(USERNAME, PASSWORD)

convo_list = client.chat.bsky.convo.list_convos() # use limit and cursor to paginate
Expand All @@ -26,6 +18,7 @@ def main() -> None:
members = ', '.join(member.display_name for member in convo.members)
print(f'- ID: {convo.id} ({members})')

# resolve DID
chat_to = id_resolver.handle.resolve('test.marshal.dev')

# create or get conversation with chat_to
Expand Down
2 changes: 1 addition & 1 deletion packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def login(
Args:
login: Handle/username of the account.
password: Main or app-specific password of the account.
session_string: Session string (use :py:attr:`~export_session_string` to obtain it).
session_string: Session string (use :py:attr:`~export_session_string` to get it).
Note:
Either `session_string` or `login` and `password` should be provided.
Expand Down
11 changes: 11 additions & 0 deletions packages/atproto_client/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[Request
def request(self) -> Request:
return self._request

def update_base_url(self, base_url: t.Optional[str] = None) -> None:
"""Update XRPC base URL.
Typically used for switching between PDSs.
Args:
base_url: New base URL.
Defaults to bsky.social.
"""
self._base_url = _handle_base_url(base_url)

def _build_url(self, nsid: str) -> str:
return f'{self._base_url}/{nsid}'

Expand Down
2 changes: 1 addition & 1 deletion packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def login(
Args:
login: Handle/username of the account.
password: Main or app-specific password of the account.
session_string: Session string (use :py:attr:`~export_session_string` to obtain it).
session_string: Session string (use :py:attr:`~export_session_string` to get it).
Note:
Either `session_string` or `login` and `password` should be provided.
Expand Down
8 changes: 8 additions & 0 deletions packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SessionChangeCallback,
SessionEvent,
SessionResponse,
get_session_pds_endpoint,
)
from atproto_client.exceptions import LoginRequiredError

Expand Down Expand Up @@ -123,14 +124,18 @@ def _set_session_common(self, session: SessionResponse) -> Session:
self._refresh_jwt = session.refresh_jwt
self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt)

pds_endpoint = get_session_pds_endpoint(session)

self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)

self._set_auth_headers(session.access_jwt)
self._update_pds_endpoint(pds_endpoint)

return self._session

Expand All @@ -141,6 +146,9 @@ def _get_auth_headers(token: str) -> t.Dict[str, str]:
def _set_auth_headers(self, token: str) -> None:
self.request.set_additional_headers(self._get_auth_headers(token))

def _update_pds_endpoint(self, pds_endpoint: str) -> None:
self.update_base_url(pds_endpoint)

def export_session_string(self) -> str:
"""Export session string.
Expand Down
28 changes: 25 additions & 3 deletions packages/atproto_client/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum

import typing_extensions as te
from atproto_core.did_doc import DidDocument, is_valid_did_doc

if t.TYPE_CHECKING:
from atproto_client import models
Expand All @@ -23,6 +24,7 @@ class Session:
did: str
access_jwt: str
refresh_jwt: str
pds_endpoint: t.Optional[str] = 'https://bsky.social' # Backward compatibility for old sessions

def __repr__(self) -> str:
return f'<Session handle={self.handle} did={self.did}>'
Expand All @@ -36,16 +38,24 @@ def encode(self) -> str:
self.did,
self.access_jwt,
self.refresh_jwt,
self.pds_endpoint,
]
return _SESSION_STRING_SEPARATOR.join(payload)

@classmethod
def decode(cls, session_string: str) -> 'Session':
handle, did, access_jwt, refresh_jwt = session_string.split(_SESSION_STRING_SEPARATOR)
return cls(handle, did, access_jwt, refresh_jwt)
fields = session_string.split(_SESSION_STRING_SEPARATOR)

if len(fields) == 4:
# Old session format
handle, did, access_jwt, refresh_jwt = fields
return cls(handle, did, access_jwt, refresh_jwt)

handle, did, access_jwt, refresh_jwt, pds_endpoint = session_string.split(_SESSION_STRING_SEPARATOR)
return cls(handle, did, access_jwt, refresh_jwt, pds_endpoint)

def copy(self) -> 'Session':
return Session(self.handle, self.did, self.access_jwt, self.refresh_jwt)
return Session(self.handle, self.did, self.access_jwt, self.refresh_jwt, self.pds_endpoint)

#: Alias for :attr:`encode`
export = encode
Expand All @@ -59,3 +69,15 @@ def copy(self) -> 'Session':

SessionChangeCallback = t.Callable[[SessionEvent, Session], None]
AsyncSessionChangeCallback = t.Callable[[SessionEvent, Session], t.Coroutine[t.Any, t.Any, None]]


def get_session_pds_endpoint(session: SessionResponse) -> t.Optional[str]:
"""Return the PDS endpoint of the given session."""
if isinstance(session, Session):
return session.pds_endpoint

if is_valid_did_doc(session.did_doc):
doc = DidDocument.from_dict(session.did_doc)
return doc.get_pds_endpoint()

return None
47 changes: 47 additions & 0 deletions tests/test_atproto_client/client/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from atproto_client import Session
from atproto_client.client.session import get_session_pds_endpoint


def test_session_old_format_migration() -> None:
expected_pds = 'https://bsky.social'
session_string_old = 'handle:::did:::access_jwt:::refresh_jwt'
session_string_new = f'handle:::did:::access_jwt:::refresh_jwt:::{expected_pds}'

session = Session.decode(session_string_old)

assert session.handle == 'handle'
assert session.did == 'did'
assert session.access_jwt == 'access_jwt'
assert session.refresh_jwt == 'refresh_jwt'
assert session.pds_endpoint == expected_pds

assert session.encode() == session_string_new


def test_session_roundtrip() -> None:
session_string = 'handle:::did:::access_jwt:::refresh_jwt:::https://blabla.bla'
session = Session.decode(session_string)
assert session.encode() == session_string


def test_session_copy() -> None:
session_string = 'handle:::did:::access_jwt:::refresh_jwt:::https://blabla.bla'
session = Session.decode(session_string)
session_copy = session.copy()

assert session_copy.handle == session.handle
assert session_copy.did == session.did
assert session_copy.access_jwt == session.access_jwt
assert session_copy.refresh_jwt == session.refresh_jwt
assert session_copy.pds_endpoint == session.pds_endpoint

assert session_copy.encode() == session.encode()


def test_get_session_pds_endpoint() -> None:
expected_pds = 'https://blabla.bla'
session = Session('handle', 'did', 'access_jwt', 'refresh_jwt', expected_pds)
assert get_session_pds_endpoint(session) == expected_pds

session = Session('handle', 'did', 'access_jwt', 'refresh_jwt')
assert get_session_pds_endpoint(session) == 'https://bsky.social'

0 comments on commit 4828cac

Please sign in to comment.