Skip to content

Commit

Permalink
fix: ensure client/server session is a pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
mikefero committed Jun 22, 2024
1 parent 9d74728 commit 36d3459
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 25 deletions.
9 changes: 6 additions & 3 deletions websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type WebSocketClientEventHandler interface {
// OnConnectedHandler is a function callback for indicating that the client
// connection is completed while creating handles for closing connections and
// sending messages on the WebSocket.
OnConnectedHandler(resp *http.Response, session Session) error
OnConnectedHandler(resp *http.Response, session *Session) error

// OnDisconnectionHandler is a function callback for WebSocket connections
// that is executed upon disconnection of a client.
Expand Down Expand Up @@ -112,13 +112,16 @@ func NewWebSocketClient(opts WebSocketClientOpts) (*WebSocketClient, error) {
func (c *WebSocketClient) Run(ctx context.Context) error {
conn, resp, err := c.dialer.Dial(c.serverURL.String(), nil)
if err != nil {
return fmt.Errorf("unable to connect to server URL at %s (%s): %w", c.serverURL.String(), resp.Status, err)
if resp != nil {
return fmt.Errorf("unable to connect to server URL at %s (%s): %w", c.serverURL.String(), resp.Status, err)
}
return fmt.Errorf("unable to connect to server URL at %s: %w", c.serverURL.String(), err)
}
defer conn.Close()
defer c.closeConnection(conn)
defer c.handler.OnDisconnectionHandler()

if err := c.handler.OnConnectedHandler(resp, Session{
if err := c.handler.OnConnectedHandler(resp, &Session{
// Generate a close function for the session
Close: func() {
c.closeConnection(conn)
Expand Down
23 changes: 12 additions & 11 deletions websocket_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type webSocketClient struct {
cancel context.CancelFunc
client *ankh.WebSocketClient
mockHandler ankh.WebSocketClientEventHandler
session ankh.Session
session *ankh.Session
}

func createWebSocketClient(t *testing.T, serverURL url.URL, enableTLS bool) *webSocketClient {
Expand Down Expand Up @@ -77,14 +77,14 @@ func runWebSocketClient(t *testing.T, client *webSocketClient, shouldFail bool)
}
}()

client.session = ankh.Session{}
client.session = &ankh.Session{}
if !shouldFail {
t.Log("verify WebSocket client connects")
captorSession := Captor[ankh.Session]()
captorSession := Captor[*ankh.Session]()
WhenSingle(client.mockHandler.OnConnectedHandler(Any[*http.Response](), captorSession.Capture())).ThenReturn(nil)
waitForCapture(t, captorSession)
client.session = captorSession.Last()
Verify(client.mockHandler, Once()).OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())
Verify(client.mockHandler, Once()).OnConnectedHandler(Any[*http.Response](), Any[*ankh.Session]())
}
}

Expand Down Expand Up @@ -116,7 +116,7 @@ func TestWebSocketClient(t *testing.T) {
t.Log("creating WebSocket server")
server := createWebSocketServer(t, tt.withTLS)
serverHandler := server.mockHandlers[0]
captorServerSession := Captor[ankh.Session]()
captorServerSession := Captor[*ankh.Session]()
When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil)
WhenSingle(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil)
defer server.cancel()
Expand All @@ -133,7 +133,7 @@ func TestWebSocketClient(t *testing.T) {
waitForCapture(t, captorServerSession)
serverSession := captorServerSession.Last()
Verify(serverHandler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())
Verify(serverHandler, Once()).OnConnectedHandler(Exact("client-key"), Any[ankh.Session]())
Verify(serverHandler, Once()).OnConnectedHandler(Exact("client-key"), Any[*ankh.Session]())

t.Log("verify ping message is received from the server and client receives pong message")
captorPing := Captor[string]()
Expand All @@ -156,6 +156,7 @@ func TestWebSocketClient(t *testing.T) {
t.Log("verify message from the server to the client is received")
captorClientReadMessage := Captor[[]byte]()
serverSession.Send([]byte("ankh-server"))
waitFor(t) // wait for the read messages to be handled
Verify(handler, Once()).OnReadMessageHandler(Exact(websocket.BinaryMessage), captorClientReadMessage.Capture())
waitForCapture(t, captorClientReadMessage)
require.Equal(t, []byte("ankh-server"), captorClientReadMessage.Last())
Expand All @@ -181,12 +182,12 @@ func TestWebSocketClient(t *testing.T) {
serverURL.Path = "/path1"
client := createWebSocketClient(t, serverURL, tt.withTLS)
handler := client.mockHandler
WhenSingle(handler.OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())).ThenReturn(errors.New("connection error"))
WhenSingle(handler.OnConnectedHandler(Any[*http.Response](), Any[*ankh.Session]())).ThenReturn(errors.New("connection error"))
runWebSocketClient(t, client, true)
defer client.cancel()

waitFor(t) // wait connected handler to be called
Verify(handler, Once()).OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())
Verify(handler, Once()).OnConnectedHandler(Any[*http.Response](), Any[*ankh.Session]())
Verify(handler, Once()).OnDisconnectionHandler()
VerifyNoMoreInteractions(handler)
})
Expand All @@ -197,22 +198,22 @@ func TestWebSocketClient(t *testing.T) {

server := createWebSocketServer(t, tt.withTLS)
serverHandler := server.mockHandlers[0]
captorServerSession := Captor[ankh.Session]()
captorServerSession := Captor[*ankh.Session]()
When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil)
When(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil)
serverURL := server.serverURL
serverURL.Path = "/path1"
client := createWebSocketClient(t, serverURL, tt.withTLS)
handler := client.mockHandler
WhenSingle(handler.OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())).ThenReturn(errors.New("connection error"))
WhenSingle(handler.OnConnectedHandler(Any[*http.Response](), Any[*ankh.Session]())).ThenReturn(errors.New("connection error"))
runWebSocketClient(t, client, true)
defer client.cancel()

waitForCapture(t, captorServerSession)
serverSession := captorServerSession.Last()
serverSession.Close()
waitFor(t) // wait connected handler to be called
Verify(handler, Once()).OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())
Verify(handler, Once()).OnConnectedHandler(Any[*http.Response](), Any[*ankh.Session]())
Verify(handler, Once()).OnDisconnectionHandler()
VerifyNoMoreInteractions(handler)
})
Expand Down
4 changes: 2 additions & 2 deletions websocket_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type WebSocketServerEventHandler interface {
// OnConnectedHandler is a function callback for indicating that the client
// connection is completed while creating handles for closing connections and
// sending messages on the WebSocket.
OnConnectedHandler(clientKey any, session Session) error
OnConnectedHandler(clientKey any, session *Session) error

// OnDisconnectionHandler is a function callback for WebSocket connections
// that is executed upon disconnection of a client.
Expand Down Expand Up @@ -244,7 +244,7 @@ func (s *WebSocketServer) handleConnection(ctx context.Context, w http.ResponseW
return writeMessage(conn, &mutex, websocket.PongMessage, data)
})

