Skip to content

Commit

Permalink
add presets
Browse files Browse the repository at this point in the history
  • Loading branch information
FANGAreNotGnu committed Nov 1, 2024
1 parent 43c3ed2 commit ab63347
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 11 deletions.
8 changes: 3 additions & 5 deletions config/config.yaml → configs/best_quality.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ save_artifacts:
enabled: False
append_timestamp: True
path: "./aga-artifacts"
# path: "s3://autogluon-assistant-agts/outputs/<run_id>/aga-artifacts/"
feature_transformers:
- _target_: autogluon_assistant.transformer.CAAFETransformer
eval_model: lightgbm
Expand All @@ -24,20 +23,19 @@ autogluon:
predictor_init_kwargs: {}
predictor_fit_kwargs:
verbosity: 2
presets: best_quality
time_limit: 60
presets: medium_quality
time_limit: 120
dynamic_stacking: True
num_bag_folds: 5
num_stack_levels: 2
llm:
# Note: bedrock is only supported in limited AWS regions
provider: bedrock
api_key_location: BEDROCK_API_KEY
model: "anthropic.claude-3-5-sonnet-20241022-v2:0" #"anthropic.claude-3-5-sonnet-20240620-v1:0"
model: "anthropic.claude-3-5-sonnet-20241022-v2:0"
# provider: openai
# api_key_location: OPENAI_API_KEY
# model: gpt-4o-2024-08-06
# model: gpt-3.5-turbo
max_tokens: 512
proxy_url: null
temperature: 0
Expand Down
29 changes: 29 additions & 0 deletions configs/high_quality.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# AutoGluon-Assistant preset configuration: "high_quality".
# Selected via the --presets/-p CLI option and loaded from the "configs"
# directory under the package root (see utils/configs.py).
infer_eval_metric: True
detect_and_drop_id_column: False
task_preprocessors_timeout: 3600  # seconds
save_artifacts:
  enabled: False  # artifact saving is off by default
  append_timestamp: True
  path: "./aga-artifacts"
# Intentionally left empty (null): no feature transformers run under this preset.
feature_transformers:
autogluon:
  predictor_init_kwargs: {}
  # Keyword arguments forwarded to the AutoGluon predictor's fit() call.
  predictor_fit_kwargs:
    verbosity: 2
    presets: high_quality  # AutoGluon quality preset matching this file's name
    time_limit: 3600  # training budget in seconds
    dynamic_stacking: True
    num_bag_folds: 5
    num_stack_levels: 2
llm:
  # Note: bedrock is only supported in limited AWS regions
  provider: bedrock
  api_key_location: BEDROCK_API_KEY
  model: "anthropic.claude-3-5-sonnet-20241022-v2:0"
  # provider: openai
  # api_key_location: OPENAI_API_KEY
  # model: gpt-4o-2024-08-06
  max_tokens: 512
  proxy_url: null
  temperature: 0  # deterministic sampling
  # NOTE(review): assuming `verbose` belongs under the llm section — the
  # scraped diff loses indentation; confirm against the original file.
  verbose: True
29 changes: 29 additions & 0 deletions configs/medium_quality.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# AutoGluon-Assistant preset configuration: "medium_quality".
# Selected via the --presets/-p CLI option and loaded from the "configs"
# directory under the package root (see utils/configs.py).
infer_eval_metric: True
detect_and_drop_id_column: False
task_preprocessors_timeout: 3600  # seconds
save_artifacts:
  enabled: False  # artifact saving is off by default
  append_timestamp: True
  path: "./aga-artifacts"
# Intentionally left empty (null): no feature transformers run under this preset.
feature_transformers:
autogluon:
  predictor_init_kwargs: {}
  # Keyword arguments forwarded to the AutoGluon predictor's fit() call.
  predictor_fit_kwargs:
    verbosity: 2
    presets: medium_quality  # AutoGluon quality preset matching this file's name
    time_limit: 600  # training budget in seconds (shorter than high/best quality)
    dynamic_stacking: True
    num_bag_folds: 5
    num_stack_levels: 2
