diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 29a3829a136..759ff39bdf2 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -6,7 +6,7 @@ import logging import os -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError @@ -61,9 +61,16 @@ def __init__( "storageFileShareName", }, ) - assert self._parent is not None and isinstance( - self._parent, SupportsAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + # Ensure that the parent service is an authentication service that provides + # a TokenCredential. + assert ( + self._parent is not None + and isinstance(self._parent, SupportsAuth) + and get_type_hints(self._parent.get_credential).get("return") == TokenCredential + ), ( + "Azure Authentication service not provided. " + "Include services/remote/azure/service-auth.jsonc?" + ) self._auth_service: SupportsAuth[TokenCredential] = self._parent self._share_client: Optional[ShareClient] = None @@ -71,9 +78,10 @@ def _get_share_client(self) -> ShareClient: """Get the Azure file share client object.""" if self._share_client is None: credential = self._auth_service.get_credential() - assert isinstance( - credential, TokenCredential - ), f"Expected a TokenCredential, but got {type(credential)} instead." + assert isinstance(credential, TokenCredential), ( + f"Expected a TokenCredential, but got {type(credential)} instead. " + "Include services/remote/azure/service-auth.jsonc?" + ) self._share_client = ShareClient.from_share_url( self._SHARE_URL.format( account_name=self.config["storageAccountName"], diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index e010fd140b9..2c2e74e3fb5 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -50,15 +50,18 @@ def test_load_service_config_examples( """Tests loading a config example.""" parent: Service = config_loader_service config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) + # Add other services that require a SupportsAuth parent service as necessary. + # AzureFileShareService requires an AzureAuth service to be loaded as well. + # mock_auth_service_config = "services/remote/mock/mock_auth_service.jsonc" + azure_auth_service_config = "services/remote/azure/service-auth.jsonc" requires_auth_service_parent = { - "AzureFileShareService", + "AzureFileShareService": azure_auth_service_config, } config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1] - if config_class_name in requires_auth_service_parent: - # AzureFileShareService requires an auth service to be loaded as well. + if auth_service_config_path := requires_auth_service_parent.get(config_class_name): auth_service_config = config_loader_service.load_config( - "services/remote/mock/mock_auth_service.jsonc", + auth_service_config_path, ConfigSchema.SERVICE, ) auth_service = config_loader_service.build_service( @@ -66,6 +69,7 @@ def test_load_service_config_examples( parent=config_loader_service, ) parent = auth_service + # Make an instance of the class based on the config. service_inst = config_loader_service.build_service( config=config,