Skip to content

Commit

Permalink
feat: add WebSocket connection state to the Session object
Browse files Browse the repository at this point in the history
This feature adds the TLS connection state of the established connection
on the server to the Session instance that is passed into the
OnConnectedHandler().
  • Loading branch information
mikefero committed Jul 9, 2024
1 parent a1af837 commit 37a6632
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 0 deletions.
4 changes: 4 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.
package ankh

import "crypto/tls"

// CloseConnection will terminate the connection with the client/server
// application. This handle will be given during the OnConnectedHandler callback
// and is guaranteed to be thread safe.
Expand All @@ -36,4 +38,6 @@ type Session struct {
Ping PingMessage
// Send will send a binary message to the client/server application.
Send SendMessage
// ConnectionState contains the TLS connection state.
ConnectionState *tls.ConnectionState
}
10 changes: 10 additions & 0 deletions websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ func (c *WebSocketClient) Run(ctx context.Context) error {
defer c.closeConnection(conn)
defer c.handler.OnDisconnectionHandler()

// Get the TLS connection state
var connectionState tls.ConnectionState
if nConn := conn.NetConn(); nConn != nil {
tlsConn, ok := nConn.(*tls.Conn)
if ok {
connectionState = tlsConn.ConnectionState()
}
}

if err := c.handler.OnConnectedHandler(resp, &Session{
// Generate a close function for the session
Close: func() {
Expand All @@ -134,6 +143,7 @@ func (c *WebSocketClient) Run(ctx context.Context) error {
Send: func(data []byte) error {
return writeMessage(conn, &c.mutex, websocket.BinaryMessage, data)
},
ConnectionState: &connectionState,
}); err != nil {
// Unable to handle connection
return fmt.Errorf("unable to handle connection: %w", err)
Expand Down
4 changes: 4 additions & 0 deletions websocket_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ func TestWebSocketClient(t *testing.T) {
t.Log("obtain server session for further client testing")
waitForCapture(t, captorServerSession)
serverSession := captorServerSession.Last()
if tt.withTLS {
require.NotNil(t, serverSession.ConnectionState)
require.Equal(t, (uint16(tls.VersionTLS13)), serverSession.ConnectionState.Version)
}
Verify(serverHandler, Once()).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())
Verify(serverHandler, Once()).OnConnectedHandler(Exact("client-key"), Any[*ankh.Session]())

Expand Down
10 changes: 10 additions & 0 deletions websocket_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ func (s *WebSocketServer) handleConnection(ctx context.Context, w http.ResponseW
defer s.closeConnection(conn, &mutex, clientKey, handler)
defer handler.OnDisconnectionHandler(clientKey)

// Get the TLS connection state
var connectionState tls.ConnectionState
if nConn := conn.NetConn(); nConn != nil {
tlsConn, ok := nConn.(*tls.Conn)
if ok {
connectionState = tlsConn.ConnectionState()
}
}

conn.SetPingHandler(func(appData string) error {
data := handler.OnPingHandler(clientKey, appData)
return writeMessage(conn, &mutex, websocket.PongMessage, data)
Expand All @@ -257,6 +266,7 @@ func (s *WebSocketServer) handleConnection(ctx context.Context, w http.ResponseW
Send: func(data []byte) error {
return writeMessage(conn, &mutex, websocket.BinaryMessage, data)
},
ConnectionState: &connectionState,
}); err != nil {
// Unable to handle connection
return
Expand Down
11 changes: 11 additions & 0 deletions websocket_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,17 @@ func TestWebSocketServer(t *testing.T) {
client3.cancel()
client4.cancel()
}()
if tt.withTLS {
tlsV13 := (uint16(tls.VersionTLS13))
require.NotNil(t, captor1Session.Last().ConnectionState)
require.Equal(t, tlsV13, captor1Session.Last().ConnectionState.Version)
require.NotNil(t, captor2Session.Last().ConnectionState)
require.Equal(t, tlsV13, captor2Session.Last().ConnectionState.Version)
require.NotNil(t, captor3Session.Last().ConnectionState)
require.Equal(t, tlsV13, captor3Session.Last().ConnectionState.Version)
require.NotNil(t, captor4Session.Last().ConnectionState)
require.Equal(t, tlsV13, captor4Session.Last().ConnectionState.Version)
}

t.Log("verify WebSocket clients connected")
Verify(handler1, Times(2)).OnConnectionHandler(Any[http.ResponseWriter](), Any[*http.Request]())
Expand Down

0 comments on commit 37a6632

Please sign in to comment.