diff --git a/ee/middleware/upstreamoauth/middleware.go b/ee/middleware/upstreamoauth/middleware.go new file mode 100644 index 00000000000..7a08595f328 --- /dev/null +++ b/ee/middleware/upstreamoauth/middleware.go @@ -0,0 +1,110 @@ +package upstreamoauth + +import ( + "fmt" + "net/http" + + "github.com/sirupsen/logrus" + + "github.com/TykTechnologies/tyk/header" + "github.com/TykTechnologies/tyk/internal/event" + "github.com/TykTechnologies/tyk/internal/httputil" + "github.com/TykTechnologies/tyk/internal/model" +) + +// Middleware implements upstream OAuth middleware. +type Middleware struct { + Spec model.MergedAPI + Gw Gateway + + Base BaseMiddleware + + clientCredentialsStorageHandler Storage + passwordStorageHandler Storage +} + +// Middleware implements model.Middleware. +var _ model.Middleware = &Middleware{} + +// NewMiddleware returns a new instance of Middleware. +func NewMiddleware(gw Gateway, mw BaseMiddleware, spec model.MergedAPI, ccStorageHandler Storage, pwStorageHandler Storage) *Middleware { + return &Middleware{ + Base: mw, + Gw: gw, + Spec: spec, + clientCredentialsStorageHandler: ccStorageHandler, + passwordStorageHandler: pwStorageHandler, + } +} + +// Logger returns a logger with middleware filled out. +func (m *Middleware) Logger() *logrus.Entry { + return m.Base.Logger().WithField("mw", m.Name()) +} + +// Name returns the name for the middleware. +func (m *Middleware) Name() string { + return MiddlewareName +} + +// EnabledForSpec checks if streaming is enabled on the config. +func (m *Middleware) EnabledForSpec() bool { + if !m.Spec.UpstreamAuth.IsEnabled() { + return false + } + + if !m.Spec.UpstreamAuth.OAuth.Enabled { + return false + } + + return true +} + +// Init initializes the middleware. +func (m *Middleware) Init() { + m.Logger().Debug("Initializing Upstream basic auth Middleware") +} + +// ProcessRequest will handle upstream OAuth. +func (m *Middleware) ProcessRequest(_ http.ResponseWriter, r *http.Request, _ interface{}) (error, int) { + provider, err := NewOAuthHeaderProvider(m.Spec.UpstreamAuth.OAuth) + if err != nil { + return fmt.Errorf("failed to get OAuth header provider: %w", err), http.StatusInternalServerError + } + + payload, err := provider.getOAuthToken(r, m) + if err != nil { + return fmt.Errorf("failed to get OAuth token: %w", err), http.StatusInternalServerError + } + + upstreamOAuthProvider := Provider{ + HeaderName: header.Authorization, + AuthValue: payload, + } + + headerName := provider.getHeaderName(m) + if headerName != "" { + upstreamOAuthProvider.HeaderName = headerName + } + + if provider.headerEnabled(m) { + headerName := provider.getHeaderName(m) + if headerName != "" { + upstreamOAuthProvider.HeaderName = headerName + } + } + + httputil.SetUpstreamAuth(r, upstreamOAuthProvider) + return nil, http.StatusOK +} + +// FireEvent emits an upstream OAuth event with an optional custom message. +func (mw *Middleware) FireEvent(r *http.Request, e event.Event, message string, apiId string) { + if message == "" { + message = event.String(e) + } + mw.Base.FireEvent(e, EventUpstreamOAuthMeta{ + EventMetaDefault: model.NewEventMetaDefault(r, message), + APIID: apiId, + }) +} diff --git a/ee/middleware/upstreamoauth/model.go b/ee/middleware/upstreamoauth/model.go new file mode 100644 index 00000000000..9c96bc747fd --- /dev/null +++ b/ee/middleware/upstreamoauth/model.go @@ -0,0 +1,54 @@ +package upstreamoauth + +import ( + "time" + + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/ctx" + "github.com/TykTechnologies/tyk/internal/httpctx" + "github.com/TykTechnologies/tyk/internal/model" +) + +const ( + ErrorEventName = "UpstreamOAuthError" + MiddlewareName = "UpstreamOAuth" + + ClientCredentialsAuthorizeType = "clientCredentials" + PasswordAuthorizeType = "password" +) + +// BaseMiddleware is the subset of BaseMiddleware APIs that the middleware uses. +type BaseMiddleware interface { + model.LoggerProvider + FireEvent(name apidef.TykEvent, meta interface{}) +} + +// Gateway is the subset of Gateway APIs that the middleware uses. +type Gateway interface { + model.ConfigProvider +} + +// Type Storage is a subset of storage.RedisCluster +type Storage interface { + GetKey(key string) (string, error) + SetKey(string, string, int64) error + Lock(key string, timeout time.Duration) (bool, error) +} + +type ClientCredentialsOAuthProvider struct{} + +type PerAPIClientCredentialsOAuthProvider struct{} + +type PasswordOAuthProvider struct{} + +type TokenData struct { + Token string `json:"token"` + ExtraMetadata map[string]interface{} `json:"extra_metadata"` +} + +var ( + ctxData = httpctx.NewValue[map[string]any](ctx.ContextData) + + CtxGetData = ctxData.Get + CtxSetData = ctxData.Set +) diff --git a/ee/middleware/upstreamoauth/provider.go b/ee/middleware/upstreamoauth/provider.go new file mode 100644 index 00000000000..cb599f28343 --- /dev/null +++ b/ee/middleware/upstreamoauth/provider.go @@ -0,0 +1,204 @@ +package upstreamoauth + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "strings" + "time" + + "github.com/sirupsen/logrus" + "golang.org/x/oauth2" + oauth2clientcredentials "golang.org/x/oauth2/clientcredentials" + + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/internal/model" +) + +// Provider implements upstream auth provider. +type Provider struct { + // Logger is the logger to be used. + Logger *logrus.Entry + // HeaderName is the header name to be used to fill upstream auth with. + HeaderName string + // AuthValue is the value of auth header. + AuthValue string +} + +// Fill sets the request's HeaderName with AuthValue +func (u Provider) Fill(r *http.Request) { + if r.Header.Get(u.HeaderName) != "" { + u.Logger.WithFields(logrus.Fields{ + "header": u.HeaderName, + }).Info("Authorization header conflict detected: Client header overwritten by Gateway upstream authentication header.") + } + r.Header.Set(u.HeaderName, u.AuthValue) +} + +type OAuthHeaderProvider interface { + // getOAuthToken returns the OAuth token for the request. + getOAuthToken(r *http.Request, mw *Middleware) (string, error) + // getHeaderName returns the header name for the OAuth token. + getHeaderName(mw *Middleware) string + // + headerEnabled(mw *Middleware) bool +} + +func NewOAuthHeaderProvider(oauthConfig apidef.UpstreamOAuth) (OAuthHeaderProvider, error) { + if !oauthConfig.IsEnabled() { + return nil, fmt.Errorf("upstream OAuth is not enabled") + } + + switch { + case len(oauthConfig.AllowedAuthorizeTypes) == 0: + return nil, fmt.Errorf("no OAuth configuration selected") + case len(oauthConfig.AllowedAuthorizeTypes) > 1: + return nil, fmt.Errorf("both client credentials and password authentication are provided") + case oauthConfig.AllowedAuthorizeTypes[0] == ClientCredentialsAuthorizeType: + return &ClientCredentialsOAuthProvider{}, nil + case oauthConfig.AllowedAuthorizeTypes[0] == PasswordAuthorizeType: + return &PasswordOAuthProvider{}, nil + default: + return nil, fmt.Errorf("no valid OAuth configuration provided") + } +} + +func (p *ClientCredentialsOAuthProvider) getOAuthToken(r *http.Request, mw *Middleware) (string, error) { + client := ClientCredentialsClient{mw} + token, err := client.GetToken(r) + if err != nil { + return handleOAuthError(r, mw, err) + } + + return fmt.Sprintf("Bearer %s", token), nil +} + +func handleOAuthError(r *http.Request, mw *Middleware, err error) (string, error) { + mw.FireEvent(r, ErrorEventName, err.Error(), mw.Spec.APIID) + return "", err +} + +func (p *ClientCredentialsOAuthProvider) getHeaderName(OAuthSpec *Middleware) string { + return OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.Header.Name +} + +func (p *ClientCredentialsOAuthProvider) headerEnabled(OAuthSpec *Middleware) bool { + return OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.Header.Enabled +} + +func newOAuth2ClientCredentialsConfig(OAuthSpec *Middleware) oauth2clientcredentials.Config { + return oauth2clientcredentials.Config{ + ClientID: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ClientID, + ClientSecret: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ClientSecret, + TokenURL: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.TokenURL, + Scopes: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.Scopes, + } +} + +func newOAuth2PasswordConfig(OAuthSpec *Middleware) oauth2.Config { + return oauth2.Config{ + ClientID: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ClientID, + ClientSecret: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.TokenURL, + }, + Scopes: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Scopes, + } +} + +type ClientCredentialsClient struct { + mw *Middleware +} + +type PasswordClient struct { + mw *Middleware +} + +func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string) string { + key := fmt.Sprintf( + "cc-%s|%s|%s|%s", + apiId, + config.ClientCredentials.ClientID, + config.ClientCredentials.TokenURL, + strings.Join(config.ClientCredentials.Scopes, ",")) + + hash := sha256.New() + hash.Write([]byte(key)) + return hex.EncodeToString(hash.Sum(nil)) +} + +func retryGetKeyAndLock(cacheKey string, cache Storage) (string, error) { + const maxRetries = 10 + const retryDelay = 100 * time.Millisecond + + var tokenData string + var err error + + for i := 0; i < maxRetries; i++ { + tokenData, err = cache.GetKey(cacheKey) + if err == nil { + return tokenData, nil + } + + lockKey := cacheKey + ":lock" + ok, err := cache.Lock(lockKey, time.Second*5) + if err == nil && ok { + return "", nil + } + + time.Sleep(retryDelay) + } + + return "", fmt.Errorf("failed to acquire lock after retries: %w", err) +} + +func SetExtraMetadata(r *http.Request, keyList []string, metadata map[string]interface{}) { + contextDataObject := CtxGetData(r) + if contextDataObject == nil { + contextDataObject = make(map[string]interface{}) + } + for _, key := range keyList { + if val, ok := metadata[key]; ok && val != "" { + contextDataObject[key] = val + } + } + CtxSetData(r, contextDataObject) +} + +// EventUpstreamOAuthMeta is the metadata structure for an upstream OAuth event +type EventUpstreamOAuthMeta struct { + model.EventMetaDefault + APIID string +} + +func (p *PasswordOAuthProvider) getOAuthToken(r *http.Request, mw *Middleware) (string, error) { + client := PasswordClient{mw} + token, err := client.GetToken(r) + if err != nil { + return handleOAuthError(r, mw, err) + } + + return fmt.Sprintf("Bearer %s", token), nil +} + +func (p *PasswordOAuthProvider) getHeaderName(OAuthSpec *Middleware) string { + return OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Header.Name +} + +func (p *PasswordOAuthProvider) headerEnabled(OAuthSpec *Middleware) bool { + return OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Header.Enabled +} + +func generatePasswordOAuthCacheKey(config apidef.UpstreamOAuth, apiId string) string { + key := fmt.Sprintf( + "pw-%s|%s|%s|%s", + apiId, + config.PasswordAuthentication.ClientID, + config.PasswordAuthentication.ClientSecret, + strings.Join(config.PasswordAuthentication.Scopes, ",")) + + hash := sha256.New() + hash.Write([]byte(key)) + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/ee/middleware/upstreamoauth/provider_client_credentials.go b/ee/middleware/upstreamoauth/provider_client_credentials.go new file mode 100644 index 00000000000..19215b22cc6 --- /dev/null +++ b/ee/middleware/upstreamoauth/provider_client_credentials.go @@ -0,0 +1,26 @@ +package upstreamoauth + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" +) + +func (cache *ClientCredentialsClient) ObtainToken(ctx context.Context) (*oauth2.Token, error) { + cfg := newOAuth2ClientCredentialsConfig(cache.mw) + tokenSource := cfg.TokenSource(ctx) + return tokenSource.Token() +} + +func (cache *ClientCredentialsClient) GetToken(r *http.Request) (string, error) { + cacheKey := generateClientCredentialsCacheKey(cache.mw.Spec.UpstreamAuth.OAuth, cache.mw.Spec.APIID) + secret := cache.mw.Gw.GetConfig().Secret + extraMetadata := cache.mw.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata + + obtainTokenFunc := func(ctx context.Context) (*oauth2.Token, error) { + return cache.ObtainToken(ctx) + } + + return getToken(r, cacheKey, obtainTokenFunc, secret, extraMetadata, cache.mw.clientCredentialsStorageHandler) +} diff --git a/ee/middleware/upstreamoauth/provider_password_authentication.go b/ee/middleware/upstreamoauth/provider_password_authentication.go new file mode 100644 index 00000000000..159267101bb --- /dev/null +++ b/ee/middleware/upstreamoauth/provider_password_authentication.go @@ -0,0 +1,25 @@ +package upstreamoauth + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" +) + +func (cache *PasswordClient) ObtainToken(ctx context.Context) (*oauth2.Token, error) { + cfg := newOAuth2PasswordConfig(cache.mw) + return cfg.PasswordCredentialsToken(ctx, cache.mw.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Username, cache.mw.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Password) +} + +func (cache *PasswordClient) GetToken(r *http.Request) (string, error) { + cacheKey := generatePasswordOAuthCacheKey(cache.mw.Spec.UpstreamAuth.OAuth, cache.mw.Spec.APIID) + secret := cache.mw.Gw.GetConfig().Secret + extraMetadata := cache.mw.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata + + obtainTokenFunc := func(ctx context.Context) (*oauth2.Token, error) { + return cache.ObtainToken(ctx) + } + + return getToken(r, cacheKey, obtainTokenFunc, secret, extraMetadata, cache.mw.passwordStorageHandler) +} diff --git a/gateway/mw_oauth2_auth_test.go b/ee/middleware/upstreamoauth/provider_test.go similarity index 90% rename from gateway/mw_oauth2_auth_test.go rename to ee/middleware/upstreamoauth/provider_test.go index 8f26b67039b..fd791c9be4e 100644 --- a/gateway/mw_oauth2_auth_test.go +++ b/ee/middleware/upstreamoauth/provider_test.go @@ -1,4 +1,4 @@ -package gateway +package upstreamoauth_test import ( "encoding/json" @@ -10,6 +10,9 @@ import ( "golang.org/x/oauth2" + "github.com/TykTechnologies/tyk/ee/middleware/upstreamoauth" + "github.com/TykTechnologies/tyk/gateway" + "github.com/stretchr/testify/assert" "github.com/TykTechnologies/tyk/apidef" @@ -17,8 +20,14 @@ import ( "github.com/TykTechnologies/tyk/test" ) -func TestUpstreamOauth2(t *testing.T) { +var StartTest = gateway.StartTest + +type APISpec = gateway.APISpec + +const ClientCredentialsAuthorizeType = upstreamoauth.ClientCredentialsAuthorizeType +const PasswordAuthorizeType = upstreamoauth.PasswordAuthorizeType +func TestProvider_ClientCredentialsAuthorizeType(t *testing.T) { tst := StartTest(nil) t.Cleanup(tst.Close) @@ -120,7 +129,7 @@ func TestUpstreamOauth2(t *testing.T) { } -func TestPasswordCredentialsTokenRequest(t *testing.T) { +func TestProvider_PasswordAuthorizeType(t *testing.T) { tst := StartTest(nil) t.Cleanup(tst.Close) @@ -235,9 +244,9 @@ func TestSetExtraMetadata(t *testing.T) { "key3": "value3", } - setExtraMetadata(req, keyList, token) + upstreamoauth.SetExtraMetadata(req, keyList, token) - contextData := ctxGetData(req) + contextData := upstreamoauth.CtxGetData(req) assert.Equal(t, "value1", contextData["key1"]) assert.Equal(t, "value2", contextData["key2"]) @@ -257,7 +266,7 @@ func TestBuildMetadataMap(t *testing.T) { }) extraMetadataKeys := []string{"key1", "key2", "key3", "key4"} - metadataMap := buildMetadataMap(token, extraMetadataKeys) + metadataMap := upstreamoauth.BuildMetadataMap(token, extraMetadataKeys) assert.Equal(t, "value1", metadataMap["key1"]) assert.Equal(t, "value2", metadataMap["key2"]) @@ -280,11 +289,11 @@ func TestCreateTokenDataBytes(t *testing.T) { extraMetadataKeys := []string{"key1", "key2", "key3", "key4"} encryptedToken := "encrypted_tyk_upstream_oauth_access_token" - tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, extraMetadataKeys) + tokenDataBytes, err := upstreamoauth.CreateTokenDataBytes(encryptedToken, token, extraMetadataKeys) assert.NoError(t, err) - var tokenData TokenData + var tokenData upstreamoauth.TokenData err = json.Unmarshal(tokenDataBytes, &tokenData) assert.NoError(t, err) @@ -296,7 +305,7 @@ func TestCreateTokenDataBytes(t *testing.T) { } func TestUnmarshalTokenData(t *testing.T) { - tokenData := TokenData{ + tokenData := upstreamoauth.TokenData{ Token: "tyk_upstream_oauth_access_token", ExtraMetadata: map[string]interface{}{ "key1": "value1", @@ -307,7 +316,7 @@ func TestUnmarshalTokenData(t *testing.T) { tokenDataBytes, err := json.Marshal(tokenData) assert.NoError(t, err) - result, err := unmarshalTokenData(string(tokenDataBytes)) + result, err := upstreamoauth.UnmarshalTokenData(string(tokenDataBytes)) assert.NoError(t, err) diff --git a/ee/middleware/upstreamoauth/token_cache.go b/ee/middleware/upstreamoauth/token_cache.go new file mode 100644 index 00000000000..34b1a4ea708 --- /dev/null +++ b/ee/middleware/upstreamoauth/token_cache.go @@ -0,0 +1,89 @@ +package upstreamoauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "golang.org/x/oauth2" + + "github.com/TykTechnologies/tyk/internal/crypto" +) + +type Cache interface { + // GetToken returns the token from cache or issues a request to obtain it from the OAuth provider. + GetToken(r *http.Request) (string, error) + // ObtainToken issues a request to obtain the token from the OAuth provider. + ObtainToken(ctx context.Context) (*oauth2.Token, error) +} + +func getToken(r *http.Request, cacheKey string, obtainTokenFunc func(context.Context) (*oauth2.Token, error), secret string, extraMetadata []string, cache Storage) (string, error) { + tokenData, err := retryGetKeyAndLock(cacheKey, cache) + if err != nil { + return "", err + } + + if tokenData != "" { + tokenContents, err := UnmarshalTokenData(tokenData) + if err != nil { + return "", err + } + decryptedToken := crypto.Decrypt(crypto.GetPaddedString(secret), tokenContents.Token) + SetExtraMetadata(r, extraMetadata, tokenContents.ExtraMetadata) + return decryptedToken, nil + } + + token, err := obtainTokenFunc(r.Context()) + if err != nil { + return "", err + } + + encryptedToken := crypto.Encrypt(crypto.GetPaddedString(secret), token.AccessToken) + tokenDataBytes, err := CreateTokenDataBytes(encryptedToken, token, extraMetadata) + if err != nil { + return "", err + } + metadataMap := BuildMetadataMap(token, extraMetadata) + SetExtraMetadata(r, extraMetadata, metadataMap) + + ttl := time.Until(token.Expiry) + if err := setTokenInCache(cache, cacheKey, string(tokenDataBytes), ttl); err != nil { + return "", err + } + + return token.AccessToken, nil +} + +func setTokenInCache(cache Storage, cacheKey string, token string, ttl time.Duration) error { + oauthTokenExpiry := time.Now().Add(ttl) + return cache.SetKey(cacheKey, token, int64(time.Until(oauthTokenExpiry).Seconds())) +} + +func CreateTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) { + td := TokenData{ + Token: encryptedToken, + ExtraMetadata: BuildMetadataMap(token, extraMetadataKeys), + } + return json.Marshal(td) +} + +func UnmarshalTokenData(tokenData string) (TokenData, error) { + var tokenContents TokenData + err := json.Unmarshal([]byte(tokenData), &tokenContents) + if err != nil { + return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err) + } + return tokenContents, nil +} + +func BuildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} { + metadataMap := make(map[string]interface{}) + for _, key := range extraMetadataKeys { + if val := token.Extra(key); val != "" && val != nil { + metadataMap[key] = val + } + } + return metadataMap +} diff --git a/gateway/api.go b/gateway/api.go index b06d8f965fa..daf76eede5b 100644 --- a/gateway/api.go +++ b/gateway/api.go @@ -27,7 +27,6 @@ package gateway import ( "bytes" - "context" "encoding/base64" "encoding/json" "errors" @@ -3062,34 +3061,6 @@ func (gw *Gateway) makeImportedOASTykAPI(next http.HandlerFunc) http.HandlerFunc } } -// TODO: Don't modify http.Request values in-place. We must right now -// because our middleware design doesn't pass around http.Request -// pointers, so we have no way to modify the pointer in a middleware. -// -// If we ever redesign middlewares - or if we find another workaround - -// revisit this. -func setContext(r *http.Request, ctx context.Context) { - r2 := r.WithContext(ctx) - *r = *r2 -} -func setCtxValue(r *http.Request, key, val interface{}) { - setContext(r, context.WithValue(r.Context(), key, val)) -} - -func ctxGetData(r *http.Request) map[string]interface{} { - if v := r.Context().Value(ctx.ContextData); v != nil { - return v.(map[string]interface{}) - } - return nil -} - -func ctxSetData(r *http.Request, m map[string]interface{}) { - if m == nil { - panic("setting a nil context ContextData") - } - setCtxValue(r, ctx.ContextData, m) -} - // ctxSetCacheOptions sets a cache key to use for the http request func ctxSetCacheOptions(r *http.Request, options *cacheOptions) { setCtxValue(r, ctx.CacheOptions, options) diff --git a/gateway/api_loader.go b/gateway/api_loader.go index 1bd3a96373b..e2c7268008a 100644 --- a/gateway/api_loader.go +++ b/gateway/api_loader.go @@ -440,7 +440,9 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, gw.mwAppendEnabled(&chainArray, upstreamBasicAuthMw) } - gw.mwAppendEnabled(&chainArray, &UpstreamOAuth{BaseMiddleware: baseMid}) + if upstreamOAuthMw := getUpstreamOAuthMw(baseMid); upstreamOAuthMw != nil { + gw.mwAppendEnabled(&chainArray, upstreamOAuthMw) + } gw.mwAppendEnabled(&chainArray, &ValidateJSON{BaseMiddleware: baseMid}) gw.mwAppendEnabled(&chainArray, &ValidateRequest{BaseMiddleware: baseMid}) diff --git a/gateway/api_test.go b/gateway/api_test.go index 3f7f8a86a2a..00bef2aa88a 100644 --- a/gateway/api_test.go +++ b/gateway/api_test.go @@ -1936,12 +1936,6 @@ func TestContextData(t *testing.T) { if ctxGetData(r) == nil { t.Fatal("expected ctxGetData to return non-nil") } - defer func() { - if r := recover(); r == nil { - t.Fatal("expected ctxSetData of zero val to panic") - } - }() - ctxSetData(r, nil) } func TestContextSession(t *testing.T) { diff --git a/gateway/event_system.go b/gateway/event_system.go index 4cf80d04bec..31ebe2c0cd8 100644 --- a/gateway/event_system.go +++ b/gateway/event_system.go @@ -1,11 +1,8 @@ package gateway import ( - "bytes" - "encoding/base64" "errors" "fmt" - "net/http" "time" "github.com/sirupsen/logrus" @@ -59,24 +56,11 @@ const ( EventTokenDeleted = event.TokenDeleted ) -// EventMetaDefault is a standard embedded struct to be used with custom event metadata types, gives an interface for -// easily extending event metadata objects -type EventMetaDefault struct { - Message string - OriginatingRequest string -} - type EventHostStatusMeta struct { EventMetaDefault HostInfo HostHealthReport } -// EventUpstreamOAuthMeta is the metadata structure for an upstream OAuth event -type EventUpstreamOAuthMeta struct { - EventMetaDefault - APIID string -} - // EventKeyFailureMeta is the metadata structure for any failure related // to a key, such as quota or auth failures. type EventKeyFailureMeta struct { @@ -117,15 +101,6 @@ type EventTokenMeta struct { Key string } -// EncodeRequestToEvent will write the request out in wire protocol and -// encode it to base64 and store it in an Event object -func EncodeRequestToEvent(r *http.Request) string { - var asBytes bytes.Buffer - r.Write(&asBytes) - - return base64.StdEncoding.EncodeToString(asBytes.Bytes()) -} - // EventHandlerByName is a convenience function to get event handler instances from an API Definition func (gw *Gateway) EventHandlerByName(handlerConf apidef.EventHandlerTriggerConfig, spec *APISpec) (config.TykEventHandler, error) { diff --git a/gateway/handler_success.go b/gateway/handler_success.go index f61281a2b7f..c149dc229ca 100644 --- a/gateway/handler_success.go +++ b/gateway/handler_success.go @@ -10,14 +10,14 @@ import ( "strings" "time" - graphqlinternal "github.com/TykTechnologies/tyk/internal/graphql" - - "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/internal/httputil" + graphqlinternal "github.com/TykTechnologies/tyk/internal/graphql" + "github.com/TykTechnologies/tyk-pump/analytics" + "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" - "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/header" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/user" diff --git a/gateway/handler_success_test.go b/gateway/handler_success_test.go index e96cf99eecd..23922195675 100644 --- a/gateway/handler_success_test.go +++ b/gateway/handler_success_test.go @@ -5,17 +5,17 @@ import ( "net/http" "testing" - "github.com/TykTechnologies/graphql-go-tools/pkg/engine/datasource/httpclient" + "github.com/stretchr/testify/assert" + "github.com/TykTechnologies/graphql-go-tools/pkg/engine/datasource/httpclient" "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" - "github.com/TykTechnologies/tyk-pump/analytics" - "github.com/TykTechnologies/tyk/test" - "github.com/stretchr/testify/assert" + "github.com/TykTechnologies/tyk-pump/analytics" "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" ctxpkg "github.com/TykTechnologies/tyk/ctx" + "github.com/TykTechnologies/tyk/test" "github.com/TykTechnologies/tyk/user" ) diff --git a/gateway/middleware.go b/gateway/middleware.go index 55a68ff67fb..c3d85649d76 100644 --- a/gateway/middleware.go +++ b/gateway/middleware.go @@ -531,21 +531,6 @@ func (t *BaseMiddleware) emitRateLimitEvent(r *http.Request, e event.Event, mess }) } -// emitUpstreamOAuthEvent emits an upstream OAuth event with an optional custom message. -func (t *BaseMiddleware) emitUpstreamOAuthEvent(r *http.Request, e event.Event, message string, apiId string) { - if message == "" { - message = event.String(e) - } - - t.FireEvent(e, EventUpstreamOAuthMeta{ - EventMetaDefault: EventMetaDefault{ - Message: message, - OriginatingRequest: EncodeRequestToEvent(r), - }, - APIID: apiId, - }) -} - // handleRateLimitFailure handles the actions to be taken when a rate limit failure occurs. func (t *BaseMiddleware) handleRateLimitFailure(r *http.Request, e event.Event, message string, rateLimitKey string) (error, int) { t.emitRateLimitEvent(r, e, message, rateLimitKey) diff --git a/gateway/model.go b/gateway/model.go new file mode 100644 index 00000000000..f9315965c3d --- /dev/null +++ b/gateway/model.go @@ -0,0 +1,30 @@ +package gateway + +import ( + "net/http" + + "github.com/TykTechnologies/tyk/ctx" + "github.com/TykTechnologies/tyk/internal/event" + "github.com/TykTechnologies/tyk/internal/httpctx" + "github.com/TykTechnologies/tyk/internal/httputil" + "github.com/TykTechnologies/tyk/internal/model" +) + +type EventMetaDefault = model.EventMetaDefault + +var ( + ctxData = httpctx.NewValue[map[string]any](ctx.ContextData) + + ctxGetData = ctxData.Get + ctxSetData = ctxData.Set + + setContext = httputil.SetContext + + // how is type safety avoided: exhibit A, old school generics + setCtxValue = func(h *http.Request, key, value any) { + ctxvalue := httpctx.NewValue[any](key) + h = ctxvalue.Set(h, value) + } + + EncodeRequestToEvent = event.EncodeRequestToEvent +) diff --git a/gateway/mw_oauth2_auth.go b/gateway/mw_oauth2_auth.go index 6808a39ab45..f2fa16de282 100644 --- a/gateway/mw_oauth2_auth.go +++ b/gateway/mw_oauth2_auth.go @@ -1,447 +1,34 @@ +//go:build !ee && !dev + package gateway import ( - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" "net/http" - "strings" - "time" - - "golang.org/x/oauth2" - - "github.com/sirupsen/logrus" - oauth2clientcredentials "golang.org/x/oauth2/clientcredentials" - - "github.com/TykTechnologies/tyk/apidef" - "github.com/TykTechnologies/tyk/header" - "github.com/TykTechnologies/tyk/internal/httputil" - "github.com/TykTechnologies/tyk/storage" -) - -const ( - UpstreamOAuthErrorEventName = "UpstreamOAuthError" - UpstreamOAuthMiddlewareName = "UpstreamOAuth" ) -type OAuthHeaderProvider interface { - // getOAuthToken returns the OAuth token for the request. - getOAuthToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) - getHeaderName(OAuthSpec *UpstreamOAuth) string - headerEnabled(OAuthSpec *UpstreamOAuth) bool -} - -type ClientCredentialsOAuthProvider struct{} - -type PerAPIClientCredentialsOAuthProvider struct{} - -type PasswordOAuthProvider struct{} - -func newUpstreamOAuthClientCredentialsCache(connectionHandler *storage.ConnectionHandler) UpstreamOAuthCache { - return &upstreamOAuthClientCredentialsCache{RedisCluster: storage.RedisCluster{KeyPrefix: "upstreamOAuthCC-", ConnectionHandler: connectionHandler}} -} - -func newUpstreamOAuthPasswordCache(connectionHandler *storage.ConnectionHandler) UpstreamOAuthCache { - return &upstreamOAuthPasswordCache{RedisCluster: storage.RedisCluster{KeyPrefix: "upstreamOAuthPW-", ConnectionHandler: connectionHandler}} -} - -type upstreamOAuthClientCredentialsCache struct { - storage.RedisCluster -} - -type upstreamOAuthPasswordCache struct { - storage.RedisCluster -} - -func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { - cacheKey := generatePasswordOAuthCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID) - - tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) - if err != nil { - return "", err - } - - if tokenData != "" { - tokenContents, err := unmarshalTokenData(tokenData) - if err != nil { - return "", err - } - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata) - return decryptedToken, nil - } - - token, err := cache.obtainToken(r.Context(), OAuthSpec) - if err != nil { - return "", err - } - - encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken) - tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata) - if err != nil { - return "", err - } - metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, metadataMap) - - ttl := time.Until(token.Expiry) - if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil { - return "", err - } - - return token.AccessToken, nil -} - -func (cache *upstreamOAuthPasswordCache) obtainToken(ctx context.Context, OAuthSpec *UpstreamOAuth) (*oauth2.Token, error) { - cfg := newOAuth2PasswordConfig(OAuthSpec) - - token, err := cfg.PasswordCredentialsToken(ctx, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Username, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Password) - if err != nil { - return &oauth2.Token{}, err - } - - return token, nil -} - -type UpstreamOAuthCache interface { - // getToken returns the token from cache or issues a request to obtain it from the OAuth provider. - getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) - // obtainToken issues a request to obtain the token from the OAuth provider. - obtainToken(ctx context.Context, OAuthSpec *UpstreamOAuth) (*oauth2.Token, error) +func getUpstreamOAuthMw(base *BaseMiddleware) TykMiddleware { + return &noopUpstreamOAuth{base} } -// UpstreamOAuth is a middleware that will do basic authentication for upstream connections. -// UpstreamOAuth middleware is only supported in Tyk OAS API definitions. -type UpstreamOAuth struct { +type noopUpstreamOAuth struct { *BaseMiddleware } -// Name returns the name of middleware. -func (OAuthSpec *UpstreamOAuth) Name() string { - return UpstreamOAuthMiddlewareName -} - -// EnabledForSpec returns true if the middleware is enabled based on API Spec. -func (OAuthSpec *UpstreamOAuth) EnabledForSpec() bool { - if !OAuthSpec.Spec.UpstreamAuth.Enabled { - return false - } - - if !OAuthSpec.Spec.UpstreamAuth.OAuth.Enabled { - return false - } - - return true -} - -// ProcessRequest will inject basic auth info into request context so that it can be used during reverse proxy. -func (OAuthSpec *UpstreamOAuth) ProcessRequest(_ http.ResponseWriter, r *http.Request, _ interface{}) (error, int) { - oauthConfig := OAuthSpec.Spec.UpstreamAuth.OAuth - - upstreamOAuthProvider := UpstreamOAuthProvider{ - HeaderName: header.Authorization, - } - - provider, err := getOAuthHeaderProvider(oauthConfig) - if err != nil { - return fmt.Errorf("failed to get OAuth header provider: %w", err), http.StatusInternalServerError - } - - payload, err := provider.getOAuthToken(r, OAuthSpec) - if err != nil { - return fmt.Errorf("failed to get OAuth token: %w", err), http.StatusInternalServerError - } - - upstreamOAuthProvider.AuthValue = payload - headerName := provider.getHeaderName(OAuthSpec) - if headerName != "" { - upstreamOAuthProvider.HeaderName = headerName - } - - if provider.headerEnabled(OAuthSpec) { - headerName := provider.getHeaderName(OAuthSpec) - if headerName != "" { - upstreamOAuthProvider.HeaderName = headerName - } - } - - httputil.SetUpstreamAuth(r, upstreamOAuthProvider) +// ProcessRequest is noop implementation for upstream OAuth mw. +func (d *noopUpstreamOAuth) ProcessRequest(_ http.ResponseWriter, _ *http.Request, _ interface{}) (error, int) { return nil, http.StatusOK } -func getOAuthHeaderProvider(oauthConfig apidef.UpstreamOAuth) (OAuthHeaderProvider, error) { - if !oauthConfig.IsEnabled() { - return nil, fmt.Errorf("upstream OAuth is not enabled") - } - - switch { - case len(oauthConfig.AllowedAuthorizeTypes) == 0: - return nil, fmt.Errorf("no OAuth configuration selected") - case len(oauthConfig.AllowedAuthorizeTypes) > 1: - return nil, fmt.Errorf("both client credentials and password authentication are provided") - case oauthConfig.AllowedAuthorizeTypes[0] == apidef.OAuthAuthorizationTypeClientCredentials: - return &ClientCredentialsOAuthProvider{}, nil - case oauthConfig.AllowedAuthorizeTypes[0] == apidef.OAuthAuthorizationTypePassword: - return &PasswordOAuthProvider{}, nil - default: - return nil, fmt.Errorf("no valid OAuth configuration provided") - } -} - -func (p *PerAPIClientCredentialsOAuthProvider) getOAuthHeaderValue(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { - oauthConfig := OAuthSpec.Spec.UpstreamAuth.OAuth - - if oauthConfig.ClientCredentials.TokenProvider == nil { - cfg := newOAuth2ClientCredentialsConfig(OAuthSpec) - tokenSource := cfg.TokenSource(r.Context()) - - oauthConfig.ClientCredentials.TokenProvider = tokenSource - } - - oauthToken, err := oauthConfig.ClientCredentials.TokenProvider.Token() - if err != nil { - return handleOAuthError(r, OAuthSpec, err) - } - - payload := fmt.Sprintf("Bearer %s", oauthToken.AccessToken) - return payload, nil -} - -func handleOAuthError(r *http.Request, OAuthSpec *UpstreamOAuth, err error) (string, error) { - OAuthSpec.emitUpstreamOAuthEvent(r, UpstreamOAuthErrorEventName, err.Error(), OAuthSpec.Spec.APIID) - return "", err -} - -func (p *ClientCredentialsOAuthProvider) getOAuthToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { - if OAuthSpec.Gw.UpstreamOAuthCache == nil { - OAuthSpec.Gw.UpstreamOAuthCache = newUpstreamOAuthClientCredentialsCache(OAuthSpec.Gw.StorageConnectionHandler) - } - - token, err := OAuthSpec.Gw.UpstreamOAuthCache.getToken(r, OAuthSpec) - if err != nil { - return handleOAuthError(r, OAuthSpec, err) +// EnabledForSpec will always return false for noopUpstreamOAuth. +func (d *noopUpstreamOAuth) EnabledForSpec() bool { + if d.Spec.UpstreamAuth.OAuth.Enabled { + d.Logger().Error("Upstream OAuth is supported only in Tyk Enterprise Edition") } - return fmt.Sprintf("Bearer %s", token), nil + return false } -func (p *ClientCredentialsOAuthProvider) headerEnabled(OAuthSpec *UpstreamOAuth) bool { - return OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.Header.Enabled -} - -func (p *ClientCredentialsOAuthProvider) getHeaderName(OAuthSpec *UpstreamOAuth) string { - return OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.Header.Name -} - -func (p *PasswordOAuthProvider) getOAuthToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { - if OAuthSpec.Gw.UpstreamOAuthCache == nil { - OAuthSpec.Gw.UpstreamOAuthCache = newUpstreamOAuthPasswordCache(OAuthSpec.Gw.StorageConnectionHandler) - } - - token, err := OAuthSpec.Gw.UpstreamOAuthCache.getToken(r, OAuthSpec) - if err != nil { - return handleOAuthError(r, OAuthSpec, err) - } - - return fmt.Sprintf("Bearer %s", token), nil -} - -func (p *PasswordOAuthProvider) getHeaderName(OAuthSpec *UpstreamOAuth) string { - return OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Header.Name -} - -func (p *PasswordOAuthProvider) headerEnabled(OAuthSpec *UpstreamOAuth) bool { - return OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Header.Enabled -} - -func generatePasswordOAuthCacheKey(config apidef.UpstreamOAuth, apiId string) string { - key := fmt.Sprintf( - "%s|%s|%s|%s", - apiId, - config.PasswordAuthentication.ClientID, - config.PasswordAuthentication.ClientSecret, - strings.Join(config.PasswordAuthentication.Scopes, ",")) - - hash := sha256.New() - hash.Write([]byte(key)) - return hex.EncodeToString(hash.Sum(nil)) -} - -func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string) string { - key := fmt.Sprintf( - "%s|%s|%s|%s", - apiId, - config.ClientCredentials.ClientID, - config.ClientCredentials.TokenURL, - strings.Join(config.ClientCredentials.Scopes, ",")) - - hash := sha256.New() - hash.Write([]byte(key)) - return hex.EncodeToString(hash.Sum(nil)) -} - -type TokenData struct { - Token string `json:"token"` - ExtraMetadata map[string]interface{} `json:"extra_metadata"` -} - -func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { - cacheKey := generateClientCredentialsCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID) - - tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) - if err != nil { - return "", err - } - - if tokenData != "" { - tokenContents, err := unmarshalTokenData(tokenData) - if err != nil { - return "", err - } - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, tokenContents.ExtraMetadata) - return decryptedToken, nil - } - - token, err := cache.obtainToken(r.Context(), OAuthSpec) - if err != nil { - return "", err - } - - encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken) - tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata) - if err != nil { - return "", err - } - metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, metadataMap) - - ttl := time.Until(token.Expiry) - if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil { - return "", err - } - - return token.AccessToken, nil -} - -func createTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) { - td := TokenData{ - Token: encryptedToken, - ExtraMetadata: buildMetadataMap(token, extraMetadataKeys), - } - return json.Marshal(td) -} - -func unmarshalTokenData(tokenData string) (TokenData, error) { - var tokenContents TokenData - err := json.Unmarshal([]byte(tokenData), &tokenContents) - if err != nil { - return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err) - } - return tokenContents, nil -} - -func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} { - metadataMap := make(map[string]interface{}) - for _, key := range extraMetadataKeys { - if val := token.Extra(key); val != "" && val != nil { - metadataMap[key] = val - } - } - return metadataMap -} - -func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) { - contextDataObject := ctxGetData(r) - if contextDataObject == nil { - contextDataObject = make(map[string]interface{}) - } - for _, key := range keyList { - if val, ok := token[key]; ok && val != "" { - contextDataObject[key] = val - } - } - ctxSetData(r, contextDataObject) -} - -func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, error) { - const maxRetries = 10 - const retryDelay = 100 * time.Millisecond - - var tokenData string - var err error - - for i := 0; i < maxRetries; i++ { - tokenData, err = cache.GetKey(cacheKey) - if err == nil { - return tokenData, nil - } - - lockKey := cacheKey + ":lock" - ok, err := cache.Lock(lockKey, time.Second*5) - if err == nil && ok { - return "", nil - } - - time.Sleep(retryDelay) - } - - return "", fmt.Errorf("failed to acquire lock after retries: %v", err) -} - -func newOAuth2ClientCredentialsConfig(OAuthSpec *UpstreamOAuth) oauth2clientcredentials.Config { - return oauth2clientcredentials.Config{ - ClientID: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ClientID, - ClientSecret: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ClientSecret, - TokenURL: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.TokenURL, - Scopes: OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.Scopes, - } -} - -func newOAuth2PasswordConfig(OAuthSpec *UpstreamOAuth) oauth2.Config { - return oauth2.Config{ - ClientID: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ClientID, - ClientSecret: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ClientSecret, - Endpoint: oauth2.Endpoint{ - TokenURL: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.TokenURL, - }, - Scopes: OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.Scopes, - } -} - -func (cache *upstreamOAuthClientCredentialsCache) obtainToken(ctx context.Context, OAuthSpec *UpstreamOAuth) (*oauth2.Token, error) { - cfg := newOAuth2ClientCredentialsConfig(OAuthSpec) - - tokenSource := cfg.TokenSource(ctx) - oauthToken, err := tokenSource.Token() - if err != nil { - return &oauth2.Token{}, err - } - - return oauthToken, nil -} - -func setTokenInCache(cacheKey string, token string, ttl time.Duration, cache *storage.RedisCluster) error { - oauthTokenExpiry := time.Now().Add(ttl) - return cache.SetKey(cacheKey, token, int64(oauthTokenExpiry.Sub(time.Now()).Seconds())) -} - -// UpstreamOAuthProvider implements upstream auth provider. -type UpstreamOAuthProvider struct { - // HeaderName is the header name to be used to fill upstream auth with. - HeaderName string - // AuthValue is the value of auth header. - AuthValue string -} - -// Fill sets the request's HeaderName with AuthValue -func (u UpstreamOAuthProvider) Fill(r *http.Request) { - if r.Header.Get(u.HeaderName) != "" { - log.WithFields(logrus.Fields{ - "header": u.HeaderName, - }).Info("Authorization header conflict detected: Client header overwritten by Gateway upstream authentication header.") - } - r.Header.Set(u.HeaderName, u.AuthValue) +// Name returns the name of the mw. +func (d *noopUpstreamOAuth) Name() string { + return "NooPUpstreamOAuth" } diff --git a/gateway/mw_oauth2_auth_ee.go b/gateway/mw_oauth2_auth_ee.go new file mode 100644 index 00000000000..a034ea7d019 --- /dev/null +++ b/gateway/mw_oauth2_auth_ee.go @@ -0,0 +1,30 @@ +//go:build ee || dev + +package gateway + +import ( + "github.com/TykTechnologies/tyk/ee/middleware/upstreamoauth" + "github.com/TykTechnologies/tyk/internal/model" + "github.com/TykTechnologies/tyk/storage" +) + +func getUpstreamOAuthMw(base *BaseMiddleware) TykMiddleware { + mwSpec := model.MergedAPI{APIDefinition: base.Spec.APIDefinition} + upstreamOAuthMw := upstreamoauth.NewMiddleware( + base.Gw, + base, + mwSpec, + getClientCredentialsStorageHandler(base), + getPasswordStorageHandler(base), + ) + + return WrapMiddleware(base, upstreamOAuthMw) +} + +func getClientCredentialsStorageHandler(base *BaseMiddleware) *storage.RedisCluster { + return &storage.RedisCluster{KeyPrefix: "upstreamOAuthCC-", ConnectionHandler: base.Gw.StorageConnectionHandler} +} + +func getPasswordStorageHandler(base *BaseMiddleware) *storage.RedisCluster { + return &storage.RedisCluster{KeyPrefix: "upstreamOAuthPW-", ConnectionHandler: base.Gw.StorageConnectionHandler} +} diff --git a/gateway/mw_url_rewrite_test.go b/gateway/mw_url_rewrite_test.go index 1da2fbc3072..d7b690463cf 100644 --- a/gateway/mw_url_rewrite_test.go +++ b/gateway/mw_url_rewrite_test.go @@ -8,11 +8,9 @@ import ( "github.com/stretchr/testify/assert" + "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/ctx" - "github.com/TykTechnologies/tyk/test" - - "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/user" ) diff --git a/gateway/reverse_proxy_test.go b/gateway/reverse_proxy_test.go index 65a6d0b48c1..6335233d01c 100644 --- a/gateway/reverse_proxy_test.go +++ b/gateway/reverse_proxy_test.go @@ -24,10 +24,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/TykTechnologies/tyk/user" - - "github.com/TykTechnologies/tyk/header" - "github.com/TykTechnologies/graphql-go-tools/pkg/execution/datasource" "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" @@ -35,8 +31,10 @@ import ( "github.com/TykTechnologies/tyk/config" "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/dnscache" + "github.com/TykTechnologies/tyk/header" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/test" + "github.com/TykTechnologies/tyk/user" ) func TestCopyHeader_NoDuplicateCORSHeaders(t *testing.T) { diff --git a/gateway/rpc_backup_handlers.go b/gateway/rpc_backup_handlers.go index 3d682c81407..121b7fb64bc 100644 --- a/gateway/rpc_backup_handlers.go +++ b/gateway/rpc_backup_handlers.go @@ -1,15 +1,12 @@ package gateway import ( - "crypto/aes" - "crypto/cipher" - cryptorand "crypto/rand" - "encoding/base64" "encoding/json" "errors" - "io" "strings" + "github.com/TykTechnologies/tyk/internal/crypto" + "github.com/sirupsen/logrus" "github.com/TykTechnologies/tyk/storage" @@ -41,13 +38,13 @@ func (gw *Gateway) LoadDefinitionsFromRPCBackup() ([]*APISpec, error) { return nil, errors.New("[RPC] --> RPC Backup recovery failed: redis connection failed") } - secret := rightPad2Len(gw.GetConfig().Secret, "=", 32) + secret := crypto.GetPaddedString(gw.GetConfig().Secret) cryptoText, err := store.GetKey(checkKey) if err != nil { return nil, errors.New("[RPC] --> Failed to get node backup (" + checkKey + "): " + err.Error()) } - apiListAsString := decrypt([]byte(secret), cryptoText) + apiListAsString := crypto.Decrypt([]byte(secret), cryptoText) a := APIDefinitionLoader{Gw: gw} return a.processRPCDefinitions(apiListAsString, gw) @@ -72,8 +69,8 @@ func (gw *Gateway) saveRPCDefinitionsBackup(list string) error { return errors.New("--> RPC Backup save failed: redis connection failed") } - secret := rightPad2Len(gw.GetConfig().Secret, "=", 32) - cryptoText := encrypt([]byte(secret), list) + secret := crypto.GetPaddedString(gw.GetConfig().Secret) + cryptoText := crypto.Encrypt([]byte(secret), list) err := store.SetKey(BackupApiKeyBase+tagList, cryptoText, -1) if err != nil { return errors.New("Failed to store node backup: " + err.Error()) @@ -95,9 +92,9 @@ func (gw *Gateway) LoadPoliciesFromRPCBackup() (map[string]user.Policy, error) { return nil, errors.New("[RPC] --> RPC Policy Backup recovery failed: redis connection failed") } - secret := rightPad2Len(gw.GetConfig().Secret, "=", 32) + secret := crypto.GetPaddedString(gw.GetConfig().Secret) cryptoText, err := store.GetKey(checkKey) - listAsString := decrypt([]byte(secret), cryptoText) + listAsString := crypto.Decrypt([]byte(secret), cryptoText) if err != nil { return nil, errors.New("[RPC] --> Failed to get node policy backup (" + checkKey + "): " + err.Error()) @@ -113,10 +110,6 @@ func (gw *Gateway) LoadPoliciesFromRPCBackup() (map[string]user.Policy, error) { } } -func getPaddedSecret(secret string) []byte { - return []byte(rightPad2Len(secret, "=", 32)) -} - func (gw *Gateway) saveRPCPoliciesBackup(list string) error { if !json.Valid([]byte(list)) { return errors.New("--> RPC Backup save failure: wrong format, skipping.") @@ -136,7 +129,7 @@ func (gw *Gateway) saveRPCPoliciesBackup(list string) error { return errors.New("--> RPC Backup save failed: redis connection failed") } - cryptoText := encrypt(getPaddedSecret(gw.GetConfig().Secret), list) + cryptoText := crypto.Encrypt(crypto.GetPaddedString(gw.GetConfig().Secret), list) err := store.SetKey(BackupPolicyKeyBase+tagList, cryptoText, -1) if err != nil { return errors.New("Failed to store node backup: " + err.Error()) @@ -144,62 +137,3 @@ func (gw *Gateway) saveRPCPoliciesBackup(list string) error { return nil } - -// encrypt string to base64 crypto using AES -func encrypt(key []byte, text string) string { - plaintext := []byte(text) - - block, err := aes.NewCipher(key) - if err != nil { - log.Error(err) - return "" - } - - // The IV needs to be unique, but not secure. Therefore it's common to - // include it at the beginning of the ciphertext. - ciphertext := make([]byte, aes.BlockSize+len(plaintext)) - iv := ciphertext[:aes.BlockSize] - if _, err := io.ReadFull(cryptorand.Reader, iv); err != nil { - log.Error(err) - return "" - } - - stream := cipher.NewCFBEncrypter(block, iv) - stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext) - - // convert to base64 - return base64.URLEncoding.EncodeToString(ciphertext) -} - -// decrypt from base64 to decrypted string -func decrypt(key []byte, cryptoText string) string { - ciphertext, _ := base64.URLEncoding.DecodeString(cryptoText) - - block, err := aes.NewCipher(key) - if err != nil { - log.Error(err) - return "" - } - - // The IV needs to be unique, but not secure. Therefore it's common to - // include it at the beginning of the ciphertext. - if len(ciphertext) < aes.BlockSize { - log.Error("ciphertext too short") - return "" - } - iv := ciphertext[:aes.BlockSize] - ciphertext = ciphertext[aes.BlockSize:] - - stream := cipher.NewCFBDecrypter(block, iv) - - // XORKeyStream can work in-place if the two arguments are the same. - stream.XORKeyStream(ciphertext, ciphertext) - - return string(ciphertext) -} - -func rightPad2Len(s, padStr string, overallLen int) string { - padCountInt := 1 + (overallLen-len(padStr))/len(padStr) - retStr := s + strings.Repeat(padStr, padCountInt) - return retStr[:overallLen] -} diff --git a/gateway/server.go b/gateway/server.go index bc7b40a1ec4..aa20ee19eea 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -125,8 +125,6 @@ type Gateway struct { HostCheckTicker chan struct{} HostCheckerClient *http.Client TracerProvider otel.TracerProvider - // UpstreamOAuthCache is used to cache upstream OAuth tokens - UpstreamOAuthCache UpstreamOAuthCache keyGen DefaultKeyGenerator diff --git a/internal/crypto/helpers.go b/internal/crypto/helpers.go index 8bd41dea66f..54e5d3cbeba 100644 --- a/internal/crypto/helpers.go +++ b/internal/crypto/helpers.go @@ -2,15 +2,19 @@ package crypto import ( "bytes" + "crypto/aes" + "crypto/cipher" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/hex" "encoding/pem" "errors" + "io" "math/big" "net" "net/http" @@ -18,6 +22,8 @@ import ( "testing" "time" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" ) @@ -179,3 +185,70 @@ func GenerateRSAPublicKey(tb testing.TB) []byte { publicKeyPEM := pem.EncodeToMemory(publicKeyBlock) return publicKeyPEM } + +func GetPaddedString(str string) []byte { + return []byte(RightPad2Len(str, "=", 32)) +} + +// encrypt string to base64 crypto using AES +func Encrypt(key []byte, str string) string { + plaintext := []byte(str) + + block, err := aes.NewCipher(key) + if err != nil { + logrus.Error(err) + return "" + } + + // The IV needs to be unique, but not secure. Therefore, it's common to + // include it at the beginning of the ciphertext. + ciphertext := make([]byte, aes.BlockSize+len(plaintext)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + logrus.Error(err) + return "" + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext) + + // convert to base64 + return base64.URLEncoding.EncodeToString(ciphertext) +} + +// Decrypt from base64 to decrypted string +func Decrypt(key []byte, cryptoText string) string { + ciphertext, err := base64.URLEncoding.DecodeString(cryptoText) + if err != nil { + logrus.Error(err) + return "" + } + + block, err := aes.NewCipher(key) + if err != nil { + logrus.Error(err) + return "" + } + + // The IV needs to be unique, but not secure. Therefore it's common to + // include it at the beginning of the ciphertext. + if len(ciphertext) < aes.BlockSize { + logrus.Error("ciphertext too short") + return "" + } + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + + // XORKeyStream can work in-place if the two arguments are the same. + stream.XORKeyStream(ciphertext, ciphertext) + + return string(ciphertext) +} + +func RightPad2Len(s, padStr string, overallLen int) string { + padCountInt := 1 + (overallLen-len(padStr))/len(padStr) + retStr := s + strings.Repeat(padStr, padCountInt) + return retStr[:overallLen] +} diff --git a/internal/event/event.go b/internal/event/event.go index e0ce929af69..f24b5f4796c 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -1,7 +1,9 @@ package event import ( + "bytes" "context" + "encoding/base64" "net/http" ) @@ -122,3 +124,15 @@ func Get(ctx context.Context) []Event { } return nil } + +// EncodeRequestToEvent will write the request out in wire protocol and +// encode it to base64 and store it in an Event object +func EncodeRequestToEvent(r *http.Request) string { + var asBytes bytes.Buffer + err := r.Write(&asBytes) + if err != nil { + return "" + } + + return base64.StdEncoding.EncodeToString(asBytes.Bytes()) +} diff --git a/internal/httpctx/context.go b/internal/httpctx/context.go new file mode 100644 index 00000000000..a139172e9a5 --- /dev/null +++ b/internal/httpctx/context.go @@ -0,0 +1,28 @@ +package httpctx + +import ( + "context" + "net/http" +) + +type Value[T any] struct { + Key any +} + +func NewValue[T any](key any) *Value[T] { + return &Value[T]{Key: key} +} + +func (v *Value[T]) Get(r *http.Request) (res T) { + if val := r.Context().Value(v.Key); val != nil { + res, _ = val.(T) + } + return +} + +func (v *Value[T]) Set(r *http.Request, val T) *http.Request { + ctx := context.WithValue(r.Context(), v.Key, val) + h := r.WithContext(ctx) + *r = *h + return h +} diff --git a/internal/httpctx/context_test.go b/internal/httpctx/context_test.go new file mode 100644 index 00000000000..758fc60b37c --- /dev/null +++ b/internal/httpctx/context_test.go @@ -0,0 +1,64 @@ +package httpctx_test + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/httpctx" +) + +func TestValue_SetAndGet(t *testing.T) { + // Define a key and instantiate a new Value with type map[string]any + key := "testKey" + value := httpctx.NewValue[map[string]any](key) + + // Prepare a map to store in context + expectedData := map[string]any{ + "userID": 123, + "userRole": "admin", + } + + // Create a new HTTP request using httptest + req := httptest.NewRequest("GET", "/", nil) + + // Set the value in the request's context + req = value.Set(req, expectedData) + + // Retrieve the value from the context + retrievedData := value.Get(req) + assert.Equal(t, expectedData, retrievedData, "Retrieved data does not match expected data") +} + +func TestValue_GetWithMissingKey(t *testing.T) { + // Define a key and instantiate a new Value with type map[string]any + key := "missingKey" + value := httpctx.NewValue[map[string]any](key) + + // Create a new HTTP request using httptest + req := httptest.NewRequest("GET", "/", nil) + + // Try to retrieve the value from the context + retrievedData := value.Get(req) + + // Expect not to find any data + assert.Nil(t, retrievedData, "Expected retrieved data to be nil for a missing key") +} + +func TestValue_SetDifferentTypes(t *testing.T) { + // Test using a different type for Value, e.g., int + intKey := "intKey" + intValue := httpctx.NewValue[int](intKey) + + // Create a new HTTP request using httptest + req := httptest.NewRequest("GET", "/", nil) + + // Set an int value in the context + expectedInt := 42 + req = intValue.Set(req, expectedInt) + + // Retrieve the int value from the context + retrievedInt := intValue.Get(req) + assert.Equal(t, expectedInt, retrievedInt, "Retrieved int value does not match expected value") +} diff --git a/internal/model/events.go b/internal/model/events.go new file mode 100644 index 00000000000..2b0604d4a84 --- /dev/null +++ b/internal/model/events.go @@ -0,0 +1,22 @@ +package model + +import ( + "net/http" + + "github.com/TykTechnologies/tyk/internal/event" +) + +// EventMetaDefault is a standard embedded struct to be used with custom event metadata types, gives an interface for +// easily extending event metadata objects +type EventMetaDefault struct { + Message string + OriginatingRequest string +} + +// NewEventMetaDefault creates an instance of model.EventMetaDefault. +func NewEventMetaDefault(r *http.Request, message string) EventMetaDefault { + return EventMetaDefault{ + Message: message, + OriginatingRequest: event.EncodeRequestToEvent(r), + } +}