Skip to content

Commit

Permalink
override tests and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
MDobransky committed Aug 28, 2024
1 parent eea51cd commit 6427626
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 24 deletions.
67 changes: 47 additions & 20 deletions rialto/runner/config_overrides.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,60 @@
from typing import Dict
# Copyright 2022 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["override_config"]

from typing import Dict, List, Tuple

from loguru import logger


def _split_index_key(key: str) -> Tuple[str, str]:
name = key.split("[")[0]
index = key.split("[")[1].replace("]", "")
return name, index


def _find_first_match(config: List, index: str) -> int:
index_key, index_value = index.split("=")
return next(i for i, x in enumerate(config) if x.get(index_key) == index_value)


def _override(config, path, value) -> Dict:
key = path[0]
if "[" in key:
name = key.split("[")[0]
index = key.split("[")[1].replace("]", "")
name, index = _split_index_key(key)
if name not in config:
raise ValueError(f"Invalid key {name}")
if "=" in index:
index_key, index_value = index.split("=")
position = next(i for i, x in enumerate(config[name]) if x.get(index_key) == index_value)
if len(path) == 1:
config[name][position] = value
else:
config[name][position] = _override(config[name][position], path[1:], value)
index = _find_first_match(config[name], index)
else:
index = int(index)
if index >= 0:
if len(path) == 1:
config[name][index] = value
else:
config[name][index] = _override(config[name][index], path[1:], value)
if index >= 0 and index < len(config[name]):
if len(path) == 1:
config[name][index] = value
else:
if len(path) == 1:
config[name].append(value)
else:
raise ValueError(f"Invalid index {index} for key {name} in path {path}")
config[name][index] = _override(config[name][index], path[1:], value)
elif index == -1:
if len(path) == 1:
config[name].append(value)
else:
raise ValueError(f"Invalid index {index} for key {name} in path {path}")
else:
raise IndexError(f"Index {index} out of bounds for key {key}")
else:
if key not in config:
raise ValueError(f"Invalid key {key}")
if len(path) == 1:
config[key] = value
else:
Expand All @@ -38,8 +65,8 @@ def _override(config, path, value) -> Dict:
def override_config(config: Dict, overrides: Dict) -> Dict:
"""Override config with user input
:param config: Config dictionary
:param overrides: Dictionary of overrides
:param config: config dictionary
:param overrides: dictionary of overrides
:return: Overridden config
"""
for path, value in overrides.items():
Expand Down
72 changes: 68 additions & 4 deletions tests/runner/test_overrides.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
# Copyright 2022 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from rialto.runner import Runner


def test_overrides_simple(spark):
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"runner.mail.to": ["[email protected]", "[email protected]", "[email protected]"]},
)
Expand All @@ -14,7 +29,7 @@ def test_overrides_simple(spark):
def test_overrides_array_index(spark):
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"runner.mail.to[1]": "[email protected]"},
)
Expand All @@ -24,7 +39,7 @@ def test_overrides_array_index(spark):
def test_overrides_array_append(spark):
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"runner.mail.to[-1]": "test"},
)
Expand All @@ -34,8 +49,57 @@ def test_overrides_array_append(spark):
def test_overrides_array_lookup(spark):
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"},
)
assert runner.config.pipelines[0].target.target_schema == "new_schema"


def test_overrides_combined(spark):
runner = Runner(
spark,
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={
"runner.mail.to": ["[email protected]", "[email protected]", "[email protected]"],
"pipelines[name=SimpleGroup].target.target_schema": "new_schema",
"pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1,
},
)
assert runner.config.runner.mail.to == ["[email protected]", "[email protected]", "[email protected]"]
assert runner.config.pipelines[0].target.target_schema == "new_schema"
assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1


def test_index_out_of_range(spark):
with pytest.raises(IndexError) as error:
Runner(
spark,
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"runner.mail.to[8]": "test"},
)
assert error.value.args[0] == "Index 8 out of bounds for key to[8]"


def test_invalid_index_key(spark):
with pytest.raises(ValueError) as error:
Runner(
spark,
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"runner.mail.test[8]": "test"},
)
assert error.value.args[0] == "Invalid key test"


def test_invalid_key(spark):
with pytest.raises(ValueError) as error:
Runner(
spark,
config_path="tests/runner/overrider.yaml",
run_date="2023-03-31",
overrides={"runner.mail.test": "test"},
)
assert error.value.args[0] == "Invalid key test"

0 comments on commit 6427626

Please sign in to comment.