From a0dd49a79570a3b75a880e47038161983de94f2e Mon Sep 17 00:00:00 2001 From: ramfox Date: Thu, 23 Sep 2021 00:58:10 -0400 Subject: [PATCH] wip: add `Subscribe` & `Unsubscribe` process to websocket connections --- lib/websocket.go | 208 +++++++++++++++++++++++++++++++++++++++--- lib/websocket_test.go | 98 +++++++++++++++++++- 2 files changed, 291 insertions(+), 15 deletions(-) diff --git a/lib/websocket.go b/lib/websocket.go index a91ce8b6b..589d81c76 100644 --- a/lib/websocket.go +++ b/lib/websocket.go @@ -2,8 +2,15 @@ package lib import ( "context" + "encoding/json" + "fmt" + "io" "net/http" + "sync" + "github.com/google/uuid" + "github.com/qri-io/qri/auth/key" + "github.com/qri-io/qri/auth/token" "github.com/qri-io/qri/event" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" @@ -11,6 +18,23 @@ import ( const qriWebsocketProtocol = "qri-websocket" +// newID returns a new websocket connection ID +func newID() string { + return uuid.New().String() +} + +// SetIDRand sets the random reader that NewID uses as a source of random bytes +// passing in nil will default to crypto.Rand. This can be used to make ID +// generation deterministic for tests. eg: +// myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..." +// workflow.SetIDRand(strings.NewReader(myString)) +// a := NewID() +// workflow.SetIDRand(strings.NewReader(myString)) +// b := NewID() +func SetIDRand(r io.Reader) { + uuid.SetRand(r) +} + // WebsocketHandler defines the handler interface type WebsocketHandler interface { WSConnectionHandler(w http.ResponseWriter, r *http.Request) @@ -20,7 +44,16 @@ type WebsocketHandler interface { // and serves to maintain the list of connections type wsHandler struct { // Collect all websocket connections - conns []*websocket.Conn + conns map[string]*wsConn + connsLock sync.Mutex + keystore key.Store + subscriptions map[string]string + subLock sync.Mutex +} + +type wsConn struct { + profileID string + conn *websocket.Conn } var _ WebsocketHandler = (*wsHandler)(nil) @@ -29,7 +62,11 @@ var _ WebsocketHandler = (*wsHandler)(nil) // can connect to in order to get realtime events func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler, error) { ws := &wsHandler{ - conns: []*websocket.Conn{}, + conns: map[string]*wsConn{}, + connsLock: sync.Mutex{}, + keystore: inst.keystore, + subscriptions: map[string]string{}, + subLock: sync.Mutex{}, } inst.bus.SubscribeAll(ws.wsMessageHandler) @@ -38,7 +75,7 @@ func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler, // WSConnectionHandler handles websocket upgrade requests and accepts the connection func (h *wsHandler) WSConnectionHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{qriWebsocketProtocol}, InsecureSkipVerify: true, }) @@ -46,7 +83,14 @@ func (h *wsHandler) WSConnectionHandler(w http.ResponseWriter, r *http.Request) log.Debugf("Websocket accept error: %s", err) return } - h.conns = append(h.conns, c) + connID := newID() + wsc := &wsConn{ + conn: conn, + } + h.connsLock.Lock() + defer h.connsLock.Unlock() + h.conns[connID] = wsc + go h.read(connID) } func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error { @@ -58,14 +102,154 @@ func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error { "data": e.Payload, } - log.Debugf("sending event %q to %d websocket conns", e.Type, len(h.conns)) - for k, c := range h.conns { - go func(k int, c *websocket.Conn) { - err := wsjson.Write(ctx, c, evt) - if err != nil { - log.Errorf("connection %d: wsjson write error: %s", k, err) - } - }(k, c) + profileIDString := e.ProfileID + if profileIDString == "" { + log.Debugf("Event with SessionID %q has no scope. Not sending event over websocket.", e.SessionID) + return nil + } + connID, ok := h.subscriptions[profileIDString] + if !ok { + return fmt.Errorf("no websocket connection ID found for profile %q", profileIDString) + } + c, ok := h.conns[connID] + if !ok { + h.unsubscribeConn(profileIDString) + return fmt.Errorf("no websocket connection found for connection ID %q, profile %q", connID, profileIDString) + } + log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString) + err := wsjson.Write(ctx, c.conn, evt) + if err != nil { + log.Errorf("connection %q: wsjson write error: %s", profileIDString, err) + } + return nil +} + +// subscribeConn authenticates the given token and adds the connID to the map +// of "subscribed" connections +func (h *wsHandler) subscribeConn(connID, tokenString string) error { + ctx := context.TODO() + tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore) + if err != nil { + return err } + + claims, ok := tok.Claims.(*token.Claims) + if !ok || claims.Subject == "" { + return fmt.Errorf("cannot get profile.ID from token") + } + // TODO(b5): at this point we have a valid signature of a profileID string + // but no proof that this profile is owned by the key that signed the + // token. We either need ProfileID == KeyID, or we need a UCAN. we need to + // check for those, ideally in a method within the profile package that + // abstracts over profile & key agreement + + h.connsLock.Lock() + c, ok := h.conns[connID] + if !ok { + return fmt.Errorf("no connection for connection ID %q found", connID) + } + c.profileID = claims.Subject + h.connsLock.Unlock() + + h.subLock.Lock() + defer h.subLock.Unlock() + h.subscriptions[claims.Subject] = connID return nil } + +// unsubscribeConn remove the profileID and connID from the map of "subscribed" +// connections +func (h *wsHandler) unsubscribeConn(profileID string) { + h.subLock.Lock() + defer h.subLock.Unlock() + delete(h.subscriptions, profileID) +} + +// removeConn removes the conn from the map of connections and subscriptions +// closing the connection if needed +func (h *wsHandler) removeConn(connID string) { + c, ok := h.conns[connID] + if !ok { + return + } + defer func() { + c.conn.Close(websocket.StatusNormalClosure, "pruning connection") + }() + if c.profileID != "" { + h.unsubscribeConn(c.profileID) + } + h.connsLock.Lock() + defer h.connsLock.Unlock() + delete(h.conns, connID) +} + +// read listens to the given connection, handling any messages that come through +// stops listening if it encounters any error +func (h *wsHandler) read(id string) error { + ctx := context.Background() + msg := &message{} + var err error + wsc, ok := h.conns[id] + if !ok { + return fmt.Errorf("connection for connection ID %q not found", id) + } + + for { + err = wsjson.Read(ctx, wsc.conn, msg) + if err != nil { + // all websocket methods that return w/ failure + // close the connection + h.removeConn(id) + return err + } + go h.handleMessage(id, msg) + } +} + +// handleMessage handles each message based on msgType +func (h *wsHandler) handleMessage(id string, msg *message) { + switch msg.Type { + case wsSubscribe: + subMsg := &subscribeMessage{} + err := json.Unmarshal(msg.Payload, subMsg) + if err != nil { + log.Debugf("connection %q - error unmarshaling payload for subscribe message: %s", id, err) + } + h.subscribeConn(id, subMsg.Token) + case wsUnsubscribe: + c, ok := h.conns[id] + if !ok { + log.Errorf("conn not found %q", id) + return + } + h.unsubscribeConn(c.profileID) + default: + log.Debug("unknown message type over websocket %s: %q", id, msg.Type) + } +} + +// msgType is the type of message that we receive on the +type msgType string + +const ( + // wsSubscribe indicates the connection is trying to become + // an authenticated connection + // payload is a `subscribeMessage` + wsSubscribe = msgType("subscribe") + // wsUnsubscribe indicates the connection no longer wants + // to be authenticated + // payload is nil + wsUnsubscribe = msgType("unsubscribe") +) + +// message is the expected structure of an incoming websocket message +type message struct { + Type msgType `json:"type"` + Payload json.RawMessage `json:"payload"` +} + +// subscribeMessage is the expected structure of an incoming "subscribe" +// message +type subscribeMessage struct { + Token string `json:"token"` +} diff --git a/lib/websocket_test.go b/lib/websocket_test.go index b60a8b239..5b019fe1a 100644 --- a/lib/websocket_test.go +++ b/lib/websocket_test.go @@ -1,10 +1,19 @@ package lib import ( + "bufio" + "bytes" "context" + "net" + "net/http" + "net/http/httptest" + "strings" "testing" "github.com/qri-io/qfs" + "github.com/qri-io/qri/auth/key" + testkeys "github.com/qri-io/qri/auth/key/test" + "github.com/qri-io/qri/auth/token" testcfg "github.com/qri-io/qri/config/test" repotest "github.com/qri-io/qri/repo/test" ) @@ -26,7 +35,18 @@ func TestWebsocket(t *testing.T) { instCtx, instCancel := context.WithCancel(context.Background()) defer instCancel() - inst, err := NewInstance(instCtx, tr.QriPath, OptConfig(cfg)) + // create key store & add test key + kd := testkeys.GetKeyData(0) + ks, err := key.NewMemStore() + if err != nil { + t.Fatal(err) + } + if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil { + t.Fatal(err) + } + + // create instance + inst, err := NewInstance(instCtx, tr.QriPath, OptConfig(cfg), OptKeyStore(ks)) if err != nil { t.Fatal(err) } @@ -34,15 +54,87 @@ func TestWebsocket(t *testing.T) { subsCount := inst.bus.NumSubscribers() wsCtx, wsCancel := context.WithCancel(context.Background()) - _, err = NewWebsocketHandler(wsCtx, inst) + defer wsCancel() + + // create WebsocketHandler + websocketHandler, err := NewWebsocketHandler(wsCtx, inst) if err != nil { t.Fatal(err) } + wsh := websocketHandler.(*wsHandler) // websockets should subscribe the WS message handler if inst.bus.NumSubscribers() != subsCount+1 { t.Fatalf("failed to subscribe websocket handlers") } - wsCancel() + // add connection + randIDStr := "test_connection_id_str" + SetIDRand(strings.NewReader(randIDStr)) + connID := newID() + SetIDRand(strings.NewReader(randIDStr)) + + wsh.WSConnectionHandler(mockWebsocketWriterAndRequest()) + _, ok := wsh.conns[connID] + if !ok { + t.Fatal("WSConnectionHandler did not create a connection") + } + + // create a token from a private key + kd = testkeys.GetKeyData(0) + tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, kd.KeyID.String(), 0) + if err != nil { + t.Fatal(err) + } + // upgrade connection w/ valid token + wsh.subscribeConn(connID, tokenStr) + proID := kd.KeyID.String() + gotConnID, ok := wsh.subscriptions[proID] + if !ok { + t.Fatal("wsHandler.SubscribeConn did not add profileID or conn to subscriptions map") + } + if gotConnID != connID { + t.Fatalf("wsHandler.SubscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnID) + } + + // unsubscribe connection via profileID + wsh.unsubscribeConn(proID) + _, ok = wsh.subscriptions[proID] + if ok { + t.Fatal("wsHandler.UnsubscribeConn did not remove the profileID from the subscription map") + } + + // remove the connection + wsh.removeConn(connID) + _, ok = wsh.conns[connID] + if ok { + t.Fatal("wsHandler.Removeconn did not remove the connection from the map of conns") + } +} + +func mockWebsocketWriterAndRequest() (http.ResponseWriter, *http.Request) { + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "keep-alive, Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "test_key") + return w, r +} + +type mockHijacker struct { + http.ResponseWriter +} + +var _ http.Hijacker = mockHijacker{} + +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + c, _ := net.Pipe() + r := bufio.NewReader(strings.NewReader("test_reader")) + w := bufio.NewWriter(&bytes.Buffer{}) + rw := bufio.NewReadWriter(r, w) + return c, rw, nil }