Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolver Resolution Test Utils #9

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion rialto/jobs/decorators/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import importlib
import typing
from contextlib import contextmanager
from unittest.mock import patch
from unittest.mock import patch, create_autospec, MagicMock
from rialto.jobs.decorators.resolver import Resolver, ResolverException
from rialto.jobs.decorators.job_base import JobBase


def _passthrough_decorator(*args, **kwargs) -> typing.Callable:
Expand Down Expand Up @@ -58,3 +60,45 @@ def disable_job_decorators(module) -> None:
yield

importlib.reload(module)


def resolver_resolves(spark, job: JobBase) -> bool:
"""
Checker method for your dependency resoultion.

If your job's dependencies are all defined and resolvable, returns true.
Otherwise, throws an exception.

:param spark: SparkSession object.
:param job: Job to try and resolve.

:return: bool, True if job can be resolved
"""

class SmartStorage:
def __init__(self):
self._storage = Resolver._storage.copy()
self._call_stack = []

def __setitem__(self, key, value):
self._storage[key] = value

def keys(self):
return self._storage.keys()

def __getitem__(self, func_name):
if func_name in self._call_stack:
raise ResolverException(f"Circular Dependence on {func_name}!")

self._call_stack.append(func_name)

real_method = self._storage[func_name]
fake_method = create_autospec(real_method)
fake_method.side_effect = lambda *args, **kwargs: self._call_stack.remove(func_name)

return fake_method

with patch("rialto.jobs.decorators.resolver.Resolver._storage", SmartStorage()):
job().run(reader=MagicMock(), run_date=MagicMock(), spark=spark)

return True
51 changes: 51 additions & 0 deletions tests/jobs/test_job/dependency_tests_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from rialto.jobs.decorators import job, datasource


@datasource
def a():
return 1


@datasource
def b(a):
return a + 1


@datasource
def c(a, b):
return a + b


@job
def ok_dependency_job(c):
return c + 1


@datasource
def d(a, circle_1):
return circle_1 + a


@datasource
def circle_1(circle_2):
return circle_2 + 1


@datasource
def circle_2(circle_1):
return circle_1 + 1


@job
def circular_dependency_job(d):
return d + 1


@job
def missing_dependency_job(a, x):
return x + a


@job
def default_dependency_job(run_date, spark, config, dependencies, table_reader, feature_loader):
return 1
29 changes: 27 additions & 2 deletions tests/jobs/test_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import rialto.jobs.decorators as decorators
import tests.jobs.test_job.test_job as test_job
import tests.jobs.test_job.dependency_tests_job as dependency_tests_job
from rialto.jobs.decorators.resolver import Resolver
from rialto.jobs.decorators.test_utils import disable_job_decorators
from rialto.jobs.decorators.test_utils import disable_job_decorators, resolver_resolves


def test_raw_dataset_patch(mocker):
Expand Down Expand Up @@ -46,3 +47,27 @@ def test_custom_name_job_function_patch(mocker):
assert test_job.custom_name_job_function() == "custom_job_name_return"

spy_dec.assert_not_called()


def test_resolver_resolves_ok_job(spark):
assert resolver_resolves(spark, dependency_tests_job.ok_dependency_job)


def test_resolver_resolves_default_dependency(spark):
assert resolver_resolves(spark, dependency_tests_job.default_dependency_job)


def test_resolver_resolves_fails_circular_dependency(spark):
with pytest.raises(Exception) as exc_info:
assert resolver_resolves(spark, dependency_tests_job.circular_dependency_job)

assert exc_info is not None
assert str(exc_info.value) == "Circular Dependence on circle_1!"


def test_resolver_resolves_fails_missing_dependency(spark):
with pytest.raises(Exception) as exc_info:
assert resolver_resolves(spark, dependency_tests_job.missing_dependency_job)

assert exc_info is not None
assert str(exc_info.value) == "x declaration not found!"
Loading