diff --git a/pkg/utils/connect/aws/config.go b/pkg/utils/connect/aws/config.go index bc121be109..00d7d1dcfb 100644 --- a/pkg/utils/connect/aws/config.go +++ b/pkg/utils/connect/aws/config.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/aws/aws-sdk-go-v2/aws" awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" @@ -31,6 +32,7 @@ import ( stscredstypesv2 "github.com/aws/aws-sdk-go-v2/service/sts/types" awsv1 "github.com/aws/aws-sdk-go/aws" credentialsv1 "github.com/aws/aws-sdk-go/aws/credentials" + stscredsv1 "github.com/aws/aws-sdk-go/aws/credentials/stscreds" endpointsv1 "github.com/aws/aws-sdk-go/aws/endpoints" requestv1 "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" @@ -81,6 +83,18 @@ var userAgentV1 = requestv1.NamedHandler{ Fn: requestv1.MakeAddToUserAgentHandler("crossplane-provider-aws", version.Version), } +// userAgentV2 constructs the Crossplane user agent for AWS v2 clients +var userAgentV2 = config.WithAPIOptions([]func(*middleware.Stack) error{ + awsmiddleware.AddUserAgentKeyValue("crossplane-provider-aws", version.Version), +}) + +var ( + muV1 sync.Mutex + muV2 sync.Mutex + defaultConfigV2 *aws.Config + defaultConfigV1 *awsv1.Config +) + // GetConfig constructs an *aws.Config that can be used to authenticate to AWS // API by the AWS clients. func GetConfig(ctx context.Context, c client.Client, mg resource.Managed, region string) (*aws.Config, error) { @@ -376,25 +390,37 @@ func getWebidentityTokenFilePath() string { return webIdentityTokenFileDefaultPath } +// GetDefaultConfigV2 returns a shallow copy of a default SDK +// config. We use this to get a shared credentials cache. +func GetDefaultConfigV2(ctx context.Context) (aws.Config, error) { + // TODO: Possible performance improvement by using an RWMutex and RLock + // to allow parallel copying. + // However, this would likely increase the complexity of the code. + muV2.Lock() + defer muV2.Unlock() + + if defaultConfigV2 == nil { + cfg, err := config.LoadDefaultConfig(ctx, userAgentV2) + if err != nil { + return aws.Config{}, errors.Wrap(err, "failed to load default AWS config") + } + defaultConfigV2 = &cfg + } + + return defaultConfigV2.Copy(), nil +} + // UsePodServiceAccount assumes an IAM role configured via a ServiceAccount. // https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html func UsePodServiceAccount(ctx context.Context, _ []byte, _, region string) (*aws.Config, error) { - if region == GlobalRegion { - cfg, err := config.LoadDefaultConfig( - ctx, - middlewareV2, - ) - return &cfg, errors.Wrap(err, "failed to load default AWS config") - } - cfg, err := config.LoadDefaultConfig( - ctx, - middlewareV2, - config.WithRegion(region), - ) + cfg, err := GetDefaultConfigV2(ctx) if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("failed to load default AWS config with region %s", region)) + return nil, err + } + if region != GlobalRegion { + cfg.Region = region } - return &cfg, err + return &cfg, nil } // NOTE(muvaf): ACK-generated controllers use aws/aws-sdk-go instead of @@ -637,28 +663,41 @@ func UsePodServiceAccountV1AssumeRoleWithWebIdentity(ctx context.Context, _ []by return SetResolverV1(pc, awsv1.NewConfig().WithCredentials(v1creds).WithRegion(region)), nil } +// GetDefaultConfigV1 returns a shallow copy of a default SDK +// config. We use this to get a shared credentials cache. +func GetDefaultConfigV1() (*awsv1.Config, error) { + // TODO: Possible performance improvement by using an RWMutex and RLock + // to allow parallel copying. + // However, this would likely increase the complexity of the code. + muV1.Lock() + defer muV1.Unlock() + if defaultConfigV1 == nil { + cfg := awsv1.NewConfig() + sess, err := GetSessionV1(cfg) + if err != nil { + return nil, errors.Wrap(err, "failed to load default AWS config") + } + envCfg, err := config.NewEnvConfig() + if err != nil { + return nil, errors.Wrap(err, "failed to load default AWS env config") + } + creds := stscredsv1.NewWebIdentityCredentials(sess, envCfg.RoleARN, envCfg.RoleSessionName, envCfg.WebIdentityTokenFilePath) //nolint:staticcheck + defaultConfigV1 = cfg.WithCredentials(creds) + } + return defaultConfigV1.Copy(), nil +} + // UsePodServiceAccountV1 assumes an IAM role configured via a ServiceAccount. // https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html func UsePodServiceAccountV1(ctx context.Context, _ []byte, pc *v1beta1.ProviderConfig, _, region string) (*awsv1.Config, error) { - cfg, err := config.LoadDefaultConfig( - ctx, - middlewareV2, - ) - if err != nil { - return nil, errors.Wrap(err, "failed to load default AWS config") - } - v2creds, err := cfg.Credentials.Retrieve(ctx) + cfg, err := GetDefaultConfigV1() if err != nil { - return nil, errors.Wrap(err, "failed to retrieve credentials") + return nil, err } - if region == GlobalRegion { - region = cfg.Region + if region != GlobalRegion { + cfg = cfg.WithRegion(region) } - v1creds := credentialsv1.NewStaticCredentials( - v2creds.AccessKeyID, - v2creds.SecretAccessKey, - v2creds.SessionToken) - return SetResolverV1(pc, awsv1.NewConfig().WithCredentials(v1creds).WithRegion(region)), nil + return SetResolverV1(pc, cfg), nil } // SetResolverV1 parses annotations from the managed resource