Skip to content

Commit

Permalink
chore/schema handshake metadata pt2 (#74)
Browse files Browse the repository at this point in the history
Why
===

Follow-up to #72 to account for `schema.json` files without metadata.

What changed
============

If the `handshakeSchema` field is defined, then the parameter is
required. Otherwise, the parameter is `Literal[None]`, which matches the
previous behavior of the default.

Due to the metadata field now being required, I had to remove the `=
None` default parameter. I think this is alright.

It'll mean that #73 will likely need to change to
`handshake_metadata_factory: HandshakeType | Callable[[],
Awaitable[HandshakeType]]` to avoid `async def stub() -> None: return
None` just to satisfy the async requirement. It's somewhat challenging
to capture the exact semantics we want here, but I think that's alright.

Test plan
=========

Do typechecks pass?
  • Loading branch information
blast-hardcheese authored Aug 26, 2024
1 parent d5aabb4 commit 7062bf6
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 75 deletions.
121 changes: 66 additions & 55 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ nanoid = "^2.0.0"
pydantic = {git = "https://github.com/pydantic/pydantic.git", rev = "f5d6acfe19fca38fad802458dab2b4c859182d7b"}
websockets = "^12.0"
pydantic-core = "^2.20.1"
msgpack-types = "^0.3.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
2 changes: 1 addition & 1 deletion replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[HandshakeType] = None,
handshake_metadata: HandshakeType,
) -> None:
self._client_id = client_id
self._server_id = server_id
Expand Down
6 changes: 3 additions & 3 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[HandshakeType] = None,
handshake_metadata: HandshakeType,
):
super().__init__(
transport_id=client_id,
Expand Down Expand Up @@ -226,7 +226,7 @@ async def websocket_closed_callback() -> None:
try:
await send_transport_message(
TransportMessage(
from_=transport_id,
from_=transport_id, # type: ignore
to=to_id,
streamId=stream_id,
controlFlags=0,
Expand Down Expand Up @@ -276,7 +276,7 @@ async def _establish_handshake(
transport_id: str,
to_id: str,
session_id: str,
handshake_metadata: Optional[HandshakeType],
handshake_metadata: HandshakeType,
websocket: WebSocketCommonProtocol,
old_session: Optional[ClientSession],
) -> Tuple[
Expand Down
13 changes: 8 additions & 5 deletions replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RiverService(BaseModel):

class RiverSchema(BaseModel):
services: Dict[str, RiverService]
handshakeSchema: RiverConcreteType
handshakeSchema: Optional[RiverConcreteType]


RiverSchemaFile = RootModel[RiverSchema]
Expand Down Expand Up @@ -266,10 +266,13 @@ def generate_river_client_module(
"",
]

(handshake_type, handshake_chunks) = encode_type(
schema_root.handshakeSchema, "HandshakeSchema"
)
chunks.extend(handshake_chunks)
if schema_root.handshakeSchema is not None:
(handshake_type, handshake_chunks) = encode_type(
schema_root.handshakeSchema, "HandshakeSchema"
)
chunks.extend(handshake_chunks)
else:
handshake_type = "Literal[None]"

for schema_name, schema in schema_root.services.items():
current_chunks: List[str] = [
Expand Down
11 changes: 5 additions & 6 deletions replit_river/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Any, Callable, Coroutine

import msgpack # type: ignore
import msgpack
import websockets
from pydantic import ValidationError
from pydantic_core import ValidationError as PydanticCoreValidationError
Expand Down Expand Up @@ -43,12 +43,11 @@ async def send_transport_message(
) -> None:
logger.debug("sending a message %r to ws %s", msg, ws)
try:
await ws.send(
prefix_bytes
+ msgpack.packb(
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
)
packed = msgpack.packb(
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
)
assert isinstance(packed, bytes)
await ws.send(prefix_bytes + packed)
except websockets.exceptions.ConnectionClosed as e:
await websocket_closed_callback()
raise WebsocketClosedException("Websocket closed during send message") from e
Expand Down
2 changes: 1 addition & 1 deletion replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def _send_handshake_response(
response_message = TransportMessage(
streamId=request_message.streamId,
id=nanoid.generate(),
from_=request_message.to,
from_=request_message.to, # type: ignore
to=request_message.from_,
seq=0,
ack=0,
Expand Down
2 changes: 1 addition & 1 deletion replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def send_message(
msg = TransportMessage(
streamId=stream_id,
id=nanoid.generate(),
from_=self._transport_id,
from_=self._transport_id, # type: ignore
to=self._to_id,
seq=await self._seq_manager.get_seq_and_increment(),
ack=await self._seq_manager.get_ack(),
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from typing import Any, AsyncGenerator, NoReturn
from typing import Any, AsyncGenerator, Literal

import nanoid # type: ignore
import pytest
Expand Down Expand Up @@ -36,7 +36,7 @@ def transport_message(
) -> TransportMessage:
return TransportMessage(
id=str(nanoid.generate()),
from_=from_,
from_=from_, # type: ignore
to=to,
streamId=streamId,
seq=seq,
Expand Down Expand Up @@ -139,11 +139,12 @@ async def client(
) -> AsyncGenerator[Client, None]:
try:
async with serve(server.serve, "localhost", 8765):
client: Client[NoReturn] = Client(
client: Client[Literal[None]] = Client(
"ws://localhost:8765",
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
handshake_metadata=None,
)
try:
yield client
Expand Down

0 comments on commit 7062bf6

Please sign in to comment.