Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposed fixes for token-based auth in azure fileshare service #820

Merged
merged 11 commits into from
Aug 3, 2024
8 changes: 7 additions & 1 deletion mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import azure.core.credentials as azure_cred
import azure.identity as azure_id
from azure.keyvault.secrets import SecretClient
from pytz import UTC
Expand All @@ -20,7 +21,7 @@
_LOG = logging.getLogger(__name__)


class AzureAuthService(Service, SupportsAuth):
class AzureAuthService(Service, SupportsAuth[azure_cred.TokenCredential]):
"""Helper methods to get access to Azure services."""

_REQ_INTERVAL = 300 # = 5 min
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand Down Expand Up @@ -133,3 +135,7 @@ def get_access_token(self) -> str:
def get_auth_headers(self) -> dict:
"""Get the authorization part of HTTP headers for REST API calls."""
return {"Authorization": "Bearer " + self.get_access_token()}

def get_credential(self) -> azure_cred.TokenCredential:
"""Return the Azure SDK credential object."""
return self._cred
14 changes: 10 additions & 4 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union

import azure.core.credentials as azure_cred
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.fileshare import ShareClient

Expand Down Expand Up @@ -60,20 +61,25 @@ def __init__(
"storageFileShareName",
},
)
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
self._auth_service: SupportsAuth[azure_cred.TokenCredential] = self._parent
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
"""Get the Azure file share client object."""
if self._share_client is None:
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
credential = self._auth_service.get_credential()
assert isinstance(
credential, azure_cred.TokenCredential
), f"Expected a TokenCredential, but got {type(credential)} instead."
Copy link
Contributor

@bpkroth bpkroth Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still a late runtime error that I was trying to turn into a config load error in #819

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I made a new change to #819 to handle that now.
From what I can tell the rest of this one is good, but it'd be good to get confirmation that DefaultCredential doesn't need refreshed, especially with the SP part.

self._share_client = ShareClient.from_share_url(
self._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self._parent.get_access_token(),
credential=credential,
token_intent="backup",
)
return self._share_client
Expand Down
16 changes: 14 additions & 2 deletions mlos_bench/mlos_bench/services/types/authenticator_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
#
"""Protocol interface for authentication for the cloud services."""

from typing import Protocol, runtime_checkable
from typing import Protocol, TypeVar, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class SupportsAuth(Protocol):
class SupportsAuth(Protocol[T_co]):
"""Protocol interface for authentication for the cloud services."""

def get_access_token(self) -> str:
Expand All @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict:
access_header : dict
HTTP header containing the access token.
"""

def get_credential(self) -> T_co:
"""
Get the credential object for cloud services.

Returns
-------
credential : T
Cloud-specific credential object.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,28 @@ def test_load_service_config_examples(
config_path: str,
) -> None:
"""Tests loading a config example."""
parent: Service = config_loader_service
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
# Add other services that require a SupportsAuth parent service as necessary.
requires_auth_service_parent = {
"AzureFileShareService",
}
config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1]
if config_class_name in requires_auth_service_parent:
# AzureFileShareService requires an auth service to be loaded as well.
auth_service_config = config_loader_service.load_config(
"services/remote/mock/mock_auth_service.jsonc",
ConfigSchema.SERVICE,
)
auth_service = config_loader_service.build_service(
config=auth_service_config,
parent=config_loader_service,
)
parent = auth_service
# Make an instance of the class based on the config.
service_inst = config_loader_service.build_service(
config=config,
parent=config_loader_service,
parent=parent,
)
assert service_inst is not None
assert isinstance(service_inst, Service)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_LOG = logging.getLogger(__name__)


class MockAuthService(Service, SupportsAuth):
class MockAuthService(Service, SupportsAuth[str]):
"""A collection Service functions for mocking authentication ops."""

def __init__(
Expand All @@ -32,6 +32,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -41,3 +42,6 @@ def get_access_token(self) -> str:

def get_auth_headers(self) -> dict:
return {"Authorization": "Bearer " + self.get_access_token()}

def get_credential(self) -> str:
return "MOCK CREDENTIAL"
Loading