Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add request headers configuration for the WebSocket client #5

Merged
merged 1 commit into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type WebSocketClientOpts struct {
// HandShakeTimeout specifies the amount of time allowed to complete the
// WebSocket handshake.
HandShakeTimeout time.Duration
// RequestHeaders specifies the headers to be sent with the WebSocket
// handshake.
RequestHeaders map[string][]string
// ServerURL specifies the WebSocket server URL.
ServerURL url.URL
// TLSConfig specifies the TLS configuration for the WebSocketClient.
Expand All @@ -75,8 +78,9 @@ type WebSocketClientOpts struct {

// WebSocketClient is the client instance for a WebSocket.
type WebSocketClient struct {
handler WebSocketClientEventHandler
serverURL url.URL
handler WebSocketClientEventHandler
requestHeaders http.Header
serverURL url.URL

dialer websocket.Dialer
isConnected bool
Expand All @@ -101,16 +105,17 @@ func NewWebSocketClient(opts WebSocketClientOpts) (*WebSocketClient, error) {
}

return &WebSocketClient{
serverURL: opts.ServerURL,
handler: opts.Handler,
serverURL: opts.ServerURL,
handler: opts.Handler,
requestHeaders: opts.RequestHeaders,

dialer: dialer,
isConnected: false,
}, nil
}

func (c *WebSocketClient) Run(ctx context.Context) error {
conn, resp, err := c.dialer.Dial(c.serverURL.String(), nil)
conn, resp, err := c.dialer.Dial(c.serverURL.String(), c.requestHeaders)
if err != nil {
if resp != nil {
return fmt.Errorf("unable to connect to server URL at %s (%s): %w", c.serverURL.String(), resp.Status, err)
Expand Down
12 changes: 10 additions & 2 deletions websocket_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ func createWebSocketClient(t *testing.T, serverURL url.URL, enableTLS bool) *web

handler := Mock[ankh.WebSocketClientEventHandler]()
client, err := ankh.NewWebSocketClient(ankh.WebSocketClientOpts{
ServerURL: serverURL,
ServerURL: serverURL,
RequestHeaders: map[string][]string{
"X-Custom-Header": {"custom-value"},
},
Handler: handler,
HandShakeTimeout: 5 * time.Second,
TLSConfig: tlsConfig,
Expand Down Expand Up @@ -117,7 +120,8 @@ func TestWebSocketClient(t *testing.T) {
server := createWebSocketServer(t, tt.withTLS)
serverHandler := server.mockHandlers[0]
captorServerSession := Captor[*ankh.Session]()
When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())).ThenReturn("client-key", nil)
captorServerRequest := Captor[*http.Request]()
When(serverHandler.OnConnectionHandler(Any[http.ResponseWriter](), captorServerRequest.Capture())).ThenReturn("client-key", nil)
WhenSingle(serverHandler.OnConnectedHandler(Exact("client-key"), captorServerSession.Capture())).ThenReturn(nil)
defer server.cancel()

Expand All @@ -129,6 +133,10 @@ func TestWebSocketClient(t *testing.T) {
handler := client.mockHandler
defer client.cancel()

t.Log("verify client connection request headers are sent to the server")
waitForCapture(t, captorServerRequest)
require.Equal(t, "custom-value", captorServerRequest.Last().Header.Get("X-Custom-Header"))

t.Log("obtain server session for further client testing")
waitForCapture(t, captorServerSession)
serverSession := captorServerSession.Last()
Expand Down
2 changes: 1 addition & 1 deletion websocket_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func waitFor(t *testing.T) {
func waitForCapture[T any](t *testing.T, captor matchers.ArgumentCaptor[T]) {
t.Helper()

defaultWaitForCapture := 10 * time.Millisecond
defaultWaitForCapture := 100 * time.Millisecond
waitForCapturStr := os.Getenv("ANKH_TEST_WAIT_FOR_CAPTURE")
waitForCapture := defaultWaitForCapture
if len(waitForCapturStr) != 0 {
Expand Down