-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bdfe008
commit 43c3ed2
Showing
3 changed files
with
136 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .configs import load_config | ||
from .files import is_text_file, load_pd_quietly |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import logging | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional | ||
|
||
from omegaconf import OmegaConf | ||
|
||
|
||
def _get_default_config_path() -> Path: | ||
""" | ||
Get default config folder under package root | ||
Returns Path to the config.yaml file | ||
""" | ||
current_file = Path(__file__).parent.parent.parent.parent.absolute() | ||
config_path = current_file / "config" / "config.yaml" | ||
|
||
if not config_path.exists(): | ||
raise ValueError(f"Config file not found at expected location: {config_path}") | ||
|
||
return config_path | ||
|
||
|
||
def parse_override(override: str) -> tuple: | ||
""" | ||
Parse a single override string in the format 'key=value' or 'key.nested=value' | ||
Args: | ||
override: String in format "key=value" or "key.nested=value" | ||
Returns: | ||
Tuple of (key, value) | ||
Raises: | ||
ValueError: If override string is not in correct format | ||
""" | ||
if "=" not in override: | ||
raise ValueError(f"Invalid override format: {override}. Must be in format 'key=value' or 'key.nested=value'") | ||
key, value = override.split("=", 1) | ||
return key, value | ||
|
||
|
||
def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]: | ||
""" | ||
Apply command-line overrides to config | ||
Args: | ||
config: Base configuration | ||
overrides: List of overrides in format ["key1=value1", "key2.nested=value2"] | ||
Returns: | ||
Updated configuration | ||
""" | ||
if not overrides: | ||
return config | ||
|
||
# Convert overrides to nested dict | ||
override_conf = {} | ||
for override in overrides: | ||
key, value = parse_override(override) | ||
|
||
# Try to convert value to appropriate type | ||
try: | ||
# Try to evaluate as literal (for numbers, bools, etc) | ||
value = eval(value) | ||
except: | ||
# Keep as string if eval fails | ||
pass | ||
|
||
# Handle nested keys | ||
current = override_conf | ||
key_parts = key.split(".") | ||
for part in key_parts[:-1]: | ||
current = current.setdefault(part, {}) | ||
current[key_parts[-1]] = value | ||
|
||
# Convert override dict to OmegaConf and merge | ||
override_conf = OmegaConf.create(override_conf) | ||
return OmegaConf.merge(config, override_conf) | ||
|
||
|
||
def load_config(config_path: Optional[str] = None, overrides: Optional[List[str]] = None) -> Dict[str, Any]: | ||
""" | ||
Load configuration from yaml file, merging with default config and applying overrides | ||
Args: | ||
config_path: Optional path to config file. If provided, will merge with and override default config | ||
overrides: Optional list of command-line overrides in format ["key1=value1", "key2.nested=value2"] | ||
Returns: | ||
Loaded and merged configuration | ||
Raises: | ||
ValueError: If config file not found or invalid | ||
""" | ||
# Load default config | ||
default_config_path = _get_default_config_path() | ||
logging.info(f"Loading default config from: {default_config_path}") | ||
config = OmegaConf.load(default_config_path) | ||
|
||
# If custom config provided, merge it | ||
if config_path: | ||
custom_config_path = Path(config_path) | ||
if not custom_config_path.is_file(): | ||
raise ValueError(f"Custom config file not found at: {custom_config_path}") | ||
|
||
logging.info(f"Loading custom config from: {custom_config_path}") | ||
custom_config = OmegaConf.load(custom_config_path) | ||
config = OmegaConf.merge(config, custom_config) | ||
logging.info("Successfully merged custom config with default config") | ||
|
||
# Apply command-line overrides if any | ||
if overrides: | ||
logging.info(f"Applying command-line overrides: {overrides}") | ||
config = apply_overrides(config, overrides) | ||
logging.info("Successfully applied command-line overrides") | ||
|
||
return config |