Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] try to make DefaultAzureCredential func configurable #65

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions eureka_ml_insights/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
TestModel,
)

from azure.identity import DefaultAzureCredential

from .config import ModelConfig

# For models that require secret keys, you can store the keys in a json file and provide the path to the file
Expand All @@ -32,6 +34,7 @@
"key_name": "your_openai_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
}

OAI_O1_PREVIEW_CONFIG = ModelConfig(
Expand Down Expand Up @@ -96,6 +99,7 @@
"key_name": "your_gemini_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
}

GEMINI_V15_PRO_CONFIG = ModelConfig(
Expand All @@ -119,6 +123,7 @@
"key_name": "your_claude_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
}

CLAUDE_3_OPUS_CONFIG = ModelConfig(
Expand Down Expand Up @@ -168,6 +173,7 @@
"key_name": "your_llama_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
},
"model_name": "meta-llama-3-1-70b-instruct",
},
Expand All @@ -181,6 +187,7 @@
"key_name": "your_llama_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
},
"model_name": "Meta-Llama-3-1-405B-Instruct",
},
Expand All @@ -195,6 +202,7 @@
"key_name": "your_mistral_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
},
"model_name": "Mistral-large-2407",
},
Expand Down
16 changes: 10 additions & 6 deletions eureka_ml_insights/data_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import jsonlines
import pandas as pd
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobClient, ContainerClient
from datasets import load_dataset
from PIL import Image
Expand Down Expand Up @@ -234,6 +233,7 @@ def __init__(
path,
account_url,
blob_container,
credential_func:callable=lambda _: None,
total_lines=None,
image_column_names=None,
image_column_search_regex="image",
Expand All @@ -260,7 +260,7 @@ def __init__(
self.container_client = ContainerClient(
account_url=self.account_url,
container_name=self.blob_container,
credential=DefaultAzureCredential(),
credential=credential_func(),
logger=self.logger,
)

Expand Down Expand Up @@ -312,13 +312,13 @@ def read(self):
class AzureBlobReader:
"""Reads an Azure storage blob from a full URL to a str"""

def read_azure_blob(self, blob_url) -> str:
def read_azure_blob(self, blob_url, credential_func:callable=lambda _: None) -> str:
"""
Reads an Azure storage blob..
args:
blob_url: str, The Azure storage blob full URL.
"""
blob_client = BlobClient.from_blob_url(blob_url, credential=DefaultAzureCredential(), logger=self.logger)
blob_client = BlobClient.from_blob_url(blob_url, credential=credential_func(), logger=self.logger)
# real all the bytes from the blob
file = blob_client.download_blob().readall()
file = file.decode("utf-8")
Expand All @@ -336,6 +336,7 @@ def __init__(
account_url: str,
blob_container: str,
blob_name: str,
credential_func:callable=lambda _: None,
):
"""
Initializes an AzureJsonReader.
Expand All @@ -346,10 +347,11 @@ def __init__(
"""
self.blob_url = f"{account_url}/{blob_container}/{blob_name}"
super().__init__(self.blob_url)
self.credential_func = credential_func
self.logger = AzureStorageLogger().get_logger()

def read(self) -> dict:
file = super().read_azure_blob(self.blob_url)
file = super().read_azure_blob(self.blob_url, credential_func=self.credential_func)
if self.format == ".json":
data = json.loads(file)
elif self.format == ".jsonl":
Expand Down Expand Up @@ -464,6 +466,7 @@ def __init__(
account_url: str,
blob_container: str,
blob_name: str,
credential_func:callable = lambda _: None,
format: str = None,
transform: Optional[DFTransformBase] = None,
**kwargs,
Expand All @@ -480,10 +483,11 @@ def __init__(
"""
self.blob_url = f"{account_url}/{blob_container}/{blob_name}"
super().__init__(self.blob_url, format, transform, **kwargs)
self.credential_func = credential_func
self.logger = AzureStorageLogger().get_logger()

def _load_dataset(self) -> pd.DataFrame:
file = super().read_azure_blob(self.blob_url)
file = super().read_azure_blob(self.blob_url, credential_func=self.credential_func)
if self.format == ".jsonl":
jlr = jsonlines.Reader(file.splitlines())
df = pd.DataFrame(jlr.iter(skip_empty=True, skip_invalid=True))
Expand Down
9 changes: 5 additions & 4 deletions eureka_ml_insights/metrics/kitab_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
ServiceRequestError,
ServiceResponseError,
)
from azure.identity import DefaultAzureCredential
from fuzzywuzzy import fuzz

from eureka_ml_insights.metrics import CompositeMetric
Expand All @@ -38,12 +37,14 @@ def __init__(self, temp_path_names, azure_lang_service_config):
"https://huggingface.co/datasets/microsoft/kitab/raw/main/code/utils/gpt_4_name_data_processed.csv",
temp_path_names,
)
self.credential_func = azure_lang_service_config["secret_key_params"].get("credential_func", lambda _: None)
# requires an Azure Cognitive Services Endpoint
# (ref: https://learn.microsoft.com/en-us/azure/ai-services/language-service/)
self.key = get_secret(
key_name=azure_lang_service_config["secret_key_params"].get("key_name", None),
local_keys_path=azure_lang_service_config["secret_key_params"].get("local_keys_path", None),
key_vault_url=azure_lang_service_config["secret_key_params"].get("key_vault_url", None),
credential_func=self.credential_func
)
self.endpoint = azure_lang_service_config["url"]
self.text_analytics_credential = self.get_verified_credential()
Expand All @@ -58,11 +59,11 @@ def get_verified_credential(self):
logging.info(f"Failed to create the TextAnalyticsClient using AzureKeyCredential")
logging.info("The error is caused by: {}".format(e))
try:
text_analytics_client = TextAnalyticsClient(endpoint=self.endpoint, credential=DefaultAzureCredential())
text_analytics_client = TextAnalyticsClient(endpoint=self.endpoint, credential=self.credential_func())
text_analytics_client.recognize_entities(["New York City"], model_version=model_version)
return DefaultAzureCredential()
return self.credential_func()
except Exception as e:
logging.info(f"Failed to create the TextAnalyticsClient using DefaultAzureCredential")
logging.info(f"Failed to create the TextAnalyticsClient using provided credential func")
logging.info("The error is caused by: {}".format(e))
return None

Expand Down
12 changes: 9 additions & 3 deletions eureka_ml_insights/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import anthropic
import tiktoken
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from azure.identity import get_bearer_token_provider

from eureka_ml_insights.secret_management import get_secret

Expand Down Expand Up @@ -222,6 +222,8 @@ class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
stream: bool = False

def __post_init__(self):
if self.secret_key_params is None:
raise ValueError("secret_key_params must be provided.")
try:
super().__post_init__()
self.headers = {
Expand All @@ -235,7 +237,7 @@ def __post_init__(self):
}
except ValueError:
self.bearer_token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
self.secret_key_params["credential_func"](), "https://cognitiveservices.azure.com/.default"
)
self.headers = {
"Content-Type": "application/json",
Expand Down Expand Up @@ -400,11 +402,15 @@ def get_response(self, request):
class AzureOpenAIClientMixIn:
"""This mixin provides some methods to interact with Azure OpenAI models."""

def __post_init__(self):
if self.secret_key_params is None:
raise ValueError("secret_key_params must be provided.")

def get_client(self):
from openai import AzureOpenAI

token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
self.secret_key_params["credential_func"](), "https://cognitiveservices.azure.com/.default"
)
return AzureOpenAI(
azure_endpoint=self.url,
Expand Down
22 changes: 6 additions & 16 deletions eureka_ml_insights/secret_management/secret_key_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import os
from typing import Dict, Optional

from azure.identity import DefaultAzureCredential, DeviceCodeCredential
from azure.keyvault.secrets import SecretClient

logging.basicConfig(level=logging.INFO, format="%(filename)s - %(funcName)s - %(message)s")


def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:Optional[str]=None) -> Optional[str]:
def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:Optional[str]=None, credential_func=lambda _: None) -> Optional[str]:
"""This function retrieves a key from key vault or if it is locally cached in a file.
args:
key_name: str, the name of the key to retrieve.
Expand Down Expand Up @@ -41,7 +40,7 @@ def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:
f"Key [{key_name}] not found in local keys file [{local_keys_path}] and key_vault_url is not provided."
)
else:
key_value = get_key_from_azure(key_name, key_vault_url)
key_value = get_key_from_azure(key_name, key_vault_url, credential_func=credential_func)

# if the key still wasn't found, raise an error
if key_value is None:
Expand All @@ -59,7 +58,7 @@ def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:
return key_value


def get_key_from_azure(key_name: str, key_vault_url: str) -> Optional[str]:
def get_key_from_azure(key_name: str, key_vault_url: str, credential_func=lambda _: None) -> Optional[str]:
"""This function retrieves a key from azure key vault.
args:
key_name: str, the name of the key to retrieve.
Expand All @@ -69,23 +68,14 @@ def get_key_from_azure(key_name: str, key_vault_url: str) -> Optional[str]:
"""
logging.getLogger("azure").setLevel(logging.ERROR)
try:
logging.info(f"Trying to get the key from Azure Key Vault {key_vault_url} using DefaultAzureCredential")
credential = DefaultAzureCredential(additionally_allowed_tenants=["*"])
logging.info("Trying to get the key from Azure Key Vault using provided func")
credential = credential_func(additionally_allowed_tenants=["*"])
client = SecretClient(vault_url=key_vault_url, credential=credential)
retrieved_key = client.get_secret(key_name)
return retrieved_key.value
except Exception as e:
logging.info(f"Failed to get the key from Azure Key Vault {key_vault_url} using DefaultAzureCredential")
logging.info("Failed to get the key from Azure Key Vault using provided func")
logging.info("The error is caused by: {}".format(e))
try:
logging.info(f"Trying to get the key from Azure Key Vault {key_vault_url} using DeviceCodeCredential")
credential = DeviceCodeCredential(additionally_allowed_tenants=["*"])
client = SecretClient(vault_url=key_vault_url, credential=credential)
retrieved_key = client.get_secret(key_name)
return retrieved_key.value
except Exception as e:
logging.error("Failed to get the key from Azure Key Vault using DeviceCodeCredential")
logging.error("The error is caused by: {}".format(e))
return None


Expand Down
3 changes: 3 additions & 0 deletions eureka_ml_insights/user_configs/kitab.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
)
from eureka_ml_insights.configs import ExperimentConfig

from azure.identity import DefaultAzureCredential

# Example template for an Azure Language Service config
# required for running entity recognition for evaluating human and city name
AZURE_LANG_SERVICE_CONFIG = {
Expand All @@ -43,6 +45,7 @@
"key_name": "your_azure_lang_service_secret_key_name",
"local_keys_path": "keys/keys.json",
"key_vault_url": None,
"credential_func": DefaultAzureCredential,
},
}

Expand Down
3 changes: 2 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from eureka_ml_insights.metrics import ClassicMetric, CompositeMetric

from azure.identity import DefaultAzureCredential

class TestModel:
def __init__(self, model_name="generic_test_model"):
Expand Down Expand Up @@ -265,5 +266,5 @@ def __init__(self, path, n_iter, image_column_names=None):

class TestAzureMMDataLoader(EarlyStoppableIterable, AzureMMDataLoader):
def __init__(self, path, n_iter, account_url, blob_container, image_column_names=None):
super().__init__(path, account_url, blob_container, image_column_names=image_column_names)
super().__init__(path, account_url, blob_container, credential_func=DefaultAzureCredential, image_column_names=image_column_names)
self.n_iter = n_iter
Loading