Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose conformance to CredentialsProviderV2. #40

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/amzn/smoke-aws-support.git",
"state" : {
"revision" : "141efadb31e399736b23cfd2478af3dbdc170259",
"version" : "1.5.0"
"revision" : "54d2a727df6d440c8c2415124f45c72484e448cd",
"version" : "1.6.0"
}
},
{
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ let package = Package(
dependencies: [
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.19.0"),
.package(url: "https://github.com/amzn/smoke-aws.git", from: "2.44.174"),
.package(url: "https://github.com/amzn/smoke-aws-support.git", from: "1.5.0"),
.package(url: "https://github.com/amzn/smoke-aws-support.git", from: "1.6.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public extension AwsContainerRotatingCredentialsProvider {
logger: Logging.Logger = Logger(label: "com.amazon.SmokeAWSCredentials"),
traceContext _: TraceContextType,
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton)
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
return self.get(fromEnvironment: environment,
logger: logger,
eventLoopProvider: eventLoopProvider)
Expand All @@ -97,11 +97,11 @@ public extension AwsContainerRotatingCredentialsProvider {
static func get(fromEnvironment environment: [String: String] = ProcessInfo.processInfo.environment,
logger: Logging.Logger = Logger(label: "com.amazon.SmokeAWSCredentials"),
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton)
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
var credentialsLogger = logger
credentialsLogger[metadataKey: "credentials.source"] = "environment"

var credentialsProvider: StoppableCredentialsProvider?
var credentialsProvider: (StoppableCredentialsProvider & CredentialsProviderV2)?
if let credentialsRetriever = getRotatingCredentialsRetriever(fromEnvironment: environment,
logger: credentialsLogger,
eventLoopProvider: eventLoopProvider,
Expand Down Expand Up @@ -129,7 +129,7 @@ public extension AwsContainerRotatingCredentialsProvider {
static func get(fromEnvironment environment: [String: String] = ProcessInfo.processInfo.environment,
logger: Logging.Logger = Logger(label: "com.amazon.SmokeAWSCredentials"),
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton) async
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
return await self.get(fromEnvironment: environment,
logger: logger,
dataRetrieverOverride: nil,
Expand All @@ -141,11 +141,11 @@ public extension AwsContainerRotatingCredentialsProvider {
logger: Logging.Logger = Logger(label: "com.amazon.SmokeAWSCredentials"),
dataRetrieverOverride: (() throws -> Data)?,
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton) async
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
var credentialsLogger = logger
credentialsLogger[metadataKey: "credentials.source"] = "environment"

var credentialsProvider: StoppableCredentialsProvider?
var credentialsProvider: (StoppableCredentialsProvider & CredentialsProviderV2)?
if let credentialsRetriever = getRotatingCredentialsRetriever(fromEnvironment: environment,
logger: credentialsLogger,
eventLoopProvider: eventLoopProvider,
Expand All @@ -172,7 +172,7 @@ public extension AwsContainerRotatingCredentialsProvider {

private static func getStaticCredentialsProvider(fromEnvironment environment: [String: String],
logger: Logger)
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
// get the values of the environment variables
let awsAccessKeyId = environment["AWS_ACCESS_KEY_ID"]
let awsSecretAccessKey = environment["AWS_SECRET_ACCESS_KEY"]
Expand Down
56 changes: 26 additions & 30 deletions Sources/SmokeAWSCredentials/AwsRotatingCredentialsProviderV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private actor CurrentCredentials {
// this is held seperately to `state` so the existing credentials can continue to
// be used until the background refresh is complete
private var backgroundPendingCredentialsTask: Task<ExpiringCredentials, Swift.Error>?
private let backgroundLogger: Logger
private let logger: Logger
private let credentialsStreamContinuation: AsyncStream<ExpiringCredentials>.Continuation

private enum State {
Expand All @@ -72,14 +72,14 @@ private actor CurrentCredentials {
init(
credentials: ExpiringCredentials,
expiringCredentialsRetriever: ExpiringCredentialsAsyncRetriever,
backgroundLogger: Logger,
logger: Logger,
credentialsStreamContinuation: AsyncStream<ExpiringCredentials>.Continuation,
expirationBufferSeconds: Double,
backgroundExpirationBufferSeconds: Double
) {
self.state = .present(credentials)
self.expiringCredentialsRetriever = expiringCredentialsRetriever
self.backgroundLogger = backgroundLogger
self.logger = logger
self.credentialsStreamContinuation = credentialsStreamContinuation
self.expirationBufferSeconds = expirationBufferSeconds
self.backgroundExpirationBufferSeconds = backgroundExpirationBufferSeconds
Expand All @@ -104,33 +104,38 @@ private actor CurrentCredentials {
Gets the current credentials, ensuring that these credentials are always valid
*/
func get(
isBackgroundRefresh: Bool = false,
logger: Logger = Logger(label: "com.azmn.smoke-aws-credentials.CurrentCredentials.get")
isBackgroundRefresh: Bool = false
) async throws -> AWSCore.Credentials {
switch self.state {
case .present(let presentValue):
// if not within the buffer period and about to become expired
if !isBackgroundRefresh, let expiration = presentValue.expiration,
expiration > Date(timeIntervalSinceNow: self.expirationBufferSeconds) {
// these credentials can be used
logger.trace("Current credentials used.")
self.logger.trace("Current credentials used. Current credentials do not expire until \(expiration.iso8601)")

return presentValue
} else if let backgroundPendingCredentialsTask = self.backgroundPendingCredentialsTask {
self.logger.trace("Waiting on existing background credentials refresh")

// if there is an-progress background refresh
// normally we wouldn't wait on this task but the current credentials are now expired
// so they can't be used
return try await backgroundPendingCredentialsTask.value
}

logger.trace("Replacing current credentials.")
if let expiration = presentValue.expiration {
self.logger.trace("Replacing current credentials. Current credentials expiring at \(expiration.iso8601)")
} else {
self.logger.trace("Replacing current credentials.")
}
case .pending(let task):
// There is a pending credentials refresh
logger.trace("Waiting on existing credentials refresh")
self.logger.trace("Waiting on existing credentials refresh")

return try await task.value
case .missing:
logger.trace("Fetching new credentials.")
self.logger.trace("Fetching new credentials.")
}

// get the task for this entry
Expand Down Expand Up @@ -162,8 +167,8 @@ private actor CurrentCredentials {
do {
try await self.expiringCredentialsRetriever.shutdown()
} catch {
self.backgroundLogger.warning("ExpiringCredentialsRetriever failed to shutdown cleanly",
metadata: ["cause": "\(error)"])
self.logger.warning("ExpiringCredentialsRetriever failed to shutdown cleanly",
metadata: ["cause": "\(error)"])
}

switch self.state {
Expand Down Expand Up @@ -240,34 +245,34 @@ private actor CurrentCredentials {
let overflowMinutes = Int(waitDurationInMinutes) % 60

if waitDurationInSeconds > 0 {
self.backgroundLogger.trace(
self.logger.trace(
"Credentials updated; rotation scheduled in \(wholeNumberOfHours) hours, \(overflowMinutes) minutes.")
do {
try await Task.sleep(nanoseconds: UInt64(waitDurationInSeconds) * secondsToNanoSeconds)
} catch is CancellationError {
self.backgroundLogger.trace(
self.logger.trace(
"Background credentials rotation cancelled.")
return
} catch {
self.backgroundLogger.error(
self.logger.error(
"Background credentials rotation failed due to error \(error).")
return
}
}

do {
_ = try await self.get(isBackgroundRefresh: true, logger: self.backgroundLogger)
_ = try await self.get(isBackgroundRefresh: true)
} catch is CancellationError {
self.backgroundLogger.trace(
self.logger.trace(
"Background credentials rotation cancelled.")
return
} catch {
self.backgroundLogger.error(
self.logger.error(
"Background credentials rotation failed due to error \(error).")
return
}

self.backgroundLogger.trace(
self.logger.trace(
"Background credentials rotation completed.")
}
}
Expand Down Expand Up @@ -323,15 +328,10 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider, Cre
self.expiringCredentials = try expiringCredentialsRetriever.get()
self.status = .initialized

var decoratedLogger = logger
if let roleSessionName {
decoratedLogger[metadataKey: "roleSessionName"] = "\(roleSessionName)"
}

self.credentialsStream = AsyncStream.makeStream(of: ExpiringCredentials.self)
self.currentCredentials = CurrentCredentials(credentials: self.expiringCredentials,
expiringCredentialsRetriever: expiringCredentialsRetriever,
backgroundLogger: decoratedLogger,
logger: logger,
credentialsStreamContinuation: self.credentialsStream.continuation,
expirationBufferSeconds: expirationBufferSeconds,
backgroundExpirationBufferSeconds: backgroundExpirationBufferSeconds)
Expand All @@ -345,15 +345,10 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider, Cre
self.expiringCredentials = try await expiringCredentialsRetriever.getCredentials()
self.status = .initialized

var decoratedLogger = logger
if let roleSessionName {
decoratedLogger[metadataKey: "roleSessionName"] = "\(roleSessionName)"
}

self.credentialsStream = AsyncStream.makeStream(of: ExpiringCredentials.self)
self.currentCredentials = CurrentCredentials(credentials: self.expiringCredentials,
expiringCredentialsRetriever: expiringCredentialsRetriever,
backgroundLogger: decoratedLogger,
logger: logger,
credentialsStreamContinuation: self.credentialsStream.continuation,
expirationBufferSeconds: expirationBufferSeconds,
backgroundExpirationBufferSeconds: backgroundExpirationBufferSeconds)
Expand Down Expand Up @@ -456,6 +451,7 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider, Cre
}

public func getCredentials() async throws -> Credentials {

return try await self.currentCredentials.get()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ public extension SmokeAWSCore.CredentialsProvider {
durationSeconds: Int?,
logger: Logging.Logger = Logger(label: "com.amazon.SmokeAWSCredentials"),
retryConfiguration: HTTPClientRetryConfiguration = .default,
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton) -> StoppableCredentialsProvider? {
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton)
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
return self.getAssumedRotatingCredentials(
roleArn: roleArn,
roleSessionName: roleSessionName,
Expand All @@ -143,7 +144,8 @@ public extension SmokeAWSCore.CredentialsProvider {
durationSeconds: Int?,
logger: Logging.Logger = Logger(label: "com.amazon.SmokeAWSCredentials"),
retryConfiguration: HTTPClientRetryConfiguration = .default,
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton) async -> StoppableCredentialsProvider? {
eventLoopProvider: HTTPClient.EventLoopGroupProvider = .singleton) async
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
return await self.getAssumedRotatingCredentials(
roleArn: roleArn,
roleSessionName: roleSessionName,
Expand Down Expand Up @@ -177,7 +179,8 @@ public extension SmokeAWSCore.CredentialsProvider {
traceContext: TraceContextType,
retryConfiguration: HTTPClientRetryConfiguration = .default,
eventLoopProvider: HTTPClient
.EventLoopGroupProvider = .singleton) -> StoppableCredentialsProvider? {
.EventLoopGroupProvider = .singleton)
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
var credentialsLogger = logger
credentialsLogger[metadataKey: "credentials.source"] = "assumed.\(roleSessionName)"
let reporting = CredentialsInvocationReporting(logger: credentialsLogger,
Expand All @@ -201,7 +204,8 @@ public extension SmokeAWSCore.CredentialsProvider {
traceContext: TraceContextType,
retryConfiguration: HTTPClientRetryConfiguration = .default,
eventLoopProvider: HTTPClient
.EventLoopGroupProvider = .singleton) async -> StoppableCredentialsProvider? {
.EventLoopGroupProvider = .singleton) async
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
var credentialsLogger = logger
credentialsLogger[metadataKey: "credentials.source"] = "assumed.\(roleSessionName)"
let reporting = CredentialsInvocationReporting(logger: credentialsLogger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ extension SecurityTokenClientProtocol {
retryConfiguration: HTTPClientRetryConfiguration,
eventLoopProvider: HTTPClient.EventLoopGroupProvider,
reportingConfiguration: SmokeAWSClientReportingConfiguration<SecurityTokenModelOperations> = .none)
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
let credentialsRetriever = AWSSTSExpiringCredentialsRetriever(
credentialsProvider: credentialsProvider,
roleArn: roleArn,
Expand Down Expand Up @@ -278,7 +278,7 @@ extension SecurityTokenClientProtocolV2 {
retryConfiguration: HTTPClientRetryConfiguration,
eventLoopProvider: HTTPClient.EventLoopGroupProvider,
reportingConfiguration: SmokeAWSClientReportingConfiguration<SecurityTokenModelOperations> = .none) async
-> StoppableCredentialsProvider? {
-> (StoppableCredentialsProvider & CredentialsProviderV2)? {
let credentialsRetriever = AWSSTSExpiringCredentialsRetriever(
credentialsProvider: credentialsProvider,
roleArn: roleArn,
Expand Down Expand Up @@ -307,6 +307,20 @@ extension SecurityTokenClientProtocolV2 {
}
}

private let iso8601DateFormatter: DateFormatter = {
let formatter = DateFormatter()
formatter.calendar = Calendar(identifier: .iso8601)
formatter.locale = Locale(identifier: "en_US_POSIX")
formatter.dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSXXXXX"
return formatter
}()

internal extension Foundation.Date {
var iso8601: String {
iso8601DateFormatter.string(from: self)
}
}

internal extension String {
/**
Returns a date instance if this string is formatted according to
Expand Down
Loading