diff --git a/websocket_client.go b/websocket_client.go index 995f837..81176d0 100644 --- a/websocket_client.go +++ b/websocket_client.go @@ -67,6 +67,9 @@ type WebSocketClientOpts struct { // HandShakeTimeout specifies the amount of time allowed to complete the // WebSocket handshake. HandShakeTimeout time.Duration + // RequestHeaders specifies the headers to be sent with the WebSocket + // handshake. + RequestHeaders map[string][]string // ServerURL specifies the WebSocket server URL. ServerURL url.URL // TLSConfig specifies the TLS configuration for the WebSocketClient. @@ -75,8 +78,9 @@ type WebSocketClientOpts struct { // WebSocketClient is the client instance for a WebSocket. type WebSocketClient struct { - handler WebSocketClientEventHandler - serverURL url.URL + handler WebSocketClientEventHandler + requestHeaders http.Header + serverURL url.URL dialer websocket.Dialer isConnected bool @@ -101,8 +105,9 @@ func NewWebSocketClient(opts WebSocketClientOpts) (*WebSocketClient, error) { } return &WebSocketClient{ - serverURL: opts.ServerURL, - handler: opts.Handler, + serverURL: opts.ServerURL, + handler: opts.Handler, + requestHeaders: opts.RequestHeaders, dialer: dialer, isConnected: false, @@ -110,7 +115,7 @@ func NewWebSocketClient(opts WebSocketClientOpts) (*WebSocketClient, error) { } func (c *WebSocketClient) Run(ctx context.Context) error { - conn, resp, err := c.dialer.Dial(c.serverURL.String(), nil) + conn, resp, err := c.dialer.Dial(c.serverURL.String(), c.requestHeaders) if err != nil { if resp != nil { return fmt.Errorf("unable to connect to server URL at %s (%s): %w", c.serverURL.String(), resp.Status, err) diff --git a/websocket_client_test.go b/websocket_client_test.go index d619b74..08fae45 100644 --- a/websocket_client_test.go +++ b/websocket_client_test.go @@ -47,7 +47,10 @@ func createWebSocketClient(t *testing.T, serverURL url.URL, enableTLS bool) *web handler := Mock[ankh.WebSocketClientEventHandler]() client, err := ankh.NewWebSocketClient(ankh.WebSocketClientOpts{ - ServerURL: serverURL, + ServerURL: serverURL, + RequestHeaders: map[string][]string{ + "X-Custom-Header": {"custom-value"}, + }, Handler: handler, HandShakeTimeout: 5 * time.Second, TLSConfig: tlsConfig, @@ -117,7 +120,8 @@ func TestWebSocketClient(t *testing.T) { server := createWebSocketServer(t, tt.withTLS) serverHandler := server.mockHandlers[0] captorServerSession := Captor[*ankh.Session]() - When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil) + captorServerRequest := Captor[*http.Request]() + When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), captorServerRequest.Capture())).ThenReturn("client-key", nil) WhenSingle(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil) defer server.cancel() @@ -129,6 +133,10 @@ func TestWebSocketClient(t *testing.T) { handler := client.mockHandler defer client.cancel() + t.Log("verify client connection request headers are sent to the server") + waitForCapture(t, captorServerRequest) + require.Equal(t, "custom-value", captorServerRequest.Last().Header.Get("X-Custom-Header")) + t.Log("obtain server session for further client testing") waitForCapture(t, captorServerSession) serverSession := captorServerSession.Last() diff --git a/websocket_server_test.go b/websocket_server_test.go index 92d4b7f..97c21f5 100644 --- a/websocket_server_test.go +++ b/websocket_server_test.go @@ -137,7 +137,7 @@ func waitFor(t *testing.T) { func waitForCapture[T any](t *testing.T, captor matchers.ArgumentCaptor[T]) { t.Helper() - defaultWaitForCapture := 10 * time.Millisecond + defaultWaitForCapture := 100 * time.Millisecond waitForCapturStr := os.Getenv("ANKH_TEST_WAIT_FOR_CAPTURE") waitForCapture := defaultWaitForCapture if len(waitForCapturStr) != 0 {