Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes v200 #19

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
project = "rialto"
copyright = "2022, Marek Dobransky"
author = "Marek Dobransky"
release = "1.3.0"
release = "2.0.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
9 changes: 9 additions & 0 deletions rialto/jobs/module_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def register_callable(cls, callable):
callable_module = callable.__module__
cls.add_callable_to_module(callable, callable_module)

@classmethod
def remove_module(cls, module):
"""
Remove a module from the storage.
:param module: The module to be removed.
"""
cls._storage.pop(module.__name__, None)

@classmethod
def register_dependency(cls, module, parent_name):
"""
Expand Down
9 changes: 6 additions & 3 deletions rialto/jobs/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import MagicMock, create_autospec, patch

from rialto.jobs.job_base import JobBase
from rialto.jobs.module_register import ModuleRegister
from rialto.jobs.resolver import Resolver, ResolverException


Expand Down Expand Up @@ -59,15 +60,17 @@ def disable_job_decorators(module) -> None:
:return: None
"""
with _disable_job_decorators():
ModuleRegister.remove_module(module)
importlib.reload(module)
yield

ModuleRegister.remove_module(module)
importlib.reload(module)


def resolver_resolves(spark, job: JobBase) -> bool:
"""
Checker method for your dependency resoultion.
Checker method for your dependency resolution.

If your job's dependencies are all defined and resolvable, returns true.
Otherwise, throws an exception.
Expand Down Expand Up @@ -100,8 +103,8 @@ def stack_watching_resolver_resolve(self, callable):

return result

with patch(f"rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve):
with patch(f"rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x):
with patch("rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve):
with patch("rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x):
job().run(
reader=MagicMock(),
run_date=MagicMock(),
Expand Down
7 changes: 5 additions & 2 deletions rialto/maker/feature_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def _set_values(self, df: DataFrame, key: typing.Union[str, typing.List[str]], m
:return: None
"""
self.data_frame = df
self.key = key
if isinstance(key, str):
self.key = [key]
else:
self.key = key
self.make_date = make_date

def _order_by_dependencies(self, feature_holders: typing.List[FeatureHolder]) -> typing.List[FeatureHolder]:
Expand Down Expand Up @@ -136,7 +139,7 @@ def _make_sequential(self, keep_preexisting: bool) -> DataFrame:
)
if not keep_preexisting:
logger.info("Dropping non-selected columns")
self.data_frame = self.data_frame.select(self.key, *feature_names)
self.data_frame = self.data_frame.select(*self.key, *feature_names)
return self._filter_null_keys(self.data_frame)

def _make_aggregated(self) -> DataFrame:
Expand Down
4 changes: 3 additions & 1 deletion rialto/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(

if self.date_from > self.date_until:
raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}")
logger.info(f"Running period from {self.date_from} until {self.date_until}")
logger.info(f"Running period set to: {self.date_from} - {self.date_until}")

def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame:
"""
Expand Down Expand Up @@ -285,6 +285,7 @@ def _run_pipeline(self, pipeline: PipelineConfig):

def __call__(self):
"""Execute pipelines"""
logger.info("Executing pipelines")
try:
if self.op:
selected = [p for p in self.config.pipelines if p.name == self.op]
Expand All @@ -297,3 +298,4 @@ def __call__(self):
finally:
print(self.tracker.records)
self.tracker.report(self.config.runner.mail)
logger.info("Execution finished")
5 changes: 5 additions & 0 deletions tests/jobs/test_job/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def job_function():
return "job_function_return"


@job
def job_with_datasource(dataset):
return dataset


@job(custom_name="custom_job_name")
def custom_name_job_function():
return "custom_job_name_return"
Expand Down
8 changes: 8 additions & 0 deletions tests/jobs/test_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves
from tests.jobs.test_job import test_job


def test_resolve_after_disable(spark):
with disable_job_decorators(test_job):
assert test_job.job_with_datasource("test") == "test"
assert resolver_resolves(spark, test_job.job_with_datasource)
7 changes: 7 additions & 0 deletions tests/maker/test_FeatureMaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def test_sequential_multi_key(input_df):
assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns


def test_sequential_multi_key_drop(input_df):
df, _ = FeatureMaker.make(
input_df, ["CUSTOMER_KEY", "TYPE"], date.today(), sequential_outbound, keep_preexisting=False
)
assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns


def test_sequential_keeps(input_df):
df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=True)
assert "AMT" in df.columns
Expand Down
Loading