diff --git a/rialto/runner/config_overrides.py b/rialto/runner/config_overrides.py index 310c58e..72d59ac 100644 --- a/rialto/runner/config_overrides.py +++ b/rialto/runner/config_overrides.py @@ -1,33 +1,60 @@ -from typing import Dict +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["override_config"] + +from typing import Dict, List, Tuple from loguru import logger +def _split_index_key(key: str) -> Tuple[str, str]: + name = key.split("[")[0] + index = key.split("[")[1].replace("]", "") + return name, index + + +def _find_first_match(config: List, index: str) -> int: + index_key, index_value = index.split("=") + return next(i for i, x in enumerate(config) if x.get(index_key) == index_value) + + def _override(config, path, value) -> Dict: key = path[0] if "[" in key: - name = key.split("[")[0] - index = key.split("[")[1].replace("]", "") + name, index = _split_index_key(key) + if name not in config: + raise ValueError(f"Invalid key {name}") if "=" in index: - index_key, index_value = index.split("=") - position = next(i for i, x in enumerate(config[name]) if x.get(index_key) == index_value) - if len(path) == 1: - config[name][position] = value - else: - config[name][position] = _override(config[name][position], path[1:], value) + index = _find_first_match(config[name], index) else: index = int(index) - if index >= 0: - if len(path) == 1: - config[name][index] = value - else: - config[name][index] = _override(config[name][index], path[1:], value) + if index >= 0 and index < len(config[name]): + if len(path) == 1: + config[name][index] = value else: - if len(path) == 1: - config[name].append(value) - else: - raise ValueError(f"Invalid index {index} for key {name} in path {path}") + config[name][index] = _override(config[name][index], path[1:], value) + elif index == -1: + if len(path) == 1: + config[name].append(value) + else: + raise ValueError(f"Invalid index {index} for key {name} in path {path}") + else: + raise IndexError(f"Index {index} out of bounds for key {key}") else: + if key not in config: + raise ValueError(f"Invalid key {key}") if len(path) == 1: config[key] = value else: @@ -38,8 +65,8 @@ def _override(config, path, value) -> Dict: def override_config(config: Dict, overrides: Dict) -> Dict: """Override config with user input - :param config: Config dictionary - :param overrides: Dictionary of overrides + :param config: config dictionary + :param overrides: dictionary of overrides :return: Overridden config """ for path, value in overrides.items(): diff --git a/tests/runner/test_overrides.py b/tests/runner/test_overrides.py index 0962c0e..996fa10 100644 --- a/tests/runner/test_overrides.py +++ b/tests/runner/test_overrides.py @@ -1,10 +1,25 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + from rialto.runner import Runner def test_overrides_simple(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]}, ) @@ -14,7 +29,7 @@ def test_overrides_simple(spark): def test_overrides_array_index(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"runner.mail.to[1]": "a@b.c"}, ) @@ -24,7 +39,7 @@ def test_overrides_array_index(spark): def test_overrides_array_append(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"runner.mail.to[-1]": "test"}, ) @@ -34,8 +49,57 @@ def test_overrides_array_append(spark): def test_overrides_array_lookup(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}, ) assert runner.config.pipelines[0].target.target_schema == "new_schema" + + +def test_overrides_combined(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"], + "pipelines[name=SimpleGroup].target.target_schema": "new_schema", + "pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1, + }, + ) + assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + assert runner.config.pipelines[0].target.target_schema == "new_schema" + assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1 + + +def test_index_out_of_range(spark): + with pytest.raises(IndexError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[8]": "test"}, + ) + assert error.value.args[0] == "Index 8 out of bounds for key to[8]" + + +def test_invalid_index_key(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.test[8]": "test"}, + ) + assert error.value.args[0] == "Invalid key test" + + +def test_invalid_key(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.test": "test"}, + ) + assert error.value.args[0] == "Invalid key test"