executes black, ruff, isort and adds check for required env vars
Alex-Wenner-FHR committed Jan 2, 2025
1 parent db91df7 commit 27b165b
Showing 4 changed files with 34 additions and 68 deletions.
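Most of the diff below is mechanical reformatting from running black, ruff, and isort; the behavioral change is in `_get_azure_chat_model`, which now fails fast when required Azure environment variables are missing. A minimal sketch of an environment that satisfies the new checks (all values are placeholders, not real credentials or a pinned API version):

```python
import os

# The three variables validated in _get_azure_chat_model; values below are placeholders.
os.environ["AZURE_API_KEY"] = "<your-azure-openai-api-key>"
os.environ["OPENAI_API_VERSION"] = "2024-06-01"  # assumption: any valid Azure OpenAI API version
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com"
```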
37 changes: 10 additions & 27 deletions src/autogluon/assistant/assistant.py
@@ -8,12 +8,7 @@
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
 
-from autogluon.assistant.llm import (
-    AssistantChatBedrock,
-    AssistantChatOpenAI,
-    AssistantAzureChatOpenAI,
-    LLMFactory,
-)
+from autogluon.assistant.llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
 
 from .predictor import AutogluonTabularPredictor
 from .task import TabularPredictionTask
@@ -36,9 +31,7 @@
 def timeout(seconds: int, error_message: Optional[str] = None):
     if sys.platform == "win32":
         # Windows implementation using threading
-        timer = threading.Timer(
-            seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message))
-        )
+        timer = threading.Timer(seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message)))
         timer.start()
         try:
             yield
@@ -62,9 +55,9 @@ class TabularPredictionAssistant:
 
     def __init__(self, config: DictConfig) -> None:
         self.config = config
-        self.llm: Union[
-            AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock
-        ] = LLMFactory.get_chat_model(config.llm)
+        self.llm: Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock] = (
+            LLMFactory.get_chat_model(config.llm)
+        )
         self.predictor = AutogluonTabularPredictor(config.autogluon)
         self.feature_transformers_config = get_feature_transformers_config(config)
@@ -104,19 +97,13 @@ def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
             ):
                 task = preprocessor.transform(task)
             except Exception as e:
-                self.handle_exception(
-                    f"Task inference preprocessing: {preprocessor_class}", e
-                )
+                self.handle_exception(f"Task inference preprocessing: {preprocessor_class}", e)
 
         bold_start = "\033[1m"
         bold_end = "\033[0m"
 
-        logger.info(
-            f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}"
-        )
-        logger.info(
-            f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}"
-        )
+        logger.info(f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}")
+        logger.info(f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}")
         logger.info("Task understanding complete!")
         return task
@@ -126,9 +113,7 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
         task = self.inference_task(task)
         if self.feature_transformers_config:
             logger.info("Automatic feature generation starts...")
-            fe_transformers = [
-                instantiate(ft_config) for ft_config in self.feature_transformers_config
-            ]
+            fe_transformers = [instantiate(ft_config) for ft_config in self.feature_transformers_config]
             for fe_transformer in fe_transformers:
                 try:
                     with timeout(
@@ -137,9 +122,7 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
                     ):
                         task = fe_transformer.fit_transform(task)
                 except Exception as e:
-                    self.handle_exception(
-                        f"Task preprocessing: {fe_transformer.name}", e
-                    )
+                    self.handle_exception(f"Task preprocessing: {fe_transformer.name}", e)
             logger.info("Automatic feature generation complete!")
         else:
             logger.info("Automatic feature generation is disabled. ")

3 changes: 1 addition & 2 deletions src/autogluon/assistant/llm/__init__.py
@@ -1,9 +1,8 @@
-from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory, AssistantAzureChatOpenAI
+from .llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
 
 __all__ = [
     "AssistantAzureChatOpenAI",
     "AssistantChatOpenAI",
     "AssistantChatBedrock",
     "LLMFactory",
-
 ]
55 changes: 18 additions & 37 deletions src/autogluon/assistant/llm/llm.py
@@ -7,9 +7,9 @@
 import botocore
 from langchain.schema import AIMessage, BaseMessage
 from langchain_aws import ChatBedrock
-from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from omegaconf import DictConfig
-from openai import OpenAI, AzureOpenAI
+from openai import AzureOpenAI, OpenAI
 from pydantic import BaseModel, Field
 from tenacity import retry, stop_after_attempt, wait_exponential
@@ -36,9 +36,7 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(
-        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
-    )
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         response = super().invoke(*args, **kwargs)
@@ -77,9 +75,7 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(
-        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
-    )
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         response = super().invoke(*args, **kwargs)
@@ -117,9 +113,7 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(
-        stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10)
-    )
+    @retry(stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10))
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         try:
@@ -152,11 +146,7 @@ def get_openai_models() -> List[str]:
         try:
             client = OpenAI()
             models = client.models.list()
-            return [
-                model.id
-                for model in models
-                if model.id.startswith(("gpt-3.5", "gpt-4"))
-            ]
+            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
         except Exception as e:
             print(f"Error fetching OpenAI models: {e}")
             return []
@@ -177,11 +167,7 @@ def get_azure_models() -> List[str]:
         try:
             client = AzureOpenAI()
             models = client.models.list()
-            return [
-                model.id
-                for model in models
-                if model.id.startswith(("gpt-3.5", "gpt-4"))
-            ]
+            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
         except Exception as e:
             print(f"Error fetching Azure models: {e}")
             return []
@@ -212,11 +198,14 @@ def _get_azure_chat_model(
         else:
             raise Exception("Azure API env variable AZURE_API_KEY not set")
 
-        logger.info(
-            f"AGA is using model {config.model} from Azure to assist you with the task."
-        )
+        if "OPENAI_API_VERSION" not in os.environ:
+            raise Exception("Azure API env variable OPENAI_API_VERSION not set")
+        if "AZURE_OPENAI_ENDPOINT" not in os.environ:
+            raise Exception("Azure API env variable AZURE_OPENAI_ENDPOINT not set")
+
+        logger.info(f"AGA is using model {config.model} from Azure to assist you with the task.")
         return AssistantAzureChatOpenAI(
-            api_key = api_key,
+            api_key=api_key,
             model_name=config.model,
             temperature=config.temperature,
             max_tokens=config.max_tokens,
@@ -232,9 +221,7 @@ def _get_openai_chat_model(
         else:
             raise Exception("OpenAI API env variable OPENAI_API_KEY not set")
 
-        logger.info(
-            f"AGA is using model {config.model} from OpenAI to assist you with the task."
-        )
+        logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.")
         return AssistantChatOpenAI(
             model_name=config.model,
             temperature=config.temperature,
@@ -246,9 +233,7 @@ def _get_openai_chat_model(
 
     @staticmethod
     def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
-        logger.info(
-            f"AGA is using model {config.model} from Bedrock to assist you with the task."
-        )
+        logger.info(f"AGA is using model {config.model} from Bedrock to assist you with the task.")
 
         return AssistantChatBedrock(
             model_id=config.model,
@@ -266,19 +251,15 @@ def get_chat_model(
         cls, config: DictConfig
     ) -> Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock]:
         valid_providers = cls.get_valid_providers()
-        assert (
-            config.provider in valid_providers
-        ), f"{config.provider} is not a valid provider in: {valid_providers}"
+        assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}"
 
         valid_models = cls.get_valid_models(config.provider)
         assert (
             config.model in valid_models
         ), f"{config.model} is not a valid model in: {valid_models} for provider {config.provider}"
 
         if config.model not in WHITE_LIST_LLM:
-            logger.warning(
-                f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}"
-            )
+            logger.warning(f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}")
 
         if config.provider == "azure":
             return LLMFactory._get_azure_chat_model(config)

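For reference, a sketch of driving the factory end to end; the config fields mirror the attributes read above (`provider`, `model`, `temperature`, `max_tokens`), but the concrete values are illustrative assumptions and the real config may carry fields not visible in this diff:

```python
from omegaconf import OmegaConf

from autogluon.assistant.llm import LLMFactory

# Illustrative config; provider/model must pass the assertions in get_chat_model.
llm_config = OmegaConf.create(
    {
        "provider": "openai",
        "model": "gpt-4o-2024-08-06",  # assumption: a model returned by get_openai_models()
        "temperature": 0.0,            # illustrative value
        "max_tokens": 1024,            # illustrative value
    }
)
llm = LLMFactory.get_chat_model(llm_config)
```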
7 changes: 5 additions & 2 deletions src/autogluon/assistant/ui/constants.py
@@ -37,13 +37,16 @@
     "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
     "GPT 4o with OpenAI": "gpt-4o-2024-08-06",
     "GPT 4o with Azure": "gpt-4o-2024-08-06",
-
 }
 
 LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o with OpenAI", "GPT 4o with Azure"]
 
 # Provider configuration
-PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o with OpenAI": "openai", "GPT 4o with Azure": "azure"}
+PROVIDER_MAPPING = {
+    "Claude 3.5 with Amazon Bedrock": "bedrock",
+    "GPT 4o with OpenAI": "openai",
+    "GPT 4o with Azure": "azure",
+}
 
 INITIAL_STAGE = {
     "Task Understanding": [],
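As a quick illustration of how these UI constants fit together (`MODEL_MAPPING` is a hypothetical name standing in for the display-name-to-model-id dict whose definition is truncated above):

```python
selected = LLM_OPTIONS[2]              # e.g. "GPT 4o with Azure" chosen in the UI dropdown
provider = PROVIDER_MAPPING[selected]  # -> "azure"
model_id = MODEL_MAPPING[selected]     # hypothetical dict name; -> "gpt-4o-2024-08-06"
```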
