From 80c152cfce817285a22df42650137b6b269e266b Mon Sep 17 00:00:00 2001 From: Michael Fero <6863207+mikefero@users.noreply.github.com> Date: Tue, 11 Jun 2024 20:30:48 -0400 Subject: [PATCH] feat: adding Ankh WebSocket client implementation - Added client and client tests for event based WebSocket - Cleaned up server implementation and removed errors causing panics - Refactored server tests to use new client - Added increased timeouts for CI/CD pipeline - Added client usage to README - Added missing acknowledgements --- .github/workflows/test_and_coverage.yml | 3 + README.md | 190 ++++++++++-- common.go | 7 + websocket_client.go | 218 ++++++++++++++ websocket_client_test.go | 220 ++++++++++++++ websocket_server.go | 25 +- websocket_server_test.go | 379 ++++++++++-------------- 7 files changed, 796 insertions(+), 246 deletions(-) create mode 100644 websocket_client.go create mode 100644 websocket_client_test.go diff --git a/.github/workflows/test_and_coverage.yml b/.github/workflows/test_and_coverage.yml index c27d785..8596b01 100644 --- a/.github/workflows/test_and_coverage.yml +++ b/.github/workflows/test_and_coverage.yml @@ -1,4 +1,7 @@ name: Test +env: + ANKH_TEST_WAIT_FOR: 500ms + ANKH_TEST_WAIT_FOR_CAPTURE: 100ms concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.sha }} cancel-in-progress: true diff --git a/README.md b/README.md index c14aebd..20c9744 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ build scalable real-time applications. To use Ankh, you'll need to have Go installed on your system. You can [download and install Go from the official website]. -### WebSocket Server +### WebSocket Client/Server Features: @@ -32,9 +32,157 @@ Features: #### Usage -Here's an example of how to set up and run an Ankh WebSocket server. +##### WebSocket Client -#### 1. Define Your Event Handlers +###### 1. Define Your Event Handlers + +Implement the `WebSocketClientEventHandler` interface to handle WebSocket +events: + +```go +type MyWebSocketClientHandler struct{} + +func (h *MyWebSocketClientHandler) OnConnectedHandler(resp *http.Response, session ankh.Session) error { + // Handle post-connection setup + fmt.Println("connected to server") + return nil +} + +func (h *MyWebSocketClientHandler) OnDisconnectionHandler() { + // Handle disconnection cleanup + fmt.Println("disconnected from server") +} + +func (h *MyWebSocketClientHandler) OnDisconnectionErrorHandler(err error) { + // Handle disconnection errors + fmt.Println("disconnection error:", err) +} + +func (h *MyWebSocketClientHandler) OnPongHandler(appData string) { + // Handle pong messages + fmt.Println("pong received:", appData) +} + +func (h *MyWebSocketClientHandler) OnReadMessageHandler(messageType int, data []byte) { + // Handle incoming messages + fmt.Println("message received:", string(data)) +} + +func (h *MyWebSocketClientHandler) OnReadMessageErrorHandler(err error) { + // Handle read message errors + fmt.Println("read message error:", err) +} + +func (h *MyWebSocketClientHandler) OnReadMessagePanicHandler(err error) { + // Handle read message panic + fmt.Println("read message panic:", err) +} +``` + +###### 2. Create and Configure the WebSocket Client + +Configure the client with the appropriate options: + +```go +opts := ankh.WebSocketClientOpts{ + Handler: &MyWebSocketClientHandler{}, + HandShakeTimeout: 10 * time.Second, + ServerURL: url.URL{Scheme: "ws", Host: "localhost:3737", Path: "/path"}, + TLSConfig: nil, // Or provide a TLS configuration for secure connections +} + +client, err := ankh.NewWebSocketClient(opts) +if err != nil { + log.Fatalf("failed to create client: %v", err) +} +``` + +###### 3. Run the Client + +Run the client within a context to manage its lifecycle: + +```go +ctx, cancel := context.WithCancel(context.Background()) +defer cancel() + +if err := client.Run(ctx); err != nil { + log.Fatalf("client error: %v", err) +} +``` + +###### Example + +Here's a complete example combining the above steps: + +```go +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "net/url" + "time" + + "github.com/mikefero/ankh" +) + +type MyWebSocketClientHandler struct{} + +func (h *MyWebSocketClientHandler) OnConnectedHandler(resp *http.Response, session ankh.Session) error { + fmt.Println("connected to server") + return nil +} + +func (h *MyWebSocketClientHandler) OnDisconnectionHandler() { + fmt.Println("disconnected from server") +} + +func (h *MyWebSocketClientHandler) OnDisconnectionErrorHandler(err error) { + fmt.Println("disconnection error:", err) +} + +func (h *MyWebSocketClientHandler) OnPongHandler(appData string) { + fmt.Println("pong received:", appData) +} + +func (h *MyWebSocketClientHandler) OnReadMessageHandler(messageType int, data []byte) { + fmt.Println("message received:", string(data)) +} + +func (h *MyWebSocketClientHandler) OnReadMessageErrorHandler(err error) { + fmt.Println("read message error:", err) +} + +func (h *MyWebSocketClientHandler) OnReadMessagePanicHandler(err error) { + fmt.Println("read message panic:", err) +} + +func main() { + opts := ankh.WebSocketClientOpts{ + Handler: &MyWebSocketClientHandler{}, + HandShakeTimeout: 10 * time.Second, + ServerURL: url.URL{Scheme: "ws", Host: "localhost:3737", Path: "/path"}, + } + + client, err := ankh.NewWebSocketClient(opts) + if err != nil { + log.Fatalf("failed to create client: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := client.Run(ctx); err != nil { + log.Fatalf("client error: %v", err) + } +} +``` + +##### WebSocket Server + +###### 1. Define Your Event Handlers Implement the `WebSocketServerEventHandler` interface to handle WebSocket events: @@ -76,16 +224,15 @@ func (h *MyWebSocketServerHandler) OnDisconnectionErrorHandler(clientKey any, er log.Printf("disconnection error: %v, client: %v", err, clientKey) } -func (h *MyWebSocketServerHandler) OnPingHandler(clientKey any, appData string) ([]byte, error) { +func (h *MyWebSocketServerHandler) OnPingHandler(clientKey any, appData string) []byte { // Handle ping messages log.Printf("ping received from client: %v, data: %v", clientKey, appData) - return []byte("pong message or nil"), nil + return []byte("pong message or nil") } -func (h *MyWebSocketServerHandler) OnReadMessageHandler(clientKey any, messageType int, data []byte) error { +func (h *MyWebSocketServerHandler) OnReadMessageHandler(clientKey any, messageType int, data []byte) { // Handle incoming messages log.Printf("message received from client: %v, type: %v, data: %s", clientKey, messageType, string(data)) - return nil } func (h *MyWebSocketServerHandler) OnReadMessageErrorHandler(clientKey any, err error) { @@ -104,7 +251,7 @@ func (h *MyWebSocketServerHandler) OnWebSocketUpgraderErrorHandler(clientKey any } ``` -#### 2. Create and Configure the WebSocket Server +###### 2. Create and Configure the WebSocket Server Configure the server with the appropriate options: @@ -125,7 +272,7 @@ if err != nil { } ``` -#### 3. Run the Server +###### 3. Run the Server Run the server within a context to manage its lifecycle: @@ -138,7 +285,7 @@ if err := server.Run(ctx); err != nil { } ``` -#### Example +###### Example Here's a complete example combining the above steps: @@ -193,14 +340,13 @@ func (h *MyWebSocketServerHandler) OnDisconnectionErrorHandler(clientKey any, er log.Printf("disconnection error: %v, client: %v", err, clientKey) } -func (h *MyWebSocketServerHandler) OnPingHandler(clientKey any, appData string) ([]byte, error) { +func (h *MyWebSocketServerHandler) OnPingHandler(clientKey any, appData string) []byte { log.Printf("ping received from client: %v, data: %v", clientKey, appData) - return []byte("pong message or nil"), nil + return []byte("pong message or nil") } -func (h *MyWebSocketServerHandler) OnReadMessageHandler(clientKey any, messageType int, data []byte) error { +func (h *MyWebSocketServerHandler) OnReadMessageHandler(clientKey any, messageType int, data []byte) { log.Printf("message received from client: %v, type: %v, data: %s", clientKey, messageType, string(data)) - return nil } func (h *MyWebSocketServerHandler) OnReadMessageErrorHandler(clientKey any, err error) { @@ -246,10 +392,12 @@ The Session type provides thread-safe methods to interact with a connected WebSocket client/server. You can use it to send messages or close the connection. -- **Send a Binary Message**: To send a binary message to the client/server, use - the `Send` method. - **Close the Connection**: To close the client/server connection, use the `Close` method. +- **Send a Ping Message**: To send a ping message to the client/server, use the + `Ping` method. +- **Send a Binary Message**: To send a binary message to the client/server, use + the `Send` method. ## License @@ -258,9 +406,15 @@ This project is licensed under the Apache License, Version 2.0. See the ## Acknowledgements -[Gorilla WebSocket] - A fast, well-tested, and widely used WebSocket library in -Go. +- [Gorilla WebSocket] - A fast, well-tested, and widely used WebSocket library + in Go. +- [golangci-lint] - A fast Go linters runner for Go. It runs linters in + parallel, caching their results for much faster runs. +- [mockio] - A mocking framework for Go that helps in creating and using mocks + for testing purposes. [download and install Go from the official website]: https://golang.org/dl/ [LICENSE]: LICENSE [Gorilla WebSocket]: https://github.com/gorilla/websocket +[golangci-lint]: https://github.com/golangci/golangci-lint +[mockio]: https://github.com/ovechkin-dm/mockio/mock diff --git a/common.go b/common.go index e3da939..929646c 100644 --- a/common.go +++ b/common.go @@ -18,6 +18,11 @@ package ankh // and is guaranteed to be thread safe. type CloseConnection func() +// PingMessage will send a ping message to the client/server application. This handle +// will be given during the OnConnectedHandler callback and is guaranteed to be +// thread safe. +type PingMessage func(data []byte) error + // SendMessage will send a binary message to the client/server application. This // handle will be given during the OnConnectedHandler callback and is guaranteed // to be thread safe. @@ -27,6 +32,8 @@ type SendMessage func(data []byte) error type Session struct { // Close will close the connection with the client/server application. Close CloseConnection + // Ping will send a ping message to the client/server application. + Ping PingMessage // Send will send a binary message to the client/server application. Send SendMessage } diff --git a/websocket_client.go b/websocket_client.go new file mode 100644 index 0000000..9678483 --- /dev/null +++ b/websocket_client.go @@ -0,0 +1,218 @@ +// Copyright © 2024 Michael Fero +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package ankh + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/url" + "runtime" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// WebsocketClientEventHandler represents the callback handlers for the +// WebSocket client. +type WebSocketClientEventHandler interface { + // OnConnectedHandler is a function callback for indicating that the client + // connection is completed while creating handles for closing connections and + // sending messages on the WebSocket. + OnConnectedHandler(resp *http.Response, session Session) error + + // OnDisconnectionHandler is a function callback for WebSocket connections + // that is executed upon disconnection of a client. + OnDisconnectionHandler() + + // OnDisconnectionErrorHandler is a function callback for WebSocket + // disconnection error received from the WebSocket. + OnDisconnectionErrorHandler(err error) + + // OnPongHandler is a function callback for WebSocket connections during pong + // operations. + OnPongHandler(appData string) + + // OnReadMessageHandler is a function callback for WebSocket connection that + // is executed upon a message received from the WebSocket. + OnReadMessageHandler(messageType int, data []byte) + + // OnReadMessageErrorHandler is a function callback for WebSocket connection + // read error received from the WebSocket. + OnReadMessageErrorHandler(err error) + + // OnReadMessagePanicHandler is a function callback for WebSocket connection + // read panic received from the WebSocket. + OnReadMessagePanicHandler(err error) +} + +// WebSocketClientOpts are the options for the WebSocketClient. +type WebSocketClientOpts struct { + // Handler specifies the callback handler for the WebSocketClient. + Handler WebSocketClientEventHandler + // HandShakeTimeout specifies the amount of time allowed to complete the + // WebSocket handshake. + HandShakeTimeout time.Duration + // ServerURL specifies the WebSocket server URL. + ServerURL url.URL + // TLSConfig specifies the TLS configuration for the WebSocketClient. + TLSConfig *tls.Config +} + +// WebSocketClient is the client instance for a WebSocket. +type WebSocketClient struct { + handler WebSocketClientEventHandler + serverURL url.URL + + dialer websocket.Dialer + isConnected bool + mutex sync.Mutex +} + +// NewWebSocketClient creates a new WebSocketClient instance. Options are +// validated and will return an error if any are invalid. +func NewWebSocketClient(opts WebSocketClientOpts) (*WebSocketClient, error) { + if opts.Handler == nil { + return nil, errors.New("handler must not be empty") + } + + serverURL := opts.ServerURL + serverURL.Scheme = "ws" + dialer := websocket.Dialer{ + HandshakeTimeout: opts.HandShakeTimeout, + } + if opts.TLSConfig != nil { + serverURL.Scheme = "wss" + dialer.TLSClientConfig = opts.TLSConfig + } + + return &WebSocketClient{ + serverURL: opts.ServerURL, + handler: opts.Handler, + + dialer: dialer, + isConnected: false, + }, nil +} + +func (c *WebSocketClient) Run(ctx context.Context) error { + conn, resp, err := c.dialer.Dial(c.serverURL.String(), nil) + if err != nil { + return fmt.Errorf("unable to connect to server URL at %s (%s): %w", c.serverURL.String(), resp.Status, err) + } + defer conn.Close() + defer c.closeConnection(conn) + defer c.handler.OnDisconnectionHandler() + + if err := c.handler.OnConnectedHandler(resp, Session{ + // Generate a close function for the session + Close: func() { + c.closeConnection(conn) + }, + // Generate a ping function for the session + Ping: func(data []byte) error { + return writeMessage(conn, &c.mutex, websocket.PingMessage, data) + }, + // Generate a send message function for the session + Send: func(data []byte) error { + return writeMessage(conn, &c.mutex, websocket.BinaryMessage, data) + }, + }); err != nil { + // Unable to handle connection + return fmt.Errorf("unable to handle connection: %w", err) + } + + // Update the connection status + c.mutex.Lock() + c.isConnected = true + c.mutex.Unlock() + + conn.SetPongHandler(func(appData string) error { + c.handler.OnPongHandler(appData) + return nil + }) + + // Start the read loop + done := make(chan struct{}) + go func() { + defer close(done) + defer func() { + if r := recover(); r != nil { + var stackBuf [4096]byte + n := runtime.Stack(stackBuf[:], false) + c.handler.OnReadMessagePanicHandler(fmt.Errorf("recovered from read panic: %v\n%s", r, + string(stackBuf[:n]))) + } + }() + + for { + select { + case <-ctx.Done(): + return + default: + messageType, data, err := conn.ReadMessage() + if err != nil { + // Close call may be sent from multiple areas of the client + // application or by gorilla WebSocket automatically + if errors.Is(err, websocket.ErrCloseSent) { + return + } + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + c.handler.OnReadMessageErrorHandler(closeErr) + } + return + } + continue + } + c.handler.OnReadMessageHandler(messageType, data) + } + } + }() + + select { + case <-done: + case <-ctx.Done(): + } + + // Update the connection status + c.mutex.Lock() + defer c.mutex.Unlock() + c.isConnected = false + return nil +} + +// closeConnection will close the connection with the client application. +func (c *WebSocketClient) closeConnection(conn *websocket.Conn) { + // Ensure closed message is not sent when connection is already closed + err := writeMessage(conn, &c.mutex, websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + + // Close call may be sent from multiple areas of the client application or + // by gorilla WebSocket automatically; ensure disconnection handler is not + // called twice + if err != nil && !errors.Is(err, websocket.ErrCloseSent) { + c.handler.OnDisconnectionErrorHandler(err) + } +} + +func (c *WebSocketClient) IsConnected() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.isConnected +} diff --git a/websocket_client_test.go b/websocket_client_test.go new file mode 100644 index 0000000..d8d52ca --- /dev/null +++ b/websocket_client_test.go @@ -0,0 +1,220 @@ +// Copyright © 2024 Michael Fero +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package ankh_test + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/mikefero/ankh" + . "github.com/ovechkin-dm/mockio/mock" + "github.com/stretchr/testify/require" +) + +type webSocketClient struct { + cancel context.CancelFunc + client *ankh.WebSocketClient + mockHandler ankh.WebSocketClientEventHandler + session ankh.Session +} + +func createWebSocketClient(t *testing.T, serverURL url.URL, enableTLS bool) *webSocketClient { + t.Helper() + + var tlsConfig *tls.Config + if enableTLS { + tlsConfig = generateTestCertificate(t) + tlsConfig.ServerName = "ankh.example.com" + } + + handler := Mock[ankh.WebSocketClientEventHandler]() + client, err := ankh.NewWebSocketClient(ankh.WebSocketClientOpts{ + ServerURL: serverURL, + Handler: handler, + HandShakeTimeout: 5 * time.Second, + TLSConfig: tlsConfig, + }) + if err != nil { + t.Fatalf("failed to create WebSocket client: %v", err) + return nil + } + + return &webSocketClient{ + client: client, + mockHandler: handler, + } +} + +func runWebSocketClient(t *testing.T, client *webSocketClient, shouldFail bool) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + client.cancel = cancel + go func() { + err := client.client.Run(ctx) + if shouldFail { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }() + + client.session = ankh.Session{} + if !shouldFail { + t.Log("verify WebSocket client connects") + captorSession := Captor[ankh.Session]() + WhenSingle(client.mockHandler.OnConnectedHandler(Any[*http.Response](), captorSession.Capture())).ThenReturn(nil) + waitForCapture(t, captorSession) + client.session = captorSession.Last() + Verify(client.mockHandler, Once()).OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]()) + } +} + +func TestWebSocketClient(t *testing.T) { + t.Run("error when creating WebSocket client with empty handler", func(t *testing.T) { + t.Parallel() + + _, err := ankh.NewWebSocketClient(ankh.WebSocketClientOpts{ + ServerURL: url.URL{}, + }) + require.Error(t, err) + require.ErrorContains(t, err, "handler must not be empty") + }) + + tests := []struct { + withTLS bool + suffix string + }{ + {false, ""}, + {true, " with TLS"}, + } + for _, tt := range tests { + tt := tt // create a new instance of tt for each iteration (loopclosure) + + t.Run(fmt.Sprintf("can start a WebSocket client and handle events%s", tt.suffix), func(t *testing.T) { + t.Parallel() + SetUp(t) + + t.Log("creating WebSocket server") + server := createWebSocketServer(t, tt.withTLS) + serverHandler := server.mockHandlers[0] + captorServerSession := Captor[ankh.Session]() + When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil) + WhenSingle(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil) + defer server.cancel() + + t.Log("creating WebSocket client") + serverURL := server.serverURL + serverURL.Path = "/path1" + client := createWebSocketClient(t, serverURL, tt.withTLS) + runWebSocketClient(t, client, false) + handler := client.mockHandler + defer client.cancel() + + t.Log("obtain server session for further client testing") + waitForCapture(t, captorServerSession) + serverSession := captorServerSession.Last() + Verify(serverHandler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]()) + Verify(serverHandler, Once()).OnConnectedHandler(Exact("client-key"), Any[ankh.Session]()) + + t.Log("verify ping message is received from the server and client receives pong message") + captorPing := Captor[string]() + WhenSingle(serverHandler.OnPingHandler(Exact("client-key"), captorPing.Capture())).ThenReturn([]byte("ankh-server")) + client.session.Ping([]byte("ping")) + waitForCapture(t, captorPing) + require.Equal(t, "ping", captorPing.Last()) + Verify(serverHandler, Once()).OnPingHandler(Exact("client-key"), Exact("ping")) + Verify(handler, Once()).OnPongHandler(Exact("ankh-server")) + VerifyNoMoreInteractions(handler) + + t.Log("verify message from the client to the server is received") + captorServerReadMessage := Captor[[]byte]() + When(serverHandler.OnReadMessageHandler(Exact("client-key"), Exact(websocket.BinaryMessage), captorServerReadMessage.Capture())).ThenReturn(nil) + client.session.Send([]byte("ankh-client")) + waitForCapture(t, captorServerReadMessage) + require.Equal(t, []byte("ankh-client"), captorServerReadMessage.Last()) + Verify(serverHandler, Once()).OnReadMessageHandler(Exact("client-key"), Exact(websocket.BinaryMessage), Any[[]byte]()) + + t.Log("verify message from the server to the client is received") + captorClientReadMessage := Captor[[]byte]() + serverSession.Send([]byte("ankh-server")) + Verify(handler, Once()).OnReadMessageHandler(Exact(websocket.BinaryMessage), captorClientReadMessage.Capture()) + waitForCapture(t, captorClientReadMessage) + require.Equal(t, []byte("ankh-server"), captorClientReadMessage.Last()) + VerifyNoMoreInteractions(handler) + + t.Log("verify closing the connection of the WebSocket client will close the connection with the server") + client.session.Close() + waitFor(t) // wait for the close messages to be handled + Verify(serverHandler, Once()).OnDisconnectionHandler(Exact("client-key")) + require.False(t, client.client.IsConnected()) + Verify(handler, Once()).OnDisconnectionHandler() + VerifyNoMoreInteractions(handler) + }) + + t.Run(fmt.Sprintf("verify connection error occurs client returns error on connected event%s", tt.suffix), func(t *testing.T) { + t.Parallel() + SetUp(t) + + server := createWebSocketServer(t, tt.withTLS) + serverHandler := server.mockHandlers[0] + When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil) + serverURL := server.serverURL + serverURL.Path = "/path1" + client := createWebSocketClient(t, serverURL, tt.withTLS) + handler := client.mockHandler + WhenSingle(handler.OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())).ThenReturn(errors.New("connection error")) + runWebSocketClient(t, client, true) + defer client.cancel() + + waitFor(t) // wait connected handler to be called + Verify(handler, Once()).OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]()) + Verify(handler, Once()).OnDisconnectionHandler() + VerifyNoMoreInteractions(handler) + }) + + t.Run(fmt.Sprintf("verify closing server connection terminates client%s", tt.suffix), func(t *testing.T) { + t.Parallel() + SetUp(t) + + server := createWebSocketServer(t, tt.withTLS) + serverHandler := server.mockHandlers[0] + captorServerSession := Captor[ankh.Session]() + When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil) + When(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil) + serverURL := server.serverURL + serverURL.Path = "/path1" + client := createWebSocketClient(t, serverURL, tt.withTLS) + handler := client.mockHandler + WhenSingle(handler.OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]())).ThenReturn(errors.New("connection error")) + runWebSocketClient(t, client, true) + defer client.cancel() + + waitForCapture(t, captorServerSession) + serverSession := captorServerSession.Last() + serverSession.Close() + waitFor(t) // wait connected handler to be called + Verify(handler, Once()).OnConnectedHandler(Any[*http.Response](), Any[ankh.Session]()) + Verify(handler, Once()).OnDisconnectionHandler() + VerifyNoMoreInteractions(handler) + }) + } +} diff --git a/websocket_server.go b/websocket_server.go index 6ddba79..d41d4f6 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -54,7 +54,7 @@ type WebSocketServerEventHandler interface { // OnPingHandler is a function callback for WebSocket connections during ping // operations. The byte array returned will be sent back to the client as a // pong message. - OnPingHandler(clientKey any, appData string) ([]byte, error) + OnPingHandler(clientKey any, appData string) []byte // OnReadMessageHandler is a function callback for WebSocket connection that // is executed upon a message received from the WebSocket. @@ -108,8 +108,8 @@ type WebSocketServer struct { shutdownTimeout time.Duration } -// NewWebSocketServer creates a new WebSocketServer instance. All options are -// validated and required to be set. +// NewWebSocketServer creates a new WebSocketServer instance. Options are +// validated and will return an error if any are invalid. func NewWebSocketServer(opts WebSocketServerOpts) (*WebSocketServer, error) { // Validate all WebSocketServer options address := strings.TrimSpace(opts.Address) @@ -192,7 +192,9 @@ func (s *WebSocketServer) Run(ctx context.Context) error { } // closeConnection will close the connection with the client application. -func closeConnection(conn *websocket.Conn, mutex *sync.Mutex, clientKey any, handler WebSocketServerEventHandler) { +func (s *WebSocketServer) closeConnection(conn *websocket.Conn, mutex *sync.Mutex, clientKey any, + handler WebSocketServerEventHandler, +) { // Ensure closed message is not sent when connection is already closed err := writeMessage(conn, mutex, websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) @@ -234,21 +236,22 @@ func (s *WebSocketServer) handleConnection(ctx context.Context, w http.ResponseW return } defer conn.Close() - defer closeConnection(conn, &mutex, clientKey, handler) + defer s.closeConnection(conn, &mutex, clientKey, handler) defer handler.OnDisconnectionHandler(clientKey) conn.SetPingHandler(func(appData string) error { - data, err := handler.OnPingHandler(clientKey, appData) - if err != nil { - return fmt.Errorf("unable to handle ping: %w", err) - } + data := handler.OnPingHandler(clientKey, appData) return writeMessage(conn, &mutex, websocket.PongMessage, data) }) if err := handler.OnConnectedHandler(clientKey, Session{ // Generate a close function for the session Close: func() { - closeConnection(conn, &mutex, clientKey, handler) + s.closeConnection(conn, &mutex, clientKey, handler) + }, + // Generate a ping function for the session + Ping: func(data []byte) error { + return writeMessage(conn, &mutex, websocket.PingMessage, data) }, // Generate a send message function for the session Send: func(data []byte) error { @@ -293,7 +296,7 @@ func (s *WebSocketServer) handleConnection(ctx context.Context, w http.ResponseW } continue } - _ = handler.OnReadMessageHandler(clientKey, messageType, data) + handler.OnReadMessageHandler(clientKey, messageType, data) } } }() diff --git a/websocket_server_test.go b/websocket_server_test.go index 566321e..e0a75cd 100644 --- a/websocket_server_test.go +++ b/websocket_server_test.go @@ -20,10 +20,10 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/pem" "errors" "fmt" - "io" "math/big" "net" "net/http" @@ -52,7 +52,7 @@ func generateTestCertificate(t *testing.T) *tls.Config { template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ - Organization: []string{"fero"}, + Organization: []string{"ankh"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(30 * time.Minute), @@ -111,20 +111,55 @@ func waitForServer(t *testing.T, address string, establishConnection bool) { } } -func waitForCapture[T any](captor matchers.ArgumentCaptor[T]) { - for i := 0; i < 100; i++ { - if len(captor.Values()) != 0 { - return +func waitFor(t *testing.T) { + t.Helper() + + defaultWaitFor := 100 * time.Millisecond + waitForStr := os.Getenv("ANKH_TEST_WAIT_FOR") + waitFor := defaultWaitFor + if len(waitForStr) != 0 { + var err error + waitFor, err = time.ParseDuration(waitForStr) + if err != nil { + t.Fatalf("failed to parse timeout from ANKH_TEST_WAIT_FOR: %v", err) } - time.Sleep(10 * time.Millisecond) } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(waitFor) + }() + wg.Wait() } -type webSocketServer struct { - address string - cancel context.CancelFunc - mockHandlers []ankh.WebSocketServerEventHandler - tlsConfig *tls.Config +func waitForCapture[T any](t *testing.T, captor matchers.ArgumentCaptor[T]) { + t.Helper() + + defaultWaitForCapture := 10 * time.Millisecond + waitForCapturStr := os.Getenv("ANKH_TEST_WAIT_FOR_CAPTURE") + waitForCapture := defaultWaitForCapture + if len(waitForCapturStr) != 0 { + var err error + waitForCapture, err = time.ParseDuration(waitForCapturStr) + if err != nil { + t.Fatalf("failed to parse timeout from ANKH_TEST_WAIT_FOR_CAPTURE: %v", err) + } + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + if len(captor.Values()) != 0 { + return + } + time.Sleep(waitForCapture) + } + }() + wg.Wait() } func findUnusedLocalAddress(t *testing.T) string { @@ -144,15 +179,28 @@ func findUnusedLocalAddress(t *testing.T) string { return address } +type webSocketServer struct { + address string + cancel context.CancelFunc + mockHandlers []ankh.WebSocketServerEventHandler + serverURL url.URL + tlsConfig *tls.Config +} + func createWebSocketServer(t *testing.T, enableTLS bool) *webSocketServer { t.Helper() address := findUnusedLocalAddress(t) + serverURL := url.URL{ + Scheme: "ws", + Host: address, + } handler1 := Mock[ankh.WebSocketServerEventHandler]() handler2 := Mock[ankh.WebSocketServerEventHandler]() var tlsConfig *tls.Config if enableTLS { + serverURL.Scheme = "wss" tlsConfig = generateTestCertificate(t) } @@ -187,157 +235,11 @@ func createWebSocketServer(t *testing.T, enableTLS bool) *webSocketServer { handler1, handler2, }, + serverURL: serverURL, tlsConfig: tlsConfig, } } -type webSocketClient struct { - conn *websocket.Conn - cancel context.CancelFunc - close func() - pingMessage func(message string) - readMessage func() (messageType int, p []byte, err error) -} - -func createWebSocketClient(t *testing.T, address string, path string, enableTLS bool, -) (*webSocketClient, error) { - t.Helper() - - webSocketDialer := &websocket.Dialer{ - HandshakeTimeout: 5 * time.Second, - } - scheme := "ws" - if enableTLS { - scheme = "wss" - tlsConfig := generateTestCertificate(t) - tlsConfig.ServerName = "fero.example.com" - webSocketDialer.TLSClientConfig = tlsConfig - } - - u := url.URL{ - Scheme: scheme, - Host: address, - Path: path, - } - ctx, cancel := context.WithCancel(context.Background()) - connChan := make(chan *websocket.Conn) - errChan := make(chan error) - - go func() { - conn, _, err := webSocketDialer.DialContext(ctx, u.String(), nil) - if err != nil { - errChan <- err - return - } - connChan <- conn - }() - - select { - case conn := <-connChan: - return &webSocketClient{ - conn: conn, - cancel: cancel, - close: func() { - if err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { - t.Fatalf("unable to close socket connection: %v", err) - } - }, - pingMessage: func(message string) { - if err := conn.WriteMessage(websocket.PingMessage, []byte(message)); err != nil { - t.Fatalf("unable to send ping message: %v", err) - } - }, - readMessage: func() (int, []byte, error) { - defaultClientReadTimeout := 1 * time.Second - clientReadTimeoutStr := os.Getenv("ANKH_TEST_CLIENT_READ_TIMEOUT") - clientReadTimeout := defaultClientReadTimeout - if len(clientReadTimeoutStr) != 0 { - var err error - clientReadTimeout, err = time.ParseDuration(clientReadTimeoutStr) - if err != nil { - t.Fatalf("failed to parse timeout from ANKH_TEST_WAIT_FOR: %v", err) - } - } - ctx, cancel := context.WithTimeout(context.Background(), clientReadTimeout) - defer cancel() - - result := make(chan struct { - messageType int - p []byte - err error - }) - - go func() { - var r io.Reader - var messageType int - var err error - - messageType, r, err = conn.NextReader() - if err != nil { - result <- struct { - messageType int - p []byte - err error - }{messageType, nil, err} - return - } - p, err := io.ReadAll(r) - result <- struct { - messageType int - p []byte - err error - }{messageType, p, err} - }() - - select { - case res := <-result: - return res.messageType, res.p, res.err - case <-ctx.Done(): - return 0, nil, ctx.Err() - } - }, - }, nil - case err := <-errChan: - defer cancel() - return nil, fmt.Errorf("failed to connect client WebSocket: %v", err) - case <-ctx.Done(): - defer cancel() - return nil, fmt.Errorf("context cancelled before WebSocket connection could be established") - } -} - -func validateClientWebSocketClosed(t *testing.T, client *webSocketClient) { - t.Helper() - - _, _, err := client.readMessage() - var closeErr *websocket.CloseError - if !errors.As(err, &closeErr) { - t.Fatalf("expected close error, got %v", err) - } - require.Equal(t, websocket.CloseNormalClosure, closeErr.Code) -} - -func waitFor(t *testing.T) { - defaultWaitFor := 100 * time.Millisecond - waitForStr := os.Getenv("ANKH_TEST_WAIT_FOR") - waitFor := defaultWaitFor - if len(waitForStr) != 0 { - var err error - waitFor, err = time.ParseDuration(waitForStr) - if err != nil { - t.Fatalf("failed to parse timeout from ANKH_TEST_WAIT_FOR: %v", err) - } - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - time.Sleep(waitFor) - }() - wg.Wait() -} - func TestWebSocketServer(t *testing.T) { t.Run("error when creating WebSocket server with empty address", func(t *testing.T) { t.Parallel() @@ -405,14 +307,18 @@ func TestWebSocketServer(t *testing.T) { WhenSingle(handler1.OnConnectedHandler(Exact("client-key-2"), captor2Session.Capture())).ThenReturn(nil) WhenSingle(handler2.OnConnectedHandler(Exact("client-key-3"), captor3Session.Capture())).ThenReturn(nil) WhenSingle(handler2.OnConnectedHandler(Exact("client-key-4"), captor4Session.Capture())).ThenReturn(nil) - client1, err := createWebSocketClient(t, webSocketServer.address, "/path1", tt.withTLS) - require.NoError(t, err) - client2, err := createWebSocketClient(t, webSocketServer.address, "/path1", tt.withTLS) - require.NoError(t, err) - client3, err := createWebSocketClient(t, webSocketServer.address, "/path2", tt.withTLS) - require.NoError(t, err) - client4, err := createWebSocketClient(t, webSocketServer.address, "/path2", tt.withTLS) - require.NoError(t, err) + path1ServerURL := webSocketServer.serverURL + path1ServerURL.Path = "/path1" + path2ServerURL := webSocketServer.serverURL + path2ServerURL.Path = "/path2" + client1 := createWebSocketClient(t, path1ServerURL, tt.withTLS) + client2 := createWebSocketClient(t, path1ServerURL, tt.withTLS) + client3 := createWebSocketClient(t, path2ServerURL, tt.withTLS) + client4 := createWebSocketClient(t, path2ServerURL, tt.withTLS) + runWebSocketClient(t, client1, false) + runWebSocketClient(t, client2, false) + runWebSocketClient(t, client3, false) + runWebSocketClient(t, client4, false) defer func() { client1.cancel() client2.cancel() @@ -421,10 +327,6 @@ func TestWebSocketServer(t *testing.T) { }() t.Log("verify WebSocket clients connected") - waitForCapture(captor1Session) - waitForCapture(captor2Session) - waitForCapture(captor3Session) - waitForCapture(captor4Session) Verify(handler1, Times(2)).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]()) Verify(handler2, Times(2)).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]()) Verify(handler1, Once()).OnConnectedHandler(Exact("client-key-1"), Any[ankh.Session]()) @@ -437,23 +339,25 @@ func TestWebSocketServer(t *testing.T) { t.Log("verify closing the connection of the WebSocket client will close the connection with the server") // Close the WebSocket server connection and wait for client to receive the close message captor2Session.Last().Close() - _, _, err = client2.readMessage() - var closeErr *websocket.CloseError - if !errors.As(err, &closeErr) { - t.Fatalf("expected close error for client2, got %v", err) - } - require.Equal(t, websocket.CloseNormalClosure, closeErr.Code) waitFor(t) // wait for the close messages to be handled Verify(handler1, Once()).OnDisconnectionHandler(Exact("client-key-2")) VerifyNoMoreInteractions(handler1) + require.False(t, client2.client.IsConnected()) + + // This is mainly for coverage as the client application doesn't have ping + // event handler + t.Log("server can send ping messages to the client") + captor1Session.Last().Ping([]byte("ankh-server")) + captor3Session.Last().Ping([]byte("ankh-server")) + captor4Session.Last().Ping([]byte("ankh-server")) t.Log("verify ping message from the client to the WebSocket server is handled") - When(handler1.OnPingHandler(Exact("client-key-1"), Exact("client1"))).ThenReturn([]byte("client1-pong"), nil) - When(handler2.OnPingHandler(Exact("client-key-3"), Exact("client3"))).ThenReturn([]byte("client3-pong"), nil) - When(handler2.OnPingHandler(Exact("client-key-4"), Exact("client4"))).ThenReturn([]byte("client4-pong"), nil) - client1.pingMessage("client1") - client3.pingMessage("client3") - client4.pingMessage("client4") + WhenSingle(handler1.OnPingHandler(Exact("client-key-1"), Exact("client1"))).ThenReturn([]byte("client1-pong")) + WhenSingle(handler2.OnPingHandler(Exact("client-key-3"), Exact("client3"))).ThenReturn([]byte("client3-pong")) + WhenSingle(handler2.OnPingHandler(Exact("client-key-4"), Exact("client4"))).ThenReturn([]byte("client4-pong")) + client1.session.Ping([]byte("client1")) + client3.session.Ping([]byte("client3")) + client4.session.Ping([]byte("client4")) waitFor(t) // wait for the ping messages to be handled Verify(handler1, Once()).OnPingHandler(Exact("client-key-1"), Exact("client1")) Verify(handler2, Once()).OnPingHandler(Exact("client-key-3"), Exact("client3")) @@ -462,31 +366,34 @@ func TestWebSocketServer(t *testing.T) { VerifyNoMoreInteractions(handler2) t.Log("verify message from WebSocket server to the client is sent") - captor1Session.Last().Send([]byte("fero-1")) - messageType, data, err := client1.readMessage() - require.NoError(t, err) - require.Equal(t, websocket.BinaryMessage, messageType) - require.Equal(t, []byte("fero-1"), data) - captor3Session.Last().Send([]byte("fero-3")) - messageType, data, err = client3.readMessage() - require.NoError(t, err) - require.Equal(t, websocket.BinaryMessage, messageType) - require.Equal(t, []byte("fero-3"), data) - captor4Session.Last().Send([]byte("fero-4")) - messageType, data, err = client4.readMessage() - require.NoError(t, err) - require.Equal(t, websocket.BinaryMessage, messageType) - require.Equal(t, []byte("fero-4"), data) + captor1ReadMessage := Captor[[]byte]() + captor3ReadMessage := Captor[[]byte]() + captor4ReadMessage := Captor[[]byte]() + captor1Session.Last().Send([]byte("ankh-1")) + waitFor(t) // wait for the read messages to be handled + Verify(client1.mockHandler, Once()).OnReadMessageHandler(Exact(websocket.BinaryMessage), captor1ReadMessage.Capture()) + captor3Session.Last().Send([]byte("ankh-3")) + waitFor(t) // wait for the read messages to be handled + Verify(client3.mockHandler, Once()).OnReadMessageHandler(Exact(websocket.BinaryMessage), captor3ReadMessage.Capture()) + waitFor(t) // wait for the read messages to be handled + captor4Session.Last().Send([]byte("ankh-4")) + Verify(client4.mockHandler, Once()).OnReadMessageHandler(Exact(websocket.BinaryMessage), captor4ReadMessage.Capture()) + waitForCapture(t, captor1ReadMessage) + require.Equal(t, []byte("ankh-1"), captor1ReadMessage.Last()) + waitForCapture(t, captor3ReadMessage) + require.Equal(t, []byte("ankh-3"), captor3ReadMessage.Last()) + waitForCapture(t, captor4ReadMessage) + require.Equal(t, []byte("ankh-4"), captor4ReadMessage.Last()) t.Log("verify closing the connection of the WebSocket server will close the connection with the client") webSocketServer.cancel() - validateClientWebSocketClosed(t, client1) - validateClientWebSocketClosed(t, client3) - validateClientWebSocketClosed(t, client4) waitFor(t) // wait for the close messages to be handled Verify(handler1, Once()).OnDisconnectionHandler(Exact("client-key-1")) Verify(handler2, Once()).OnDisconnectionHandler(Exact("client-key-3")) Verify(handler2, Once()).OnDisconnectionHandler(Exact("client-key-4")) + require.False(t, client1.client.IsConnected()) + require.False(t, client3.client.IsConnected()) + require.False(t, client4.client.IsConnected()) }) t.Run(fmt.Sprintf("a error occurs when starting WebSocket server with invalid address%s", tt.suffix), func(t *testing.T) { @@ -536,10 +443,44 @@ func TestWebSocketServer(t *testing.T) { defer webSocketServer.cancel() handler := webSocketServer.mockHandlers[0] - When(handler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil) - client, err := createWebSocketClient(t, webSocketServer.address, "/path1", tt.withTLS) + captor := Captor[*http.Request]() + When(handler.OnConnectionHandler(Any[http.ResponseWriter](), captor.Capture())).ThenReturn("client-key", nil) + + var conn net.Conn + if tt.withTLS { + var err error + conn, err = tls.Dial("tcp", webSocketServer.address, webSocketServer.tlsConfig) + require.NoError(t, err) + } else { + var err error + conn, err = net.Dial("tcp", webSocketServer.address) + require.NoError(t, err) + } + defer conn.Close() + + // Perform the WebSocket upgrade request handshake manually + key := make([]byte, 16) + _, err := rand.Read(key) + require.NoError(t, err) + secWebSocketKey := base64.StdEncoding.EncodeToString(key) + request := "GET /path1 HTTP/1.1\r\n" + + "Host: %s\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Key: %s\r\n" + + "Sec-WebSocket-Version: 13\r\n\r\n" + fmt.Fprintf(conn, request, webSocketServer.address, secWebSocketKey) + + // Read the response + var responseBuffer [4096]byte + n, err := conn.Read(responseBuffer[:]) require.NoError(t, err) - client.conn.Close() + response := string(responseBuffer[:n]) + require.Contains(t, response, "HTTP/1.1 101 Switching Protocols") + + waitForCapture(t, captor) + require.NotNil(t, captor.Last()) + conn.Close() waitFor(t) // wait for handlers to be called Verify(handler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]()) @@ -556,16 +497,21 @@ func TestWebSocketServer(t *testing.T) { defer webSocketServer.cancel() handler := webSocketServer.mockHandlers[0] - When(handler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("", errors.New("connection denied")) - _, err := createWebSocketClient(t, webSocketServer.address, "/path1", tt.withTLS) - require.Error(t, err) - waitFor(t) // wait for OnConnectionHandler to be called + captor := Captor[*http.Request]() + When(handler.OnConnectionHandler(Any[http.ResponseWriter](), captor.Capture())).ThenReturn("", errors.New("connection denied")) + serverURL := webSocketServer.serverURL + serverURL.Path = "/path1" + client := createWebSocketClient(t, serverURL, tt.withTLS) + runWebSocketClient(t, client, true) + require.NotEmpty(t, client) + waitForCapture(t, captor) + require.NotNil(t, captor.Last()) Verify(handler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]()) VerifyNoMoreInteractions(handler) }) - t.Run(fmt.Sprintf("server connection handles WebSocket upgrade%s", tt.suffix), func(t *testing.T) { + t.Run(fmt.Sprintf("server connection handles WebSocket upgrade error%s", tt.suffix), func(t *testing.T) { t.Parallel() SetUp(t) webSocketServer := createWebSocketServer(t, tt.withTLS) @@ -577,21 +523,20 @@ func TestWebSocketServer(t *testing.T) { var err error conn, err = tls.Dial("tcp", webSocketServer.address, webSocketServer.tlsConfig) require.NoError(t, err) - defer conn.Close() } else { var err error conn, err = net.Dial("tcp", webSocketServer.address) require.NoError(t, err) - defer conn.Close() } + defer conn.Close() // Perform a WebSocket handshake manually which uses a missing key and version - fmt.Fprint(conn, "GET /path1 HTTP/1.1\r\n") - fmt.Fprintf(conn, "Host: %s\r\n", webSocketServer.address) - fmt.Fprintf(conn, "Upgrade: websocket\r\n") - fmt.Fprintf(conn, "Connection: Upgrade\r\n") - fmt.Fprintf(conn, "\r\n") - waitFor(t) // wait for the upgrade to be handled + request := "GET /path1 HTTP/1.1\r\n" + + "Host: %s\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n\r\n" + fmt.Fprintf(conn, request, webSocketServer.address) + waitFor(t) // wait for handlers to be called Verify(handler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]()) Verify(handler, Once()).OnWebSocketUpgraderErrorHandler(Any[any](), Any[error]())