Skip to content

Commit

Permalink
Closes #107 - refactored retrieveAwsSecret to solve the GitHub comments
Browse files Browse the repository at this point in the history
  • Loading branch information
TebaleloS committed Feb 1, 2024
1 parent 851a212 commit 9628b50
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ package za.co.absa.atum.server.api
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient
import software.amazon.awssdk.services.secretsmanager.model.{GetSecretValueRequest, SecretsManagerException}
import org.slf4j.{Logger, LoggerFactory}
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest


/**
Expand All @@ -38,24 +37,14 @@ class RetrieveAwsSecret (profileCredentials: String = "default") {
* @param secretName
* @return a sequence of string from aws secret service
*/
def retrieveAwsSecret(secretName: String): Seq[String] = {
val logger: Logger = LoggerFactory.getLogger(getClass.getName)

try {
val request = GetSecretValueRequest.builder()
.secretId(secretName)
.build()

val response = secretsManagerClient.getSecretValue(request)
response.secretString.map(_.toString)
} catch {
case e: SecretsManagerException =>
logger.error(s"Error retrieving secret key: ${e.getMessage}")
e.getMessage.map(_.toString)
} finally {
// Close the client when done
secretsManagerClient.close()
}
def retrieveAwsSecret(secretName: String): String = {
val request = GetSecretValueRequest.builder()
.secretId(secretName)
.build()

val response = secretsManagerClient.getSecretValue(request)
response.secretString
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PostgresAccessProvider @Autowired()(atumConfig: AtumConfig) {

private def overrideWithSecret(oldConfig: Config, path: String, secretName: String): Config = {

val secretString: Seq[String] = retrieveAwsSecret.retrieveAwsSecret(secretName)
val secretString: String = retrieveAwsSecret.retrieveAwsSecret(secretName)
val overrideValue = secretString.foldLeft(Try("")) {(acc, s) =>
acc.flatMap(str => Try(str + s))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class RetrieveAwsSecretTest extends AnyFlatSpec with MockitoSugar {
assert(overrideValue == expectedResults)
}

it should "handle SecretsManagerException and return error message" in {
it should "return an error message" in {
val mockSecretsManagerClient = mock[SecretsManagerClient]
val retrieveAwsSecret = new RetrieveAwsSecret("testProfile") {
override val secretsManagerClient: SecretsManagerClient = mockSecretsManagerClient
Expand All @@ -57,7 +57,11 @@ class RetrieveAwsSecretTest extends AnyFlatSpec with MockitoSugar {
val exception = SecretsManagerException.builder().message("testError").build()
when(mockSecretsManagerClient.getSecretValue(any[GetSecretValueRequest])).thenThrow(exception)

assert(retrieveAwsSecret.retrieveAwsSecret(testSecretName) == "testError".map(_.toString))
val thrown = intercept[SecretsManagerException] {
retrieveAwsSecret.retrieveAwsSecret(testSecretName)
}
assert(thrown.getMessage == "testError")
}


}

0 comments on commit 9628b50

Please sign in to comment.