Module Register rework of dependency registration #15

Merged
merged 5 commits on Sep 21, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,10 @@ All notable changes to this project will be documented in this file.
- config holder removed from jobs
- metadata_manager and feature_loader are now available arguments, depending on configuration
- added @config decorator, similar use case to @datasource, for parsing configuration
- reworked Resolver + added ModuleRegister
- datasources are no longer registered globally just by being imported, and are therefore no longer available to all jobs
- register_dependency_callable and register_dependency_module added to register datasources
- together, it is now possible to have two datasources with the same name but different implementations for two different jobs
#### TableReader
- function signatures changed
- until -> date_until
29 changes: 27 additions & 2 deletions README.md
@@ -372,7 +372,7 @@ With that sorted out, we can now provide a quick example of the *rialto.jobs* mo
```python
from pyspark.sql import DataFrame
from rialto.common import TableReader
from rialto.jobs.decorators import config_parser, job, datasource
from rialto.jobs import config_parser, job, datasource
from rialto.runner.config_loader import PipelineConfig
from pydantic import BaseModel

@@ -419,7 +419,6 @@ If you want to disable versioning of your job (adding package VERSION column to
def my_job(...):
...
```

These parameters can be used separately, or combined.

### Notes & Rules
@@ -435,6 +434,32 @@ This can be useful in **model training**.
Finally, remember, that your jobs are still just *Rialto Transformations* internally.
Meaning that at the end of the day, you should always read some data, do some operations on it and either return a pyspark DataFrame, or not return anything and let the framework return the placeholder one.


### Importing / Registering Datasources
Datasources required for a job (or another datasource) can be defined in a different module.
To make those datasources available to your job, register the module (or individual datasources from it) as dependencies using the following functions:

```python
from pyspark.sql import DataFrame
from rialto.jobs import job, register_dependency_callable, register_dependency_module
import my_package.my_datasources as md
import my_package.my_datasources_big as big_md

# Register an entire dependency module
register_dependency_module(md)

# Register a single datasource from a bigger module
register_dependency_callable(big_md.sample_datasource)

@job
def my_job(my_datasource, sample_datasource: DataFrame, ...):
...
```

Each job/datasource can only resolve datasources it has defined as dependencies.
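
Because registration is scoped to the registering module, two jobs can even use the same datasource name with different implementations. A hypothetical sketch (module and datasource names below are illustrative, not part of Rialto):

```python
# job_module_a.py (hypothetical)
from rialto.jobs import job, register_dependency_module
import my_package.datasources_v1 as v1  # defines a @datasource named my_table

register_dependency_module(v1)

@job
def job_a(my_table):
    ...  # my_table resolves to the datasources_v1 implementation


# job_module_b.py (hypothetical)
from rialto.jobs import job, register_dependency_module
import my_package.datasources_v2 as v2  # also defines a @datasource named my_table

register_dependency_module(v2)

@job
def job_b(my_table):
    ...  # my_table resolves to the datasources_v2 implementation
```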

**NOTE**: While `register_dependency_module` only registers a module's datasources as available dependencies, `register_dependency_callable` actually brings the datasource into the target module, so it also becomes available for export further down the dependency chain.
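
As a hypothetical illustration of that difference (module names below are made up): a datasource pulled in with `register_dependency_callable` becomes part of the registering module and can be re-exported to modules that depend on it, whereas `register_dependency_module` does not re-export anything.

```python
# shared_layer.py (hypothetical)
import my_package.core_datasources as core  # defines a @datasource named raw_events
from rialto.jobs import register_dependency_callable

# raw_events is brought into shared_layer itself, so it can be re-exported
register_dependency_callable(core.raw_events)


# jobs.py (hypothetical)
import shared_layer
from rialto.jobs import job, register_dependency_module

register_dependency_module(shared_layer)

@job
def my_job(raw_events):
    ...  # resolves via shared_layer; had shared_layer only called
         # register_dependency_module(core), raw_events would not be visible here
```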


### Testing
One of the main advantages of the jobs module is simplification of unit tests for your transformations. Rialto provides following tools:

719 changes: 375 additions & 344 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,5 +1,5 @@
[tool.poetry]
name = "rialto-dev"
name = "rialto"

version = "2.0.0"

@@ -31,6 +31,7 @@ pandas = "^2.1.0"
flake8-broken-line = "^1.0.0"
loguru = "^0.7.2"
importlib-metadata = "^7.2.1"
numpy = "<2.0.0"

[tool.poetry.dev-dependencies]
pyspark = "^3.4.1"
24 changes: 22 additions & 2 deletions rialto/common/utils.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["load_yaml"]
__all__ = ["load_yaml", "cast_decimals_to_floats", "get_caller_module"]

import inspect
import os
from typing import Any
from typing import Any, List

import pyspark.sql.functions as F
import yaml
@@ -51,3 +52,22 @@ def cast_decimals_to_floats(df: DataFrame) -> DataFrame:
df = df.withColumn(c, F.col(c).cast(FloatType()))

return df


def get_caller_module() -> Any:
"""
Get the module containing the function which is calling your function.

Inspects the call stack, where:
0th entry is this function
1st entry is the function which needs to know who called it
2nd entry is the calling function

Therefore, we return the module which contains the function at the 2nd place on the stack.

:return: Python Module containing the calling function.
"""

stack = inspect.stack()
last_stack = stack[2]
return inspect.getmodule(last_stack[0])
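
A small, hypothetical usage sketch of the helper above (the decorator and module names are illustrative): `get_caller_module` is intended to be called from inside another helper, one frame removed from the module it should discover.

```python
# my_decorators.py (hypothetical)
from rialto.common.utils import get_caller_module


def remember_module(fn):
    # Stack at this point: [0] get_caller_module, [1] remember_module,
    # [2] the module-level code applying @remember_module
    fn._origin_module = get_caller_module()
    return fn


# my_jobs.py (hypothetical)
# @remember_module
# def my_job():
#     ...
#
# my_job._origin_module now points at the my_jobs module
```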
4 changes: 4 additions & 0 deletions rialto/jobs/__init__.py
@@ -13,3 +13,7 @@
# limitations under the License.

from rialto.jobs.decorators import config_parser, datasource, job
from rialto.jobs.module_register import (
register_dependency_callable,
register_dependency_module,
)
18 changes: 5 additions & 13 deletions rialto/jobs/decorators.py
@@ -14,14 +14,14 @@

__all__ = ["datasource", "job", "config_parser"]

import inspect
import typing

import importlib_metadata
from loguru import logger

from rialto.common.utils import get_caller_module
from rialto.jobs.job_base import JobBase
from rialto.jobs.resolver import Resolver
from rialto.jobs.module_register import ModuleRegister


def config_parser(cf_getter: typing.Callable) -> typing.Callable:
@@ -34,7 +34,7 @@ def config_parser(cf_getter: typing.Callable) -> typing.Callable:
:param cf_getter: dataset reader function
:return: raw function, unchanged
"""
Resolver.register_callable(cf_getter)
ModuleRegister.register_callable(cf_getter)
return cf_getter


@@ -48,16 +48,10 @@ def datasource(ds_getter: typing.Callable) -> typing.Callable:
:param ds_getter: dataset reader function
:return: raw reader function, unchanged
"""
Resolver.register_callable(ds_getter)
ModuleRegister.register_callable(ds_getter)
return ds_getter


def _get_module(stack: typing.List) -> typing.Any:
last_stack = stack[1]
mod = inspect.getmodule(last_stack[0])
return mod


def _get_version(module: typing.Any) -> str:
try:
package_name, _, _ = module.__name__.partition(".")
@@ -102,9 +96,7 @@ def job(*args, custom_name=None, disable_version=False):
:return: One more job wrapper for run function (if custom name or version override specified).
Otherwise, generates Rialto Transformation Type and returns it for in-module registration.
"""
stack = inspect.stack()

module = _get_module(stack)
module = get_caller_module()
version = _get_version(module)

# Use case where it's just raw @f. Otherwise, we get [] here.
79 changes: 27 additions & 52 deletions rialto/jobs/job_base.py
@@ -17,7 +17,6 @@
import abc
import datetime
import typing
from contextlib import contextmanager

import pyspark.sql.functions as F
from loguru import logger
@@ -49,55 +48,33 @@ def get_job_name(self) -> str:
"""Job name getter"""
pass

@contextmanager
def _setup_resolver(self, run_date: datetime.date) -> None:
Resolver.register_callable(lambda: run_date, "run_date")

Resolver.register_callable(self._get_spark, "spark")
Resolver.register_callable(self._get_table_reader, "table_reader")
Resolver.register_callable(self._get_config, "config")

if self._get_feature_loader() is not None:
Resolver.register_callable(self._get_feature_loader, "feature_loader")
if self._get_metadata_manager() is not None:
Resolver.register_callable(self._get_metadata_manager, "metadata_manager")

try:
yield
finally:
Resolver.cache_clear()

def _setup(
def _get_resolver(
self,
spark: SparkSession,
run_date: datetime.date,
table_reader: TableReader,
config: PipelineConfig = None,
metadata_manager: MetadataManager = None,
feature_loader: PysparkFeatureLoader = None,
) -> None:
self._spark = spark
self._table_rader = table_reader
self._config = config
self._metadata = metadata_manager
self._feature_loader = feature_loader
) -> Resolver:
resolver = Resolver()

def _get_spark(self) -> SparkSession:
return self._spark
# Static, always-available dependencies
resolver.register_object(spark, "spark")
resolver.register_object(run_date, "run_date")
resolver.register_object(config, "config")
resolver.register_object(table_reader, "table_reader")

def _get_table_reader(self) -> TableReader:
return self._table_rader
# Optionals
if feature_loader is not None:
resolver.register_object(feature_loader, "feature_loader")

def _get_config(self) -> PipelineConfig:
return self._config
if metadata_manager is not None:
resolver.register_object(metadata_manager, "metadata_manager")

def _get_feature_loader(self) -> PysparkFeatureLoader:
return self._feature_loader
return resolver

def _get_metadata_manager(self) -> MetadataManager:
return self._metadata

def _get_timestamp_holder_result(self) -> DataFrame:
spark = self._get_spark()
def _get_timestamp_holder_result(self, spark) -> DataFrame:
return spark.createDataFrame(
[(self.get_job_name(), datetime.datetime.now())], schema="JOB_NAME string, CREATION_TIME timestamp"
)
@@ -110,17 +87,6 @@ def _add_job_version(self, df: DataFrame) -> DataFrame:

return df

def _run_main_callable(self, run_date: datetime.date) -> DataFrame:
with self._setup_resolver(run_date):
custom_callable = self.get_custom_callable()
raw_result = Resolver.register_resolve(custom_callable)

if raw_result is None:
raw_result = self._get_timestamp_holder_result()

result_with_version = self._add_job_version(raw_result)
return result_with_version

def run(
self,
reader: TableReader,
@@ -140,8 +106,17 @@ def run(
:return: dataframe
"""
try:
self._setup(spark, reader, config, metadata_manager, feature_loader)
return self._run_main_callable(run_date)
resolver = self._get_resolver(spark, run_date, reader, config, metadata_manager, feature_loader)

custom_callable = self.get_custom_callable()
raw_result = resolver.resolve(custom_callable)

if raw_result is None:
raw_result = self._get_timestamp_holder_result(spark)

result_with_version = self._add_job_version(raw_result)
return result_with_version

except Exception as e:
logger.exception(e)
raise e