From 6ca429d7bebe77468df3010328e263d0b31fe7d5 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Thu, 5 Sep 2024 12:37:30 +0200 Subject: [PATCH] custom env loader --- .flake8 | 1 + pyproject.toml | 1 - rialto/common/env_yaml.py | 28 ++++++++++++++++++++++++++++ rialto/common/utils.py | 3 ++- 4 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 rialto/common/env_yaml.py diff --git a/.flake8 b/.flake8 index 21099b7..c2cf6c9 100644 --- a/.flake8 +++ b/.flake8 @@ -14,3 +14,4 @@ extend-ignore = D100, D104, D107, + E203, diff --git a/pyproject.toml b/pyproject.toml index 5812612..23aa34e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ pandas = "^2.1.0" flake8-broken-line = "^1.0.0" loguru = "^0.7.2" importlib-metadata = "^7.2.1" -env_yaml = "^0.0.3" [tool.poetry.dev-dependencies] pyspark = "^3.4.1" diff --git a/rialto/common/env_yaml.py b/rialto/common/env_yaml.py new file mode 100644 index 0000000..ec2f591 --- /dev/null +++ b/rialto/common/env_yaml.py @@ -0,0 +1,28 @@ +import os +import re + +import yaml +from loguru import logger + +__all__ = ["EnvLoader"] + +_path_matcher = re.compile(r"\$\{(?P[^}^{:]+)(?::(?P[^}^{]*))?\}") + + +def _path_constructor(loader, node): + value = node.value + match = _path_matcher.match(value) + sub = os.getenv(match.group("env_name"), match.group("default_value")) + new_value = value[0 : match.start()] + sub + value[match.end() :] + logger.info(f"Config: Replacing {value}, with {new_value}") + return new_value + + +class EnvLoader(yaml.SafeLoader): + """Custom loader that replaces values with environment variables""" + + pass + + +EnvLoader.add_implicit_resolver("!env_substitute", _path_matcher, None) +EnvLoader.add_constructor("!env_substitute", _path_constructor) diff --git a/rialto/common/utils.py b/rialto/common/utils.py index b2e19b4..6f5ed1f 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -19,10 +19,11 @@ import pyspark.sql.functions as F import yaml -from env_yaml import EnvLoader from pyspark.sql import DataFrame from pyspark.sql.types import FloatType +from rialto.common.env_yaml import EnvLoader + def load_yaml(path: str) -> Any: """