Skip to content

Commit

Permalink
Allow setting websocket connection parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Nov 14, 2024
1 parent e030ff1 commit 2e037a5
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 12 deletions.
11 changes: 11 additions & 0 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
return NewClientUsingWebSocketWithConnectionParams(endpoint, wsDialer, headers, nil)
}

// NewClientUsingWebSocketWithConnectionParams returns a [WebSocketClient] which makes subscription requests
// to the given endpoint using webSocket. It allows to pass additional connection parameters
// to the server during the initial connection handshake.
//
// connectionParams is a map of connection parameters to be sent to the server
// during the initial connection handshake.
func NewClientUsingWebSocketWithConnectionParams(endpoint string, wsDialer Dialer, headers http.Header, connParams map[string]any) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
Expand All @@ -141,6 +151,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head
return &webSocketClient{
Dialer: wsDialer,
Header: headers,
connParams: connParams,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
Expand Down
13 changes: 10 additions & 3 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,22 @@ type webSocketClient struct {
Header http.Header
endpoint string
conn WSConn
connParams map[string]any
errChan chan error
subscriptions subscriptionMap
isClosing bool
sync.Mutex
}

type webSocketInitMessage struct {

Check failure on line 58 in graphql/websocket.go

View workflow job for this annotation

GitHub Actions / Lint

fieldalignment: struct with 24 pointer bytes could be 16 (govet)
Type string `json:"type"`
Payload map[string]interface{} `json:"payload"`
}

type webSocketSendMessage struct {

Check failure on line 63 in graphql/websocket.go

View workflow job for this annotation

GitHub Actions / Lint

fieldalignment: struct with 40 pointer bytes could be 32 (govet)
Payload *Request `json:"payload"`
Type string `json:"type"`
ID string `json:"id"`
Payload *Request `json:"payload"`
}

type webSocketReceiveMessage struct {
Expand All @@ -67,8 +73,9 @@ type webSocketReceiveMessage struct {
}

func (w *webSocketClient) sendInit() error {
connInitMsg := webSocketSendMessage{
Type: webSocketTypeConnInit,
connInitMsg := webSocketInitMessage{
Type: webSocketTypeConnInit,
Payload: w.connParams,
}
return w.sendStructAsJSON(connInitMsg)
}
Expand Down
21 changes: 19 additions & 2 deletions internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,17 @@ func TestSubscription(t *testing.T) {
_ = `# @genqlient
subscription count { count }`

authKey := server.AuthKey

ctx := context.Background()
server := server.RunServer()
defer server.Close()

cases := []struct {

Check failure on line 75 in internal/integration/integration_test.go

View workflow job for this annotation

GitHub Actions / Lint

fieldalignment: struct with 32 pointer bytes could be 16 (govet)
name string
unsubThreshold time.Duration
connParams map[string]interface{}
counterStart int
expected subscriptionResult
}{
{
Expand All @@ -83,6 +87,18 @@ func TestSubscription(t *testing.T) {
serverChannelClosed: true,
},
},
{
name: "server_closed_authorized_user_gets_incremented_counter",
unsubThreshold: 5 * time.Second,
connParams: map[string]interface{}{
authKey: "authorized-user-token",
},
counterStart: 1000,
expected: subscriptionResult{
clientUnsubscribed: false,
serverChannelClosed: true,
},
},
{
name: "client_unsubscribed",
unsubThreshold: 300 * time.Millisecond,
Expand All @@ -95,7 +111,8 @@ func TestSubscription(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL)
wsClient := newRoundtripWebSocketClient(t, server.URL, tc.connParams)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

Expand All @@ -104,7 +121,7 @@ func TestSubscription(t *testing.T) {
defer wsClient.Close()

var (
counter = 0
counter = tc.counterStart
start = time.Now()
result = subscriptionResult{}
)
Expand Down
4 changes: 2 additions & 2 deletions internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade
return graphql.WSConn(conn), err
}

func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
func newRoundtripWebSocketClient(t *testing.T, endpoint string, connectionParams map[string]interface{}) graphql.WebSocketClient {
dialer := websocket.DefaultDialer
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}
return &roundtripClient{
wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil),
wsWrapped: graphql.NewClientUsingWebSocketWithConnectionParams(endpoint, &MyDialer{Dialer: dialer}, nil, connectionParams),
t: t,
}
}
42 changes: 37 additions & 5 deletions internal/integration/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,28 +156,60 @@ func (m mutationResolver) CreateUser(ctx context.Context, input NewUser) (*User,
return &newUser, nil
}

func withAuthToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, authTokenCtxKey{}, token)
}

func getAuthToken(ctx context.Context) string {
if tkn, ok := ctx.Value(authTokenCtxKey{}).(string); ok {
return tkn
}
return ""
}

func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) {
respCounter := 0
if getAuthToken(ctx) == "authorized-user-token" {
respCounter = 1000
}

respChan := make(chan int, 1)
go func(respChan chan int) {
defer close(respChan)
counter := 0
closeCounter := 0
for {
if counter == 10 {
if closeCounter == 10 {
return
}
respChan <- counter
counter++
closeCounter++
respChan <- respCounter
respCounter++
time.Sleep(100 * time.Millisecond)
}
}(respChan)
return respChan, nil
}

type (
authTokenCtxKey struct{}
)

const AuthKey = "authToken"

func RunServer() *httptest.Server {
gqlgenServer := handler.New(NewExecutableSchema(Config{Resolvers: &resolver{}}))
gqlgenServer.AddTransport(transport.POST{})
gqlgenServer.AddTransport(transport.GET{})
gqlgenServer.AddTransport(transport.Websocket{})

gqlgenServer.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
if authToken, ok := initPayload[AuthKey].(string); ok && authToken != "" {
ctx = withAuthToken(ctx, authToken)
}
return ctx, &initPayload, nil
},
})

gqlgenServer.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
graphql.RegisterExtension(ctx, "foobar", "test")
return next(ctx)
Expand Down

0 comments on commit 2e037a5

Please sign in to comment.