Skip to content

Commit

Permalink
Checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenthebuilder committed Apr 18, 2024
1 parent 3e4bc65 commit 33d707b
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 107 deletions.
8 changes: 3 additions & 5 deletions replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverException
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport import TransportOptions
from replit_river.transport_options import TransportOptions

from .rpc import (
ErrorType,
Expand All @@ -19,10 +19,6 @@
ResponseType,
)

CROSIS_PREFIX_BYTES = b"\x00\x00"
PID2_PREFIX_BYTES = b"\xff\xff"
HEART_BEAT_INTERVAL_SECS = 2


class Client:
def __init__(
Expand All @@ -48,6 +44,7 @@ def __init__(

async def _create_session(self) -> None:
try:
logging.debug("Client start creating session")
client_session = await self._transport.create_client_session(
self._client_id, self._server_id, self._instance_id, self._ws
)
Expand All @@ -56,6 +53,7 @@ async def _create_session(self) -> None:
logging.error(f"Error creating session: {e}")
return
self._client_session = client_session
logging.debug("client start serving messages")
await self._client_session.start_serve_messages()

async def _wait_for_handshake(self) -> ClientSession:
Expand Down
20 changes: 10 additions & 10 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ async def send_rpc(
self._streams[stream_id] = output
try:
await self.send_message(
ws=self._websocket,
service_name=service_name,
procedure_name=procedure_name,
ws=self._ws,
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
payload=request_serializer(request),
service_name=service_name,
procedure_name=procedure_name,
)
except FailedSendingMessageException:
raise RiverException(
Expand Down Expand Up @@ -98,7 +98,7 @@ async def send_upload(
if init and init_serializer:
await self.send_message(
stream_id=stream_id,
ws=self._websocket,
ws=self._ws,
control_flags=STREAM_OPEN_BIT,
service_name=service_name,
procedure_name=procedure_name,
Expand All @@ -113,7 +113,7 @@ async def send_upload(
first_message = False
await self.send_message(
stream_id=stream_id,
ws=self._websocket,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
control_flags=control_flags,
Expand Down Expand Up @@ -169,7 +169,7 @@ async def send_subscription(
self._streams[stream_id] = output
try:
await self.send_message(
ws=self._websocket,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand Down Expand Up @@ -222,7 +222,7 @@ async def send_stream(
try:
if init and init_serializer:
await self.send_message(
ws=self._websocket,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand All @@ -234,7 +234,7 @@ async def send_stream(
request_iter = aiter(request)
first = await anext(request_iter)
await self.send_message(
ws=self._websocket,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand All @@ -253,7 +253,7 @@ async def _encode_stream() -> None:
if item is None:
continue
await self.send_message(
ws=self._websocket,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand Down Expand Up @@ -288,7 +288,7 @@ async def send_close_stream(
) -> None:
# close stream
await self.send_message(
ws=self._websocket,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand Down
16 changes: 12 additions & 4 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import ValidationError
from websockets import (
ConnectionClosedError,
WebSocketCommonProtocol,
)

Expand Down Expand Up @@ -51,7 +52,7 @@ async def _send_handshake_request(
procedureName=None,
streamId=stream_id,
controlFlags=0,
id=0,
id=self.generate_nanoid(),
seq=0,
ack=0,
payload=handshake_request.model_dump(),
Expand Down Expand Up @@ -81,8 +82,13 @@ async def create_client_session(
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
logging.debug("river client waiting for handshake response")
while True:
data = await websocket.recv()
try:
data = await websocket.recv()
except ConnectionClosedError as e:
# TODO: handle this here
pass
try:
first_message = parse_transport_msg(data, self._transport_options)
except IgnoreTransportMessageException as e:
Expand All @@ -98,8 +104,10 @@ async def create_client_session(
handshake_response = ControlMessageHandshakeResponse(
**first_message.payload
)
except ValidationError:
raise RiverException(ERROR_HANDSHAKE, "Failed to parse handshake response")
except ValidationError as e:
raise RiverException(
ERROR_HANDSHAKE, f"Failed to parse handshake response : {e}"
)
if not handshake_response.status.ok:
raise RiverException(
ERROR_HANDSHAKE, f"Handshake failed: {handshake_response.status.reason}"
Expand Down
11 changes: 5 additions & 6 deletions replit_river/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
IgnoreTransportMessageException,
InvalidTransportMessageException,
)
from replit_river.transport import TransportOptions
from replit_river.transport_options import TransportOptions


class FailedSendingMessageException(Exception):
Expand All @@ -39,11 +39,10 @@ async def send_transport_message(
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
)
)
except websockets.exceptions.ConnectionClosedOK:
raise FailedSendingMessageException(
"Trying to send message while connection closed "
f"from : {msg.from_} to {msg.to}"
)
except websockets.exceptions.ConnectionClosed as e:
raise e
except Exception as e:
raise FailedSendingMessageException(f"Exception during send message : {e}")


def formatted_bytes(message: bytes) -> str:
Expand Down
11 changes: 7 additions & 4 deletions replit_river/server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import logging
from typing import Dict, Mapping, Tuple

from websockets.exceptions import ConnectionClosedError
from websockets.server import WebSocketServerProtocol

from replit_river.server_transport import ServerTransport
from replit_river.transport import Transport, TransportOptions

from .rpc import (
Expand All @@ -16,7 +18,7 @@ def __init__(self, server_id: str, transport_options: TransportOptions) -> None:
self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {}
self._server_id = server_id or "SERVER"
self._transport_options = transport_options
self._transport = Transport(
self._transport = ServerTransport(
transport_id=self._server_id,
transport_options=transport_options,
is_server=True,
Expand All @@ -29,15 +31,16 @@ def add_rpc_handlers(
self._handlers.update(rpc_handlers)

async def serve(self, websocket: WebSocketServerProtocol) -> None:
logging.debug("got a client")
logging.debug("River server started establishing session")
try:
session = await self._transport.establish_client_transport(websocket)
session = await self._transport.establish_session(websocket)
except Exception as e:
logging.error(f"Error establishing handshake, closing websocket: {e}")
await websocket.close()
return
logging.debug("River server session established, start serving messages")
try:
await session.serve()
await session.start_serve_messages()
except ConnectionClosedError as e:
logging.debug(f"ConnectionClosedError while serving {e}")
except Exception as e:
Expand Down
20 changes: 11 additions & 9 deletions replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from replit_river.rpc import (
ControlMessageHandshakeRequest,
ControlMessageHandshakeResponse,
HandShakeStatus,
TransportMessage,
)
Expand Down Expand Up @@ -45,7 +46,7 @@ async def get_or_create_session(
instance_id,
websocket,
self._transport_options,
self._close_session,
self._delete_session,
self._is_server,
self._handlers,
)
Expand All @@ -59,7 +60,7 @@ async def get_or_create_session(
instance_id,
websocket,
self._transport_options,
self._close_session,
self._delete_session,
self._is_server,
self._handlers,
)
Expand All @@ -71,20 +72,18 @@ async def get_or_create_session(
except FailedSendingMessageException as e:
raise e
if session_to_close:
logging.info("Closing stale websocket")
await session_to_close.close()
session = self._sessions[transport_id]
websocket.close_connection_task = asyncio.create_task(
self.on_disconnect(session)
)
return session

async def establish_client_transport(
async def establish_session(
self,
websocket: WebSocketServerProtocol,
) -> Session:
async for message in websocket:
try:
msg = parse_transport_msg(message)
msg = parse_transport_msg(message, self._transport_options)
handshake_request = await self._build_handshake_from_request(
msg, websocket
)
Expand All @@ -93,6 +92,7 @@ async def establish_client_transport(
except InvalidTransportMessageException:
error_msg = "Got invalid transport message, closing connection"
raise InvalidTransportMessageException(error_msg)
logging.debug("handshake request received: %r", handshake_request)
transport_id = msg.from_
to_id = msg.from_
instance_id = handshake_request.instanceId
Expand All @@ -105,7 +105,6 @@ async def establish_client_transport(
"Error building sessions from handshake request : "
f"client_id: {transport_id}, instance_id: {instance_id}, error: {e}"
)
logging.error(error_msg)
raise InvalidTransportMessageException(error_msg)
return session
raise InvalidTransportMessageException("No handshake message received")
Expand All @@ -116,6 +115,9 @@ async def _send_handshake_response(
handshake_status: HandShakeStatus,
websocket: WebSocketCommonProtocol,
) -> TransportMessage:
response = ControlMessageHandshakeResponse(
status=handshake_status,
)
response_message = TransportMessage(
streamId=request_message.streamId,
id=nanoid.generate(),
Expand All @@ -124,7 +126,7 @@ async def _send_handshake_response(
seq=0,
ack=0,
controlFlags=0,
payload=handshake_status.model_dump(by_alias=True, exclude_none=True),
payload=response.model_dump(by_alias=True, exclude_none=True),
serviceName=request_message.serviceName,
procedureName=request_message.procedureName,
)
Expand Down
Loading

0 comments on commit 33d707b

Please sign in to comment.