llm:
  # Note: bedrock is only supported in limited AWS regions
  provider: bedrock
  api_key_location: BEDROCK_API_KEY
  model: "anthropic.claude-3-5-sonnet-20241022-v2:0"
  # provider: openai
  # api_key_location: OPENAI_API_KEY
  # model: gpt-4o-2024-08-06
  max_tokens: 512
  proxy_url: null
  temperature: 0  # deterministic sampling
  # NOTE(review): assuming `verbose` belongs under the llm section — the
  # scraped diff loses indentation; confirm against the original file.
  verbose: True
14 changes: 12 additions & 2 deletions src/autogluon_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Annotated

from .assistant import TabularPredictionAssistant
from .constants import NO_ID_COLUMN_IDENTIFIED
from .constants import DEFAULT_QUALITY, NO_ID_COLUMN_IDENTIFIED, PRESETS
from .task import TabularPredictionTask
from .utils import load_config

Expand Down Expand Up @@ -74,6 +74,10 @@ def make_prediction_outputs(task: TabularPredictionTask, predictions: pd.DataFra

def run_assistant(
task_path: Annotated[str, typer.Argument(help="Directory where task files are included")],
presets: Annotated[
Optional[str],
typer.Option("--presets", "-p", help="Presets"),
] = None,
config_path: Annotated[
Optional[str],
typer.Option("--config-path", "-c", help="Path to the configuration file (config.yaml)"),
Expand All @@ -90,9 +94,15 @@ def run_assistant(
) -> str:
logging.info("Starting AutoGluon-Assistant")

if presets is None or presets not in PRESETS:
logging.info(f"Presets is not provided or invalid: {presets}")
presets = DEFAULT_QUALITY
logging.info(f"Using default presets: {presets}")
logging.info(f"Presets: {presets}")

# Load config with all overrides
try:
config = load_config(config_path, config_overrides)
config = load_config(presets, config_path, config_overrides)
logging.info("Successfully loaded config")
except Exception as e:
logging.error(f"Failed to load config: {e}")
Expand Down
8 changes: 8 additions & 0 deletions src/autogluon_assistant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
PROBLEM_TYPES = [BINARY, MULTICLASS, REGRESSION]
CLASSIFICATION_PROBLEM_TYPES = [BINARY, MULTICLASS]

# Presets/Configs
CONFIGS = "configs"  # directory under the package root holding per-preset YAML files

# Recognized preset names; each maps to "<CONFIGS>/<name>.yaml".
MEDIUM_QUALITY = "medium_quality"
HIGH_QUALITY = "high_quality"
BEST_QUALITY = "best_quality"

PRESETS = [MEDIUM_QUALITY, HIGH_QUALITY, BEST_QUALITY]
# Fallback applied when the CLI receives no preset or an unrecognized one.
DEFAULT_QUALITY = BEST_QUALITY

# Metrics
ROC_AUC = "roc_auc"
LOG_LOSS = "log_loss"
Expand Down
14 changes: 10 additions & 4 deletions src/autogluon_assistant/utils/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@

from omegaconf import OmegaConf

from ..constants import CONFIGS

def _get_default_config_path() -> Path:

def _get_default_config_path(
presets: str,
) -> Path:
"""
Get default config folder under package root
Returns Path to the config.yaml file
"""
current_file = Path(__file__).parent.parent.parent.parent.absolute()
config_path = current_file / "config" / "config.yaml"
config_path = current_file / CONFIGS / f"{presets}.yaml"

if not config_path.exists():
raise ValueError(f"Config file not found at expected location: {config_path}")
Expand Down Expand Up @@ -77,7 +81,9 @@ def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, A
return OmegaConf.merge(config, override_conf)


def load_config(config_path: Optional[str] = None, overrides: Optional[List[str]] = None) -> Dict[str, Any]:
def load_config(
presets: str, config_path: Optional[str] = None, overrides: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Load configuration from yaml file, merging with default config and applying overrides
Expand All @@ -92,7 +98,7 @@ def load_config(config_path: Optional[str] = None, overrides: Optional[List[str]
ValueError: If config file not found or invalid
"""
# Load default config
default_config_path = _get_default_config_path()
default_config_path = _get_default_config_path(presets)
logging.info(f"Loading default config from: {default_config_path}")
config = OmegaConf.load(default_config_path)

Expand Down

0 comments on commit ab63347

Please sign in to comment.