+
+ """
+
+    @staticmethod
+    def _make_exceptions(records: List[Record]):
+        html = ""
+        for i, record in enumerate(records):
+            if record.exception is not None:
+                # NOTE: the tags below are a minimal reconstruction; the original
+                # markup was stripped from this diff, so element choices are assumptions.
+                r = f"""
+                <tr id="exception-{i}">
+                    <td>{record.job}</td>
+                    <td>{record.date}</td>
+                    <td>
+                        <details>
+                            <summary>Expand</summary>
+                            <pre>{record.exception}</pre>
+                        </details>
+                    </td>
+                </tr>
+                """
+                html += r
+        return html
+
+    @staticmethod
+    def _make_insights(records: List[Record]):
+        # NOTE: wrapper markup reconstructed; only the heading text survived in this diff.
+        return f"""
+        <h2>Exceptions</h2>
+        <table>
+            {HTMLMessage._make_exceptions(records)}
+        </table>
+        """
+
+    @staticmethod
+    def make_report(target: str, start: datetime, records: List[Record]) -> str:
+        """Create html email report"""
+        html = [
+            "<html>",  # opening tag reconstructed; original markup was stripped from this diff
+            HTMLMessage._head(),
+            HTMLMessage._body_open(),
+            HTMLMessage._make_header(target, start),
+            HTMLMessage._make_overview(records),
+            HTMLMessage._make_insights(records),
+            HTMLMessage._body_close(),
+        ]
+        return "\n".join(html)
+
+
+class Mailer:
+    @staticmethod
+    def create_message(subject: str, sender: str, receiver: str, body: str) -> MIMEMultipart:
+        msg = MIMEMultipart()
+        msg["Subject"] = subject
+        msg["From"] = sender
+        msg["To"] = receiver
+        msg.attach(MIMEText(body, "html"))
+        return msg
+
+    @staticmethod
+    def send_mail(smtp: str, message: MIMEMultipart):
+        # the context manager closes the connection even if sendmail raises
+        with smtplib.SMTP(host=smtp, port=25) as s:
+            s.sendmail(from_addr=message["From"], to_addrs=message["To"], msg=message.as_string())
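
For orientation, a minimal usage sketch tying the two classes together. The names come from this diff; `records` is assumed to be a list of runner `Record` results, and the SMTP host and addresses are placeholders:

    report = HTMLMessage.make_report("my_pipeline", datetime.now(), records)
    msg = Mailer.create_message(
        subject="Run report",
        sender="runner@example.com",
        receiver="team@example.com",
        body=report,
    )
    Mailer.send_mail(smtp="smtp.example.com", message=msg)
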
diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py
new file mode 100644
index 0000000..210cb0b
--- /dev/null
+++ b/rialto/runner/transformation.py
@@ -0,0 +1,48 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["Transformation"]
+
+import abc
+import datetime
+from typing import Dict
+
+from pyspark.sql import DataFrame, SparkSession
+
+from rialto.common import TableReader
+from rialto.metadata import MetadataManager
+
+
+class Transformation(metaclass=abc.ABCMeta):
+ """Interface for feature implementation"""
+
+ @abc.abstractmethod
+ def run(
+ self,
+ reader: TableReader,
+ run_date: datetime.date,
+ spark: SparkSession = None,
+ metadata_manager: MetadataManager = None,
+ dependencies: Dict = None,
+ ) -> DataFrame:
+ """
+ Run the transformation
+
+ :param reader: data store api object
+ :param run_date: date
+ :param spark: spark session
+        :param metadata_manager: metadata api object
+        :param dependencies: optional dictionary of job dependencies
+        :return: dataframe
+ """
+ raise NotImplementedError
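
As an illustration (not part of this diff), a subclass only has to implement `run`. This toy version ignores `reader` and `metadata_manager` and just tags a single-row frame with the run date; a real job would pull its inputs through `reader`:

    class SampleTransformation(Transformation):
        """Toy implementation used only as an illustration"""

        def run(
            self,
            reader: TableReader,
            run_date: datetime.date,
            spark: SparkSession = None,
            metadata_manager: MetadataManager = None,
            dependencies: Dict = None,
        ) -> DataFrame:
            # real jobs would read source tables via `reader` here
            return spark.createDataFrame([(str(run_date),)], ["RUN_DATE"])
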
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..79c3773
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/common/conftest.py b/tests/common/conftest.py
new file mode 100644
index 0000000..79455ff
--- /dev/null
+++ b/tests/common/conftest.py
@@ -0,0 +1,36 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture(scope="session")
+def spark(request):
+ """fixture for creating a spark session
+ Args:
+ request: pytest.FixtureRequest object
+ """
+
+ spark = (
+ SparkSession.builder.master("local[2]")
+ .appName("pytest-pyspark-local-testing")
+ .config("spark.ui.enabled", "false")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ )
+
+ request.addfinalizer(lambda: spark.stop())
+
+ return spark
diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py
new file mode 100644
index 0000000..cd3ebd9
--- /dev/null
+++ b/tests/common/test_utils.py
@@ -0,0 +1,44 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pyspark.sql.functions as F
+import pytest
+from numpy import dtype
+
+from rialto.common.utils import cast_decimals_to_floats
+
+
+@pytest.fixture
+def sample_df(spark):
+ df = spark.createDataFrame(
+ [(1, 2.33, "str", 4.55, 5.66), (1, 2.33, "str", 4.55, 5.66), (1, 2.33, "str", 4.55, 5.66)],
+ schema="a long, b float, c string, d float, e float",
+ )
+
+ return df.select("a", "b", "c", F.col("d").cast("decimal"), F.col("e").cast("decimal(18,5)"))
+
+
+def test_cast_decimals_to_floats(sample_df):
+ df_fixed = cast_decimals_to_floats(sample_df)
+
+ assert df_fixed.dtypes[3] == ("d", "float")
+ assert df_fixed.dtypes[4] == ("e", "float")
+
+
+def test_cast_decimals_to_floats_topandas_works(sample_df):
+ df_fixed = cast_decimals_to_floats(sample_df)
+ df_pd = df_fixed.toPandas()
+
+ assert df_pd.dtypes[3] == dtype("float32")
+ assert df_pd.dtypes[4] == dtype("float32")
diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py
new file mode 100644
index 0000000..dda863d
--- /dev/null
+++ b/tests/jobs/conftest.py
@@ -0,0 +1,37 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture(scope="session")
+def spark(request):
+ """fixture for creating a spark session
+ :param request: pytest.FixtureRequest object
+ """
+
+ spark = (
+ SparkSession.builder.master("local[3]")
+ .appName("pytest-pyspark-local-testing")
+ .config("spark.ui.enabled", "false")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ )
+
+ request.addfinalizer(lambda: spark.stop())
+
+ return spark
diff --git a/tests/jobs/resources.py b/tests/jobs/resources.py
new file mode 100644
index 0000000..4d33fad
--- /dev/null
+++ b/tests/jobs/resources.py
@@ -0,0 +1,43 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+
+import pandas as pd
+
+from rialto.jobs.decorators.job_base import JobBase
+
+
+def custom_callable():
+ pass
+
+
+class CustomJobNoReturnVal(JobBase):
+ def get_job_name(self) -> str:
+ return "job_name"
+
+ def get_job_version(self) -> str:
+ return "job_version"
+
+ def get_custom_callable(self) -> typing.Callable:
+ return custom_callable
+
+
+class CustomJobReturnsDataFrame(CustomJobNoReturnVal):
+ def get_custom_callable(self) -> typing.Callable:
+ def f(spark):
+ df = pd.DataFrame([["A", 1], ["B", 2]], columns=["FIRST", "SECOND"])
+
+ return spark.createDataFrame(df)
+
+ return f
diff --git a/tests/jobs/test_config_holder.py b/tests/jobs/test_config_holder.py
new file mode 100644
index 0000000..38fadb1
--- /dev/null
+++ b/tests/jobs/test_config_holder.py
@@ -0,0 +1,100 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from datetime import date
+
+import pytest
+
+from rialto.jobs.configuration.config_holder import (
+ ConfigException,
+ ConfigHolder,
+ FeatureStoreConfig,
+)
+
+
+def test_run_date_unset():
+ with pytest.raises(ConfigException):
+ ConfigHolder.get_run_date()
+
+
+def test_run_date():
+ dt = date(2023, 1, 1)
+
+ ConfigHolder.set_run_date(dt)
+
+ assert ConfigHolder.get_run_date() == dt
+
+
+def test_feature_store_config_unset():
+ with pytest.raises(ConfigException):
+ ConfigHolder.get_feature_store_config()
+
+
+def test_feature_store_config():
+ ConfigHolder.set_feature_store_config("store_schema", "metadata_schema")
+
+ fsc = ConfigHolder.get_feature_store_config()
+
+ assert type(fsc) is FeatureStoreConfig
+ assert fsc.feature_store_schema == "store_schema"
+ assert fsc.feature_metadata_schema == "metadata_schema"
+
+
+def test_config_unset():
+ config = ConfigHolder.get_config()
+
+    assert type(config) is dict
+    assert len(config) == 0
+
+
+def test_config_dict_copied_not_ref():
+ """Test that config holder config can't be set from outside"""
+ config = ConfigHolder.get_config()
+
+ config["test"] = 123
+
+ assert "test" not in ConfigHolder.get_config()
+
+
+def test_config():
+ ConfigHolder.set_custom_config(hello=123)
+ ConfigHolder.set_custom_config(world="test")
+
+ config = ConfigHolder.get_config()
+
+ assert config["hello"] == 123
+ assert config["world"] == "test"
+
+
+def test_config_from_dict():
+ ConfigHolder.set_custom_config(**{"dict_item_1": 123, "dict_item_2": 456})
+
+ config = ConfigHolder.get_config()
+
+ assert config["dict_item_1"] == 123
+ assert config["dict_item_2"] == 456
+
+
+def test_dependencies_unset():
+ deps = ConfigHolder.get_dependency_config()
+ assert len(deps.keys()) == 0
+
+
+def test_dependencies():
+ ConfigHolder.set_dependency_config({"hello": 123})
+
+ deps = ConfigHolder.get_dependency_config()
+
+ assert deps["hello"] == 123
diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py
new file mode 100644
index 0000000..e896cec
--- /dev/null
+++ b/tests/jobs/test_decorators.py
@@ -0,0 +1,65 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from importlib import import_module
+
+from rialto.jobs.configuration.config_holder import ConfigHolder
+from rialto.jobs.decorators.job_base import JobBase
+from rialto.jobs.decorators.resolver import Resolver
+
+
+def test_dataset_decorator():
+ _ = import_module("tests.jobs.test_job.test_job")
+ test_dataset = Resolver.resolve("dataset")
+
+ assert test_dataset == "dataset_return"
+
+
+def _rialto_import_stub(module_name, class_name):
+ module = import_module(module_name)
+ class_obj = getattr(module, class_name)
+ return class_obj()
+
+
+def test_job_function_type():
+ result_class = _rialto_import_stub("tests.jobs.test_job.test_job", "job_function")
+ assert issubclass(type(result_class), JobBase)
+
+
+def test_job_function_callables_filled():
+ result_class = _rialto_import_stub("tests.jobs.test_job.test_job", "job_function")
+
+ custom_callable = result_class.get_custom_callable()
+ assert custom_callable() == "job_function_return"
+
+ version = result_class.get_job_version()
+ assert version == "N/A"
+
+ job_name = result_class.get_job_name()
+ assert job_name == "job_function"
+
+
+def test_custom_name_function():
+ result_class = _rialto_import_stub("tests.jobs.test_job.test_job", "custom_job_name")
+ assert issubclass(type(result_class), JobBase)
+
+ custom_callable = result_class.get_custom_callable()
+ assert custom_callable() == "custom_job_name_return"
+
+
+def test_job_dependencies_registered(spark):
+ ConfigHolder.set_custom_config(value=123)
+ job_class = _rialto_import_stub("tests.jobs.test_job.test_job", "job_asking_for_all_deps")
+ # asserts part of the run
+ job_class.run(spark=spark, run_date=456, reader=789, metadata_manager=None, dependencies=1011)
diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py
new file mode 100644
index 0000000..12baec9
--- /dev/null
+++ b/tests/jobs/test_job/test_job.py
@@ -0,0 +1,40 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from rialto.jobs.decorators import datasource, job
+
+
+@datasource
+def dataset():
+ return "dataset_return"
+
+
+@job
+def job_function():
+ return "job_function_return"
+
+
+@job("custom_job_name")
+def custom_name_job_function():
+ return "custom_job_name_return"
+
+
+@job
+def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader):
+ assert spark is not None
+ assert run_date == 456
+ assert config["value"] == 123
+ assert table_reader == 789
+ assert dependencies == 1011
diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py
new file mode 100644
index 0000000..2cdc741
--- /dev/null
+++ b/tests/jobs/test_job_base.py
@@ -0,0 +1,93 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import datetime
+from unittest.mock import MagicMock, patch
+
+import pyspark.sql.types
+
+import tests.jobs.resources as resources
+from rialto.jobs.configuration.config_holder import ConfigHolder, FeatureStoreConfig
+from rialto.jobs.decorators.resolver import Resolver
+from rialto.loader import PysparkFeatureLoader
+
+
+def test_setup_except_feature_loader(spark):
+ table_reader = MagicMock()
+ date = datetime.date(2023, 1, 1)
+
+ ConfigHolder.set_custom_config(hello=1, world=2)
+
+ resources.CustomJobNoReturnVal().run(
+ reader=table_reader, run_date=date, spark=spark, metadata_manager=None, dependencies={1: 1}
+ )
+
+ assert Resolver.resolve("run_date") == date
+ assert Resolver.resolve("config") == ConfigHolder.get_config()
+ assert Resolver.resolve("dependencies") == ConfigHolder.get_dependency_config()
+ assert Resolver.resolve("spark") == spark
+ assert Resolver.resolve("table_reader") == table_reader
+
+
+@patch(
+ "rialto.jobs.configuration.config_holder.ConfigHolder.get_feature_store_config",
+ return_value=FeatureStoreConfig(feature_store_schema="schema", feature_metadata_schema="metadata_schema"),
+)
+def test_setup_feature_loader(mock_get_feature_store_config, spark):
+ table_reader = MagicMock()
+ date = datetime.date(2023, 1, 1)
+
+ resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None)
+
+ assert type(Resolver.resolve("feature_loader")) == PysparkFeatureLoader
+
+
+def test_custom_callable_called(spark, mocker):
+ spy_cc = mocker.spy(resources, "custom_callable")
+
+ table_reader = MagicMock()
+ date = datetime.date(2023, 1, 1)
+
+ resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None)
+
+ spy_cc.assert_called_once()
+
+
+def test_no_return_value_adds_version_timestamp_dataframe(spark):
+ table_reader = MagicMock()
+ date = datetime.date(2023, 1, 1)
+
+ result = resources.CustomJobNoReturnVal().run(
+ reader=table_reader, run_date=date, spark=spark, metadata_manager=None
+ )
+
+ assert type(result) is pyspark.sql.DataFrame
+ assert result.columns == ["JOB_NAME", "CREATION_TIME", "VERSION"]
+ assert result.first()["VERSION"] == "job_version"
+ assert result.count() == 1
+
+
+def test_return_dataframe_forwarded_with_version(spark):
+ table_reader = MagicMock()
+ date = datetime.date(2023, 1, 1)
+
+ result = resources.CustomJobReturnsDataFrame().run(
+ reader=table_reader, run_date=date, spark=spark, metadata_manager=None
+ )
+
+ assert type(result) is pyspark.sql.DataFrame
+ assert result.columns == ["FIRST", "SECOND", "VERSION"]
+ assert result.first()["VERSION"] == "job_version"
+ assert result.count() == 2
diff --git a/tests/jobs/test_resolver.py b/tests/jobs/test_resolver.py
new file mode 100644
index 0000000..df56b72
--- /dev/null
+++ b/tests/jobs/test_resolver.py
@@ -0,0 +1,65 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from rialto.jobs.decorators.resolver import Resolver, ResolverException
+
+
+def test_simple_resolve_custom_name():
+ def f():
+ return 7
+
+ Resolver.register_callable(f, "hello")
+
+ assert Resolver.resolve("hello") == 7
+
+
+def test_simple_resolve_infer_f_name():
+ def f():
+ return 7
+
+ Resolver.register_callable(f)
+
+ assert Resolver.resolve("f") == 7
+
+
+def test_dependency_resolve():
+ def f():
+ return 7
+
+ def g(f):
+ return f + 1
+
+ Resolver.register_callable(f)
+ Resolver.register_callable(g)
+
+ assert Resolver.resolve("g") == 8
+
+
+def test_resolve_non_defined():
+ with pytest.raises(ResolverException):
+ Resolver.resolve("whatever")
+
+
+def test_register_resolve(mocker):
+ def f():
+ return 7
+
+ mocker.patch("rialto.jobs.decorators.resolver.Resolver.register_callable", return_value="f")
+ mocker.patch("rialto.jobs.decorators.resolver.Resolver.resolve")
+
+ Resolver.register_resolve(f)
+
+ Resolver.register_callable.assert_called_once_with(f)
+ Resolver.resolve.assert_called_once_with("f")
diff --git a/tests/jobs/test_test_utils.py b/tests/jobs/test_test_utils.py
new file mode 100644
index 0000000..a6b31b2
--- /dev/null
+++ b/tests/jobs/test_test_utils.py
@@ -0,0 +1,48 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import rialto.jobs.decorators as decorators
+import tests.jobs.test_job.test_job as test_job
+from rialto.jobs.decorators.resolver import Resolver
+from rialto.jobs.decorators.test_utils import disable_job_decorators
+
+
+def test_raw_dataset_patch(mocker):
+ spy_rc = mocker.spy(Resolver, "register_callable")
+ spy_dec = mocker.spy(decorators, "datasource")
+
+ with disable_job_decorators(test_job):
+ assert test_job.dataset() == "dataset_return"
+
+ spy_dec.assert_not_called()
+ spy_rc.assert_not_called()
+
+
+def test_job_function_patch(mocker):
+ spy_dec = mocker.spy(decorators, "job")
+
+ with disable_job_decorators(test_job):
+ assert test_job.job_function() == "job_function_return"
+
+ spy_dec.assert_not_called()
+
+
+def test_custom_name_job_function_patch(mocker):
+ spy_dec = mocker.spy(decorators, "job")
+
+ with disable_job_decorators(test_job):
+ assert test_job.custom_name_job_function() == "custom_job_name_return"
+
+ spy_dec.assert_not_called()
diff --git a/tests/loader/__init__.py b/tests/loader/__init__.py
new file mode 100644
index 0000000..79c3773
--- /dev/null
+++ b/tests/loader/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/loader/metadata_config/full_example.yaml b/tests/loader/metadata_config/full_example.yaml
new file mode 100644
index 0000000..9ad780c
--- /dev/null
+++ b/tests/loader/metadata_config/full_example.yaml
@@ -0,0 +1,33 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+selection:
+ - group: A
+ prefix: A
+ features:
+ - A1
+ - A2
+ - group: B
+ prefix: B
+ features:
+ - B1
+ - B2
+base:
+ group: D
+ keys:
+ - K
+ - L
+maps:
+ - M
+ - N
diff --git a/tests/loader/metadata_config/missing_field_example.yaml b/tests/loader/metadata_config/missing_field_example.yaml
new file mode 100644
index 0000000..1caf3b0
--- /dev/null
+++ b/tests/loader/metadata_config/missing_field_example.yaml
@@ -0,0 +1,27 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+selection:
+ - group: A
+ prefix: A
+ features:
+ - A1
+ - A2
+ - group: B
+ features:
+ - B1
+ - B2
+base:
+ group: D
+ keys: K
diff --git a/tests/loader/metadata_config/missing_value_example.yaml b/tests/loader/metadata_config/missing_value_example.yaml
new file mode 100644
index 0000000..844e25f
--- /dev/null
+++ b/tests/loader/metadata_config/missing_value_example.yaml
@@ -0,0 +1,28 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+selection:
+ - group: A
+ prefix: A
+ features:
+ - A1
+ - A2
+ - group: B
+ prefix: B
+ features:
+ - B1
+ - B2
+base:
+ group: D
+ keys:
diff --git a/tests/loader/metadata_config/no_map_example.yaml b/tests/loader/metadata_config/no_map_example.yaml
new file mode 100644
index 0000000..d3679fa
--- /dev/null
+++ b/tests/loader/metadata_config/no_map_example.yaml
@@ -0,0 +1,30 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+selection:
+ - group: A
+ prefix: A
+ features:
+ - A1
+ - A2
+ - group: B
+ prefix: B
+ features:
+ - B1
+ - B2
+base:
+ group: D
+ keys:
+ - K
+ - L
diff --git a/tests/loader/metadata_config/test_main_config.py b/tests/loader/metadata_config/test_main_config.py
new file mode 100644
index 0000000..b09f155
--- /dev/null
+++ b/tests/loader/metadata_config/test_main_config.py
@@ -0,0 +1,56 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from pydantic import ValidationError
+
+from rialto.loader.config_loader import get_feature_config
+
+
+def test_get_config_full_cfg():
+ cfg = get_feature_config("tests/loader/metadata_config/full_example.yaml")
+ assert len(cfg.selection) == 2
+ assert cfg.selection[0].group == "A"
+ assert cfg.selection[0].prefix == "A"
+ assert cfg.selection[0].features == ["A1", "A2"]
+ assert cfg.selection[1].group == "B"
+ assert cfg.selection[1].prefix == "B"
+ assert cfg.selection[1].features == ["B1", "B2"]
+ assert cfg.base.group == "D"
+ assert cfg.base.keys == ["K", "L"]
+ assert cfg.maps == ["M", "N"]
+
+
+def test_get_config_no_map_cfg():
+ cfg = get_feature_config("tests/loader/metadata_config/no_map_example.yaml")
+ assert len(cfg.selection) == 2
+ assert cfg.selection[0].group == "A"
+ assert cfg.selection[0].prefix == "A"
+ assert cfg.selection[0].features == ["A1", "A2"]
+ assert cfg.selection[1].group == "B"
+ assert cfg.selection[1].prefix == "B"
+ assert cfg.selection[1].features == ["B1", "B2"]
+ assert cfg.base.group == "D"
+ assert cfg.base.keys == ["K", "L"]
+ assert cfg.maps is None
+
+
+def test_get_config_no_base_key():
+ with pytest.raises(ValidationError):
+ get_feature_config("tests/loader/metadata_config/missing_value_example.yaml")
+
+
+def test_get_config_no_prefix_field():
+ with pytest.raises(ValidationError):
+ get_feature_config("tests/loader/metadata_config/missing_field_example.yaml")
diff --git a/tests/loader/pyspark/dataframe_builder.py b/tests/loader/pyspark/dataframe_builder.py
new file mode 100644
index 0000000..94a755e
--- /dev/null
+++ b/tests/loader/pyspark/dataframe_builder.py
@@ -0,0 +1,27 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+
+from pyspark.sql.types import DataType, StructField, StructType
+
+
+def dataframe_builder(
+    spark, data: typing.List, columns: typing.List[typing.Tuple[str, DataType]]
+):
+ schema_builder = []
+ for name, data_type in columns:
+ schema_builder.append(StructField(name, data_type, True))
+ schema = StructType(schema_builder)
+ return spark.createDataFrame(data, schema)
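
A typical call, with the (name, type) column pairs matching those in resources.py below (types from pyspark.sql.types):

    from pyspark.sql.types import IntegerType, StringType

    df = dataframe_builder(
        spark,
        data=[("K1", 1), ("K2", 2)],
        columns=[("KEY", StringType()), ("A1", IntegerType())],
    )
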
diff --git a/tests/loader/pyspark/dummy_loaders.py b/tests/loader/pyspark/dummy_loaders.py
new file mode 100644
index 0000000..a2b0cb8
--- /dev/null
+++ b/tests/loader/pyspark/dummy_loaders.py
@@ -0,0 +1,24 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date
+
+from rialto.loader.data_loader import DataLoader
+
+
+class DummyDataLoader(DataLoader):
+ def __init__(self):
+ super().__init__()
+
+ def read_group(self, group: str, information_date: date):
+ return None
diff --git a/tests/loader/pyspark/example_cfg.yaml b/tests/loader/pyspark/example_cfg.yaml
new file mode 100644
index 0000000..6b19277
--- /dev/null
+++ b/tests/loader/pyspark/example_cfg.yaml
@@ -0,0 +1,24 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+selection:
+ - group: B
+ prefix: B
+ features:
+ - F1
+ - F3
+base:
+ group: D
+ keys:
+ - KEY1
diff --git a/tests/loader/pyspark/resources.py b/tests/loader/pyspark/resources.py
new file mode 100644
index 0000000..64a8363
--- /dev/null
+++ b/tests/loader/pyspark/resources.py
@@ -0,0 +1,44 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pyspark.sql.types import FloatType, IntegerType, StringType
+
+feature_group_a_data = [("K1", 1, "A"), ("K2", 2, "A"), ("K3", 3, None)]
+feature_group_a_columns = [("KEY", StringType()), ("A1", IntegerType()), ("A2", StringType())]
+
+base_frame_data = [("K1",), ("K2",), ("K3",)]
+base_frame_columns = [("KEY1", StringType())]
+
+mapping1_data = [("M1", "K1", "N11"), ("M2", "K3", "N23")]
+mapping1_columns = [("KEY2", StringType()), ("KEY1", StringType()), ("KEY3", StringType())]
+
+mapping2_data = [("K1", "M1"), ("K2", "M1"), ("K3", "M2"), ("K4", "M2")]
+mapping2_columns = [("KEY1", StringType()), ("KEY2", StringType())]
+
+mapping3_data = [("N11", "H5"), ("N23", "H6")]
+mapping3_columns = [("KEY3", StringType()), ("KEY4", StringType())]
+
+expected_mapping_data = [("K1", "M1", "N11", "H5"), ("K3", "M2", "N23", "H6")]
+expected_mapping_columns = [
+ ("KEY1", StringType()),
+ ("KEY2", StringType()),
+ ("KEY3", StringType()),
+ ("KEY4", StringType()),
+]
+
+feature_group_b_data = [("K1", "A", 5, None), ("K3", "B", 7, 0.36)]
+feature_group_b_columns = [("KEY1", StringType()), ("F1", StringType()), ("F2", IntegerType()), ("F3", FloatType())]
+
+expected_features_b_data = [("K1", "A", None), ("K2", None, None), ("K3", "B", 0.36)]
+expected_features_b_columns = [("KEY1", StringType()), ("B_F1", StringType()), ("B_F3", FloatType())]
diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py
new file mode 100644
index 0000000..3ad653e
--- /dev/null
+++ b/tests/loader/pyspark/test_from_cfg.py
@@ -0,0 +1,137 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock
+
+import pytest
+from chispa import assert_df_equality
+from pyspark.sql import SparkSession
+
+import tests.loader.pyspark.resources as r
+from rialto.loader.config_loader import get_feature_config
+from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader
+from tests.loader.pyspark.dataframe_builder import dataframe_builder as dfb
+from tests.loader.pyspark.dummy_loaders import DummyDataLoader
+
+
+@pytest.fixture(scope="session")
+def spark(request):
+ """fixture for creating a spark session
+ :param request: pytest.FixtureRequest object
+ """
+ spark = (
+ SparkSession.builder.master("local[2]")
+ .appName("pytest-pyspark-local-testing")
+ .config("spark.ui.enabled", "false")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ )
+
+ request.addfinalizer(lambda: spark.stop())
+
+ return spark
+
+
+@pytest.fixture(scope="session")
+def loader(spark):
+ return PysparkFeatureLoader(spark, DummyDataLoader(), MagicMock())
+
+
+VALID_LIST = [(["a"], ["a"]), (["a"], ["a", "b", "c"]), (["c", "a"], ["a", "b", "c"])]
+
+
+@pytest.mark.parametrize("valid_terms", VALID_LIST)
+def test_all_keys_in_true(loader, valid_terms):
+ assert loader._are_all_keys_in(valid_terms[0], valid_terms[1]) is True
+
+
+INVALID_LIST = [(["d"], ["a"]), (["a", "d"], ["a", "b", "c"]), (["c", "a", "b"], ["a", "c"])]
+
+
+@pytest.mark.parametrize("invalid_terms", INVALID_LIST)
+def test_all_keys_in_false(loader, invalid_terms):
+ assert loader._are_all_keys_in(invalid_terms[0], invalid_terms[1]) is False
+
+
+def test_add_prefix(loader):
+ df = dfb(loader.spark, data=r.feature_group_a_data, columns=r.feature_group_a_columns)
+ assert loader._add_prefix(df, "A", ["KEY"]).columns == ["KEY", "A_A1", "A_A2"]
+
+
+def test_join_keymaps(loader, spark):
+ key_maps = [
+ PysparkFeatureLoader.KeyMap(dfb(spark, data=r.mapping1_data, columns=r.mapping1_columns), ["KEY1", "KEY2"]),
+ PysparkFeatureLoader.KeyMap(dfb(spark, data=r.mapping2_data, columns=r.mapping2_columns), ["KEY1"]),
+ PysparkFeatureLoader.KeyMap(dfb(spark, data=r.mapping3_data, columns=r.mapping3_columns), ["KEY3"]),
+ ]
+ mapped = loader._join_keymaps(dfb(spark, data=r.base_frame_data, columns=r.base_frame_columns), key_maps)
+ expected = dfb(spark, data=r.expected_mapping_data, columns=r.expected_mapping_columns)
+ assert_df_equality(mapped, expected, ignore_column_order=True, ignore_row_order=True)
+
+
+def test_add_group(spark, monkeypatch):
+ class GroupMd:
+ def __init__(self):
+ self.key = ["KEY1"]
+
+ def __call__(self, *args, **kwargs):
+ return self
+
+ metadata = MagicMock()
+ monkeypatch.setattr(metadata, "get_group", GroupMd())
+ loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ loader.metadata = metadata
+
+ base = dfb(spark, data=r.base_frame_data, columns=r.base_frame_columns)
+ df = dfb(spark, data=r.feature_group_b_data, columns=r.feature_group_b_columns)
+ group_cfg = get_feature_config("tests/loader/pyspark/example_cfg.yaml").selection[0]
+
+ features = loader._add_feature_group(base, df, group_cfg)
+ expected = dfb(spark, data=r.expected_features_b_data, columns=r.expected_features_b_columns)
+ assert_df_equality(features, expected, ignore_column_order=True, ignore_row_order=True)
+
+
+def test_get_group_metadata(spark, mocker):
+ mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", return_value=7)
+
+ loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ ret_val = loader.get_group_metadata("group_name")
+
+ assert ret_val == 7
+ loader.metadata.get_group.assert_called_once_with("group_name")
+
+
+def test_get_feature_metadata(spark, mocker):
+ mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_feature", return_value=8)
+
+ loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ ret_val = loader.get_feature_metadata("group_name", "feature")
+
+ assert ret_val == 8
+ loader.metadata.get_feature.assert_called_once_with("group_name", "feature")
+
+
+def test_get_metadata_from_cfg(spark, mocker):
+ mocker.patch(
+ "rialto.loader.pyspark_feature_loader.MetadataManager.get_feature",
+ side_effect=lambda g, f: {"B": {"F1": 1, "F3": 2}}[g][f],
+ )
+ mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", side_effect=lambda g: {"B": 10}[g])
+
+ loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ metadata = loader.get_metadata_from_cfg("tests/loader/pyspark/example_cfg.yaml")
+
+ assert metadata["B_F1"] == 1
+ assert metadata["B_F3"] == 2
+ assert len(metadata.keys()) == 2
diff --git a/tests/maker/__init__.py b/tests/maker/__init__.py
new file mode 100644
index 0000000..79c3773
--- /dev/null
+++ b/tests/maker/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/maker/conftest.py b/tests/maker/conftest.py
new file mode 100644
index 0000000..79455ff
--- /dev/null
+++ b/tests/maker/conftest.py
@@ -0,0 +1,36 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture(scope="session")
+def spark(request):
+ """fixture for creating a spark session
+ Args:
+ request: pytest.FixtureRequest object
+ """
+
+ spark = (
+ SparkSession.builder.master("local[2]")
+ .appName("pytest-pyspark-local-testing")
+ .config("spark.ui.enabled", "false")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ )
+
+ request.addfinalizer(lambda: spark.stop())
+
+ return spark
diff --git a/tests/maker/test_FeatureFunction.py b/tests/maker/test_FeatureFunction.py
new file mode 100644
index 0000000..43590ae
--- /dev/null
+++ b/tests/maker/test_FeatureFunction.py
@@ -0,0 +1,74 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+import pytest
+
+from rialto.maker.containers import FeatureFunction
+from rialto.metadata import ValueType
+
+
+def test_name_generation_no_parameters():
+ func = FeatureFunction("feature", Mock())
+ assert func.get_feature_name() == "FEATURE"
+
+
+def test_name_generation_with_parameter():
+ func = FeatureFunction("feature", Mock())
+ func.parameters["param"] = 6
+ assert func.get_feature_name() == "FEATURE_PARAM_6"
+
+
+def test_name_generation_multiple_params():
+ func = FeatureFunction("feature", Mock())
+ func.parameters["paramC"] = 1
+ func.parameters["paramA"] = 4
+ func.parameters["paramB"] = 6
+ assert func.get_feature_name() == "FEATURE_PARAMA_4_PARAMB_6_PARAMC_1"
+
+
+def test_feature_type_default_is_nominal():
+ func = FeatureFunction("feature", Mock())
+ assert func.type == ValueType.nominal
+
+
+@pytest.mark.parametrize(
+ "feature_type",
+ [(ValueType.nominal, "nominal"), (ValueType.ordinal, "ordinal"), (ValueType.numerical, "numerical")],
+)
+def test_feature_type_getter(feature_type: tuple):
+ func = FeatureFunction("feature", Mock(), feature_type[0])
+ assert func.get_type() == feature_type[1]
+
+
+def test_serialization():
+ func = FeatureFunction("feature", Mock())
+ func.parameters["paramC"] = 1
+ func.parameters["paramA"] = 4
+ assert (
+ func.__str__()
+ == "Name: feature\n\tParameters: {'paramC': 1, 'paramA': 4}\n\tType: nominal\n\tDescription: basic feature"
+ )
+
+
+def test_metadata():
+ func = FeatureFunction("feature", Mock(), ValueType.ordinal)
+ func.parameters["paramC"] = 1
+ func.parameters["paramA"] = 4
+ func.dependencies = ["featureB", "featureC"]
+ func.description = "nice feature"
+
+ assert func.metadata().name == "FEATURE_PARAMA_4_PARAMC_1"
+ assert func.metadata().value_type == ValueType.ordinal
+ assert func.metadata().description == "nice feature"
diff --git a/tests/maker/test_FeatureHolder.py b/tests/maker/test_FeatureHolder.py
new file mode 100644
index 0000000..5c00cdb
--- /dev/null
+++ b/tests/maker/test_FeatureHolder.py
@@ -0,0 +1,36 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from rialto.maker.containers import FeatureFunction, FeatureHolder
+from rialto.metadata import ValueType
+
+
+def test_metadata_return_type_empty():
+ assert isinstance(FeatureHolder().get_metadata(), list)
+
+
+def test_metadata_return_type():
+ fh = FeatureHolder()
+ fh.append(FeatureFunction("feature_nominal", Mock(), ValueType.nominal))
+ assert isinstance(fh.get_metadata(), list)
+
+
+def test_metadata_value():
+ fh = FeatureHolder()
+ ff = FeatureFunction("feature_ordinal", Mock(), ValueType.ordinal)
+ ff.parameters["param"] = 3
+ fh.append(ff)
+ metadata = fh.get_metadata()
+ assert metadata[0].value_type == ValueType.ordinal
diff --git a/tests/maker/test_FeatureMaker.py b/tests/maker/test_FeatureMaker.py
new file mode 100644
index 0000000..b8f9c1a
--- /dev/null
+++ b/tests/maker/test_FeatureMaker.py
@@ -0,0 +1,187 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date
+
+import pandas as pd
+import pytest
+
+from rialto.maker.feature_maker import FeatureMaker
+from rialto.metadata import ValueType
+from tests.maker.test_features import (
+ aggregated_num_sum_outbound,
+ aggregated_num_sum_txn,
+ dependent_features_fail,
+ dependent_features_fail2,
+ dependent_features_ok,
+ sequential_avg_outbound,
+ sequential_avg_txn,
+ sequential_for_testing,
+ sequential_outbound,
+ sequential_outbound_with_param,
+)
+
+
+@pytest.fixture
+def input_df(spark):
+ df = pd.DataFrame(
+ [
+ [42, "A", "C_1"],
+ [-35, "A", "C_1"],
+ [-12, "B", "C_1"],
+ [-65, "B", "C_1"],
+ [12, "A", "C_2"],
+ [16, "A", "C_2"],
+ [-10, "A", "C_2"],
+ ],
+ columns=["AMT", "TYPE", "CUSTOMER_KEY"],
+ )
+ return spark.createDataFrame(df)
+
+
+def test_sequential_column_exists(input_df):
+ df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=True)
+ assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns
+
+
+def test_sequential_multi_key(input_df):
+ df, _ = FeatureMaker.make(
+ input_df, ["CUSTOMER_KEY", "TYPE"], date.today(), sequential_outbound, keep_preexisting=True
+ )
+ assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns
+
+
+def test_sequential_keeps(input_df):
+ df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=True)
+ assert "AMT" in df.columns
+
+
+def test_sequential_drops(input_df):
+ df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=False)
+ assert "AMT" not in df.columns
+
+
+def test_sequential_key_not_dropped(input_df):
+ df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=False)
+ assert "CUSTOMER_KEY" in df.columns
+
+
+def test_sequential_with_params_column_exists(input_df):
+ df, _ = FeatureMaker.make(
+ input_df, "CUSTOMER_KEY", date.today(), sequential_outbound_with_param, keep_preexisting=False
+ )
+ assert "TRANSACTIONS_OUTBOUND_VALUE_V_TYPE_A" in df.columns
+
+
+def test_aggregated_column_exists(input_df):
+ df, _ = FeatureMaker.make_aggregated(input_df, "CUSTOMER_KEY", date.today(), aggregated_num_sum_txn)
+ assert "TRANSACTIONS_NUM_TRANSACTIONS" in df.columns
+
+
+def test_aggregated_key_exists(input_df):
+ df, _ = FeatureMaker.make_aggregated(input_df, "CUSTOMER_KEY", date.today(), aggregated_num_sum_txn)
+ assert "CUSTOMER_KEY" in df.columns
+
+
+def test_aggregated_multi_key_exists(input_df):
+ df, _ = FeatureMaker.make_aggregated(input_df, ["CUSTOMER_KEY", "TYPE"], date.today(), aggregated_num_sum_txn)
+ assert "CUSTOMER_KEY" in df.columns and "TYPE" in df.columns
+
+
+def test_maker_metadata(input_df):
+ df, metadata = FeatureMaker.make_aggregated(input_df, "CUSTOMER_KEY", date.today(), aggregated_num_sum_txn)
+ assert metadata[0].value_type == ValueType.numerical
+
+
+def test_double_chained_makers_column_exists(input_df):
+ df, _ = FeatureMaker.make_aggregated(input_df, "CUSTOMER_KEY", date.today(), aggregated_num_sum_txn)
+ df, _ = FeatureMaker.make(df, "CUSTOMER_KEY", date.today(), sequential_avg_txn)
+ assert "TRANSACTIONS_AVG_TRANSACTION" in df.columns
+
+
+def test_triple_chained_makers_column_exists(input_df):
+ # create outbound column
+ df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound)
+ # agg outbound sum and num
+ df, _ = FeatureMaker.make_aggregated(df, "CUSTOMER_KEY", date.today(), aggregated_num_sum_outbound)
+ # create outbound avg
+ df, _ = FeatureMaker.make(df, "CUSTOMER_KEY", date.today(), sequential_avg_outbound)
+ assert "TRANSACTIONS_AVG_OUTBOUND" in df.columns
+
+
+def test_triple_chained_makers_key_exists(input_df):
+ # create outbound column
+ df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound)
+ # agg outbound sum and num
+ df, _ = FeatureMaker.make_aggregated(df, "CUSTOMER_KEY", date.today(), aggregated_num_sum_outbound)
+ # create outbound avg
+ df, _ = FeatureMaker.make(df, "CUSTOMER_KEY", date.today(), sequential_avg_outbound)
+ assert "CUSTOMER_KEY" in df.columns
+
+
+def test_dependency_resolution(input_df):
+ ordered = FeatureMaker._order_by_dependencies(FeatureMaker._load_features(dependent_features_ok))
+ ordered = [f[0].name for f in ordered]
+ assert ordered.index("f4_raw") == 0
+ assert ordered.index("f3_depends_f2") < ordered.index("f1_depends_f3_f5")
+ assert ordered.index("f5_depends_f4") < ordered.index("f1_depends_f3_f5")
+ assert ordered.index("f4_raw") < ordered.index("f2_depends_f4")
+ assert ordered.index("f2_depends_f4") < ordered.index("f3_depends_f2")
+ assert ordered.index("f4_raw") < ordered.index("f5_depends_f4")
+
+
+def test_dependency_resolution_cycle(input_df):
+ with pytest.raises(Exception, match="Feature dependencies can't be resolved!"):
+ FeatureMaker._order_by_dependencies(FeatureMaker._load_features(dependent_features_fail))
+
+
+def test_dependency_resolution_self_reference(input_df):
+ with pytest.raises(Exception, match="Feature dependencies can't be resolved!"):
+ FeatureMaker._order_by_dependencies(FeatureMaker._load_features(dependent_features_fail2))
+
+
+def test_find_single_feature():
+ features = FeatureMaker._register_module(sequential_for_testing)
+ feature = FeatureMaker._find_feature("FOR_TESTING_PARAM_B", features)
+ assert feature.get_feature_name() == "FOR_TESTING_PARAM_B"
+
+
+def test_make_single_feature_column_exists(input_df):
+ out = FeatureMaker.make_single_feature(input_df, "FOR_TESTING_PARAM_B", sequential_for_testing)
+ assert "FOR_TESTING_PARAM_B" in out.columns
+
+
+def test_make_single_feature_column_single(input_df):
+ out = FeatureMaker.make_single_feature(input_df, "FOR_TESTING_PARAM_B", sequential_for_testing)
+ assert len(out.columns) == 1
+
+
+def test_make_single_agg_feature_column_exists(input_df):
+ out = FeatureMaker.make_single_agg_feature(
+ input_df, "TRANSACTIONS_SUM_TRANSACTIONS", "CUSTOMER_KEY", aggregated_num_sum_txn
+ )
+ assert "TRANSACTIONS_SUM_TRANSACTIONS" in out.columns
+
+
+def test_make_single_agg_feature_column_single(input_df):
+ out = FeatureMaker.make_single_agg_feature(
+ input_df, "TRANSACTIONS_SUM_TRANSACTIONS", "CUSTOMER_KEY", aggregated_num_sum_txn
+ )
+ assert len(out.columns) == 2
+
+
+def test_make_single_agg_feature_multikey(input_df):
+ out = FeatureMaker.make_single_agg_feature(
+ input_df, "TRANSACTIONS_SUM_TRANSACTIONS", ["CUSTOMER_KEY", "TYPE"], aggregated_num_sum_txn
+ )
+ assert len(out.columns) == 3
diff --git a/tests/maker/test_features/__init__.py b/tests/maker/test_features/__init__.py
new file mode 100644
index 0000000..79c3773
--- /dev/null
+++ b/tests/maker/test_features/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/maker/test_features/aggregated_num_sum_outbound.py b/tests/maker/test_features/aggregated_num_sum_outbound.py
new file mode 100644
index 0000000..ce3937b
--- /dev/null
+++ b/tests/maker/test_features/aggregated_num_sum_outbound.py
@@ -0,0 +1,27 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_num_outbound() -> Column:
+ return F.count(F.col("transactions_outbound_value"))
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_sum_outbound() -> Column:
+ return F.sum(F.col("transactions_outbound_value"))
diff --git a/tests/maker/test_features/aggregated_num_sum_txn.py b/tests/maker/test_features/aggregated_num_sum_txn.py
new file mode 100644
index 0000000..6c807af
--- /dev/null
+++ b/tests/maker/test_features/aggregated_num_sum_txn.py
@@ -0,0 +1,27 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_num_transactions() -> Column:
+ return F.count(F.col("AMT"))
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_sum_transactions() -> Column:
+ return F.sum(F.col("AMT"))
diff --git a/tests/maker/test_features/dependent_features_fail.py b/tests/maker/test_features/dependent_features_fail.py
new file mode 100644
index 0000000..d9a8c7f
--- /dev/null
+++ b/tests/maker/test_features/dependent_features_fail.py
@@ -0,0 +1,30 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
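+# Intentionally circular: f1 depends on f2 and f2 on f1, so resolving this module should fail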
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f2_depends_f1")
+def f1_depends_f2() -> Column:
+ return F.col("CUSTOMER_KEY")
+
+
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f1_depends_f2")
+def f2_depends_f1() -> Column:
+ return F.col("CUSTOMER_KEY")
diff --git a/tests/maker/test_features/dependent_features_fail2.py b/tests/maker/test_features/dependent_features_fail2.py
new file mode 100644
index 0000000..4964c8a
--- /dev/null
+++ b/tests/maker/test_features/dependent_features_fail2.py
@@ -0,0 +1,24 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
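+# Depends on "f5", which is never defined in this module, so resolution should fail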
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f5")
+def f1_depends_f5() -> Column:
+ return F.col("CUSTOMER_KEY")
diff --git a/tests/maker/test_features/dependent_features_ok.py b/tests/maker/test_features/dependent_features_ok.py
new file mode 100644
index 0000000..232f08b
--- /dev/null
+++ b/tests/maker/test_features/dependent_features_ok.py
@@ -0,0 +1,48 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
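+# Resolvable graph: f4_raw is independent; f2 and f5 depend on it, f3 on f2, and f1 on f3 and f5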
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f3_depends_f2")
+@maker.depends("f5_depends_f4")
+def f1_depends_f3_f5() -> Column:
+ return F.col("CUSTOMER_KEY")
+
+
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f4_raw")
+def f2_depends_f4() -> Column:
+ return F.col("CUSTOMER_KEY")
+
+
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f2_depends_f4")
+def f3_depends_f2() -> Column:
+ return F.col("CUSTOMER_KEY")
+
+
+@maker.feature(maker.ValueType.nominal)
+def f4_raw() -> Column:
+ return F.col("CUSTOMER_KEY")
+
+
+@maker.feature(maker.ValueType.nominal)
+@maker.depends("f4_raw")
+def f5_depends_f4() -> Column:
+ return F.col("CUSTOMER_KEY")
diff --git a/tests/maker/test_features/sequential_avg_outbound.py b/tests/maker/test_features/sequential_avg_outbound.py
new file mode 100644
index 0000000..cedad5f
--- /dev/null
+++ b/tests/maker/test_features/sequential_avg_outbound.py
@@ -0,0 +1,22 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_avg_outbound() -> Column:
+ return F.col("transactions_sum_outbound") / F.col("transactions_num_outbound")
diff --git a/tests/maker/test_features/sequential_avg_txn.py b/tests/maker/test_features/sequential_avg_txn.py
new file mode 100644
index 0000000..65d1f7f
--- /dev/null
+++ b/tests/maker/test_features/sequential_avg_txn.py
@@ -0,0 +1,22 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_avg_transaction() -> Column:
+ return F.col("transactions_sum_transactions") / F.col("transactions_num_transactions")
diff --git a/tests/maker/test_features/sequential_for_testing.py b/tests/maker/test_features/sequential_for_testing.py
new file mode 100644
index 0000000..5a8de84
--- /dev/null
+++ b/tests/maker/test_features/sequential_for_testing.py
@@ -0,0 +1,26 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
+@maker.feature(maker.ValueType.numerical)
+@maker.param("param", ["A", "B"])
+def for_testing(param) -> Column:
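+    # keep only negative amounts, and only on rows whose TYPE matches the bound param value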
+ filtered = F.col("TYPE") == param
+ outbound = F.when(F.col("AMT") < 0, F.col("AMT")).otherwise(None)
+ return F.when(filtered, outbound)
diff --git a/tests/maker/test_features/sequential_outbound.py b/tests/maker/test_features/sequential_outbound.py
new file mode 100644
index 0000000..6b0764e
--- /dev/null
+++ b/tests/maker/test_features/sequential_outbound.py
@@ -0,0 +1,22 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
+@maker.feature(maker.ValueType.numerical)
+def transactions_outbound_value() -> Column:
+ return F.when(F.col("AMT") < 0, F.col("AMT")).otherwise(None)
diff --git a/tests/maker/test_features/sequential_outbound_with_param.py b/tests/maker/test_features/sequential_outbound_with_param.py
new file mode 100644
index 0000000..eb50d80
--- /dev/null
+++ b/tests/maker/test_features/sequential_outbound_with_param.py
@@ -0,0 +1,26 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pyspark.sql.functions as F
+from pyspark.sql import Column
+
+from rialto import maker
+
+
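+# @maker.param expands this into one feature instance per v_type value ("A" and "B")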
+@maker.feature(maker.ValueType.numerical)
+@maker.param("v_type", ["A", "B"])
+def transactions_outbound_value(v_type) -> Column:
+ filtered = F.col("TYPE") == v_type
+ outbound = F.when(F.col("AMT") < 0, F.col("AMT")).otherwise(None)
+ return F.when(filtered, outbound)
diff --git a/tests/maker/test_wrappers.py b/tests/maker/test_wrappers.py
new file mode 100644
index 0000000..135b4ad
--- /dev/null
+++ b/tests/maker/test_wrappers.py
@@ -0,0 +1,118 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from rialto.maker.containers import FeatureHolder
+from rialto.maker.wrappers import depends, desc, feature, param
+from rialto.metadata import ValueType
+
+
+def dummy_feature_function():
+ return None
+
+
+def dummy_feature_with_args(parameter_1, parameter_2, parameter_3):
+ return parameter_1 + parameter_2 + parameter_3
+
+
+def test_feature_from_holder():
+ val = feature(ValueType.numerical)(FeatureHolder())
+ assert isinstance(val, FeatureHolder)
+
+
+def test_feature_from_function_return_type():
+ val = feature(ValueType.numerical)(dummy_feature_function)
+ assert isinstance(val, FeatureHolder)
+
+
+def test_feature_from_function_function_name():
+ val = feature(ValueType.numerical)(dummy_feature_function)
+ assert val[0].get_feature_name() == "DUMMY_FEATURE_FUNCTION"
+
+
+def test_feature_from_function_function_object():
+ val = feature(ValueType.numerical)(dummy_feature_function)
+ assert val[0].callable == dummy_feature_function
+
+
+def test_parametrize_from_function_return_type():
+ val = param("parameter", [1, 2, 3])(dummy_feature_with_args)
+ assert isinstance(val, FeatureHolder)
+
+
+def test_parametrize_from_function_size():
+ val = param("parameter", [1, 2, 3])(dummy_feature_with_args)
+ assert len(val) == 3
+
+
+def test_parametrize_chained_size():
+ val = param("parameter_1", [1, 2, 3])(dummy_feature_with_args)
+ val = param("parameter_2", [4, 5, 6])(val)
+ assert len(val) == 9
+
+
+def test_parametrize_chained_values():
+ val = param("parameter_1", [1, 2, 3])(dummy_feature_with_args)
+ val = param("parameter_2", [4, 5, 6])(val)
+ val = param("parameter_3", [7, 8, 9])(val)
+ # expecting ordered combinations (1,4,7)(1,4,8)(1,4,9)(1,5,7)(1,5,8).....
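+    # index 13 = 1*9 + 1*3 + 1, i.e. (parameter_1, parameter_2, parameter_3) == (2, 5, 8)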
+ assert (
+ val[13].parameters["parameter_1"] == 2
+ and val[13].parameters["parameter_2"] == 5
+ and val[13].parameters["parameter_3"] == 8
+ )
+
+
+def test_parametrize_chained_callable():
+ val = param("parameter_1", [1, 2, 3])(dummy_feature_with_args)
+ val = param("parameter_2", [4, 5, 6])(val)
+ val = param("parameter_3", [7, 8, 9])(val)
+ assert val[13].callable() == 15
+
+
+def test_feature_keeps_size():
+ val = feature(ValueType.ordinal)(dummy_feature_function)
+ assert len(val) == 1
+
+
+def test_depends():
+ val = depends("previous")(dummy_feature_function)
+ assert val[0].dependencies[0] == "previous"
+
+
+def test_depends_keeps_size():
+ val = depends("previous")(dummy_feature_function)
+ assert len(val) == 1
+
+
+def test_description():
+ val = desc("Feature A")(dummy_feature_function)
+ assert val[0].description == "Feature A"
+
+
+def test_description_keeps_size():
+ val = desc("Feature A")(dummy_feature_function)
+ assert len(val) == 1
+
+
+def test_chaining():
+ f = desc("Feature A")(dummy_feature_function)
+ f = param("Param", ["B"])(f)
+ f = depends("previous")(f)
+ f = feature(ValueType.ordinal)(f)
+ assert f[0].description == "Feature A"
+ assert f[0].dependencies[0] == "previous"
+ assert f[0].get_type() == "ordinal"
+ assert f[0].get_feature_name() == "DUMMY_FEATURE_FUNCTION_PARAM_B"
+ assert len(f) == 1
diff --git a/tests/metadata/__init__.py b/tests/metadata/__init__.py
new file mode 100644
index 0000000..79c3773
--- /dev/null
+++ b/tests/metadata/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/metadata/conftest.py b/tests/metadata/conftest.py
new file mode 100644
index 0000000..b0cd24e
--- /dev/null
+++ b/tests/metadata/conftest.py
@@ -0,0 +1,56 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from pyspark.sql import SparkSession
+
+from rialto.metadata.metadata_manager import MetadataManager
+from tests.metadata.resources import (
+ feature_base,
+ feature_schema,
+ group_base,
+ group_schema,
+)
+
+
+@pytest.fixture(scope="session")
+def spark(request):
+ """fixture for creating a spark session
+ :param request: pytest.FixtureRequest object
+ """
+
+ spark = (
+ SparkSession.builder.master("local[3]")
+ .appName("pytest-pyspark-local-testing")
+ .config("spark.ui.enabled", "false")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ )
+
+ request.addfinalizer(lambda: spark.stop())
+
+ return spark
+
+
+@pytest.fixture(scope="function")
+def mdc(spark):
+ """
+ Metadata manager fixture with mocked metadata
+ :param spark: spark
+ :return: pytest fixture
+ """
+ mdc = MetadataManager(spark)
+ mdc.groups = spark.createDataFrame(group_base, group_schema)
+ mdc.features = spark.createDataFrame(feature_base, feature_schema)
+ return mdc
diff --git a/tests/metadata/resources.py b/tests/metadata/resources.py
new file mode 100644
index 0000000..46f0ab1
--- /dev/null
+++ b/tests/metadata/resources.py
@@ -0,0 +1,64 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pyspark.sql.types import ArrayType, StringType, StructField, StructType
+
+from rialto.metadata import FeatureMetadata, GroupMetadata, Schedule, ValueType
+
+group_schema = StructType(
+ [
+ StructField("group_name", StringType(), False),
+ StructField("group_frequency", StringType(), False),
+ StructField("group_description", StringType(), False),
+ StructField("group_key", ArrayType(StringType(), True), False),
+ StructField("group_fs_name", StringType(), False),
+ ]
+)
+
+feature_schema = StructType(
+ [
+ StructField("feature_name", StringType(), True),
+ StructField("feature_type", StringType(), True),
+ StructField("feature_description", StringType(), True),
+ StructField("group_name", StringType(), True),
+ ]
+)
+
+group_base = [
+ ("Group1", "weekly", "group1", ["key1"], "group_1"),
+ ("Group2", "monthly", "group2", ["key2", "key3"], "group_2"),
+]
+
+feature_base = [
+ ("Feature1", "nominal", "feature1", "Group2"),
+ ("Feature2", "nominal", "feature2", "Group2"),
+]
+
+group_md1 = GroupMetadata(
+ name="Group1",
+ fs_name="group_1",
+ frequency=Schedule.weekly,
+ description="group1",
+ key=["key1"],
+)
+
+group_md2 = GroupMetadata(
+ name="Group2",
+ fs_name="group_2",
+ frequency=Schedule.monthly,
+ description="group2",
+ key=["key2", "key3"],
+ features=["Feature1", "Feature2"],
+)
+
+feature_md1 = FeatureMetadata(name="Feature1", value_type=ValueType.nominal, description="feature1", group=group_md2)
diff --git a/tests/metadata/test_metadata_connector.py b/tests/metadata/test_metadata_connector.py
new file mode 100644
index 0000000..6594e6c
--- /dev/null
+++ b/tests/metadata/test_metadata_connector.py
@@ -0,0 +1,43 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from tests.metadata.resources import feature_md1, group_md1, group_md2
+
+
+def test_get_group_no_features(mdc):
+ assert str(mdc.get_group("Group1")) == str(group_md1)
+
+
+def test_get_group_w_features(mdc):
+ assert str(mdc.get_group("Group2")) == str(group_md2)
+
+
+def test_get_group_none(mdc):
+ with pytest.raises(Exception):
+ mdc.get_group("Group42")
+
+
+def test_get_feature(mdc):
+ assert str(mdc.get_feature("Group2", "Feature1")) == str(feature_md1)
+
+
+def test_get_feature_none_group(mdc):
+ with pytest.raises(Exception):
+ mdc.get_feature("Group42", "Feature1")
+
+
+def test_get_feature_none_feature(mdc):
+ with pytest.raises(Exception):
+ mdc.get_feature("Group2", "Feature8")
diff --git a/tests/runner/__init__.py b/tests/runner/__init__.py
new file mode 100644
index 0000000..79c3773
--- /dev/null
+++ b/tests/runner/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py
new file mode 100644
index 0000000..44f0c09
--- /dev/null
+++ b/tests/runner/conftest.py
@@ -0,0 +1,44 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from pyspark.sql import SparkSession
+
+from rialto.runner import Runner
+
+
+@pytest.fixture(scope="session")
+def spark(request):
+ """fixture for creating a spark session
+ :param request: pytest.FixtureRequest object
+ """
+
+ spark = (
+ SparkSession.builder.master("local[3]")
+ .appName("pytest-pyspark-local-testing")
+ .config("spark.ui.enabled", "false")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ )
+
+ request.addfinalizer(lambda: spark.stop())
+
+ return spark
+
+
+@pytest.fixture(scope="function")
+def basic_runner(spark):
+ return Runner(
+ spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31"
+ )
diff --git a/tests/runner/runner_resources.py b/tests/runner/runner_resources.py
new file mode 100644
index 0000000..bd39947
--- /dev/null
+++ b/tests/runner/runner_resources.py
@@ -0,0 +1,38 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pyspark.sql.types import DateType, StringType, StructField, StructType
+
+from rialto.runner.date_manager import DateManager
+
+simple_group_data = [
+ ("A", DateManager.str_to_date("2023-03-05")),
+ ("B", DateManager.str_to_date("2023-03-12")),
+ ("C", DateManager.str_to_date("2023-03-19")),
+]
+
+general_schema = StructType([StructField("KEY", StringType(), True), StructField("DATE", DateType(), True)])
+
+
+dep1_data = [
+ ("E", DateManager.str_to_date("2023-03-05")),
+ ("F", DateManager.str_to_date("2023-03-10")),
+ ("G", DateManager.str_to_date("2023-03-15")),
+ ("H", DateManager.str_to_date("2023-03-25")),
+]
+
+dep2_data = [
+ ("J", DateManager.str_to_date("2022-11-01")),
+ ("K", DateManager.str_to_date("2022-12-01")),
+ ("L", DateManager.str_to_date("2023-01-01")),
+]
diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py
new file mode 100644
index 0000000..9088e0c
--- /dev/null
+++ b/tests/runner/test_date_manager.py
@@ -0,0 +1,172 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import datetime
+
+import pytest
+
+from rialto.runner.config_loader import IntervalConfig, ScheduleConfig
+from rialto.runner.date_manager import DateManager
+
+
+def test_str_to_date():
+ assert DateManager.str_to_date("2023-03-05") == datetime.strptime("2023-03-05", "%Y-%m-%d").date()
+
+
+@pytest.mark.parametrize(
+ "units , value, res",
+ [("days", 7, "2023-02-26"), ("weeks", 3, "2023-02-12"), ("months", 5, "2022-10-05"), ("years", 2, "2021-03-5")],
+)
+def test_date_from(units, value, res):
+ rundate = DateManager.str_to_date("2023-03-05")
+ date_from = DateManager.date_subtract(run_date=rundate, units=units, value=value)
+ assert date_from == DateManager.str_to_date(res)
+
+
+def test_date_from_bad():
+ rundate = DateManager.str_to_date("2023-03-05")
+ with pytest.raises(ValueError) as exception:
+ DateManager.date_subtract(run_date=rundate, units="random", value=1)
+ assert str(exception.value) == "Unknown time unit random"
+
+
+def test_all_dates():
+ all_dates = DateManager.all_dates(
+ date_from=DateManager.str_to_date("2023-02-05"),
+ date_to=DateManager.str_to_date("2023-04-12"),
+ )
+ assert len(all_dates) == 67
+ assert all_dates[1] == DateManager.str_to_date("2023-02-06")
+
+
+def test_all_dates_reversed():
+ all_dates = DateManager.all_dates(
+ date_from=DateManager.str_to_date("2023-04-12"),
+ date_to=DateManager.str_to_date("2023-02-05"),
+ )
+ assert len(all_dates) == 67
+ assert all_dates[1] == DateManager.str_to_date("2023-02-06")
+
+
+def test_run_dates_weekly():
+ cfg = ScheduleConfig(frequency="weekly", day=5)
+
+ run_dates = DateManager.run_dates(
+ date_from=DateManager.str_to_date("2023-02-05"),
+ date_to=DateManager.str_to_date("2023-04-07"),
+ schedule=cfg,
+ )
+
+ expected = [
+ "2023-02-10",
+ "2023-02-17",
+ "2023-02-24",
+ "2023-03-03",
+ "2023-03-10",
+ "2023-03-17",
+ "2023-03-24",
+ "2023-03-31",
+ "2023-04-07",
+ ]
+ expected = [DateManager.str_to_date(d) for d in expected]
+ assert run_dates == expected
+
+
+def test_run_dates_monthly():
+ cfg = ScheduleConfig(frequency="monthly", day=5)
+
+ run_dates = DateManager.run_dates(
+ date_from=DateManager.str_to_date("2022-08-05"),
+ date_to=DateManager.str_to_date("2023-04-07"),
+ schedule=cfg,
+ )
+
+ expected = [
+ "2022-08-05",
+ "2022-09-05",
+ "2022-10-05",
+ "2022-11-05",
+ "2022-12-05",
+ "2023-01-05",
+ "2023-02-05",
+ "2023-03-05",
+ "2023-04-05",
+ ]
+ expected = [DateManager.str_to_date(d) for d in expected]
+ assert run_dates == expected
+
+
+def test_run_dates_daily():
+ cfg = ScheduleConfig(frequency="daily")
+
+ run_dates = DateManager.run_dates(
+ date_from=DateManager.str_to_date("2023-03-28"),
+ date_to=DateManager.str_to_date("2023-04-03"),
+ schedule=cfg,
+ )
+
+ expected = [
+ "2023-03-28",
+ "2023-03-29",
+ "2023-03-30",
+ "2023-03-31",
+ "2023-04-01",
+ "2023-04-02",
+ "2023-04-03",
+ ]
+ expected = [DateManager.str_to_date(d) for d in expected]
+ assert run_dates == expected
+
+
+def test_run_dates_invalid():
+ cfg = ScheduleConfig(frequency="random")
+ with pytest.raises(ValueError) as exception:
+ DateManager.run_dates(
+ date_from=DateManager.str_to_date("2023-03-28"),
+ date_to=DateManager.str_to_date("2023-04-03"),
+ schedule=cfg,
+ )
+ assert str(exception.value) == "Unknown frequency random"
+
+
+@pytest.mark.parametrize(
+ "shift, res",
+ [(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")],
+)
+def test_to_info_date(shift, res):
+ cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units="days", value=shift))
+ base = DateManager.str_to_date("2023-03-05")
+ info = DateManager.to_info_date(base, cfg)
+ assert DateManager.str_to_date(res) == info
+
+
+@pytest.mark.parametrize(
+ "unit, result",
+ [("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")],
+)
+def test_info_date_shift_units(unit, result):
+ cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units=unit, value=3))
+ base = DateManager.str_to_date("2023-03-05")
+ info = DateManager.to_info_date(base, cfg)
+ assert DateManager.str_to_date(result) == info
+
+
+def test_info_date_shift_combined():
+ cfg = ScheduleConfig(
+ frequency="daily",
+ info_date_shift=[IntervalConfig(units="months", value=3), IntervalConfig(units="days", value=4)],
+ )
+ base = DateManager.str_to_date("2023-03-05")
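+    # 2023-03-05 minus 3 months is 2022-12-05; minus 4 more days is 2022-12-01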
+ info = DateManager.to_info_date(base, cfg)
+ assert DateManager.str_to_date("2022-12-01") == info
diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py
new file mode 100644
index 0000000..0459411
--- /dev/null
+++ b/tests/runner/test_runner.py
@@ -0,0 +1,360 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+import datetime
+from typing import Optional
+
+import pytest
+from pyspark.sql import DataFrame
+
+from rialto.common.table_reader import DataReader
+from rialto.jobs.configuration.config_holder import ConfigHolder
+from rialto.runner.runner import DateManager, Runner
+from rialto.runner.table import Table
+from tests.runner.runner_resources import (
+ dep1_data,
+ dep2_data,
+ general_schema,
+ simple_group_data,
+)
+from tests.runner.transformations.simple_group import SimpleGroup
+
+
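+# Minimal DataReader stub serving fixed in-memory frames keyed by table name;
+# unknown tables fall through and return None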
+class MockReader(DataReader):
+ def __init__(self, spark):
+ self.spark = spark
+
+ def get_table(
+ self,
+ table: str,
+ info_date_from: Optional[datetime.date] = None,
+ info_date_to: Optional[datetime.date] = None,
+ date_column: str = None,
+ uppercase_columns: bool = False,
+ ) -> DataFrame:
+ if table == "catalog.schema.simple_group":
+ return self.spark.createDataFrame(simple_group_data, general_schema)
+ if table == "source.schema.dep1":
+ return self.spark.createDataFrame(dep1_data, general_schema)
+ if table == "source.schema.dep2":
+ return self.spark.createDataFrame(dep2_data, general_schema)
+
+ def get_latest(
+ self,
+ table: str,
+ until: Optional[datetime.date] = None,
+ date_column: str = None,
+ uppercase_columns: bool = False,
+ ) -> DataFrame:
+ pass
+
+
+def test_table_exists(spark, mocker, basic_runner):
+ mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True)
+ basic_runner._table_exists("abc")
+ mock.assert_called_once_with("abc")
+
+
+def test_infer_column(spark, mocker, basic_runner):
+    column = namedtuple("Column", ["name", "isPartition"])
+ catalog = [column("a", True), column("b", False), column("c", False)]
+
+ mock = mocker.patch("pyspark.sql.Catalog.listColumns", return_value=catalog)
+ partition = basic_runner._delta_partition("aaa")
+ assert partition == "a"
+ mock.assert_called_once_with("aaa")
+
+
+def test_load_module(spark, basic_runner):
+ module = basic_runner._load_module(basic_runner.config.pipelines[0].module)
+ assert isinstance(module, SimpleGroup)
+
+
+def test_generate(spark, mocker, basic_runner):
+ run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run")
+ group = SimpleGroup()
+ basic_runner._generate(group, DateManager.str_to_date("2023-01-31"))
+ run.assert_called_once_with(
+ reader=basic_runner.reader,
+ run_date=DateManager.str_to_date("2023-01-31"),
+ spark=spark,
+ metadata_manager=basic_runner.metadata,
+ dependencies=None,
+ )
+
+
+def test_generate_w_dep(spark, mocker, basic_runner):
+ run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run")
+ group = SimpleGroup()
+ basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2].dependencies)
+ run.assert_called_once_with(
+ reader=basic_runner.reader,
+ run_date=DateManager.str_to_date("2023-01-31"),
+ spark=spark,
+ metadata_manager=basic_runner.metadata,
+ dependencies={
+ "source1": basic_runner.config.pipelines[2].dependencies[0],
+ "source2": basic_runner.config.pipelines[2].dependencies[1],
+ },
+ )
+
+
+def test_init_dates(spark):
+ runner = Runner(
+ spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31"
+ )
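+    # config.yaml watches 2 months back, so date_from is 2023-01-31 for run_date 2023-03-31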
+ assert runner.date_from == DateManager.str_to_date("2023-01-31")
+ assert runner.date_until == DateManager.str_to_date("2023-03-31")
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+ assert runner.date_from == DateManager.str_to_date("2023-03-01")
+ assert runner.date_until == DateManager.str_to_date("2023-03-31")
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config2.yaml",
+ feature_metadata_schema="",
+ run_date="2023-03-31",
+ )
+ assert runner.date_from == DateManager.str_to_date("2023-02-24")
+ assert runner.date_until == DateManager.str_to_date("2023-03-31")
+
+
+def test_possible_run_dates(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+
+ dates = runner.get_possible_run_dates(runner.config.pipelines[0].schedule)
+ expected = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
+ assert dates == [DateManager.str_to_date(d) for d in expected]
+
+
+def test_info_dates(spark, basic_runner):
+ run = ["2023-02-05", "2023-02-12", "2023-02-19", "2023-02-26", "2023-03-05"]
+ run = [DateManager.str_to_date(d) for d in run]
+ info = basic_runner.get_info_dates(basic_runner.config.pipelines[0].schedule, run)
+ expected = ["2023-02-02", "2023-02-09", "2023-02-16", "2023-02-23", "2023-03-02"]
+ assert info == [DateManager.str_to_date(d) for d in expected]
+
+
+def test_completion(spark, mocker, basic_runner):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ basic_runner.reader = MockReader(spark)
+
+ dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
+ dates = [DateManager.str_to_date(d) for d in dates]
+
+ comp = basic_runner._get_completion(Table(table_path="catalog.schema.simple_group", partition="DATE"), dates)
+ expected = [False, True, True, True, False]
+ assert comp == expected
+
+
+def test_completion_rerun(spark, mocker, basic_runner):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ runner = Runner(
+ spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31"
+ )
+ runner.reader = MockReader(spark)
+
+ dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
+ dates = [DateManager.str_to_date(d) for d in dates]
+
+ comp = runner._get_completion(Table(table_path="catalog.schema.simple_group", partition="DATE"), dates)
+ expected = [False, True, True, True, False]
+ assert comp == expected
+
+
+def test_check_dates_have_partition(spark, mocker):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+ runner.reader = MockReader(spark)
+ dates = ["2023-03-04", "2023-03-05", "2023-03-06"]
+ dates = [DateManager.str_to_date(d) for d in dates]
+ res = runner.check_dates_have_partition(Table(schema_path="source.schema", table="dep1", partition="DATE"), dates)
+ expected = [False, True, False]
+ assert res == expected
+
+
+def test_check_dates_have_partition_no_table(spark, mocker):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=False)
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+ dates = ["2023-03-04", "2023-03-05", "2023-03-06"]
+ dates = [DateManager.str_to_date(d) for d in dates]
+ res = runner.check_dates_have_partition(Table(schema_path="source.schema", table="dep66", partition="DATE"), dates)
+ expected = [False, False, False]
+ assert res == expected
+
+
+@pytest.mark.parametrize(
+ "r_date, expected",
+ [("2023-02-26", False), ("2023-03-05", True)],
+)
+def test_check_dependencies(spark, mocker, r_date, expected):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+ runner.reader = MockReader(spark)
+ res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date))
+ assert res == expected
+
+
+def test_check_no_dependencies(spark, mocker):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+ runner.reader = MockReader(spark)
+ res = runner.check_dependencies(runner.config.pipelines[1], DateManager.str_to_date("2023-03-05"))
+ assert res is True
+
+
+def test_select_dates(spark, mocker):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-01",
+ date_until="2023-03-31",
+ )
+ runner.reader = MockReader(spark)
+
+ r, i = runner._select_run_dates(
+ runner.config.pipelines[0], Table(table_path="catalog.schema.simple_group", partition="DATE")
+ )
+ expected_run = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
+ expected_run = [DateManager.str_to_date(d) for d in expected_run]
+ expected_info = ["2023-03-02", "2023-03-09", "2023-03-16", "2023-03-23"]
+ expected_info = [DateManager.str_to_date(d) for d in expected_info]
+ assert r == expected_run
+ assert i == expected_info
+
+
+def test_select_dates_all_done(spark, mocker):
+ mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+
+ runner = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_metadata_schema="",
+ date_from="2023-03-02",
+ date_until="2023-03-02",
+ )
+ runner.reader = MockReader(spark)
+
+ r, i = runner._select_run_dates(
+ runner.config.pipelines[0], Table(table_path="catalog.schema.simple_group", partition="DATE")
+ )
+    assert r == []
+    assert i == []
+ assert r == expected_run
+ assert i == expected_info
+
+
+def test_op_selected(spark, mocker):
+ mocker.patch("rialto.runner.tracker.Tracker.report")
+ run = mocker.patch("rialto.runner.runner.Runner._run_pipeline")
+
+ runner = Runner(
+ spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="SimpleGroup"
+ )
+
+ runner()
+    run.assert_called_once()
+
+
+def test_op_bad(spark, mocker):
+ mocker.patch("rialto.runner.tracker.Tracker.report")
+ mocker.patch("rialto.runner.runner.Runner._run_pipeline")
+
+ runner = Runner(
+ spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="BadOp"
+ )
+
+ with pytest.raises(ValueError) as exception:
+ runner()
+ assert str(exception.value) == "Unknown operation selected: BadOp"
+
+
+def test_custom_config(spark, mocker):
+ cc_spy = mocker.spy(ConfigHolder, "set_custom_config")
+ custom_config = {"cc": 42}
+
+ _ = Runner(spark, config_path="tests/runner/transformations/config.yaml", custom_job_config=custom_config)
+
+ cc_spy.assert_called_once_with(cc=42)
+
+
+def test_feature_store_config(spark, mocker):
+ fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config")
+
+ _ = Runner(
+ spark,
+ config_path="tests/runner/transformations/config.yaml",
+ feature_store_schema="schema",
+ feature_metadata_schema="metadata",
+ )
+
+ fs_spy.assert_called_once_with("schema", "metadata")
+
+
+def test_no_configs(spark, mocker):
+ cc_spy = mocker.spy(ConfigHolder, "set_custom_config")
+ fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config")
+
+ _ = Runner(spark, config_path="tests/runner/transformations/config.yaml")
+
+ cc_spy.assert_not_called()
+ fs_spy.assert_not_called()
diff --git a/tests/runner/test_table.py b/tests/runner/test_table.py
new file mode 100644
index 0000000..82e6fa6
--- /dev/null
+++ b/tests/runner/test_table.py
@@ -0,0 +1,29 @@
+from rialto.runner.table import Table
+
+
+def test_table_basic_init():
+ t = Table(catalog="cat", schema="sch", table="tab", schema_path=None, table_path=None, class_name=None)
+
+ assert t.get_table_path() == "cat.sch.tab"
+ assert t.get_schema_path() == "cat.sch"
+
+
+def test_table_classname_init():
+ t = Table(catalog=None, schema=None, table=None, schema_path="cat.sch", table_path=None, class_name="ClaSs")
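+    # class_name is snake_cased ("ClaSs" -> "cla_ss") to derive the table name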
+
+ assert t.get_table_path() == "cat.sch.cla_ss"
+ assert t.get_schema_path() == "cat.sch"
+ assert t.catalog == "cat"
+ assert t.schema == "sch"
+ assert t.table == "cla_ss"
+
+
+def test_table_path_init():
+ t = Table(catalog=None, schema=None, table=None, schema_path=None, table_path="cat.sch.tab", class_name=None)
+
+ assert t.get_table_path() == "cat.sch.tab"
+ assert t.get_schema_path() == "cat.sch"
+ assert t.catalog == "cat"
+ assert t.schema == "sch"
+ assert t.table == "tab"
diff --git a/tests/runner/transformations/__init__.py b/tests/runner/transformations/__init__.py
new file mode 100644
index 0000000..eaa15cd
--- /dev/null
+++ b/tests/runner/transformations/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from tests.runner.transformations.simple_group import SimpleGroup # noqa
diff --git a/tests/runner/transformations/config.yaml b/tests/runner/transformations/config.yaml
new file mode 100644
index 0000000..2bfeaf1
--- /dev/null
+++ b/tests/runner/transformations/config.yaml
@@ -0,0 +1,82 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+general:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
+ watched_period_units: "months"
+ watched_period_value: 2
+ job: "run" # run/check
+ mail:
+ sender: test@testing.org
+ smtp: server.test
+ to:
+ - developer@testing.org
+ - developer2@testing.org
+ subject: test report
+pipelines:
+ - name: SimpleGroup
+ module:
+ python_module: tests.runner.transformations
+ python_class: SimpleGroup
+ schedule:
+ frequency: weekly
+ day: 7
+ info_date_shift:
+ value: 3
+ units: days
+ dependencies:
+ - table: source.schema.dep1
+ interval:
+ units: "days"
+ value: 1
+ date_col: "DATE"
+ - table: source.schema.dep2
+ interval:
+ units: "months"
+ value: 3
+ date_col: "DATE"
+ - name: GroupNoDeps
+ module:
+ python_module: tests.runner.transformations
+ python_class: SimpleGroup
+ schedule:
+ frequency: weekly
+ day: 7
+ info_date_shift:
+ value: 3
+ units: days
+ - name: NamedDeps
+ module:
+ python_module: tests.runner.transformations
+ python_class: SimpleGroup
+ schedule:
+ frequency: weekly
+ day: 7
+ info_date_shift:
+ value: 3
+ units: days
+ dependencies:
+ - table: source.schema.dep1
+ name: source1
+ interval:
+ units: "days"
+ value: 1
+ date_col: "DATE"
+ - table: source.schema.dep2
+ name: source2
+ interval:
+ units: "months"
+ value: 3
+ date_col: "batch"
diff --git a/tests/runner/transformations/config2.yaml b/tests/runner/transformations/config2.yaml
new file mode 100644
index 0000000..a91894b
--- /dev/null
+++ b/tests/runner/transformations/config2.yaml
@@ -0,0 +1,45 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+general:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
+ watched_period_units: "weeks"
+ watched_period_value: 5
+ job: "run" # run/check
+ mail:
+ sender: test@testing.org
+ smtp: server.test
+ to:
+ - developer@testing.org
+ subject: test report
+pipelines:
+- name: SimpleGroup
+ module:
+ python_module: transformations
+ python_class: SimpleGroup
+ schedule:
+ frequency: weekly
+ day: 7
+ dependencies:
+ - table: source.schema.dep1
+ interval:
+ units: "days"
+ value: 1
+ date_col: "DATE"
+ - table: source.schema.dep2
+ interval:
+ units: "months"
+ value: 1
+ date_col: "DATE"
diff --git a/tests/runner/transformations/simple_group.py b/tests/runner/transformations/simple_group.py
new file mode 100644
index 0000000..fcda5c7
--- /dev/null
+++ b/tests/runner/transformations/simple_group.py
@@ -0,0 +1,34 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+from typing import Dict
+
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import StructType
+
+from rialto.common import TableReader
+from rialto.metadata import MetadataManager
+from rialto.runner import Transformation
+
+
+class SimpleGroup(Transformation):
+ def run(
+ self,
+ reader: TableReader,
+ run_date: datetime.date,
+ spark: SparkSession = None,
+ metadata_manager: MetadataManager = None,
+ dependencies: Dict = None,
+ ) -> DataFrame:
+ return spark.createDataFrame([], StructType([]))