diff --git a/gateway/mw_streaming.go b/gateway/mw_streaming.go index 4d0235ede71..3e6e603b90c 100644 --- a/gateway/mw_streaming.go +++ b/gateway/mw_streaming.go @@ -3,23 +3,24 @@ package gateway import ( "context" "encoding/json" - "errors" "fmt" + "io" "net/http" "net/url" - "strings" "sync" "sync/atomic" "time" - "github.com/sirupsen/logrus" - "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + "github.com/TykTechnologies/tyk/internal/errors" "github.com/TykTechnologies/tyk/internal/streaming" ) const ( + strippedRequestKey = "stripped-request-key" + tykStreamsVariablesKey = "tyk-streams-variables-key" // ExtensionTykStreaming is the oas extension for tyk streaming ExtensionTykStreaming = "x-tyk-streaming" ) @@ -38,21 +39,19 @@ var globalStreamCounter atomic.Int64 // StreamingMiddleware is a middleware that handles streaming functionality type StreamingMiddleware struct { *BaseMiddleware - streamManagers sync.Map // Map of consumer group IDs to StreamManager - ctx context.Context - cancel context.CancelFunc - allowedUnsafe []string - defaultStreamManager *StreamManager + streamManagers sync.Map // Map of consumer group IDs to StreamManager + ctx context.Context + cancel context.CancelFunc + allowedUnsafe []string + router *mux.Router } // StreamManager is responsible for creating a single stream type StreamManager struct { - streams sync.Map - routeLock sync.Mutex - muxer *mux.Router - mw *StreamingMiddleware - dryRun bool - listenPaths []string + streams sync.Map + routeLock sync.Mutex + muxer *mux.Router + mw *StreamingMiddleware } func (sm *StreamManager) initStreams(r *http.Request, config *StreamsConfig) { @@ -60,37 +59,18 @@ func (sm *StreamManager) initStreams(r *http.Request, config *StreamsConfig) { sm.muxer = mux.NewRouter() for streamID, streamConfig := range config.Streams { - sm.setUpOrDryRunStream(streamConfig, streamID) - } - - // If it is default stream manager, init muxer - if r == nil { - for _, path := range sm.listenPaths { - sm.muxer.HandleFunc(path, func(_ http.ResponseWriter, _ *http.Request) { - // Dummy handler - }) - } - } -} - -func (sm *StreamManager) setUpOrDryRunStream(streamConfig any, streamID string) { - if streamMap, ok := streamConfig.(map[string]interface{}); ok { - httpPaths := GetHTTPPaths(streamMap) - - if sm.dryRun { - if len(httpPaths) == 0 { - err := sm.createStream(streamID, streamMap) - if err != nil { - sm.mw.Logger().WithError(err).Errorf("Error creating stream %s", streamID) - } + if streamMap, ok := streamConfig.(map[string]interface{}); ok { + httpPaths := GetHTTPPaths(streamMap) + if r == nil && len(httpPaths) > 0 { + // If r is nil that means a default stream manager is being created for background jobs. + // If httpPaths is not an empty slice, we do not need any background job for this stream. + continue } - } else { err := sm.createStream(streamID, streamMap) if err != nil { sm.mw.Logger().WithError(err).Errorf("Error creating stream %s", streamID) } } - sm.listenPaths = append(sm.listenPaths, httpPaths...) } } @@ -143,20 +123,47 @@ func (s *StreamingMiddleware) EnabledForSpec() bool { return false } +func (s *StreamingMiddleware) registerHandlers(config *StreamsConfig) { + for streamId, rawConfig := range config.Streams { + streamConfig := rawConfig.(map[string]interface{}) + + httpServerInputPath := findHTTPServerInputPath(streamConfig) + for _, path := range GetHTTPPaths(streamConfig) { + if path == httpServerInputPath { + // We only use this handler to receive messages from the HTTP endpoint + // Consider this: + // input: + // http_server: + // path: /post + // timeout: 1s + // output: + // http_server: + // ws_path: /subscribe + s.router.HandleFunc(path, s.inputHttpServerPublishHandler) + } else { + // Subscription handler responds to WebSocket requests and hands over the request to Bento + s.router.HandleFunc(path, s.subscriptionHandler) + } + } + s.Logger().Debugf("Tyk Stream handlers have been registered for stream: %s", streamId) + } +} + // Init initializes the middleware func (s *StreamingMiddleware) Init() { s.Logger().Debug("Initializing StreamingMiddleware") + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.router = mux.NewRouter() - s.Logger().Debug("Initializing default stream manager") - s.defaultStreamManager = s.createStreamManager(nil) + s.createStreamManager(nil) // create a default stream manager here for background jobs. + s.registerHandlers(s.getStreamsConfig(nil)) } func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManager { newStreamManager := &StreamManager{ - muxer: mux.NewRouter(), - mw: s, - dryRun: r == nil, + muxer: mux.NewRouter(), + mw: s, } streamID := fmt.Sprintf("_%d", time.Now().UnixNano()) s.streamManagers.Store(streamID, newStreamManager) @@ -167,7 +174,7 @@ func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManage return newStreamManager } -// Helper function to extract paths from an http_server configuration +// Helper function to extract paths from a http_server configuration func extractPaths(httpConfig map[string]interface{}) []string { var paths []string defaultPaths := map[string]string{ @@ -208,6 +215,17 @@ func handleBroker(brokerConfig map[string]interface{}) []string { return paths } +func findHTTPServerInputPath(streamConfig map[string]interface{}) string { + if componentMap, ok := streamConfig["input"].(map[string]interface{}); ok { + if httpServerConfig, ok := componentMap["http_server"].(map[string]interface{}); ok { + if val, ok := httpServerConfig["path"].(string); ok { + return val + } + } + } + return "" +} + // GetHTTPPaths is the ain function to get HTTP paths from the stream configuration func GetHTTPPaths(streamConfig map[string]interface{}) []string { var paths []string @@ -294,7 +312,8 @@ func (sm *StreamManager) createStream(streamID string, config map[string]interfa muxer: sm.muxer, sm: sm, // child logger is necessary to prevent race condition - logger: sm.mw.Logger().WithField("stream", streamFullID), + logger: sm.mw.Logger().WithField("stream", streamFullID), + httpServerInputPath: findHTTPServerInputPath(config), }) if err != nil { sm.mw.Logger().Errorf("Failed to start stream %s: %v", streamFullID, err) @@ -307,48 +326,123 @@ func (sm *StreamManager) createStream(streamID string, config map[string]interfa return nil } -func (sm *StreamManager) hasPath(path string) bool { - for _, p := range sm.listenPaths { - if strings.TrimPrefix(path, "/") == strings.TrimPrefix(p, "/") { - return true - } - } - return false -} - // ProcessRequest will handle the streaming functionality func (s *StreamingMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) { strippedPath := s.Spec.StripListenPath(r.URL.Path) - if !s.defaultStreamManager.hasPath(strippedPath) { - return nil, http.StatusOK - } s.Logger().Debugf("Processing request: %s, %s", r.URL.Path, strippedPath) - newRequest := &http.Request{ - Method: r.Method, - URL: &url.URL{Scheme: r.URL.Scheme, Host: r.URL.Host, Path: strippedPath}, + variables := make(map[string]any) + // Clone the request here to transfer some variables to the underlying components such as Bento + clonedRequest := r.Clone(context.WithValue(r.Context(), tykStreamsVariablesKey, variables)) + + strippedPathRequest := &http.Request{ + Method: clonedRequest.Method, + URL: &url.URL{Scheme: clonedRequest.URL.Scheme, Host: clonedRequest.URL.Host, Path: strippedPath}, } + variables[strippedRequestKey] = strippedPathRequest - if !s.defaultStreamManager.muxer.Match(newRequest, &mux.RouteMatch{}) { + // Use the muxer to find a matched route for the request. + routeMatch := &mux.RouteMatch{} + if !s.router.Match(strippedPathRequest, routeMatch) { return nil, http.StatusOK } + routeMatch.Handler.ServeHTTP(w, clonedRequest) + return nil, mwStatusRespond +} + +func (s *StreamingMiddleware) inputHttpServerPublishHandler(w http.ResponseWriter, r *http.Request) { + // This method handles publishing messages via an HTTP endpoint without creating a new + // Bento stream for every HTTP request. + // + // It simply iterates over the existing streams and hands over the request to Bento. + // + // TODO: We may implement a queue or buffer here to store or distribute messages in a different way. + + var err error + s.streamManagers.Range(func(_, value interface{}) bool { + manager := value.(*StreamManager) + dummyResponse := &dummyResponseWriter{} + + var body io.ReadCloser + body, err = copyBody(r.Body, true) + if err != nil { + return false // break + } + clonedRequest := r.Clone(r.Context()) + clonedRequest.Body = body + s.handOverRequestToBento(manager, dummyResponse, clonedRequest) + return true // continue + }) + + if err != nil { + doJSONWrite(w, http.StatusInternalServerError, err.Error()) + return + } + + // Message received + w.WriteHeader(http.StatusAccepted) +} + +func (s *StreamingMiddleware) subscriptionHandler(w http.ResponseWriter, r *http.Request) { + manager := s.createStreamManager(r) + s.handOverRequestToBento(manager, w, r) +} + +func (s *StreamingMiddleware) getRouteMatch(manager *StreamManager, r *http.Request) (*mux.RouteMatch, error) { + manager.routeLock.Lock() + defer manager.routeLock.Unlock() + var match mux.RouteMatch - streamManager := s.createStreamManager(r) - streamManager.routeLock.Lock() - streamManager.muxer.Match(newRequest, &match) - streamManager.routeLock.Unlock() + if !manager.muxer.Match(r, &match) { + // request does not match any of this router's or its subrouters' routes then this function returns false. + return nil, mux.ErrNotFound + } + if match.MatchErr != nil { + return nil, match.MatchErr + } + return &match, nil +} + +func getStrippedRequest(r *http.Request) (*http.Request, error) { + variables, ok := r.Context().Value(tykStreamsVariablesKey).(map[string]any) + if !ok { + return nil, fmt.Errorf("%s could not be found in request context", tykStreamsVariablesKey) + } + strippedRequest, ok := variables[strippedRequestKey].(*http.Request) + if !ok { + return nil, fmt.Errorf("%s could not be found in request variables", strippedRequestKey) + } + return strippedRequest, nil +} + +func (s *StreamingMiddleware) handOverRequestToBento(manager *StreamManager, w http.ResponseWriter, r *http.Request) { + strippedRequest, err := getStrippedRequest(r) + if err != nil { + doJSONWrite(w, http.StatusInternalServerError, apiError(err.Error())) + return + } + + match, err := s.getRouteMatch(manager, strippedRequest) + if err != nil { + var code int = http.StatusInternalServerError + if errors.Is(err, mux.ErrNotFound) { + code = http.StatusNotFound + } + doJSONWrite(w, code, apiError(err.Error())) + return + } // direct Bento handler handler, ok := match.Handler.(http.HandlerFunc) if !ok { - return errors.New("invalid route handler"), http.StatusInternalServerError + doJSONWrite(w, http.StatusInternalServerError, apiError("invalid route handler")) + return } + // Wait until the subscription has killed by one of the parties. handler.ServeHTTP(w, r) - - return nil, mwStatusRespond } // Unload closes and remove active streams @@ -394,11 +488,12 @@ func (s *StreamingMiddleware) Unload() { } type handleFuncAdapter struct { - streamID string - sm *StreamManager - mw *StreamingMiddleware - muxer *mux.Router - logger *logrus.Entry + streamID string + sm *StreamManager + mw *StreamingMiddleware + muxer *mux.Router + logger *logrus.Entry + httpServerInputPath string } func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter, *http.Request)) { @@ -410,16 +505,39 @@ func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter, } h.sm.routeLock.Lock() + defer h.sm.routeLock.Unlock() + h.muxer.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { defer func() { - // Stop the stream when the HTTP request finishes - if err := h.sm.removeStream(h.streamID); err != nil { - h.logger.Errorf("Failed to stop stream %s: %v", h.streamID, err) + // If this handler handles a request that publishes a message via HTTP, we don't need to + // remove the stream. It was just an HTTP request that handled by Bento + if h.httpServerInputPath != path { + // Stop the stream when the HTTP request finishes + if err := h.sm.removeStream(h.streamID); err != nil { + h.logger.Errorf("Failed to stop stream %s: %v", h.streamID, err) + } } }() f(w, r) }) - h.sm.routeLock.Unlock() + h.logger.Debugf("Registered handler for path: %s", path) } + +type dummyResponseWriter struct { +} + +func (m dummyResponseWriter) Header() http.Header { + return http.Header{} +} + +func (m dummyResponseWriter) Write(bytes []byte) (int, error) { + return len(bytes), nil +} + +func (m dummyResponseWriter) WriteHeader(statusCode int) { + return +} + +var _ http.ResponseWriter = (*dummyResponseWriter)(nil) diff --git a/gateway/mw_streaming_test.go b/gateway/mw_streaming_test.go index c55ef7641dc..ac174f86d72 100644 --- a/gateway/mw_streaming_test.go +++ b/gateway/mw_streaming_test.go @@ -1,6 +1,7 @@ package gateway import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -148,6 +149,18 @@ streams: '@service': benthos ` +const bentoHTTPServerTemplate = ` +streams: + test: + input: + http_server: + path: /post + timeout: 1s + output: + http_server: + ws_path: /subscribe +` + func TestStreamingAPISingleClient(t *testing.T) { ctx := context.Background() @@ -215,6 +228,7 @@ func TestStreamingAPISingleClient(t *testing.T) { assert.Equal(t, fmt.Sprintf("Hello %d", i), string(p), "message not equal") } } + func TestStreamingAPIMultipleClients(t *testing.T) { ctx := context.Background() @@ -289,7 +303,6 @@ func TestStreamingAPIMultipleClients(t *testing.T) { assert.Equal(t, fmt.Sprintf("Hello %d", i), string(p), fmt.Sprintf("message not equal for client %d", clientID)) } } - } func setUpStreamAPI(ts *Test, apiName string, streamConfig string) error { @@ -781,3 +794,108 @@ func TestWebSocketConnectionClosedOnAPIReload(t *testing.T) { t.Log("WebSocket connection was successfully closed on API reload") } + +func TestStreamingAPISingleClient_Input_HTTPServer(t *testing.T) { + ts := StartTest(func(globalConf *config.Config) { + globalConf.Streaming.Enabled = true + }) + t.Cleanup(func() { + ts.Close() + }) + + apiName := "test-api" + if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil { + t.Fatal(err) + } + + const totalMessages = 3 + + dialer := websocket.Dialer{ + HandshakeTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + wsURL := strings.Replace(ts.URL, "http", "ws", 1) + fmt.Sprintf("/%s/subscribe", apiName) + wsConn, _, err := dialer.Dial(wsURL, nil) + require.NoError(t, err, "failed to connect to ws server") + t.Cleanup(func() { + if err = wsConn.Close(); err != nil { + t.Logf("failed to close ws connection: %v", err) + } + }) + + publishURL := fmt.Sprintf("%s/%s/post", ts.URL, apiName) + for i := 0; i < totalMessages; i++ { + data := []byte(fmt.Sprintf("{\"test\": \"message %d\"}", i)) + resp, err := http.Post(publishURL, "application/json", bytes.NewReader(data)) + require.NoError(t, err) + _ = resp.Body.Close() + } + + err = wsConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + require.NoError(t, err, "error setting read deadline") + + for i := 0; i < totalMessages; i++ { + _, p, err := wsConn.ReadMessage() + require.NoError(t, err, "error reading message") + assert.Equal(t, fmt.Sprintf("{\"test\": \"message %d\"}", i), string(p), "message not equal") + } +} + +func TestStreamingAPIMultipleClients_Input_HTTPServer(t *testing.T) { + ts := StartTest(func(globalConf *config.Config) { + globalConf.Streaming.Enabled = true + }) + t.Cleanup(func() { + ts.Close() + }) + + apiName := "test-api" + if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil { + t.Fatal(err) + } + + const ( + totalClients = 3 + totalMessages = 3 + ) + dialer := websocket.Dialer{ + HandshakeTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + wsURL := strings.Replace(ts.URL, "http", "ws", 1) + fmt.Sprintf("/%s/subscribe", apiName) + + // Create multiple WebSocket connections + var wsConns []*websocket.Conn + for i := 0; i < totalClients; i++ { + wsConn, _, err := dialer.Dial(wsURL, nil) + require.NoError(t, err, fmt.Sprintf("failed to connect to ws server for client %d", i)) + wsConns = append(wsConns, wsConn) + t.Cleanup(func() { + if err := wsConn.Close(); err != nil { + t.Logf("failed to close ws connection: %v", err) + } + }) + } + + publishURL := fmt.Sprintf("%s/%s/post", ts.URL, apiName) + for i := 0; i < totalMessages; i++ { + data := []byte(fmt.Sprintf("{\"test\": \"message %d\"}", i)) + resp, err := http.Post(publishURL, "application/json", bytes.NewReader(data)) + require.NoError(t, err) + _ = resp.Body.Close() + } + + // Read messages from all clients + for clientID, wsConn := range wsConns { + err := wsConn.SetReadDeadline(time.Now().Add(5000 * time.Millisecond)) + require.NoError(t, err, fmt.Sprintf("error setting read deadline for client %d", clientID)) + + for i := 0; i < totalMessages; i++ { + _, p, err := wsConn.ReadMessage() + require.NoError(t, err, fmt.Sprintf("error reading message for client %d, message %d", clientID, i)) + assert.Equal(t, fmt.Sprintf("{\"test\": \"message %d\"}", i), string(p), fmt.Sprintf("message not equal for client %d", clientID)) + } + } +}