Skip to content

Commit

Permalink
feat: add request headers configuration for the WebSocket client
Browse files Browse the repository at this point in the history
This feature allows additional headers to be sent during the handshake
process from the client.
  • Loading branch information
mikefero authored Aug 4, 2024
1 parent c7e6ca0 commit 7add4ff
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
15 changes: 10 additions & 5 deletions websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type WebSocketClientOpts struct {
// HandShakeTimeout specifies the amount of time allowed to complete the
// WebSocket handshake.
HandShakeTimeout time.Duration
// RequestHeaders specifies the headers to be sent with the WebSocket
// handshake.
RequestHeaders map[string][]string
// ServerURL specifies the WebSocket server URL.
ServerURL url.URL
// TLSConfig specifies the TLS configuration for the WebSocketClient.
Expand All @@ -75,8 +78,9 @@ type WebSocketClientOpts struct {

// WebSocketClient is the client instance for a WebSocket.
type WebSocketClient struct {
handler WebSocketClientEventHandler
serverURL url.URL
handler WebSocketClientEventHandler
requestHeaders http.Header
serverURL url.URL

dialer websocket.Dialer
isConnected bool
Expand All @@ -101,16 +105,17 @@ func NewWebSocketClient(opts WebSocketClientOpts) (*WebSocketClient, error) {
}

return &WebSocketClient{
serverURL: opts.ServerURL,
handler: opts.Handler,
serverURL: opts.ServerURL,
handler: opts.Handler,
requestHeaders: opts.RequestHeaders,

dialer: dialer,
isConnected: false,
}, nil
}

func (c *WebSocketClient) Run(ctx context.Context) error {
conn, resp, err := c.dialer.Dial(c.serverURL.String(), nil)
conn, resp, err := c.dialer.Dial(c.serverURL.String(), c.requestHeaders)
if err != nil {
if resp != nil {
return fmt.Errorf("unable to connect to server URL at %s (%s): %w", c.serverURL.String(), resp.Status, err)
Expand Down
12 changes: 10 additions & 2 deletions websocket_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ func createWebSocketClient(t *testing.T, serverURL url.URL, enableTLS bool) *web

handler := Mock[ankh.WebSocketClientEventHandler]()
client, err := ankh.NewWebSocketClient(ankh.WebSocketClientOpts{
ServerURL: serverURL,
ServerURL: serverURL,
RequestHeaders: map[string][]string{
"X-Custom-Header": {"custom-value"},
},
Handler: handler,
HandShakeTimeout: 5 * time.Second,
TLSConfig: tlsConfig,
Expand Down Expand Up @@ -117,7 +120,8 @@ func TestWebSocketClient(t *testing.T) {
server := createWebSocketServer(t, tt.withTLS)
serverHandler := server.mockHandlers[0]
captorServerSession := Captor[*ankh.Session]()
When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil)
captorServerRequest := Captor[*http.Request]()
When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), captorServerRequest.Capture())).ThenReturn("client-key", nil)
WhenSingle(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil)
defer server.cancel()

Expand All @@ -129,6 +133,10 @@ func TestWebSocketClient(t *testing.T) {
handler := client.mockHandler
defer client.cancel()

t.Log("verify client connection request headers are sent to the server")
waitForCapture(t, captorServerRequest)
require.Equal(t, "custom-value", captorServerRequest.Last().Header.Get("X-Custom-Header"))

t.Log("obtain server session for further client testing")
waitForCapture(t, captorServerSession)
serverSession := captorServerSession.Last()
Expand Down
2 changes: 1 addition & 1 deletion websocket_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func waitFor(t *testing.T) {
func waitForCapture[T any](t *testing.T, captor matchers.ArgumentCaptor[T]) {
t.Helper()

defaultWaitForCapture := 10 * time.Millisecond
defaultWaitForCapture := 100 * time.Millisecond
waitForCapturStr := os.Getenv("ANKH_TEST_WAIT_FOR_CAPTURE")
waitForCapture := defaultWaitForCapture
if len(waitForCapturStr) != 0 {
Expand Down

0 comments on commit 7add4ff

Please sign in to comment.