Skip to content

Commit

Permalink
Refine Config Logic (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
FANGAreNotGnu authored Nov 1, 2024
1 parent bdfe008 commit 43c3ed2
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 15 deletions.
34 changes: 19 additions & 15 deletions src/autogluon_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,24 @@
import logging
import os
from pathlib import Path
from typing import Optional
from typing import List, Optional

import pandas as pd
import typer
from hydra import compose, initialize
from omegaconf import OmegaConf
from rich import print as rprint
from typing_extensions import Annotated

from .assistant import TabularPredictionAssistant
from .constants import NO_ID_COLUMN_IDENTIFIED
from .task import TabularPredictionTask
from .utils import load_config

logging.basicConfig(level=logging.INFO)

__all__ = ["TabularPredictionAssistant", "TabularPredictionTask"]


def _resolve_config_path(path: str):
print(Path.cwd())
return os.path.relpath(Path(path), Path(__file__).parent.absolute())


def get_task(path: Path) -> TabularPredictionTask:
"""Get a task from a path."""

Expand Down Expand Up @@ -81,18 +76,27 @@ def run_assistant(
task_path: Annotated[str, typer.Argument(help="Directory where task files are included")],
config_path: Annotated[
Optional[str],
typer.Option("--config-path", "-c", help="Path to the configuration file (config.yaml)"),
] = None,
config_overrides: Annotated[
Optional[List[str]],
typer.Option(
"--config-path", "-c", help="Path to the configuration directory, which includes a config.yaml file"
"--config_overrides",
"-o",
help="Override config values. Format: key=value or key.nested=value. Can be used multiple times.",
),
] = "./config/",
] = None,
output_filename: Annotated[Optional[str], typer.Option(help="Output File")] = "",
config_overrides: Annotated[Optional[str], typer.Option(help="Overrides for the config in Hydra format")] = "",
) -> str:
"""Run AutoGluon-Assistant on a task defined in a path."""
rel_config_path = _resolve_config_path(config_path)
with initialize(version_base=None, config_path=rel_config_path):
overrides_list = config_overrides.split(" ") if config_overrides else []
config = compose(config_name="config", overrides=overrides_list)
logging.info("Starting AutoGluon-Assistant")

# Load config with all overrides
try:
config = load_config(config_path, config_overrides)
logging.info("Successfully loaded config")
except Exception as e:
logging.error(f"Failed to load config: {e}")
raise

rprint("🤖 [bold red] Welcome to AutoGluon-Assistant [/bold red]")

Expand Down
1 change: 1 addition & 0 deletions src/autogluon_assistant/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .configs import load_config
from .files import is_text_file, load_pd_quietly
116 changes: 116 additions & 0 deletions src/autogluon_assistant/utils/configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf


def _get_default_config_path() -> Path:
"""
Get default config folder under package root
Returns Path to the config.yaml file
"""
current_file = Path(__file__).parent.parent.parent.parent.absolute()
config_path = current_file / "config" / "config.yaml"

if not config_path.exists():
raise ValueError(f"Config file not found at expected location: {config_path}")

return config_path


def parse_override(override: str) -> tuple:
"""
Parse a single override string in the format 'key=value' or 'key.nested=value'
Args:
override: String in format "key=value" or "key.nested=value"
Returns:
Tuple of (key, value)
Raises:
ValueError: If override string is not in correct format
"""
if "=" not in override:
raise ValueError(f"Invalid override format: {override}. Must be in format 'key=value' or 'key.nested=value'")
key, value = override.split("=", 1)
return key, value


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
"""
Apply command-line overrides to config
Args:
config: Base configuration
overrides: List of overrides in format ["key1=value1", "key2.nested=value2"]
Returns:
Updated configuration
"""
if not overrides:
return config

# Convert overrides to nested dict
override_conf = {}
for override in overrides:
key, value = parse_override(override)

# Try to convert value to appropriate type
try:
# Try to evaluate as literal (for numbers, bools, etc)
value = eval(value)
except:
# Keep as string if eval fails
pass

# Handle nested keys
current = override_conf
key_parts = key.split(".")
for part in key_parts[:-1]:
current = current.setdefault(part, {})
current[key_parts[-1]] = value

# Convert override dict to OmegaConf and merge
override_conf = OmegaConf.create(override_conf)
return OmegaConf.merge(config, override_conf)


def load_config(config_path: Optional[str] = None, overrides: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Load configuration from yaml file, merging with default config and applying overrides
Args:
config_path: Optional path to config file. If provided, will merge with and override default config
overrides: Optional list of command-line overrides in format ["key1=value1", "key2.nested=value2"]
Returns:
Loaded and merged configuration
Raises:
ValueError: If config file not found or invalid
"""
# Load default config
default_config_path = _get_default_config_path()
logging.info(f"Loading default config from: {default_config_path}")
config = OmegaConf.load(default_config_path)

# If custom config provided, merge it
if config_path:
custom_config_path = Path(config_path)
if not custom_config_path.is_file():
raise ValueError(f"Custom config file not found at: {custom_config_path}")

logging.info(f"Loading custom config from: {custom_config_path}")
custom_config = OmegaConf.load(custom_config_path)
config = OmegaConf.merge(config, custom_config)
logging.info("Successfully merged custom config with default config")

# Apply command-line overrides if any
if overrides:
logging.info(f"Applying command-line overrides: {overrides}")
config = apply_overrides(config, overrides)
logging.info("Successfully applied command-line overrides")

return config

0 comments on commit 43c3ed2

Please sign in to comment.