diff --git a/tests/test_vault_utils.py b/tests/test_vault_utils.py index 3df41cc..a5c703d 100644 --- a/tests/test_vault_utils.py +++ b/tests/test_vault_utils.py @@ -2,9 +2,9 @@ import string from unittest.mock import patch, MagicMock import json -from dotenv import dotenv_values from whispr.utils.vault import fetch_secrets, get_filled_secrets, prepare_vault_config +from whispr.utils.vault import get_raw_secret from whispr.utils.crypto import generate_rand_secret from whispr.enums import VaultType @@ -190,3 +190,138 @@ def test_all_punctuation_exclusion(self): # Ensure no punctuation is in the result for ch in secret: self.assertNotIn(ch, string.punctuation) + + +class GetRawSecretTestCase(unittest.TestCase): + """Unit tests for get_raw_secret function.""" + + def setUp(self): + """Set up shared test data.""" + self.secret_name = "test_secret" + self.aws_region = "us-east-1" + self.azure_vault_url = "https://my-azure-vault.vault.azure.net/" + self.gcp_project_id = "my-gcp-project" + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_no_vault_provided(self, mock_fetch_secrets, mock_logger): + """Test that an empty dict is returned and an error is logged when no vault is provided.""" + mock_fetch_secrets.return_value = {"some_key": "some_value"} + + result = get_raw_secret(self.secret_name, vault="") + + # Expect an empty dict, an error log, and no call to fetch_secrets + self.assertEqual(result, {}) + mock_logger.error.assert_called_once() + mock_fetch_secrets.assert_not_called() + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_no_secret_name_provided(self, mock_fetch_secrets, mock_logger): + """Test that an empty dict is returned and an error is logged when no secret name is provided.""" + mock_fetch_secrets.return_value = {"some_key": "some_value"} + + result = get_raw_secret(secret_name="", vault=VaultType.AWS.value, region=self.aws_region) + + self.assertEqual(result, {}) + mock_logger.error.assert_called_once() + mock_fetch_secrets.assert_not_called() + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_aws_missing_region(self, mock_fetch_secrets, mock_logger): + """Test that an empty dict is returned and an error is logged for AWS if region is missing.""" + mock_fetch_secrets.return_value = {"aws_key": "aws_value"} + + result = get_raw_secret(self.secret_name, VaultType.AWS.value) + + self.assertEqual(result, {}) + mock_logger.error.assert_called_once() + mock_fetch_secrets.assert_not_called() + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_azure_missing_vault_url(self, mock_fetch_secrets, mock_logger): + """Test that an empty dict is returned and an error is logged for Azure if vault_url is missing.""" + mock_fetch_secrets.return_value = {"azure_key": "azure_value"} + + result = get_raw_secret(self.secret_name, VaultType.AZURE.value) + + self.assertEqual(result, {}) + mock_logger.error.assert_called_once() + mock_fetch_secrets.assert_not_called() + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_gcp_missing_project_id(self, mock_fetch_secrets, mock_logger): + """Test that an empty dict is returned and an error is logged for GCP if project_id is missing.""" + mock_fetch_secrets.return_value = {"gcp_key": "gcp_value"} + + result = get_raw_secret(self.secret_name, VaultType.GCP.value) + + self.assertEqual(result, {}) + mock_logger.error.assert_called_once() + mock_fetch_secrets.assert_not_called() + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_aws_success(self, mock_fetch_secrets, mock_logger): + """Test successful retrieval for AWS with valid region.""" + expected_response = {"aws_key": "aws_value"} + mock_fetch_secrets.return_value = expected_response + + result = get_raw_secret( + secret_name=self.secret_name, + vault=VaultType.AWS.value, + region=self.aws_region + ) + + self.assertEqual(result, expected_response) + mock_logger.error.assert_not_called() + mock_fetch_secrets.assert_called_once_with({ + "secret_name": self.secret_name, + "vault": VaultType.AWS.value, + "region": self.aws_region + }) + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_azure_success(self, mock_fetch_secrets, mock_logger): + """Test successful retrieval for Azure with valid vault_url.""" + expected_response = {"azure_key": "azure_value"} + mock_fetch_secrets.return_value = expected_response + + result = get_raw_secret( + secret_name=self.secret_name, + vault=VaultType.AZURE.value, + vault_url=self.azure_vault_url + ) + + self.assertEqual(result, expected_response) + mock_logger.error.assert_not_called() + mock_fetch_secrets.assert_called_once_with({ + "secret_name": self.secret_name, + "vault": VaultType.AZURE.value, + "vault_url": self.azure_vault_url + }) + + @patch("whispr.utils.vault.logger", new_callable=MagicMock) + @patch("whispr.utils.vault.fetch_secrets") + def test_gcp_success(self, mock_fetch_secrets, mock_logger): + """Test successful retrieval for GCP with valid project_id.""" + expected_response = {"gcp_key": "gcp_value"} + mock_fetch_secrets.return_value = expected_response + + result = get_raw_secret( + secret_name=self.secret_name, + vault=VaultType.GCP.value, + project_id=self.gcp_project_id + ) + + self.assertEqual(result, expected_response) + mock_logger.error.assert_not_called() + mock_fetch_secrets.assert_called_once_with({ + "secret_name": self.secret_name, + "vault": VaultType.GCP.value, + "project_id": self.gcp_project_id + })