Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Switchable Version Column
Browse files Browse the repository at this point in the history
vvancak committed Jul 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent f0ae0c7 commit fd228a1
Showing 7 changed files with 64 additions and 19 deletions.
37 changes: 24 additions & 13 deletions rialto/jobs/decorators/decorators.py
Original file line number Diff line number Diff line change
@@ -14,10 +14,10 @@

__all__ = ["datasource", "job"]

import importlib_metadata
import inspect
import typing

import importlib_metadata
from loguru import logger

from rialto.jobs.decorators.job_base import JobBase
@@ -47,7 +47,7 @@ def _get_module(stack: typing.List) -> typing.Any:
def _get_version(module: typing.Any) -> str:
try:
package_name, _, _ = module.__name__.partition(".")
dist_name = importlib_metadata.packages_distributions()[package_name][0]
dist_name = importlib_metadata.packages_distributions()[package_name][0]
return importlib_metadata.version(dist_name)

except Exception:
@@ -73,29 +73,40 @@ def _generate_rialto_job(callable: typing.Callable, module: object, class_name:
return generated_class


def job(name_or_callable: typing.Union[str, typing.Callable]) -> typing.Union[typing.Callable, typing.Type]:
def job(*args, custom_name=None, disable_version=False):
"""
Rialto jobs decorator.
Transforms a python function into a rialto transformation, which can be imported and run by Rialto Runner.
Allows a custom name, via @job("custom_name_here") or can be just used as @job and the function's name is used.
Mainly used as a bare @job: the function's name is used, and the job's outputs get the VERSION column filled in automatically.
To override this behaviour, use @job(custom_name=XXX, disable_version=True).
:param name_or_callable: str for custom job name. Otherwise, run function.
:return: One more job wrapper for run function (if custom name specified).
:param *args: list of positional arguments. Empty in case custom_name or disable_version is specified.
:param custom_name: str for custom job name.
:param disable_version: bool to disable autofilling the VERSION column in the job's outputs.
:return: One more job wrapper for run function (if custom name or version override specified).
Otherwise, generates Rialto Transformation Type and returns it for in-module registration.
"""
stack = inspect.stack()

module = _get_module(stack)
version = _get_version(module)

if type(name_or_callable) is str:
# Use case where it's just raw @f. Otherwise we get [] here.
if len(args) == 1 and callable(args[0]):
return _generate_rialto_job(callable=args[0], module=module, class_name=module.__name__, version=version)

# Setting default custom name, in case user only disables version
if custom_name is None:
custom_name = module.__name__

def inner_wrapper(callable):
return _generate_rialto_job(callable, module, name_or_callable, version)
# Setting version to None causes JobBase to not fill it
if disable_version:
version = None

return inner_wrapper
# We need to return one more wrapper
def inner_wrapper(f):
return _generate_rialto_job(callable=f, module=module, class_name=custom_name, version=version)

else:
name = name_or_callable.__name__
return _generate_rialto_job(name_or_callable, module, name, version)
return inner_wrapper
6 changes: 5 additions & 1 deletion rialto/jobs/decorators/job_base.py
Original file line number Diff line number Diff line change
@@ -96,7 +96,11 @@ def _get_timestamp_holder_result(self) -> DataFrame:

def _add_job_version(self, df: DataFrame) -> DataFrame:
version = self.get_job_version()
return df.withColumn("VERSION", F.lit(version))

if version is not None:
return df.withColumn("VERSION", F.lit(version))

return df

def _run_main_callable(self, run_date: datetime.date) -> DataFrame:
with self._setup_resolver(run_date):
7 changes: 3 additions & 4 deletions rialto/jobs/decorators/test_utils.py
Original file line number Diff line number Diff line change
@@ -20,12 +20,11 @@
from unittest.mock import patch


def _passthrough_decorator(x: typing.Callable) -> typing.Callable:
if type(x) is str:
def _passthrough_decorator(*args, **kwargs) -> typing.Callable:
if len(args) == 0:
return _passthrough_decorator

else:
return x
return args[0]


@contextmanager
5 changes: 5 additions & 0 deletions tests/jobs/resources.py
Original file line number Diff line number Diff line change
@@ -41,3 +41,8 @@ def f(spark):
return spark.createDataFrame(df)

return f


class CustomJobNoVersion(CustomJobNoReturnVal):
    """Test job variant of ``CustomJobNoReturnVal`` that reports no version."""

    def get_job_version(self) -> None:
        """Return ``None`` to signal that this job has no version."""
        # The annotation is ``None`` (not ``str``) because no string is ever returned.
        return None
11 changes: 11 additions & 0 deletions tests/jobs/test_decorators.py
Original file line number Diff line number Diff line change
@@ -57,6 +57,17 @@ def test_custom_name_function():
custom_callable = result_class.get_custom_callable()
assert custom_callable() == "custom_job_name_return"

job_name = result_class.get_job_name()
assert job_name == "custom_job_name"


def test_job_disabling_version():
    """A job declared with ``disable_version=True`` must report ``None`` as its version."""
    job_instance = _rialto_import_stub("tests.jobs.test_job.test_job", "disable_version_job_function")

    # The stub must hand back an object whose type derives from JobBase.
    assert issubclass(type(job_instance), JobBase)

    # With versioning disabled, the generated job exposes no version at all.
    assert job_instance.get_job_version() is None


def test_job_dependencies_registered(spark):
ConfigHolder.set_custom_config(value=123)
7 changes: 6 additions & 1 deletion tests/jobs/test_job/test_job.py
Original file line number Diff line number Diff line change
@@ -26,11 +26,16 @@ def job_function():
return "job_function_return"


@job("custom_job_name")
@job(custom_name="custom_job_name")
def custom_name_job_function():
return "custom_job_name_return"


# Fixture job registered with version filling disabled; it only returns a marker string.
@job(disable_version=True)
def disable_version_job_function():
return "disabled_version_job_return"


@job
def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader):
assert spark is not None
10 changes: 10 additions & 0 deletions tests/jobs/test_job_base.py
Original file line number Diff line number Diff line change
@@ -91,3 +91,13 @@ def test_return_dataframe_forwarded_with_version(spark):
assert result.columns == ["FIRST", "SECOND", "VERSION"]
assert result.first()["VERSION"] == "job_version"
assert result.count() == 2


def test_none_job_version_wont_fill_job_colun(spark):
    """Running a job whose version is ``None`` must not add a VERSION column."""
    reader_mock = MagicMock()
    run_day = datetime.date(2023, 1, 1)

    job_output = resources.CustomJobNoVersion().run(
        reader=reader_mock,
        run_date=run_day,
        spark=spark,
        metadata_manager=None,
    )

    # The result must still be a Spark DataFrame, just without the version column.
    assert type(job_output) is pyspark.sql.DataFrame
    assert "VERSION" not in job_output.columns

0 comments on commit fd228a1

Please sign in to comment.