Skip to content

Commit

Permalink
tests: add test case for get raw secret
Browse files Browse the repository at this point in the history
  • Loading branch information
narenaryan committed Jan 15, 2025
1 parent 21e8281 commit 029c899
Showing 1 changed file with 136 additions and 1 deletion.
137 changes: 136 additions & 1 deletion tests/test_vault_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import string
from unittest.mock import patch, MagicMock
import json
from dotenv import dotenv_values

from whispr.utils.vault import fetch_secrets, get_filled_secrets, prepare_vault_config
from whispr.utils.vault import get_raw_secret
from whispr.utils.crypto import generate_rand_secret
from whispr.enums import VaultType

Expand Down Expand Up @@ -190,3 +190,138 @@ def test_all_punctuation_exclusion(self):
# Ensure no punctuation is in the result
for ch in secret:
self.assertNotIn(ch, string.punctuation)


class GetRawSecretTestCase(unittest.TestCase):
"""Unit tests for get_raw_secret function."""

def setUp(self):
"""Set up shared test data."""
self.secret_name = "test_secret"
self.aws_region = "us-east-1"
self.azure_vault_url = "https://my-azure-vault.vault.azure.net/"
self.gcp_project_id = "my-gcp-project"

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_no_vault_provided(self, mock_fetch_secrets, mock_logger):
"""Test that an empty dict is returned and an error is logged when no vault is provided."""
mock_fetch_secrets.return_value = {"some_key": "some_value"}

result = get_raw_secret(self.secret_name, vault="")

# Expect an empty dict, an error log, and no call to fetch_secrets
self.assertEqual(result, {})
mock_logger.error.assert_called_once()
mock_fetch_secrets.assert_not_called()

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_no_secret_name_provided(self, mock_fetch_secrets, mock_logger):
"""Test that an empty dict is returned and an error is logged when no secret name is provided."""
mock_fetch_secrets.return_value = {"some_key": "some_value"}

result = get_raw_secret(secret_name="", vault=VaultType.AWS.value, region=self.aws_region)

self.assertEqual(result, {})
mock_logger.error.assert_called_once()
mock_fetch_secrets.assert_not_called()

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_aws_missing_region(self, mock_fetch_secrets, mock_logger):
"""Test that an empty dict is returned and an error is logged for AWS if region is missing."""
mock_fetch_secrets.return_value = {"aws_key": "aws_value"}

result = get_raw_secret(self.secret_name, VaultType.AWS.value)

self.assertEqual(result, {})
mock_logger.error.assert_called_once()
mock_fetch_secrets.assert_not_called()

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_azure_missing_vault_url(self, mock_fetch_secrets, mock_logger):
"""Test that an empty dict is returned and an error is logged for Azure if vault_url is missing."""
mock_fetch_secrets.return_value = {"azure_key": "azure_value"}

result = get_raw_secret(self.secret_name, VaultType.AZURE.value)

self.assertEqual(result, {})
mock_logger.error.assert_called_once()
mock_fetch_secrets.assert_not_called()

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_gcp_missing_project_id(self, mock_fetch_secrets, mock_logger):
"""Test that an empty dict is returned and an error is logged for GCP if project_id is missing."""
mock_fetch_secrets.return_value = {"gcp_key": "gcp_value"}

result = get_raw_secret(self.secret_name, VaultType.GCP.value)

self.assertEqual(result, {})
mock_logger.error.assert_called_once()
mock_fetch_secrets.assert_not_called()

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_aws_success(self, mock_fetch_secrets, mock_logger):
"""Test successful retrieval for AWS with valid region."""
expected_response = {"aws_key": "aws_value"}
mock_fetch_secrets.return_value = expected_response

result = get_raw_secret(
secret_name=self.secret_name,
vault=VaultType.AWS.value,
region=self.aws_region
)

self.assertEqual(result, expected_response)
mock_logger.error.assert_not_called()
mock_fetch_secrets.assert_called_once_with({
"secret_name": self.secret_name,
"vault": VaultType.AWS.value,
"region": self.aws_region
})

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_azure_success(self, mock_fetch_secrets, mock_logger):
"""Test successful retrieval for Azure with valid vault_url."""
expected_response = {"azure_key": "azure_value"}
mock_fetch_secrets.return_value = expected_response

result = get_raw_secret(
secret_name=self.secret_name,
vault=VaultType.AZURE.value,
vault_url=self.azure_vault_url
)

self.assertEqual(result, expected_response)
mock_logger.error.assert_not_called()
mock_fetch_secrets.assert_called_once_with({
"secret_name": self.secret_name,
"vault": VaultType.AZURE.value,
"vault_url": self.azure_vault_url
})

@patch("whispr.utils.vault.logger", new_callable=MagicMock)
@patch("whispr.utils.vault.fetch_secrets")
def test_gcp_success(self, mock_fetch_secrets, mock_logger):
"""Test successful retrieval for GCP with valid project_id."""
expected_response = {"gcp_key": "gcp_value"}
mock_fetch_secrets.return_value = expected_response

result = get_raw_secret(
secret_name=self.secret_name,
vault=VaultType.GCP.value,
project_id=self.gcp_project_id
)

self.assertEqual(result, expected_response)
mock_logger.error.assert_not_called()
mock_fetch_secrets.assert_called_once_with({
"secret_name": self.secret_name,
"vault": VaultType.GCP.value,
"project_id": self.gcp_project_id
})

0 comments on commit 029c899

Please sign in to comment.