if err := handler.OnConnectedHandler(clientKey, Session{
if err := handler.OnConnectedHandler(clientKey, &Session{
// Generate a close function for the session
Close: func() {
s.closeConnection(conn, &mutex, clientKey, handler)
Expand Down
18 changes: 9 additions & 9 deletions websocket_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,10 @@ func TestWebSocketServer(t *testing.T) {
When(handler2.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).
ThenReturn("client-key-3", nil).
ThenReturn("client-key-4", nil)
captor1Session := Captor[ankh.Session]()
captor2Session := Captor[ankh.Session]()
captor3Session := Captor[ankh.Session]()
captor4Session := Captor[ankh.Session]()
captor1Session := Captor[*ankh.Session]()
captor2Session := Captor[*ankh.Session]()
captor3Session := Captor[*ankh.Session]()
captor4Session := Captor[*ankh.Session]()
WhenSingle(handler1.OnConnectedHandler(Exact("client-key-1"), captor1Session.Capture())).ThenReturn(nil)
WhenSingle(handler1.OnConnectedHandler(Exact("client-key-2"), captor2Session.Capture())).ThenReturn(nil)
WhenSingle(handler2.OnConnectedHandler(Exact("client-key-3"), captor3Session.Capture())).ThenReturn(nil)
Expand All @@ -329,10 +329,10 @@ func TestWebSocketServer(t *testing.T) {
t.Log("verify WebSocket clients connected")
Verify(handler1, Times(2)).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())
Verify(handler2, Times(2)).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())
Verify(handler1, Once()).OnConnectedHandler(Exact("client-key-1"), Any[ankh.Session]())
Verify(handler1, Once()).OnConnectedHandler(Exact("client-key-2"), Any[ankh.Session]())
Verify(handler2, Once()).OnConnectedHandler(Exact("client-key-3"), Any[ankh.Session]())
Verify(handler2, Once()).OnConnectedHandler(Exact("client-key-4"), Any[ankh.Session]())
Verify(handler1, Once()).OnConnectedHandler(Exact("client-key-1"), Any[*ankh.Session]())
Verify(handler1, Once()).OnConnectedHandler(Exact("client-key-2"), Any[*ankh.Session]())
Verify(handler2, Once()).OnConnectedHandler(Exact("client-key-3"), Any[*ankh.Session]())
Verify(handler2, Once()).OnConnectedHandler(Exact("client-key-4"), Any[*ankh.Session]())
VerifyNoMoreInteractions(handler1)
VerifyNoMoreInteractions(handler2)

Expand Down Expand Up @@ -484,7 +484,7 @@ func TestWebSocketServer(t *testing.T) {
waitFor(t) // wait for handlers to be called

Verify(handler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())
Verify(handler, Once()).OnConnectedHandler(Exact("client-key"), Any[ankh.Session]())
Verify(handler, Once()).OnConnectedHandler(Exact("client-key"), Any[*ankh.Session]())
Verify(handler, Once()).OnDisconnectionHandler(Exact("client-key"))
Verify(handler, Once()).OnReadMessageErrorHandler(Exact("client-key"), Any[error]()) // connection closed improperly
VerifyNoMoreInteractions(handler)
Expand Down

0 comments on commit 36d3459

Please sign in to comment.