From b5b15aed15f4de976fe9fba2603fa6d0e427cb95 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Fri, 27 Sep 2024 15:12:07 +0200 Subject: [PATCH 1/3] list key for sequential features --- docs/source/conf.py | 2 +- rialto/maker/feature_maker.py | 7 +++++-- tests/maker/test_FeatureMaker.py | 7 +++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1ce2224..07257e5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,7 +28,7 @@ project = "rialto" copyright = "2022, Marek Dobransky" author = "Marek Dobransky" -release = "1.3.0" +release = "2.0.1" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/rialto/maker/feature_maker.py b/rialto/maker/feature_maker.py index 6cbc1ee..7aa6c85 100644 --- a/rialto/maker/feature_maker.py +++ b/rialto/maker/feature_maker.py @@ -49,7 +49,10 @@ def _set_values(self, df: DataFrame, key: typing.Union[str, typing.List[str]], m :return: None """ self.data_frame = df - self.key = key + if isinstance(key, str): + self.key = [key] + else: + self.key = key self.make_date = make_date def _order_by_dependencies(self, feature_holders: typing.List[FeatureHolder]) -> typing.List[FeatureHolder]: @@ -136,7 +139,7 @@ def _make_sequential(self, keep_preexisting: bool) -> DataFrame: ) if not keep_preexisting: logger.info("Dropping non-selected columns") - self.data_frame = self.data_frame.select(self.key, *feature_names) + self.data_frame = self.data_frame.select(*self.key, *feature_names) return self._filter_null_keys(self.data_frame) def _make_aggregated(self) -> DataFrame: diff --git a/tests/maker/test_FeatureMaker.py b/tests/maker/test_FeatureMaker.py index b8f9c1a..f3da26d 100644 --- a/tests/maker/test_FeatureMaker.py +++ b/tests/maker/test_FeatureMaker.py @@ -61,6 +61,13 @@ def test_sequential_multi_key(input_df): assert 
"TRANSACTIONS_OUTBOUND_VALUE" in df.columns +def test_sequential_multi_key_drop(input_df): + df, _ = FeatureMaker.make( + input_df, ["CUSTOMER_KEY", "TYPE"], date.today(), sequential_outbound, keep_preexisting=False + ) + assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns + + def test_sequential_keeps(input_df): df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=True) assert "AMT" in df.columns From 1020f55d074621f396b79b57876c975c8602f4e7 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Fri, 27 Sep 2024 15:16:04 +0200 Subject: [PATCH 2/3] logging clarity --- rialto/runner/runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index ac9d6bc..49280b6 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -67,7 +67,7 @@ def __init__( if self.date_from > self.date_until: raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") - logger.info(f"Running period from {self.date_from} until {self.date_until}") + logger.info(f"Running period set to: {self.date_from} - {self.date_until}") def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: """ @@ -285,6 +285,7 @@ def _run_pipeline(self, pipeline: PipelineConfig): def __call__(self): """Execute pipelines""" + logger.info("Executing pipelines") try: if self.op: selected = [p for p in self.config.pipelines if p.name == self.op] @@ -297,3 +298,4 @@ def __call__(self): finally: print(self.tracker.records) self.tracker.report(self.config.runner.mail) + logger.info("Execution finished") From 1acc6bd015da2dd36085cacb6d9ec889f582f62f Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Fri, 27 Sep 2024 15:59:01 +0200 Subject: [PATCH 3/3] fix for register on reload --- rialto/jobs/module_register.py | 9 +++++++++ rialto/jobs/test_utils.py | 9 ++++++--- tests/jobs/test_job/test_job.py | 5 +++++ 
tests/jobs/test_register.py | 8 ++++++++ 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 tests/jobs/test_register.py diff --git a/rialto/jobs/module_register.py b/rialto/jobs/module_register.py index 8283454..27a55ef 100644 --- a/rialto/jobs/module_register.py +++ b/rialto/jobs/module_register.py @@ -56,6 +56,15 @@ def register_callable(cls, callable): callable_module = callable.__module__ cls.add_callable_to_module(callable, callable_module) + @classmethod + def remove_module(cls, module): + """ + Remove a module from the storage. + + :param module: The module to be removed. + """ + cls._storage.pop(module.__name__, None) + @classmethod def register_dependency(cls, module, parent_name): """ diff --git a/rialto/jobs/test_utils.py b/rialto/jobs/test_utils.py index d8f2945..cced2fe 100644 --- a/rialto/jobs/test_utils.py +++ b/rialto/jobs/test_utils.py @@ -20,6 +20,7 @@ from unittest.mock import MagicMock, create_autospec, patch from rialto.jobs.job_base import JobBase +from rialto.jobs.module_register import ModuleRegister from rialto.jobs.resolver import Resolver, ResolverException @@ -59,15 +60,17 @@ def disable_job_decorators(module) -> None: :return: None """ with _disable_job_decorators(): + ModuleRegister.remove_module(module) importlib.reload(module) yield + ModuleRegister.remove_module(module) importlib.reload(module) def resolver_resolves(spark, job: JobBase) -> bool: """ - Checker method for your dependency resoultion. + Checker method for your dependency resolution. If your job's dependencies are all defined and resolvable, returns true. Otherwise, throws an exception. 
@@ -100,8 +103,8 @@ def stack_watching_resolver_resolve(self, callable): return result - with patch(f"rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve): - with patch(f"rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x): + with patch("rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve): + with patch("rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x): job().run( reader=MagicMock(), run_date=MagicMock(), diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index 4e47364..01c8df9 100644 --- a/tests/jobs/test_job/test_job.py +++ b/tests/jobs/test_job/test_job.py @@ -29,6 +29,11 @@ def job_function(): return "job_function_return" +@job +def job_with_datasource(dataset): + return dataset + + @job(custom_name="custom_job_name") def custom_name_job_function(): return "custom_job_name_return" diff --git a/tests/jobs/test_register.py b/tests/jobs/test_register.py new file mode 100644 index 0000000..1904537 --- /dev/null +++ b/tests/jobs/test_register.py @@ -0,0 +1,8 @@ +from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves +from tests.jobs.test_job import test_job + + +def test_resolve_after_disable(spark): + with disable_job_decorators(test_job): + assert test_job.job_with_datasource("test") == "test" + assert resolver_resolves(spark, test_job.job_with_datasource)