From 93815c23fa21a191f53df61408210926f3117068 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Tue, 3 Sep 2024 11:35:05 +0200 Subject: [PATCH] RLG-3595 rialto v2 (#12) * v2 changes --- CHANGELOG.md | 27 +++ README.md | 176 +++++++++------ poetry.lock | 31 +-- pyproject.toml | 5 +- rialto/common/__init__.py | 2 +- rialto/common/table_reader.py | 70 ++---- rialto/common/utils.py | 39 +--- rialto/jobs/__init__.py | 2 + rialto/jobs/configuration/config_holder.py | 130 ----------- rialto/jobs/decorators/__init__.py | 2 +- rialto/jobs/decorators/decorators.py | 22 +- rialto/jobs/decorators/job_base.py | 46 ++-- rialto/jobs/decorators/resolver.py | 4 +- rialto/jobs/decorators/test_utils.py | 7 +- rialto/loader/__init__.py | 1 - rialto/loader/data_loader.py | 45 ---- rialto/loader/interfaces.py | 20 +- rialto/loader/pyspark_feature_loader.py | 43 ++-- rialto/runner/config_loader.py | 59 +++-- rialto/runner/config_overrides.py | 76 +++++++ rialto/runner/runner.py | 215 +++++-------------- rialto/runner/tracker.py | 13 +- rialto/runner/transformation.py | 15 +- rialto/runner/utils.py | 74 +++++++ tests/jobs/test_config_holder.py | 100 --------- tests/jobs/test_decorators.py | 11 +- tests/jobs/test_job/dependency_tests_job.py | 4 +- tests/jobs/test_job/test_job.py | 12 +- tests/jobs/test_job_base.py | 34 ++- tests/loader/pyspark/dummy_loaders.py | 24 --- tests/loader/pyspark/test_from_cfg.py | 11 +- tests/runner/conftest.py | 4 +- tests/runner/overrider.yaml | 86 ++++++++ tests/runner/test_date_manager.py | 4 +- tests/runner/test_overrides.py | 137 ++++++++++++ tests/runner/test_runner.py | 164 ++++---------- tests/runner/transformations/config.yaml | 23 +- tests/runner/transformations/config2.yaml | 8 +- tests/runner/transformations/simple_group.py | 6 +- 39 files changed, 844 insertions(+), 908 deletions(-) delete mode 100644 rialto/jobs/configuration/config_holder.py delete mode 100644 rialto/loader/data_loader.py create mode 100644 rialto/runner/config_overrides.py create mode 100644 rialto/runner/utils.py delete mode 100644 tests/jobs/test_config_holder.py delete mode 100644 tests/loader/pyspark/dummy_loaders.py create mode 100644 tests/runner/overrider.yaml create mode 100644 tests/runner/test_overrides.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cfd48eb..63e9791 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,33 @@ # Change Log All notable changes to this project will be documented in this file. +## 2.0.0 - 2024-mm-dd + #### Runner + - runner config now accepts environment variables + - restructured runner config + - added metadata and feature loader sections + - target moved to pipeline + - dependency date_col is now mandatory + - custom extras config is available in each pipeline and will be passed as dictionary available under pipeline_config.extras + - general section is renamed to runner + - info_date_shift is always a list + - transformation header changed + - added argument to skip dependency checking + - added overrides parameter to allow for dynamic overriding of config values + - removed date_from and date_to from arguments, use overrides instead + #### Jobs + - jobs are now the main way to create all pipelines + - config holder removed from jobs + - metadata_manager and feature_loader are now available arguments, depending on configuration + - added @config decorator, similar use case to @datasource, for parsing configuration + #### TableReader + - function signatures changed + - until -> date_until + - info_date_from -> date_from, info_date_to -> date_to + - date_column is now mandatory + - removed TableReaders ability to infer schema from partitions or properties + #### Loader + - removed DataLoader class, now only PysparkFeatureLoader is needed with additional parameters ## 1.3.0 - 2024-06-07 diff --git a/README.md b/README.md index 4f52d50..2ac915f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ - +from pydantic import BaseModelfrom rialto.runner.config_loader import PipelineConfigfrom rialto.jobs import config # Rialto @@ -53,31 +53,21 @@ runner() A runner by default executes all the jobs provided in the configuration file, for all the viable execution dates according to the configuration file for which the job has not yet run successfully (i.e. the date partition doesn't exist on the storage) This behavior can be modified by various parameters and switches available. -* **feature_metadata_schema** - path to schema where feature metadata are read and stored, needed for [maker](#maker) jobs and jobs that utilized feature [loader](#loader) * **run_date** - date at which the runner is triggered (defaults to day of running) -* **date_from** - starting date (defaults to rundate - config watch period) -* **date_until** - end date (defaults to rundate) -* **feature_store_schema** - location of features, needed for jobs utilizing feature [loader](#loader) -* **custom_job_config** - dictionary with key-value pairs that will be accessible under the "config" variable in your rialto jobs * **rerun** - rerun all jobs even if they already succeeded in the past runs * **op** - run only selected operation / pipeline - +* **skip_dependencies** - ignore dependency checks and run all jobs +* **overrides** - dictionary of overrides for the configuration Transformations are not included in the runner itself, it imports them dynamically according to the configuration, therefore it's necessary to have them locally installed. -A runner created table has will have automatically created **rialto_date_column** table property set according to target partition set in the configuration. - ### Configuration ```yaml -general: - target_schema: catalog.schema # schema where tables will be created, must exist - target_partition_column: INFORMATION_DATE # date to partition new tables on - source_date_column_property: rialto_date_column # name of the date property on source tables +runner: watched_period_units: "months" # unit of default run period watched_period_value: 2 # value of default run period - job: "run" # run for running the pipelines, check for only checking dependencies mail: to: # a list of email addresses - name@host.domain @@ -100,7 +90,7 @@ pipelines: # a list of pipelines to run dependencies: # list of dependent tables - table: catalog.schema.table1 name: "table1" # Optional table name, used to recall dependency details in transformation - date_col: generation_date # Optional date column name, takes priority + date_col: generation_date # Mandatory date column name interval: # mandatory availability interval, subtracted from scheduled day units: "days" value: 1 @@ -109,6 +99,18 @@ pipelines: # a list of pipelines to run interval: units: "months" value: 1 + target: + target_schema: catalog.schema # schema where tables will be created, must exist + target_partition_column: INFORMATION_DATE # date to partition new tables on + metadata_manager: # optional + metadata_schema: catalog.metadata # schema where metadata is stored + feature_loader: # optional + config_path: model_features_config.yaml # path to the feature loader configuration file + feature_schema: catalog.feature_tables # schema where feature tables are stored + metadata_schema: catalog.metadata # schema where metadata is stored + extras: #optional arguments processed as dictionary + some_value: 3 + some_other_value: giraffe - name: PipelineTable1 # will be written as pipeline_table1 module: @@ -127,8 +129,67 @@ pipelines: # a list of pipelines to run interval: units: "days" value: 6 + target: + target_schema: catalog.schema # schema where tables will be created, must exist + target_partition_column: INFORMATION_DATE # date to partition new tables on +``` + +The configuration can be dynamically overridden by providing a dictionary of overrides to the runner. All overrides must adhere to configurations schema, with pipeline.extras section available for custom schema. +Here are few examples of overrides: + +#### Simple override of a single value +Specify the path to the value in the configuration file as a dot-separated string + +```python +Runner( + spark, + config_path="tests/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.watch_period_value": 4}, + ) +``` + +#### Override list element +You can refer to list elements by their index (starting with 0) +```python +overrides={"runner.mail.to[1]": "a@b.c"} +``` + +#### Append to list +You can append to list by using index -1 +```python +overrides={"runner.mail.to[-1]": "test@test.com"} +``` + +#### Lookup by attribute value in a list +You can use the following syntax to find a specific element in a list by its attribute value +```python +overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}, +``` + +#### Injecting/Replacing whole sections +You can directly replace a bigger section of the configuration by providing a dictionary +When the whole section doesn't exist, it will be added to the configuration, however it needs to be added as a whole. +i.e. if the yaml file doesn't specify feature_loader, you can't just add a feature_loader.config_path, you need to add the whole section. +```python +overrides={"pipelines[name=SimpleGroup].feature_loader": + {"config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata"}} ``` +#### Multiple overrides +You can provide multiple overrides at once, the order of execution is not guaranteed +```python +overrides={"runner.watch_period_value": 4, + "runner.watch_period_units": "weeks", + "pipelines[name=SimpleGroup].target.target_schema": "new_schema", + "pipelines[name=SimpleGroup].feature_loader": + {"config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata"} + } +``` ## 2.2 - maker @@ -302,6 +363,7 @@ We have a set of pre-defined dependencies: * **dependencies** returns a dictionary containing the job dependencies config * **table_reader** returns *TableReader* * **feature_loader** provides *PysparkFeatureLoader* +* **metadata_manager** provides *MetadataManager* Apart from that, each **datasource** also becomes a fully usable dependency. Note, that this means that datasources can also be dependent on other datasources - just beware of any circular dependencies! @@ -310,19 +372,30 @@ With that sorted out, we can now provide a quick example of the *rialto.jobs* mo ```python from pyspark.sql import DataFrame from rialto.common import TableReader -from rialto.jobs.decorators import job, datasource +from rialto.jobs.decorators import config, job, datasource +from rialto.runner.config_loader import PipelineConfig +from pydantic import BaseModel + + +class ConfigModel(BaseModel): + some_value: int + some_other_value: str + +@config +def my_config(config: PipelineConfig): + return ConfigModel(**config.extras) @datasource def my_datasource(run_date: datetime.date, table_reader: TableReader) -> DataFrame: - return table_reader.get_latest("my_catalog.my_schema.my_table", until=run_date) + return table_reader.get_latest("my_catalog.my_schema.my_table", date_until=run_date) @job -def my_job(my_datasource: DataFrame) -> DataFrame: - return my_datasource.withColumn("HelloWorld", F.lit(1)) +def my_job(my_datasource: DataFrame, my_config: ConfigModel) -> DataFrame: + return my_datasource.withColumn("HelloWorld", F.lit(my_config.some_value)) ``` -This piece of code -1. creates a rialto transformation called *my_job*, which is then callable by the rialto runner. +This piece of code +1. creates a rialto transformation called *my_job*, which is then callable by the rialto runner. 2. It sources the *my_datasource* and then runs *my_job* on top of that datasource. 3. Rialto adds VERSION (of your package) and INFORMATION_DATE (as per config) columns automatically. 4. The rialto runner stores the final to a catalog, to a table according to the job's name. @@ -383,20 +456,20 @@ import my_package.test_job_module as tjm # Datasource Testing def test_datasource_a(): ... mocks here ... - + with disable_job_decorators(tjm): datasource_a_output = tjm.datasource_a(... mocks ...) - + ... asserts ... - + # Job Testing def test_my_job(): datasource_a_mock = ... ... other mocks... - + with disable_job_decorators(tjm): job_output = tjm.my_job(datasource_a_mock, ... mocks ...) - + ... asserts ... ``` @@ -418,19 +491,6 @@ This module is used to load features from feature store into your models and scr Two public classes are exposed form this module. **DatabricksLoader**(DataLoader), **PysparkFeatureLoader**(FeatureLoaderInterface). -### DatabricksLoader -This is a support class for feature loader and provides the data reading capability from the feature store. - -This class needs to be instantiated with an active spark session and a path to the feature store schema (in the format of "catalog_name.schema_name"). -Optionally a date_column information can be passed, otherwise it defaults to use INFORMATION_DATE -```python -from rialto.loader import DatabricksLoader - -data_loader = DatabricksLoader(spark= spark_instance, schema= "catalog.schema", date_column= "INFORMATION_DATE") -``` - -This class provides one method, read_group(...), which returns a whole feature group for selected date. This is mostly used inside feature loader. - ### PysparkFeatureLoader This class needs to be instantiated with an active spark session, data loader and a path to the metadata schema (in the format of "catalog_name.schema_name"). @@ -438,17 +498,16 @@ This class needs to be instantiated with an active spark session, data loader an ```python from rialto.loader import PysparkFeatureLoader -feature_loader = PysparkFeatureLoader(spark= spark_instance, data_loader= data_loader_instance, metadata_schema= "catalog.schema") +feature_loader = PysparkFeatureLoader(spark= spark_instance, feature_schema="catalog.schema", metadata_schema= "catalog.schema2", date_column="information_date") ``` #### Single feature ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() feature = feature_loader.get_feature(group_name="CustomerFeatures", feature_name="AGE", information_date=my_date) @@ -459,11 +518,10 @@ metadata = feature_loader.get_feature_metadata(group_name="CustomerFeatures", fe This method of data access is only recommended for experimentation, as the group schema can evolve over time. ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() features = feature_loader.get_group(group_name="CustomerFeatures", information_date=my_date) @@ -473,11 +531,10 @@ metadata = feature_loader.get_group_metadata(group_name="CustomerFeatures") #### Configuration ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() features = feature_loader.get_features_from_cfg(path="local/configuration/file.yaml", information_date=my_date) @@ -563,6 +620,7 @@ reader = TableReader(spark=spark_instance) ``` usage of _get_table_: + ```python # get whole table df = reader.get_table(table="catalog.schema.table", date_column="information_date") @@ -573,10 +631,11 @@ from datetime import datetime start = datetime.strptime("2020-01-01", "%Y-%m-%d").date() end = datetime.strptime("2024-01-01", "%Y-%m-%d").date() -df = reader.get_table(table="catalog.schema.table", info_date_from=start, info_date_to=end) +df = reader.get_table(table="catalog.schema.table", date_from=start, date_to=end, date_column="information_date") ``` usage of _get_latest_: + ```python # most recent partition df = reader.get_latest(table="catalog.schema.table", date_column="information_date") @@ -584,7 +643,7 @@ df = reader.get_latest(table="catalog.schema.table", date_column="information_da # most recent partition until until = datetime.strptime("2020-01-01", "%Y-%m-%d").date() -df = reader.get_latest(table="catalog.schema.table", until=until, date_column="information_date") +df = reader.get_latest(table="catalog.schema.table", date_until=until, date_column="information_date") ``` For full information on parameters and their optionality see technical documentation. @@ -592,21 +651,6 @@ For full information on parameters and their optionality see technical documenta _TableReader_ needs an active spark session and an information which column is the **date column**. There are three options how to pass that information on. -In order of priority from highest: -* Explicit _date_column_ parameter in _get_table_ and _get_latest_ -```python -reader.get_latest(table="catalog.schema.table", date_column="information_date") -``` -* Inferred from delta metadata, triggered by init parameter, only works on delta tables (e.g. doesn't work on views) -```python -reader = TableReader(spark=spark_instance, infer_partition=True) -reader.get_latest(table="catalog.schema.table") -``` -* A custom sql property defined on the table containing the date column name, defaults to _rialto_date_column_ -```python -reader = TableReader(spark=spark_instance, date_property="rialto_date_column") -reader.get_latest(table="catalog.schema.table") -``` # 3. Contributing Contributing: diff --git a/poetry.lock b/poetry.lock index 0cb768b..66ca41b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -343,6 +343,20 @@ files = [ {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, ] +[[package]] +name = "env-yaml" +version = "0.0.3" +description = "Provides a yaml loader which substitutes environment variables and supports defaults" +optional = false +python-versions = "*" +files = [ + {file = "env-yaml-0.0.3.tar.gz", hash = "sha256:b6b55b18c28fb623793137a8e55bd666d6483af7fd0162a41a62325ce662fda6"}, + {file = "env_yaml-0.0.3-py3-none-any.whl", hash = "sha256:f56723c8997bea1240bf634b9e29832714dd9745a42cbc2649f1238a6a576244"}, +] + +[package.dependencies] +pyyaml = ">=6.0" + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -751,9 +765,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -906,8 +920,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -1170,7 +1184,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1178,16 +1191,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1204,7 +1209,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1212,7 +1216,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1544,4 +1547,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "243b1919c3e881039c2cd7b4e786f455b15a78872278050e7850e6a21c706c8e" +content-hash = "6e87c6539147b57b03fb983b28d15396c2eccfe95661805eda7d9f77602d1f58" diff --git a/pyproject.toml b/pyproject.toml index 8255885..5812612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] -name = "rialto" +name = "rialto-dev" -version = "1.3.2" +version = "2.0.0" packages = [ { include = "rialto" }, @@ -31,6 +31,7 @@ pandas = "^2.1.0" flake8-broken-line = "^1.0.0" loguru = "^0.7.2" importlib-metadata = "^7.2.1" +env_yaml = "^0.0.3" [tool.poetry.dev-dependencies] pyspark = "^3.4.1" diff --git a/rialto/common/__init__.py b/rialto/common/__init__.py index 93e8922..1bd5055 100644 --- a/rialto/common/__init__.py +++ b/rialto/common/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from rialto.common.table_reader import TableReader +from rialto.common.table_reader import DataReader, TableReader diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index 1aef614..d3926f2 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -21,8 +21,6 @@ import pyspark.sql.functions as F from pyspark.sql import DataFrame, SparkSession -from rialto.common.utils import get_date_col_property, get_delta_partition - class DataReader(metaclass=abc.ABCMeta): """ @@ -36,16 +34,15 @@ class DataReader(metaclass=abc.ABCMeta): def get_latest( self, table: str, - until: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_until: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get latest available date partition of the table until specified date :param table: input table path - :param until: Optional until date (inclusive) - :param date_column: column to filter dates on, takes highest priority + :param date_until: Optional until date (inclusive) :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ @@ -55,18 +52,17 @@ def get_latest( def get_table( self, table: str, - info_date_from: Optional[datetime.date] = None, - info_date_to: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_from: Optional[datetime.date] = None, + date_to: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get a whole table or a slice by selected dates :param table: input table path - :param info_date_from: Optional date from (inclusive) - :param info_date_to: Optional date to (inclusive) - :param date_column: column to filter dates on, takes highest priority + :param date_from: Optional date from (inclusive) + :param date_to: Optional date to (inclusive) :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ @@ -76,17 +72,13 @@ def get_table( class TableReader(DataReader): """An implementation of data reader for databricks tables""" - def __init__(self, spark: SparkSession, date_property: str = "rialto_date_column", infer_partition: bool = False): + def __init__(self, spark: SparkSession): """ Init :param spark: - :param date_property: Databricks table property specifying date column, take priority over inference - :param infer_partition: infer date column as tables partition from delta metadata """ self.spark = spark - self.date_property = date_property - self.infer_partition = infer_partition super().__init__() def _uppercase_column_names(self, df: DataFrame) -> DataFrame: @@ -106,41 +98,26 @@ def _get_latest_available_date(self, df: DataFrame, date_col: str, until: Option df = df.select(F.max(date_col)).alias("latest") return df.head()[0] - def _get_date_col(self, table: str, date_column: str): - """ - Get tables date column - - column specified at get_table/get_latest takes priority, if inference is enabled it - takes 2nd place, last resort is table property - """ - if date_column: - return date_column - elif self.infer_partition: - return get_delta_partition(self.spark, table) - else: - return get_date_col_property(self.spark, table, self.date_property) - def get_latest( self, table: str, - until: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_until: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get latest available date partition of the table until specified date :param table: input table path - :param until: Optional until date (inclusive) + :param date_until: Optional until date (inclusive) :param date_column: column to filter dates on, takes highest priority :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ - date_col = self._get_date_col(table, date_column) df = self.spark.read.table(table) - selected_date = self._get_latest_available_date(df, date_col, until) - df = df.filter(F.col(date_col) == selected_date) + selected_date = self._get_latest_available_date(df, date_column, date_until) + df = df.filter(F.col(date_column) == selected_date) if uppercase_columns: df = self._uppercase_column_names(df) @@ -149,28 +126,27 @@ def get_latest( def get_table( self, table: str, - info_date_from: Optional[datetime.date] = None, - info_date_to: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_from: Optional[datetime.date] = None, + date_to: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get a whole table or a slice by selected dates :param table: input table path - :param info_date_from: Optional date from (inclusive) - :param info_date_to: Optional date to (inclusive) + :param date_from: Optional date from (inclusive) + :param date_to: Optional date to (inclusive) :param date_column: column to filter dates on, takes highest priority :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ - date_col = self._get_date_col(table, date_column) df = self.spark.read.table(table) - if info_date_from: - df = df.filter(F.col(date_col) >= info_date_from) - if info_date_to: - df = df.filter(F.col(date_col) <= info_date_to) + if date_from: + df = df.filter(F.col(date_column) >= date_from) + if date_to: + df = df.filter(F.col(date_column) <= date_to) if uppercase_columns: df = self._uppercase_column_names(df) return df diff --git a/rialto/common/utils.py b/rialto/common/utils.py index c5527a8..b2e19b4 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["load_yaml", "get_date_col_property", "get_delta_partition"] +__all__ = ["load_yaml"] import os from typing import Any import pyspark.sql.functions as F import yaml +from env_yaml import EnvLoader from pyspark.sql import DataFrame from pyspark.sql.types import FloatType @@ -34,46 +35,14 @@ def load_yaml(path: str) -> Any: raise FileNotFoundError(f"Can't find {path}.") with open(path, "r") as stream: - return yaml.safe_load(stream) - - -def get_date_col_property(spark, table: str, property: str) -> str: - """ - Retrieve a data column name from a given table property - - :param spark: spark session - :param table: path to table - :param property: name of the property - :return: data column name - """ - props = spark.sql(f"show tblproperties {table}") - date_col = props.filter(F.col("key") == property).select("value").collect() - if len(date_col): - return date_col[0].value - else: - raise RuntimeError(f"Table {table} has no property {property}.") - - -def get_delta_partition(spark, table: str) -> str: - """ - Select first partition column of the delta table - - :param table: full table name - :return: partition column name - """ - columns = spark.catalog.listColumns(table) - partition_columns = list(filter(lambda c: c.isPartition, columns)) - if len(partition_columns): - return partition_columns[0].name - else: - raise RuntimeError(f"Delta table has no partitions: {table}.") + return yaml.load(stream, EnvLoader) def cast_decimals_to_floats(df: DataFrame) -> DataFrame: """ Find all decimal types in the table and cast them to floats. Fixes errors in .toPandas() conversions. - :param df: pyspark DataFrame + :param df: input df :return: pyspark DataFrame with fixed types """ decimal_cols = [col_name for col_name, data_type in df.dtypes if "decimal" in data_type] diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py index 79c3773..a6ee6cb 100644 --- a/rialto/jobs/__init__.py +++ b/rialto/jobs/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from rialto.jobs.decorators import config, datasource, job diff --git a/rialto/jobs/configuration/config_holder.py b/rialto/jobs/configuration/config_holder.py deleted file mode 100644 index 161c61a..0000000 --- a/rialto/jobs/configuration/config_holder.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ["ConfigException", "FeatureStoreConfig", "ConfigHolder"] - -import datetime -import typing - -from pydantic import BaseModel - - -class ConfigException(Exception): - """Wrong Configuration Exception""" - - pass - - -class FeatureStoreConfig(BaseModel): - """Configuration of Feature Store Paths""" - - feature_store_schema: str = None - feature_metadata_schema: str = None - - -class ConfigHolder: - """ - Main Rialto Jobs config holder. - - Configured via job_runner and then called from job_base / job decorators. - """ - - _config = {} - _dependencies = {} - _run_date = None - _feature_store_config: FeatureStoreConfig = None - - @classmethod - def set_run_date(cls, run_date: datetime.date) -> None: - """ - Inicialize run Date - - :param run_date: datetime.date, run date - :return: None - """ - cls._run_date = run_date - - @classmethod - def get_run_date(cls) -> datetime.date: - """ - Run date - - :return: datetime.date, Run date - """ - if cls._run_date is None: - raise ConfigException("Run Date not Set !") - return cls._run_date - - @classmethod - def set_feature_store_config(cls, feature_store_schema: str, feature_metadata_schema: str) -> None: - """ - Inicialize feature store config - - :param feature_store_schema: str, schema name - :param feature_metadata_schema: str, metadata schema name - :return: None - """ - cls._feature_store_config = FeatureStoreConfig( - feature_store_schema=feature_store_schema, feature_metadata_schema=feature_metadata_schema - ) - - @classmethod - def get_feature_store_config(cls) -> FeatureStoreConfig: - """ - Feature Store Config - - :return: FeatureStoreConfig - """ - if cls._feature_store_config is None: - raise ConfigException("Feature Store Config not Set !") - - return cls._feature_store_config - - @classmethod - def get_config(cls) -> typing.Dict: - """ - Get config dictionary - - :return: dictionary of key-value pairs - """ - return cls._config.copy() - - @classmethod - def set_custom_config(cls, **kwargs) -> None: - """ - Set custom key-value pairs for custom config - - :param kwargs: key-value pairs to setup - :return: None - """ - cls._config.update(kwargs) - - @classmethod - def get_dependency_config(cls) -> typing.Dict: - """ - Get rialto job dependency config - - :return: dictionary with dependency config - """ - return cls._dependencies - - @classmethod - def set_dependency_config(cls, dependencies: typing.Dict) -> None: - """ - Get rialto job dependency config - - :param dependencies: dictionary with the config - :return: None - """ - cls._dependencies = dependencies diff --git a/rialto/jobs/decorators/__init__.py b/rialto/jobs/decorators/__init__.py index ba62141..6f2713a 100644 --- a/rialto/jobs/decorators/__init__.py +++ b/rialto/jobs/decorators/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .decorators import datasource, job +from .decorators import config, datasource, job diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py index f900726..d288b7b 100644 --- a/rialto/jobs/decorators/decorators.py +++ b/rialto/jobs/decorators/decorators.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["datasource", "job"] +__all__ = ["datasource", "job", "config"] import inspect import typing @@ -24,6 +24,20 @@ from rialto.jobs.decorators.resolver import Resolver +def config(ds_getter: typing.Callable) -> typing.Callable: + """ + Config parser functions decorator. + + Registers a config parsing function into a rialto job prerequisite. + You can then request the job via job function arguments. + + :param ds_getter: dataset reader function + :return: raw reader function, unchanged + """ + Resolver.register_callable(ds_getter) + return ds_getter + + def datasource(ds_getter: typing.Callable) -> typing.Callable: """ Dataset reader functions decorator. @@ -77,14 +91,14 @@ def job(*args, custom_name=None, disable_version=False): """ Rialto jobs decorator. - Transforms a python function into a rialto transormation, which can be imported and ran by Rialto Runner. + Transforms a python function into a rialto transformation, which can be imported and ran by Rialto Runner. Is mainly used as @job and the function's name is used, and the outputs get automatic. To override this behavior, use @job(custom_name=XXX, disable_version=True). :param *args: list of positional arguments. Empty in case custom_name or disable_version is specified. :param custom_name: str for custom job name. - :param disable_version: bool for disabling autofilling the VERSION column in the job's outputs. + :param disable_version: bool for disabling automatically filling the VERSION column in the job's outputs. :return: One more job wrapper for run function (if custom name or version override specified). Otherwise, generates Rialto Transformation Type and returns it for in-module registration. """ @@ -93,7 +107,7 @@ def job(*args, custom_name=None, disable_version=False): module = _get_module(stack) version = _get_version(module) - # Use case where it's just raw @f. Otherwise we get [] here. + # Use case where it's just raw @f. Otherwise, we get [] here. if len(args) == 1 and callable(args[0]): f = args[0] return _generate_rialto_job(callable=f, module=module, class_name=f.__name__, version=version) diff --git a/rialto/jobs/decorators/job_base.py b/rialto/jobs/decorators/job_base.py index 9e3ecc8..d91537f 100644 --- a/rialto/jobs/decorators/job_base.py +++ b/rialto/jobs/decorators/job_base.py @@ -24,11 +24,11 @@ from pyspark.sql import DataFrame, SparkSession from rialto.common import TableReader -from rialto.jobs.configuration.config_holder import ConfigHolder from rialto.jobs.decorators.resolver import Resolver -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner import Transformation +from rialto.runner.config_loader import PipelineConfig class JobBase(Transformation): @@ -53,12 +53,14 @@ def get_job_name(self) -> str: def _setup_resolver(self, run_date: datetime.date) -> None: Resolver.register_callable(lambda: run_date, "run_date") - Resolver.register_callable(ConfigHolder.get_config, "config") - Resolver.register_callable(ConfigHolder.get_dependency_config, "dependencies") - Resolver.register_callable(self._get_spark, "spark") Resolver.register_callable(self._get_table_reader, "table_reader") - Resolver.register_callable(self._get_feature_loader, "feature_loader") + Resolver.register_callable(self._get_config, "config") + + if self._get_feature_loader() is not None: + Resolver.register_callable(self._get_feature_loader, "feature_loader") + if self._get_metadata_manager() is not None: + Resolver.register_callable(self._get_metadata_manager, "metadata_manager") try: yield @@ -66,13 +68,18 @@ def _setup_resolver(self, run_date: datetime.date) -> None: Resolver.cache_clear() def _setup( - self, spark: SparkSession, run_date: datetime.date, table_reader: TableReader, dependencies: typing.Dict = None + self, + spark: SparkSession, + table_reader: TableReader, + config: PipelineConfig = None, + metadata_manager: MetadataManager = None, + feature_loader: PysparkFeatureLoader = None, ) -> None: self._spark = spark self._table_rader = table_reader - - ConfigHolder.set_dependency_config(dependencies) - ConfigHolder.set_run_date(run_date) + self._config = config + self._metadata = metadata_manager + self._feature_loader = feature_loader def _get_spark(self) -> SparkSession: return self._spark @@ -80,13 +87,14 @@ def _get_spark(self) -> SparkSession: def _get_table_reader(self) -> TableReader: return self._table_rader - def _get_feature_loader(self) -> PysparkFeatureLoader: - config = ConfigHolder.get_feature_store_config() + def _get_config(self) -> PipelineConfig: + return self._config - databricks_loader = DatabricksLoader(self._spark, config.feature_store_schema) - feature_loader = PysparkFeatureLoader(self._spark, databricks_loader, config.feature_metadata_schema) + def _get_feature_loader(self) -> PysparkFeatureLoader: + return self._feature_loader - return feature_loader + def _get_metadata_manager(self) -> MetadataManager: + return self._metadata def _get_timestamp_holder_result(self) -> DataFrame: spark = self._get_spark() @@ -118,8 +126,9 @@ def run( reader: TableReader, run_date: datetime.date, spark: SparkSession = None, + config: PipelineConfig = None, metadata_manager: MetadataManager = None, - dependencies: typing.Dict = None, + feature_loader: PysparkFeatureLoader = None, ) -> DataFrame: """ Rialto transformation run @@ -127,12 +136,11 @@ def run( :param reader: data store api object :param info_date: date :param spark: spark session - :param metadata_manager: metadata api object - :param dependencies: rialto job dependencies + :param config: pipeline config :return: dataframe """ try: - self._setup(spark, run_date, reader, dependencies) + self._setup(spark, reader, config, metadata_manager, feature_loader) return self._run_main_callable(run_date) except Exception as e: logger.exception(e) diff --git a/rialto/jobs/decorators/resolver.py b/rialto/jobs/decorators/resolver.py index 9f90e5a..26856d1 100644 --- a/rialto/jobs/decorators/resolver.py +++ b/rialto/jobs/decorators/resolver.py @@ -30,7 +30,7 @@ class Resolver: Resolver handles dependency management between datasets and jobs. We register different callables, which can depend on other callables. - Calling resolve() we attempts to resolve these dependencies. + Calling resolve() we attempt to resolve these dependencies. """ _storage = {} @@ -101,7 +101,7 @@ def cache_clear(cls) -> None: """ Clear resolver cache. - The resolve mehtod caches its results to avoid duplication of resolutions. + The resolve method caches its results to avoid duplication of resolutions. However, in case we re-register some callables, we need to clear cache in order to ensure re-execution of all resolutions. diff --git a/rialto/jobs/decorators/test_utils.py b/rialto/jobs/decorators/test_utils.py index 5465d6e..39d76ce 100644 --- a/rialto/jobs/decorators/test_utils.py +++ b/rialto/jobs/decorators/test_utils.py @@ -17,9 +17,10 @@ import importlib import typing from contextlib import contextmanager -from unittest.mock import patch, create_autospec, MagicMock -from rialto.jobs.decorators.resolver import Resolver, ResolverException +from unittest.mock import MagicMock, create_autospec, patch + from rialto.jobs.decorators.job_base import JobBase +from rialto.jobs.decorators.resolver import Resolver, ResolverException def _passthrough_decorator(*args, **kwargs) -> typing.Callable: @@ -34,6 +35,8 @@ def _disable_job_decorators() -> None: patches = [ patch("rialto.jobs.decorators.datasource", _passthrough_decorator), patch("rialto.jobs.decorators.decorators.datasource", _passthrough_decorator), + patch("rialto.jobs.decorators.config", _passthrough_decorator), + patch("rialto.jobs.decorators.decorators.config", _passthrough_decorator), patch("rialto.jobs.decorators.job", _passthrough_decorator), patch("rialto.jobs.decorators.decorators.job", _passthrough_decorator), ] diff --git a/rialto/loader/__init__.py b/rialto/loader/__init__.py index 7adc52d..7e1e936 100644 --- a/rialto/loader/__init__.py +++ b/rialto/loader/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from rialto.loader.data_loader import DatabricksLoader from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader diff --git a/rialto/loader/data_loader.py b/rialto/loader/data_loader.py deleted file mode 100644 index 930c2b0..0000000 --- a/rialto/loader/data_loader.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ["DatabricksLoader"] - -from datetime import date - -from pyspark.sql import DataFrame, SparkSession - -from rialto.common.table_reader import TableReader -from rialto.loader.interfaces import DataLoader - - -class DatabricksLoader(DataLoader): - """Implementation of DataLoader using TableReader to access feature tables""" - - def __init__(self, spark: SparkSession, schema: str, date_column: str = "INFORMATION_DATE"): - super().__init__() - - self.reader = TableReader(spark) - self.schema = schema - self.date_col = date_column - - def read_group(self, group: str, information_date: date) -> DataFrame: - """ - Read a feature group by getting the latest partition by date - - :param group: group name - :param information_date: partition date - :return: dataframe - """ - return self.reader.get_latest( - f"{self.schema}.{group}", until=information_date, date_column=self.date_col, uppercase_columns=True - ) diff --git a/rialto/loader/interfaces.py b/rialto/loader/interfaces.py index dad08e6..9089f40 100644 --- a/rialto/loader/interfaces.py +++ b/rialto/loader/interfaces.py @@ -12,31 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["DataLoader", "FeatureLoaderInterface"] +__all__ = ["FeatureLoaderInterface"] import abc from datetime import date from typing import Dict -class DataLoader(metaclass=abc.ABCMeta): - """ - An interface to read feature groups from storage - - Requires read_group function. - """ - - @abc.abstractmethod - def read_group(self, group: str, information_date: date): - """ - Read one feature group - - :param group: Group name - :param information_date: date - """ - raise NotImplementedError - - class FeatureLoaderInterface(metaclass=abc.ABCMeta): """ A definition of feature loading interface diff --git a/rialto/loader/pyspark_feature_loader.py b/rialto/loader/pyspark_feature_loader.py index d0eef20..7ee78fc 100644 --- a/rialto/loader/pyspark_feature_loader.py +++ b/rialto/loader/pyspark_feature_loader.py @@ -20,9 +20,9 @@ from pyspark.sql import DataFrame, SparkSession +from rialto.common import TableReader from rialto.common.utils import cast_decimals_to_floats from rialto.loader.config_loader import FeatureConfig, GroupConfig, get_feature_config -from rialto.loader.data_loader import DataLoader from rialto.loader.interfaces import FeatureLoaderInterface from rialto.metadata.metadata_manager import ( FeatureMetadata, @@ -34,7 +34,13 @@ class PysparkFeatureLoader(FeatureLoaderInterface): """Implementation of feature loader for pyspark environment""" - def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema: str): + def __init__( + self, + spark: SparkSession, + feature_schema: str, + metadata_schema: str, + date_column: str = "INFORMATION_DATE", + ): """ Init @@ -44,11 +50,28 @@ def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema """ super().__init__() self.spark = spark - self.data_loader = data_loader + self.reader = TableReader(spark) + self.feature_schema = feature_schema + self.date_col = date_column self.metadata = MetadataManager(spark, metadata_schema) KeyMap = namedtuple("KeyMap", ["df", "key"]) + def read_group(self, group: str, information_date: date) -> DataFrame: + """ + Read a feature group by getting the latest partition by date + + :param group: group name + :param information_date: partition date + :return: dataframe + """ + return self.reader.get_latest( + f"{self.feature_schema}.{group}", + date_until=information_date, + date_column=self.date_col, + uppercase_columns=True, + ) + def get_feature(self, group_name: str, feature_name: str, information_date: date) -> DataFrame: """ Get single feature @@ -60,9 +83,7 @@ def get_feature(self, group_name: str, feature_name: str, information_date: date """ print("This function is untested, use with caution!") key = self.get_group_metadata(group_name).key - return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date).select( - *key, feature_name - ) + return self.read_group(self.get_group_fs_name(group_name), information_date).select(*key, feature_name) def get_feature_metadata(self, group_name: str, feature_name: str) -> FeatureMetadata: """ @@ -83,7 +104,7 @@ def get_group(self, group_name: str, information_date: date) -> DataFrame: :return: A dataframe containing feature group key """ print("This function is untested, use with caution!") - return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date) + return self.read_group(self.get_group_fs_name(group_name), information_date) def get_group_metadata(self, group_name: str) -> GroupMetadata: """ @@ -144,7 +165,7 @@ def _get_keymaps(self, config: FeatureConfig, information_date: date) -> List[Ke """ key_maps = [] for mapping in config.maps: - df = self.data_loader.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE") + df = self.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE") key = self.metadata.get_group(mapping).key key_maps.append(PysparkFeatureLoader.KeyMap(df, key)) return key_maps @@ -174,9 +195,7 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame: """ config = get_feature_config(path) # 1 select keys from base - base = self.data_loader.read_group(self.get_group_fs_name(config.base.group), information_date).select( - config.base.keys - ) + base = self.read_group(self.get_group_fs_name(config.base.group), information_date).select(config.base.keys) # 2 join maps onto base (resolve keys) if config.maps: key_maps = self._get_keymaps(config, information_date) @@ -184,7 +203,7 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame: # 3 read, select and join other tables for group_cfg in config.selection: - df = self.data_loader.read_group(self.get_group_fs_name(group_cfg.group), information_date) + df = self.read_group(self.get_group_fs_name(group_cfg.group), information_date) base = self._add_feature_group(base, df, group_cfg) # 4 fix dtypes for pandas conversion diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index af6640b..86c142d 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["get_pipelines_config", "transform_dependencies"] +__all__ = [ + "get_pipelines_config", +] -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from pydantic import BaseModel from rialto.common.utils import load_yaml +from rialto.runner.config_overrides import override_config class IntervalConfig(BaseModel): @@ -29,13 +32,13 @@ class IntervalConfig(BaseModel): class ScheduleConfig(BaseModel): frequency: str day: Optional[int] = 0 - info_date_shift: Union[Optional[IntervalConfig], List[IntervalConfig]] = IntervalConfig(units="days", value=0) + info_date_shift: Optional[List[IntervalConfig]] = IntervalConfig(units="days", value=0) class DependencyConfig(BaseModel): table: str name: Optional[str] = None - date_col: Optional[str] = None + date_col: str interval: IntervalConfig @@ -52,37 +55,47 @@ class MailConfig(BaseModel): sent_empty: Optional[bool] = False -class GeneralConfig(BaseModel): - target_schema: str - target_partition_column: str - source_date_column_property: Optional[str] = None +class RunnerConfig(BaseModel): watched_period_units: str watched_period_value: int - job: str mail: MailConfig +class TargetConfig(BaseModel): + target_schema: str + target_partition_column: str + + +class MetadataManagerConfig(BaseModel): + metadata_schema: str + + +class FeatureLoaderConfig(BaseModel): + feature_schema: str + metadata_schema: str + + class PipelineConfig(BaseModel): name: str - module: Optional[ModuleConfig] = None + module: ModuleConfig schedule: ScheduleConfig - dependencies: List[DependencyConfig] = [] + dependencies: Optional[List[DependencyConfig]] = [] + target: TargetConfig = None + metadata_manager: Optional[MetadataManagerConfig] = None + feature_loader: Optional[FeatureLoaderConfig] = None + extras: Optional[Dict] = {} class PipelinesConfig(BaseModel): - general: GeneralConfig + runner: RunnerConfig pipelines: list[PipelineConfig] -def get_pipelines_config(path) -> PipelinesConfig: +def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: """Load and parse yaml config""" - return PipelinesConfig(**load_yaml(path)) - - -def transform_dependencies(dependencies: List[DependencyConfig]) -> Dict: - """Transform dependency config list into a dictionary""" - res = {} - for dep in dependencies: - if dep.name: - res[dep.name] = dep - return res + raw_config = load_yaml(path) + if overrides: + cfg = override_config(raw_config, overrides) + return PipelinesConfig(**cfg) + else: + return PipelinesConfig(**raw_config) diff --git a/rialto/runner/config_overrides.py b/rialto/runner/config_overrides.py new file mode 100644 index 0000000..a525525 --- /dev/null +++ b/rialto/runner/config_overrides.py @@ -0,0 +1,76 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["override_config"] + +from typing import Dict, List, Tuple + +from loguru import logger + + +def _split_index_key(key: str) -> Tuple[str, str]: + name = key.split("[")[0] + index = key.split("[")[1].replace("]", "") + return name, index + + +def _find_first_match(config: List, index: str) -> int: + index_key, index_value = index.split("=") + return next(i for i, x in enumerate(config) if x.get(index_key) == index_value) + + +def _override(config, path, value) -> Dict: + key = path[0] + if "[" in key: + name, index = _split_index_key(key) + if name not in config: + raise ValueError(f"Invalid key {name}") + if "=" in index: + index = _find_first_match(config[name], index) + else: + index = int(index) + if index >= 0 and index < len(config[name]): + if len(path) == 1: + config[name][index] = value + else: + config[name][index] = _override(config[name][index], path[1:], value) + elif index == -1: + if len(path) == 1: + config[name].append(value) + else: + raise ValueError(f"Invalid index {index} for key {name} in path {path}") + else: + raise IndexError(f"Index {index} out of bounds for key {key}") + else: + if len(path) == 1: + config[key] = value + else: + if key not in config: + raise ValueError(f"Invalid key {key}") + config[key] = _override(config[key], path[1:], value) + return config + + +def override_config(config: Dict, overrides: Dict) -> Dict: + """Override config with user input + + :param config: config dictionary + :param overrides: dictionary of overrides + :return: Overridden config + """ + for path, value in overrides.items(): + logger.info("Applying override: ", path, value) + config = _override(config, path.split("."), value) + + return config diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 343d2fe..ac9d6bc 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -16,25 +16,15 @@ import datetime from datetime import date -from importlib import import_module -from typing import List, Tuple +from typing import Dict, List, Tuple import pyspark.sql.functions as F from loguru import logger from pyspark.sql import DataFrame, SparkSession +import rialto.runner.utils as utils from rialto.common import TableReader -from rialto.common.utils import get_date_col_property, get_delta_partition -from rialto.jobs.configuration.config_holder import ConfigHolder -from rialto.metadata import MetadataManager -from rialto.runner.config_loader import ( - DependencyConfig, - ModuleConfig, - PipelineConfig, - ScheduleConfig, - get_pipelines_config, - transform_dependencies, -) +from rialto.runner.config_loader import PipelineConfig, get_pipelines_config from rialto.runner.date_manager import DateManager from rialto.runner.table import Table from rialto.runner.tracker import Record, Tracker @@ -48,100 +38,60 @@ def __init__( self, spark: SparkSession, config_path: str, - feature_metadata_schema: str = None, run_date: str = None, - date_from: str = None, - date_until: str = None, - feature_store_schema: str = None, - custom_job_config: dict = None, rerun: bool = False, op: str = None, + skip_dependencies: bool = False, + overrides: Dict = None, ): self.spark = spark - self.config = get_pipelines_config(config_path) - self.reader = TableReader( - spark, date_property=self.config.general.source_date_column_property, infer_partition=False - ) - if feature_metadata_schema: - self.metadata = MetadataManager(spark, feature_metadata_schema) - else: - self.metadata = None - self.date_from = date_from - self.date_until = date_until + self.config = get_pipelines_config(config_path, overrides) + self.reader = TableReader(spark) self.rerun = rerun + self.skip_dependencies = skip_dependencies self.op = op - self.tracker = Tracker(self.config.general.target_schema) - - if (feature_store_schema is not None) and (feature_metadata_schema is not None): - ConfigHolder.set_feature_store_config(feature_store_schema, feature_metadata_schema) - - if custom_job_config is not None: - ConfigHolder.set_custom_config(**custom_job_config) + self.tracker = Tracker() if run_date: run_date = DateManager.str_to_date(run_date) else: run_date = date.today() - if self.date_from: - self.date_from = DateManager.str_to_date(date_from) - if self.date_until: - self.date_until = DateManager.str_to_date(date_until) - - if not self.date_from: - self.date_from = DateManager.date_subtract( - run_date=run_date, - units=self.config.general.watched_period_units, - value=self.config.general.watched_period_value, - ) - if not self.date_until: - self.date_until = run_date + + self.date_from = DateManager.date_subtract( + run_date=run_date, + units=self.config.runner.watched_period_units, + value=self.config.runner.watched_period_value, + ) + + self.date_until = run_date + if self.date_from > self.date_until: raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") logger.info(f"Running period from {self.date_from} until {self.date_until}") - def _load_module(self, cfg: ModuleConfig) -> Transformation: - """ - Load feature group - - :param cfg: Feature configuration - :return: Transformation object - """ - module = import_module(cfg.python_module) - class_obj = getattr(module, cfg.python_class) - return class_obj() - - def _generate( - self, instance: Transformation, run_date: date, dependencies: List[DependencyConfig] = None - ) -> DataFrame: + def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: """ - Run feature group + Run the job :param instance: Instance of Transformation :param run_date: date to run for + :param pipeline: pipeline configuration :return: Dataframe """ - if dependencies is not None: - dependencies = transform_dependencies(dependencies) + metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) + df = instance.run( - reader=self.reader, - run_date=run_date, spark=self.spark, - metadata_manager=self.metadata, - dependencies=dependencies, + run_date=run_date, + config=pipeline, + reader=self.reader, + metadata_manager=metadata_manager, + feature_loader=feature_loader, ) logger.info(f"Generated {df.count()} records") return df - def _table_exists(self, table: str) -> bool: - """ - Check table exists in spark catalog - - :param table: full table path - :return: bool - """ - return self.spark.catalog.tableExists(table) - def _write(self, df: DataFrame, info_date: date, table: Table) -> None: """ Write dataframe to storage @@ -155,44 +105,6 @@ def _write(self, df: DataFrame, info_date: date, table: Table) -> None: df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) logger.info(f"Results writen to {table.get_table_path()}") - try: - get_date_col_property(self.spark, table.get_table_path(), "rialto_date_column") - except RuntimeError: - sql_query = ( - f"ALTER TABLE {table.get_table_path()} SET TBLPROPERTIES ('rialto_date_column' = '{table.partition}')" - ) - self.spark.sql(sql_query) - logger.info(f"Set table property rialto_date_column to {table.partition}") - - def _delta_partition(self, table: str) -> str: - """ - Select first partition column, should be only one - - :param table: full table name - :return: partition column name - """ - columns = self.spark.catalog.listColumns(table) - partition_columns = list(filter(lambda c: c.isPartition, columns)) - if len(partition_columns): - return partition_columns[0].name - else: - raise RuntimeError(f"Delta table has no partitions: {table}.") - - def _get_partitions(self, table: Table) -> List[date]: - """ - Get partition values - - :param table: Table object - :return: List of partition values - """ - rows = ( - self.reader.get_table(table.get_table_path(), date_column=table.partition) - .select(table.partition) - .distinct() - .collect() - ) - return [r[table.partition] for r in rows] - def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bool]: """ For given list of dates, check if there is a matching partition for each @@ -201,8 +113,8 @@ def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bo :param dates: list of dates to check :return: list of bool """ - if self._table_exists(table.get_table_path()): - partitions = self._get_partitions(table) + if utils.table_exists(self.spark, table.get_table_path()): + partitions = utils.get_partitions(self.reader, table) return [(date in partitions) for date in dates] else: logger.info(f"Table {table.get_table_path()} doesn't exist!") @@ -226,18 +138,9 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: possible_dep_dates = DateManager.all_dates(dep_from, run_date) - # date column options prioritization (manual column, table property, inferred from delta) - if dependency.date_col: - date_col = dependency.date_col - elif self.config.general.source_date_column_property: - date_col = get_date_col_property( - self.spark, dependency.table, self.config.general.source_date_column_property - ) - else: - date_col = get_delta_partition(self.spark, dependency.table) - logger.debug(f"Date column for {dependency.table} is {date_col}") + logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") - source = Table(table_path=dependency.table, partition=date_col) + source = Table(table_path=dependency.table, partition=dependency.date_col) if True in self.check_dates_have_partition(source, possible_dep_dates): logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") else: @@ -251,25 +154,6 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: return True - def get_possible_run_dates(self, schedule: ScheduleConfig) -> List[date]: - """ - List possible run dates according to parameters and config - - :param schedule: schedule config - :return: List of dates - """ - return DateManager.run_dates(self.date_from, self.date_until, schedule) - - def get_info_dates(self, schedule: ScheduleConfig, run_dates: List[date]) -> List[date]: - """ - Transform given dates into info dates according to the config - - :param schedule: schedule config - :param run_dates: date list - :return: list of modified dates - """ - return [DateManager.to_info_date(x, schedule) for x in run_dates] - def _get_completion(self, target: Table, info_dates: List[date]) -> List[bool]: """ Check if model has run for given dates @@ -291,8 +175,8 @@ def _select_run_dates(self, pipeline: PipelineConfig, table: Table) -> Tuple[Lis :param table: table path :return: list of run dates and list of info dates """ - possible_run_dates = self.get_possible_run_dates(pipeline.schedule) - possible_info_dates = self.get_info_dates(pipeline.schedule, possible_run_dates) + possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule) + possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] current_state = self._get_completion(table, possible_info_dates) selection = [ @@ -318,18 +202,17 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat :param target: target Table :return: success bool """ - if self.check_dependencies(pipeline, run_date): + if self.skip_dependencies or self.check_dependencies(pipeline, run_date): logger.info(f"Running {pipeline.name} for {run_date}") - if self.config.general.job == "run": - feature_group = self._load_module(pipeline.module) - df = self._generate(feature_group, run_date, pipeline.dependencies) - records = df.count() - if records > 0: - self._write(df, info_date, target) - return records - else: - raise RuntimeError("No records generated") + feature_group = utils.load_module(pipeline.module) + df = self._execute(feature_group, run_date, pipeline) + records = df.count() + if records > 0: + self._write(df, info_date, target) + return records + else: + raise RuntimeError("No records generated") return 0 def _run_pipeline(self, pipeline: PipelineConfig): @@ -340,9 +223,9 @@ def _run_pipeline(self, pipeline: PipelineConfig): :return: success bool """ target = Table( - schema_path=self.config.general.target_schema, + schema_path=pipeline.target.target_schema, class_name=pipeline.module.python_class, - partition=self.config.general.target_partition_column, + partition=pipeline.target.target_partition_column, ) logger.info(f"Loaded pipeline {pipeline.name}") @@ -371,8 +254,8 @@ def _run_pipeline(self, pipeline: PipelineConfig): ) ) except Exception as error: - print(f"An exception occurred in pipeline {pipeline.name}") - print(error) + logger.error(f"An exception occurred in pipeline {pipeline.name}") + logger.error(error) self.tracker.add( Record( job=pipeline.name, @@ -386,7 +269,7 @@ def _run_pipeline(self, pipeline: PipelineConfig): ) ) except KeyboardInterrupt: - print(f"Pipeline {pipeline.name} interrupted") + logger.error(f"Pipeline {pipeline.name} interrupted") self.tracker.add( Record( job=pipeline.name, @@ -413,4 +296,4 @@ def __call__(self): self._run_pipeline(pipeline) finally: print(self.tracker.records) - self.tracker.report(self.config.general.mail) + self.tracker.report(self.config.runner.mail) diff --git a/rialto/runner/tracker.py b/rialto/runner/tracker.py index de97fb0..57a24e6 100644 --- a/rialto/runner/tracker.py +++ b/rialto/runner/tracker.py @@ -41,8 +41,7 @@ class Record: class Tracker: """Collect information about runs and sent them out via email""" - def __init__(self, target_schema: str): - self.target_schema = target_schema + def __init__(self): self.records = [] self.last_error = None self.pipeline_start = datetime.now() @@ -55,7 +54,7 @@ def add(self, record: Record) -> None: def report(self, mail_cfg: MailConfig): """Create and send html report""" if len(self.records) or mail_cfg.sent_empty: - report = HTMLMessage.make_report(self.target_schema, self.pipeline_start, self.records) + report = HTMLMessage.make_report(self.pipeline_start, self.records) for receiver in mail_cfg.to: message = Mailer.create_message( subject=mail_cfg.subject, sender=mail_cfg.sender, receiver=receiver, body=report @@ -118,7 +117,7 @@ def _make_overview_header(): """ @staticmethod - def _make_header(target: str, start: datetime): + def _make_header(start: datetime): return f"""
@@ -127,7 +126,7 @@ def _make_header(target: str, start: datetime):
- Jobs started {str(start).split('.')[0]}, targeting {target} + Jobs started {str(start).split('.')[0]}
@@ -228,14 +227,14 @@ def _make_insights(records: List[Record]): """ @staticmethod - def make_report(target: str, start: datetime, records: List[Record]) -> str: + def make_report(start: datetime, records: List[Record]) -> str: """Create html email report""" html = [ """ """, HTMLMessage._head(), HTMLMessage._body_open(), - HTMLMessage._make_header(target, start), + HTMLMessage._make_header(start), HTMLMessage._make_overview(records), HTMLMessage._make_insights(records), HTMLMessage._body_close(), diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py index 4399ce0..5b6f2eb 100644 --- a/rialto/runner/transformation.py +++ b/rialto/runner/transformation.py @@ -16,12 +16,13 @@ import abc import datetime -from typing import Dict from pyspark.sql import DataFrame, SparkSession -from rialto.common import TableReader +from rialto.common import DataReader +from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager +from rialto.runner.config_loader import PipelineConfig class Transformation(metaclass=abc.ABCMeta): @@ -30,11 +31,12 @@ class Transformation(metaclass=abc.ABCMeta): @abc.abstractmethod def run( self, - reader: TableReader, + reader: DataReader, run_date: datetime.date, spark: SparkSession = None, + config: PipelineConfig = None, metadata_manager: MetadataManager = None, - dependencies: Dict = None, + feature_loader: PysparkFeatureLoader = None, ) -> DataFrame: """ Run the transformation @@ -42,8 +44,9 @@ def run( :param reader: data store api object :param run_date: date :param spark: spark session - :param metadata_manager: metadata api object - :param dependencies: dictionary of dependencies + :param config: pipeline config + :param metadata_manager: metadata manager + :param feature_loader: feature loader :return: dataframe """ raise NotImplementedError diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py new file mode 100644 index 0000000..b74ec1b --- /dev/null +++ b/rialto/runner/utils.py @@ -0,0 +1,74 @@ +from datetime import date +from importlib import import_module +from typing import List, Tuple + +from pyspark.sql import SparkSession + +from rialto.common import DataReader +from rialto.loader import PysparkFeatureLoader +from rialto.metadata import MetadataManager +from rialto.runner.config_loader import ModuleConfig, PipelineConfig +from rialto.runner.table import Table +from rialto.runner.transformation import Transformation + + +def load_module(cfg: ModuleConfig) -> Transformation: + """ + Load feature group + + :param cfg: Feature configuration + :return: Transformation object + """ + module = import_module(cfg.python_module) + class_obj = getattr(module, cfg.python_class) + return class_obj() + + +def table_exists(spark: SparkSession, table: str) -> bool: + """ + Check table exists in spark catalog + + :param table: full table path + :return: bool + """ + return spark.catalog.tableExists(table) + + +def get_partitions(reader: DataReader, table: Table) -> List[date]: + """ + Get partition values + + :param table: Table object + :return: List of partition values + """ + rows = ( + reader.get_table(table.get_table_path(), date_column=table.partition) + .select(table.partition) + .distinct() + .collect() + ) + return [r[table.partition] for r in rows] + + +def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]: + """ + Initialize metadata manager and feature loader + + :param spark: Spark session + :param pipeline: Pipeline configuration + :return: MetadataManager and PysparkFeatureLoader + """ + if pipeline.metadata_manager is not None: + metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema) + else: + metadata_manager = None + + if pipeline.feature_loader is not None: + feature_loader = PysparkFeatureLoader( + spark, + feature_schema=pipeline.feature_loader.feature_schema, + metadata_schema=pipeline.feature_loader.metadata_schema, + ) + else: + feature_loader = None + return metadata_manager, feature_loader diff --git a/tests/jobs/test_config_holder.py b/tests/jobs/test_config_holder.py deleted file mode 100644 index 38fadb1..0000000 --- a/tests/jobs/test_config_holder.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from datetime import date - -import pytest - -from rialto.jobs.configuration.config_holder import ( - ConfigException, - ConfigHolder, - FeatureStoreConfig, -) - - -def test_run_date_unset(): - with pytest.raises(ConfigException): - ConfigHolder.get_run_date() - - -def test_run_date(): - dt = date(2023, 1, 1) - - ConfigHolder.set_run_date(dt) - - assert ConfigHolder.get_run_date() == dt - - -def test_feature_store_config_unset(): - with pytest.raises(ConfigException): - ConfigHolder.get_feature_store_config() - - -def test_feature_store_config(): - ConfigHolder.set_feature_store_config("store_schema", "metadata_schema") - - fsc = ConfigHolder.get_feature_store_config() - - assert type(fsc) is FeatureStoreConfig - assert fsc.feature_store_schema == "store_schema" - assert fsc.feature_metadata_schema == "metadata_schema" - - -def test_config_unset(): - config = ConfigHolder.get_config() - - assert type(config) is type({}) - assert len(config.items()) == 0 - - -def test_config_dict_copied_not_ref(): - """Test that config holder config can't be set from outside""" - config = ConfigHolder.get_config() - - config["test"] = 123 - - assert "test" not in ConfigHolder.get_config() - - -def test_config(): - ConfigHolder.set_custom_config(hello=123) - ConfigHolder.set_custom_config(world="test") - - config = ConfigHolder.get_config() - - assert config["hello"] == 123 - assert config["world"] == "test" - - -def test_config_from_dict(): - ConfigHolder.set_custom_config(**{"dict_item_1": 123, "dict_item_2": 456}) - - config = ConfigHolder.get_config() - - assert config["dict_item_1"] == 123 - assert config["dict_item_2"] == 456 - - -def test_dependencies_unset(): - deps = ConfigHolder.get_dependency_config() - assert len(deps.keys()) == 0 - - -def test_dependencies(): - ConfigHolder.set_dependency_config({"hello": 123}) - - deps = ConfigHolder.get_dependency_config() - - assert deps["hello"] == 123 diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py index c6d05e6..54cb4a4 100644 --- a/tests/jobs/test_decorators.py +++ b/tests/jobs/test_decorators.py @@ -14,7 +14,6 @@ from importlib import import_module -from rialto.jobs.configuration.config_holder import ConfigHolder from rialto.jobs.decorators.job_base import JobBase from rialto.jobs.decorators.resolver import Resolver @@ -26,6 +25,13 @@ def test_dataset_decorator(): assert test_dataset == "dataset_return" +def test_config_decorator(): + _ = import_module("tests.jobs.test_job.test_job") + test_dataset = Resolver.resolve("custom_config") + + assert test_dataset == "config_return" + + def _rialto_import_stub(module_name, class_name): module = import_module(module_name) class_obj = getattr(module, class_name) @@ -70,7 +76,6 @@ def test_job_disabling_version(): def test_job_dependencies_registered(spark): - ConfigHolder.set_custom_config(value=123) job_class = _rialto_import_stub("tests.jobs.test_job.test_job", "job_asking_for_all_deps") # asserts part of the run - job_class.run(spark=spark, run_date=456, reader=789, metadata_manager=None, dependencies=1011) + job_class.run(spark=spark, run_date=456, reader=789, config=123, metadata_manager=654, feature_loader=321) diff --git a/tests/jobs/test_job/dependency_tests_job.py b/tests/jobs/test_job/dependency_tests_job.py index 3029b33..38e10ba 100644 --- a/tests/jobs/test_job/dependency_tests_job.py +++ b/tests/jobs/test_job/dependency_tests_job.py @@ -1,4 +1,4 @@ -from rialto.jobs.decorators import job, datasource +from rialto.jobs.decorators import datasource, job @datasource @@ -47,5 +47,5 @@ def missing_dependency_job(a, x): @job -def default_dependency_job(run_date, spark, config, dependencies, table_reader, feature_loader): +def default_dependency_job(run_date, spark, config, table_reader, feature_loader): return 1 diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index 460490a..3d648b5 100644 --- a/tests/jobs/test_job/test_job.py +++ b/tests/jobs/test_job/test_job.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from rialto.jobs.decorators import config, datasource, job -from rialto.jobs.decorators import datasource, job +@config +def custom_config(): + return "config_return" @datasource @@ -37,9 +40,10 @@ def disable_version_job_function(): @job -def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader): +def job_asking_for_all_deps(spark, run_date, config, table_reader, metadata_manager, feature_loader): assert spark is not None assert run_date == 456 - assert config["value"] == 123 + assert config == 123 assert table_reader == 789 - assert dependencies == 1011 + assert metadata_manager == 654 + assert feature_loader == 321 diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py index ab8284a..55fced1 100644 --- a/tests/jobs/test_job_base.py +++ b/tests/jobs/test_job_base.py @@ -14,42 +14,36 @@ import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pyspark.sql.types import tests.jobs.resources as resources -from rialto.jobs.configuration.config_holder import ConfigHolder, FeatureStoreConfig from rialto.jobs.decorators.resolver import Resolver from rialto.loader import PysparkFeatureLoader def test_setup_except_feature_loader(spark): table_reader = MagicMock() + config = MagicMock() date = datetime.date(2023, 1, 1) - ConfigHolder.set_custom_config(hello=1, world=2) - - resources.CustomJobNoReturnVal().run( - reader=table_reader, run_date=date, spark=spark, metadata_manager=None, dependencies={1: 1} - ) + resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=config) assert Resolver.resolve("run_date") == date - assert Resolver.resolve("config") == ConfigHolder.get_config() - assert Resolver.resolve("dependencies") == ConfigHolder.get_dependency_config() + assert Resolver.resolve("config") == config assert Resolver.resolve("spark") == spark assert Resolver.resolve("table_reader") == table_reader -@patch( - "rialto.jobs.configuration.config_holder.ConfigHolder.get_feature_store_config", - return_value=FeatureStoreConfig(feature_store_schema="schema", feature_metadata_schema="metadata_schema"), -) def test_setup_feature_loader(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) + feature_loader = PysparkFeatureLoader(spark, "", "", "") - resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + resources.CustomJobNoReturnVal().run( + reader=table_reader, run_date=date, spark=spark, config=None, feature_loader=feature_loader + ) assert type(Resolver.resolve("feature_loader")) == PysparkFeatureLoader @@ -60,7 +54,7 @@ def test_custom_callable_called(spark, mocker): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=None) spy_cc.assert_called_once() @@ -69,9 +63,7 @@ def test_no_return_vaue_adds_version_timestamp_dataframe(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - result = resources.CustomJobNoReturnVal().run( - reader=table_reader, run_date=date, spark=spark, metadata_manager=None - ) + result = resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=None) assert type(result) is pyspark.sql.DataFrame assert result.columns == ["JOB_NAME", "CREATION_TIME", "VERSION"] @@ -83,9 +75,7 @@ def test_return_dataframe_forwarded_with_version(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - result = resources.CustomJobReturnsDataFrame().run( - reader=table_reader, run_date=date, spark=spark, metadata_manager=None - ) + result = resources.CustomJobReturnsDataFrame().run(reader=table_reader, run_date=date, spark=spark, config=None) assert type(result) is pyspark.sql.DataFrame assert result.columns == ["FIRST", "SECOND", "VERSION"] @@ -97,7 +87,7 @@ def test_none_job_version_wont_fill_job_colun(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, config=None) assert type(result) is pyspark.sql.DataFrame assert "VERSION" not in result.columns diff --git a/tests/loader/pyspark/dummy_loaders.py b/tests/loader/pyspark/dummy_loaders.py deleted file mode 100644 index a2b0cb8..0000000 --- a/tests/loader/pyspark/dummy_loaders.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from datetime import date - -from rialto.loader.data_loader import DataLoader - - -class DummyDataLoader(DataLoader): - def __init__(self): - super().__init__() - - def read_group(self, group: str, information_date: date): - return None diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py index 3ad653e..dd2049f 100644 --- a/tests/loader/pyspark/test_from_cfg.py +++ b/tests/loader/pyspark/test_from_cfg.py @@ -21,7 +21,6 @@ from rialto.loader.config_loader import get_feature_config from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader from tests.loader.pyspark.dataframe_builder import dataframe_builder as dfb -from tests.loader.pyspark.dummy_loaders import DummyDataLoader @pytest.fixture(scope="session") @@ -45,7 +44,7 @@ def spark(request): @pytest.fixture(scope="session") def loader(spark): - return PysparkFeatureLoader(spark, DummyDataLoader(), MagicMock()) + return PysparkFeatureLoader(spark, MagicMock(), MagicMock()) VALID_LIST = [(["a"], ["a"]), (["a"], ["a", "b", "c"]), (["c", "a"], ["a", "b", "c"])] @@ -90,7 +89,7 @@ def __call__(self, *args, **kwargs): metadata = MagicMock() monkeypatch.setattr(metadata, "get_group", GroupMd()) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") loader.metadata = metadata base = dfb(spark, data=r.base_frame_data, columns=r.base_frame_columns) @@ -105,7 +104,7 @@ def __call__(self, *args, **kwargs): def test_get_group_metadata(spark, mocker): mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", return_value=7) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") ret_val = loader.get_group_metadata("group_name") assert ret_val == 7 @@ -115,7 +114,7 @@ def test_get_group_metadata(spark, mocker): def test_get_feature_metadata(spark, mocker): mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_feature", return_value=8) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") ret_val = loader.get_feature_metadata("group_name", "feature") assert ret_val == 8 @@ -129,7 +128,7 @@ def test_get_metadata_from_cfg(spark, mocker): ) mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", side_effect=lambda g: {"B": 10}[g]) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") metadata = loader.get_metadata_from_cfg("tests/loader/pyspark/example_cfg.yaml") assert metadata["B_F1"] == 1 diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py index 44f0c09..4e527be 100644 --- a/tests/runner/conftest.py +++ b/tests/runner/conftest.py @@ -39,6 +39,4 @@ def spark(request): @pytest.fixture(scope="function") def basic_runner(spark): - return Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31" - ) + return Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") diff --git a/tests/runner/overrider.yaml b/tests/runner/overrider.yaml new file mode 100644 index 0000000..3029730 --- /dev/null +++ b/tests/runner/overrider.yaml @@ -0,0 +1,86 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +runner: + watched_period_units: "months" + watched_period_value: 2 + mail: + sender: test@testing.org + smtp: server.test + to: + - developer@testing.org + - developer2@testing.org + subject: test report +pipelines: + - name: SimpleGroup + module: + python_module: tests.runner.transformations + python_class: SimpleGroup + schedule: + frequency: weekly + day: 7 + info_date_shift: + - value: 3 + units: days + - value: 2 + units: weeks + dependencies: + - table: source.schema.dep1 + interval: + units: "days" + value: 1 + date_col: "DATE" + - table: source.schema.dep2 + interval: + units: "months" + value: 3 + date_col: "DATE" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" + feature_loader: + config_path: path/to/config.yaml + feature_schema: catalog.feature_tables + metadata_schema: catalog.metadata + metadata_manager: + metadata_schema: catalog.metadata + - name: OtherGroup + module: + python_module: tests.runner.transformations + python_class: SimpleGroup + schedule: + frequency: weekly + day: 7 + info_date_shift: + - value: 3 + units: days + dependencies: + - table: source.schema.dep1 + name: source1 + interval: + units: "days" + value: 1 + date_col: "DATE" + - table: source.schema.dep2 + name: source2 + interval: + units: "months" + value: 3 + date_col: "batch" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" + extras: + some_value: 3 + some_other_value: cat diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py index 9088e0c..73b61b8 100644 --- a/tests/runner/test_date_manager.py +++ b/tests/runner/test_date_manager.py @@ -144,7 +144,7 @@ def test_run_dates_invalid(): [(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")], ) def test_to_info_date(shift, res): - cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units="days", value=shift)) + cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)]) base = DateManager.str_to_date("2023-03-05") info = DateManager.to_info_date(base, cfg) assert DateManager.str_to_date(res) == info @@ -155,7 +155,7 @@ def test_to_info_date(shift, res): [("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")], ) def test_info_date_shift_units(unit, result): - cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units=unit, value=3)) + cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)]) base = DateManager.str_to_date("2023-03-05") info = DateManager.to_info_date(base, cfg) assert DateManager.str_to_date(result) == info diff --git a/tests/runner/test_overrides.py b/tests/runner/test_overrides.py new file mode 100644 index 0000000..17fcdbe --- /dev/null +++ b/tests/runner/test_overrides.py @@ -0,0 +1,137 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from rialto.runner import Runner + + +def test_overrides_simple(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]}, + ) + assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + + +def test_overrides_array_index(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[1]": "a@b.c"}, + ) + assert runner.config.runner.mail.to == ["developer@testing.org", "a@b.c"] + + +def test_overrides_array_append(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[-1]": "test"}, + ) + assert runner.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"] + + +def test_overrides_array_lookup(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}, + ) + assert runner.config.pipelines[0].target.target_schema == "new_schema" + + +def test_overrides_combined(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"], + "pipelines[name=SimpleGroup].target.target_schema": "new_schema", + "pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1, + }, + ) + assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + assert runner.config.pipelines[0].target.target_schema == "new_schema" + assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1 + + +def test_index_out_of_range(spark): + with pytest.raises(IndexError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[8]": "test"}, + ) + assert error.value.args[0] == "Index 8 out of bounds for key to[8]" + + +def test_invalid_index_key(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.test[8]": "test"}, + ) + assert error.value.args[0] == "Invalid key test" + + +def test_invalid_key(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.test.param": "test"}, + ) + assert error.value.args[0] == "Invalid key test" + + +def test_replace_section(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "pipelines[name=SimpleGroup].feature_loader": { + "config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata", + } + }, + ) + assert runner.config.pipelines[0].feature_loader.feature_schema == "catalog.features" + + +def test_add_section(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "pipelines[name=OtherGroup].feature_loader": { + "config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata", + } + }, + ) + assert runner.config.pipelines[1].feature_loader.feature_schema == "catalog.features" diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 0459411..e23eee8 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from datetime import datetime from typing import Optional import pytest from pyspark.sql import DataFrame +import rialto.runner.utils as utils from rialto.common.table_reader import DataReader -from rialto.jobs.configuration.config_holder import ConfigHolder from rialto.runner.runner import DateManager, Runner from rialto.runner.table import Table from tests.runner.runner_resources import ( @@ -38,8 +37,8 @@ def __init__(self, spark): def get_table( self, table: str, - info_date_from: Optional[datetime.date] = None, - info_date_to: Optional[datetime.date] = None, + date_from: Optional[datetime.date] = None, + date_to: Optional[datetime.date] = None, date_column: str = None, uppercase_columns: bool = False, ) -> DataFrame: @@ -53,114 +52,79 @@ def get_table( def get_latest( self, table: str, - until: Optional[datetime.date] = None, + date_until: Optional[datetime.date] = None, date_column: str = None, uppercase_columns: bool = False, ) -> DataFrame: pass -def test_table_exists(spark, mocker, basic_runner): +def test_table_exists(spark, mocker): mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True) - basic_runner._table_exists("abc") + utils.table_exists(spark, "abc") mock.assert_called_once_with("abc") -def test_infer_column(spark, mocker, basic_runner): - column = namedtuple("catalog", ["name", "isPartition"]) - catalog = [column("a", True), column("b", False), column("c", False)] - - mock = mocker.patch("pyspark.sql.Catalog.listColumns", return_value=catalog) - partition = basic_runner._delta_partition("aaa") - assert partition == "a" - mock.assert_called_once_with("aaa") - - def test_load_module(spark, basic_runner): - module = basic_runner._load_module(basic_runner.config.pipelines[0].module) + module = utils.load_module(basic_runner.config.pipelines[0].module) assert isinstance(module, SimpleGroup) def test_generate(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() - basic_runner._generate(group, DateManager.str_to_date("2023-01-31")) + config = basic_runner.config.pipelines[0] + basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), config) + run.assert_called_once_with( reader=basic_runner.reader, run_date=DateManager.str_to_date("2023-01-31"), spark=spark, - metadata_manager=basic_runner.metadata, - dependencies=None, + config=config, + metadata_manager=None, + feature_loader=None, ) def test_generate_w_dep(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() - basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2].dependencies) + basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) run.assert_called_once_with( reader=basic_runner.reader, run_date=DateManager.str_to_date("2023-01-31"), spark=spark, - metadata_manager=basic_runner.metadata, - dependencies={ - "source1": basic_runner.config.pipelines[2].dependencies[0], - "source2": basic_runner.config.pipelines[2].dependencies[1], - }, + config=basic_runner.config.pipelines[2], + metadata_manager=None, + feature_loader=None, ) def test_init_dates(spark): - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") assert runner.date_from == DateManager.str_to_date("2023-01-31") assert runner.date_until == DateManager.str_to_date("2023-03-31") runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", + overrides={"runner.watched_period_units": "weeks", "runner.watched_period_value": 2}, ) - assert runner.date_from == DateManager.str_to_date("2023-03-01") + assert runner.date_from == DateManager.str_to_date("2023-03-17") assert runner.date_until == DateManager.str_to_date("2023-03-31") runner = Runner( spark, config_path="tests/runner/transformations/config2.yaml", - feature_metadata_schema="", run_date="2023-03-31", ) assert runner.date_from == DateManager.str_to_date("2023-02-24") assert runner.date_until == DateManager.str_to_date("2023-03-31") -def test_possible_run_dates(spark): - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", - ) - - dates = runner.get_possible_run_dates(runner.config.pipelines[0].schedule) - expected = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - assert dates == [DateManager.str_to_date(d) for d in expected] - - -def test_info_dates(spark, basic_runner): - run = ["2023-02-05", "2023-02-12", "2023-02-19", "2023-02-26", "2023-03-05"] - run = [DateManager.str_to_date(d) for d in run] - info = basic_runner.get_info_dates(basic_runner.config.pipelines[0].schedule, run) - expected = ["2023-02-02", "2023-02-09", "2023-02-16", "2023-02-23", "2023-03-02"] - assert info == [DateManager.str_to_date(d) for d in expected] - - def test_completion(spark, mocker, basic_runner): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.utils.table_exists", return_value=True) basic_runner.reader = MockReader(spark) @@ -173,11 +137,9 @@ def test_completion(spark, mocker, basic_runner): def test_completion_rerun(spark, mocker, basic_runner): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") runner.reader = MockReader(spark) dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] @@ -189,14 +151,12 @@ def test_completion_rerun(spark, mocker, basic_runner): def test_check_dates_have_partition(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) runner.reader = MockReader(spark) dates = ["2023-03-04", "2023-03-05", "2023-03-06"] @@ -207,14 +167,12 @@ def test_check_dates_have_partition(spark, mocker): def test_check_dates_have_partition_no_table(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=False) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=False) runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) dates = ["2023-03-04", "2023-03-05", "2023-03-06"] dates = [DateManager.str_to_date(d) for d in dates] @@ -228,14 +186,12 @@ def test_check_dates_have_partition_no_table(spark, mocker): [("2023-02-26", False), ("2023-03-05", True)], ) def test_check_dependencies(spark, mocker, r_date, expected): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) runner.reader = MockReader(spark) res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date)) @@ -243,14 +199,12 @@ def test_check_dependencies(spark, mocker, r_date, expected): def test_check_no_dependencies(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) runner.reader = MockReader(spark) res = runner.check_dependencies(runner.config.pipelines[1], DateManager.str_to_date("2023-03-05")) @@ -258,14 +212,13 @@ def test_check_no_dependencies(spark, mocker): def test_select_dates(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", + overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 1}, ) runner.reader = MockReader(spark) @@ -281,14 +234,13 @@ def test_select_dates(spark, mocker): def test_select_dates_all_done(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", - date_from="2023-03-02", - date_until="2023-03-02", + run_date="2023-03-02", + overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 0}, ) runner.reader = MockReader(spark) @@ -307,9 +259,7 @@ def test_op_selected(spark, mocker): mocker.patch("rialto.runner.tracker.Tracker.report") run = mocker.patch("rialto.runner.runner.Runner._run_pipeline") - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="SimpleGroup" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="SimpleGroup") runner() run.called_once() @@ -319,42 +269,8 @@ def test_op_bad(spark, mocker): mocker.patch("rialto.runner.tracker.Tracker.report") mocker.patch("rialto.runner.runner.Runner._run_pipeline") - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="BadOp" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="BadOp") with pytest.raises(ValueError) as exception: runner() assert str(exception.value) == "Unknown operation selected: BadOp" - - -def test_custom_config(spark, mocker): - cc_spy = mocker.spy(ConfigHolder, "set_custom_config") - custom_config = {"cc": 42} - - _ = Runner(spark, config_path="tests/runner/transformations/config.yaml", custom_job_config=custom_config) - - cc_spy.assert_called_once_with(cc=42) - - -def test_feature_store_config(spark, mocker): - fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config") - - _ = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - feature_store_schema="schema", - feature_metadata_schema="metadata", - ) - - fs_spy.assert_called_once_with("schema", "metadata") - - -def test_no_configs(spark, mocker): - cc_spy = mocker.spy(ConfigHolder, "set_custom_config") - fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config") - - _ = Runner(spark, config_path="tests/runner/transformations/config.yaml") - - cc_spy.assert_not_called() - fs_spy.assert_not_called() diff --git a/tests/runner/transformations/config.yaml b/tests/runner/transformations/config.yaml index 2bfeaf1..3b72107 100644 --- a/tests/runner/transformations/config.yaml +++ b/tests/runner/transformations/config.yaml @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -general: - target_schema: catalog.schema - target_partition_column: "INFORMATION_DATE" +runner: watched_period_units: "months" watched_period_value: 2 - job: "run" # run/check mail: sender: test@testing.org smtp: server.test @@ -34,8 +31,8 @@ pipelines: frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days dependencies: - table: source.schema.dep1 interval: @@ -47,6 +44,9 @@ pipelines: units: "months" value: 3 date_col: "DATE" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" - name: GroupNoDeps module: python_module: tests.runner.transformations @@ -55,8 +55,8 @@ pipelines: frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days - name: NamedDeps module: python_module: tests.runner.transformations @@ -65,8 +65,8 @@ pipelines: frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days dependencies: - table: source.schema.dep1 name: source1 @@ -80,3 +80,6 @@ pipelines: units: "months" value: 3 date_col: "batch" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" diff --git a/tests/runner/transformations/config2.yaml b/tests/runner/transformations/config2.yaml index a91894b..f7b9604 100644 --- a/tests/runner/transformations/config2.yaml +++ b/tests/runner/transformations/config2.yaml @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -general: - target_schema: catalog.schema - target_partition_column: "INFORMATION_DATE" +runner: watched_period_units: "weeks" watched_period_value: 5 - job: "run" # run/check mail: sender: test@testing.org smtp: server.test @@ -43,3 +40,6 @@ pipelines: units: "months" value: 1 date_col: "DATE" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" diff --git a/tests/runner/transformations/simple_group.py b/tests/runner/transformations/simple_group.py index fcda5c7..ec2311c 100644 --- a/tests/runner/transformations/simple_group.py +++ b/tests/runner/transformations/simple_group.py @@ -18,6 +18,7 @@ from pyspark.sql.types import StructType from rialto.common import TableReader +from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner import Transformation @@ -28,7 +29,8 @@ def run( reader: TableReader, run_date: datetime.date, spark: SparkSession = None, - metadata_manager: MetadataManager = None, - dependencies: Dict = None, + config: Dict = None, + metadata: MetadataManager = None, + feature_loader: PysparkFeatureLoader = None, ) -> DataFrame: return spark.createDataFrame([], StructType([]))