Commit a49b06c: v2 changes

MDobransky committed Aug 23, 2024 (1 parent: f2be7cb)

Showing 25 changed files with 229 additions and 534 deletions.
25 changes: 14 additions & 11 deletions README.md
@@ -312,17 +312,18 @@ from pyspark.sql import DataFrame
from rialto.common import TableReader
from rialto.jobs.decorators import job, datasource


@datasource
def my_datasource(run_date: datetime.date, table_reader: TableReader) -> DataFrame:
-    return table_reader.get_latest("my_catalog.my_schema.my_table", until=run_date)
+    return table_reader.get_latest("my_catalog.my_schema.my_table", date_until=run_date)


@job
def my_job(my_datasource: DataFrame) -> DataFrame:
    return my_datasource.withColumn("HelloWorld", F.lit(1))
```
This piece of code:
1. creates a rialto transformation called *my_job*, which is then callable by the rialto runner.
2. It sources the *my_datasource* and then runs *my_job* on top of that datasource.
3. Rialto adds VERSION (of your package) and INFORMATION_DATE (as per config) columns automatically.
4. The rialto runner stores the final result to a catalog, to a table named according to the job's name (see the sketch below).
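
A minimal sketch of what points 3 and 4 mean in practice. The catalog and schema names below are assumptions for illustration; only the metadata column names come from the list above.

```python
# Hypothetical check of the runner's output for my_job; "my_catalog.my_schema"
# is an assumed target location, and the output table is named after the job.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

output = spark.read.table("my_catalog.my_schema.my_job")

# the job's own column plus the metadata columns added by the runner
assert {"HelloWorld", "VERSION", "INFORMATION_DATE"} <= set(output.columns)
```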
@@ -383,20 +384,20 @@ import my_package.test_job_module as tjm
# Datasource Testing
def test_datasource_a():
    ... mocks here ...

    with disable_job_decorators(tjm):
        datasource_a_output = tjm.datasource_a(... mocks ...)

    ... asserts ...

# Job Testing
def test_my_job():
    datasource_a_mock = ...
    ... other mocks...

    with disable_job_decorators(tjm):
        job_output = tjm.my_job(datasource_a_mock, ... mocks ...)

    ... asserts ...
```
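
A more concrete sketch of the same testing pattern, assuming the test module contains the `my_job` from the earlier README example (it adds a constant `HelloWorld` column). The import path for `disable_job_decorators` and the Spark session setup are illustrative assumptions, not part of this diff.

```python
# Illustrative test; the module path and the import location of
# disable_job_decorators are assumptions for this sketch.
from pyspark.sql import SparkSession

from rialto.jobs.decorators.test_utils import disable_job_decorators  # assumed path
import my_package.test_job_module as tjm


def test_my_job():
    spark = SparkSession.builder.getOrCreate()

    # a small frame standing in for the real datasource
    datasource_a_mock = spark.createDataFrame([(1,), (2,)], ["id"])

    with disable_job_decorators(tjm):
        # with the decorators disabled, the job is a plain function over DataFrames
        job_output = tjm.my_job(datasource_a_mock)

    assert "HelloWorld" in job_output.columns
    assert job_output.count() == 2
```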

@@ -563,6 +564,7 @@ reader = TableReader(spark=spark_instance)
```

usage of _get_table_:

```python
# get whole table
df = reader.get_table(table="catalog.schema.table", date_column="information_date")
@@ -573,18 +575,19 @@ from datetime import datetime
start = datetime.strptime("2020-01-01", "%Y-%m-%d").date()
end = datetime.strptime("2024-01-01", "%Y-%m-%d").date()

-df = reader.get_table(table="catalog.schema.table", info_date_from=start, info_date_to=end)
+df = reader.get_table(table="catalog.schema.table", date_from=start, date_to=end)
```

usage of _get_latest_:

```python
# most recent partition
df = reader.get_latest(table="catalog.schema.table", date_column="information_date")

# most recent partition until
until = datetime.strptime("2020-01-01", "%Y-%m-%d").date()

-df = reader.get_latest(table="catalog.schema.table", until=until, date_column="information_date")
+df = reader.get_latest(table="catalog.schema.table", date_until=until, date_column="information_date")

```
For full information on parameters and their optionality see technical documentation.
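
A short v1-to-v2 migration sketch for the renamed reader parameters in this commit, reusing the `reader`, `start`, `end`, and `until` values defined above; note that `date_column` has no default in the new signatures, so it is passed explicitly here.

```python
# v1 (before this commit):
# df = reader.get_table(table="catalog.schema.table", info_date_from=start, info_date_to=end)
# df = reader.get_latest(table="catalog.schema.table", until=until, date_column="information_date")

# v2 (after this commit): info_date_from/info_date_to become date_from/date_to,
# until becomes date_until, and date_column is supplied explicitly
df = reader.get_table(
    table="catalog.schema.table",
    date_column="information_date",
    date_from=start,
    date_to=end,
)
df = reader.get_latest(
    table="catalog.schema.table",
    date_column="information_date",
    date_until=until,
)
```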
31 changes: 17 additions & 14 deletions poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "rialto"
name = "rialto-dev"

version = "1.3.2"
version = "2.0.0"

packages = [
{ include = "rialto" },
@@ -31,6 +31,7 @@ pandas = "^2.1.0"
flake8-broken-line = "^1.0.0"
loguru = "^0.7.2"
importlib-metadata = "^7.2.1"
+env_yaml = "^0.0.3"

[tool.poetry.dev-dependencies]
pyspark = "^3.4.1"
70 changes: 23 additions & 47 deletions rialto/common/table_reader.py
@@ -21,8 +21,6 @@
import pyspark.sql.functions as F
from pyspark.sql import DataFrame, SparkSession

-from rialto.common.utils import get_date_col_property, get_delta_partition


class DataReader(metaclass=abc.ABCMeta):
    """
@@ -36,16 +34,15 @@ class DataReader(metaclass=abc.ABCMeta):
    def get_latest(
        self,
        table: str,
-        until: Optional[datetime.date] = None,
-        date_column: str = None,
+        date_column: str,
+        date_until: Optional[datetime.date] = None,
        uppercase_columns: bool = False,
    ) -> DataFrame:
        """
        Get latest available date partition of the table until specified date
        :param table: input table path
-        :param until: Optional until date (inclusive)
        :param date_column: column to filter dates on, takes highest priority
+        :param date_until: Optional until date (inclusive)
        :param uppercase_columns: Option to refactor all column names to uppercase
        :return: Dataframe
        """
@@ -55,18 +52,17 @@ def get_latest(
    def get_table(
        self,
        table: str,
-        info_date_from: Optional[datetime.date] = None,
-        info_date_to: Optional[datetime.date] = None,
-        date_column: str = None,
+        date_column: str,
+        date_from: Optional[datetime.date] = None,
+        date_to: Optional[datetime.date] = None,
        uppercase_columns: bool = False,
    ) -> DataFrame:
        """
        Get a whole table or a slice by selected dates
        :param table: input table path
-        :param info_date_from: Optional date from (inclusive)
-        :param info_date_to: Optional date to (inclusive)
        :param date_column: column to filter dates on, takes highest priority
+        :param date_from: Optional date from (inclusive)
+        :param date_to: Optional date to (inclusive)
        :param uppercase_columns: Option to refactor all column names to uppercase
        :return: Dataframe
        """
@@ -76,17 +72,13 @@ def get_table(
class TableReader(DataReader):
    """An implementation of data reader for databricks tables"""

-    def __init__(self, spark: SparkSession, date_property: str = "rialto_date_column", infer_partition: bool = False):
+    def __init__(self, spark: SparkSession):
        """
        Init
        :param spark:
-        :param date_property: Databricks table property specifying date column, take priority over inference
-        :param infer_partition: infer date column as tables partition from delta metadata
        """
        self.spark = spark
-        self.date_property = date_property
-        self.infer_partition = infer_partition
        super().__init__()

    def _uppercase_column_names(self, df: DataFrame) -> DataFrame:
@@ -106,41 +98,26 @@ def _get_latest_available_date(self, df: DataFrame, date_col: str, until: Option
        df = df.select(F.max(date_col)).alias("latest")
        return df.head()[0]

-    def _get_date_col(self, table: str, date_column: str):
-        """
-        Get tables date column
-        column specified at get_table/get_latest takes priority, if inference is enabled it
-        takes 2nd place, last resort is table property
-        """
-        if date_column:
-            return date_column
-        elif self.infer_partition:
-            return get_delta_partition(self.spark, table)
-        else:
-            return get_date_col_property(self.spark, table, self.date_property)

    def get_latest(
        self,
        table: str,
-        until: Optional[datetime.date] = None,
-        date_column: str = None,
+        date_column: str,
+        date_until: Optional[datetime.date] = None,
        uppercase_columns: bool = False,
    ) -> DataFrame:
        """
        Get latest available date partition of the table until specified date
        :param table: input table path
-        :param until: Optional until date (inclusive)
+        :param date_until: Optional until date (inclusive)
        :param date_column: column to filter dates on, takes highest priority
        :param uppercase_columns: Option to refactor all column names to uppercase
        :return: Dataframe
        """
-        date_col = self._get_date_col(table, date_column)
        df = self.spark.read.table(table)

-        selected_date = self._get_latest_available_date(df, date_col, until)
-        df = df.filter(F.col(date_col) == selected_date)
+        selected_date = self._get_latest_available_date(df, date_column, date_until)
+        df = df.filter(F.col(date_column) == selected_date)

        if uppercase_columns:
            df = self._uppercase_column_names(df)
@@ -149,28 +126,27 @@ def get_latest(
    def get_table(
        self,
        table: str,
-        info_date_from: Optional[datetime.date] = None,
-        info_date_to: Optional[datetime.date] = None,
-        date_column: str = None,
+        date_column: str,
+        date_from: Optional[datetime.date] = None,
+        date_to: Optional[datetime.date] = None,
        uppercase_columns: bool = False,
    ) -> DataFrame:
        """
        Get a whole table or a slice by selected dates
        :param table: input table path
-        :param info_date_from: Optional date from (inclusive)
-        :param info_date_to: Optional date to (inclusive)
+        :param date_from: Optional date from (inclusive)
+        :param date_to: Optional date to (inclusive)
        :param date_column: column to filter dates on, takes highest priority
        :param uppercase_columns: Option to refactor all column names to uppercase
        :return: Dataframe
        """
-        date_col = self._get_date_col(table, date_column)
        df = self.spark.read.table(table)

-        if info_date_from:
-            df = df.filter(F.col(date_col) >= info_date_from)
-        if info_date_to:
-            df = df.filter(F.col(date_col) <= info_date_to)
+        if date_from:
+            df = df.filter(F.col(date_column) >= date_from)
+        if date_to:
+            df = df.filter(F.col(date_column) <= date_to)
        if uppercase_columns:
            df = self._uppercase_column_names(df)
        return df
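
Given the constructor change above (the `date_property` and `infer_partition` options are removed, and the date column is passed explicitly to each read), a minimal v2 construction sketch could look like this; the v1 call is shown only as a comment for contrast.

```python
# Sketch only: v2 TableReader no longer infers the date column from table
# properties or delta partition metadata.
from pyspark.sql import SparkSession

from rialto.common import TableReader

spark = SparkSession.builder.getOrCreate()

# v1 (before this commit):
# reader = TableReader(spark=spark, date_property="rialto_date_column", infer_partition=True)

# v2: the reader only needs a SparkSession; the date column is given per call
# through the now-required date_column argument of get_table / get_latest
reader = TableReader(spark=spark)
```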
5 changes: 3 additions & 2 deletions rialto/common/utils.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-__all__ = ["load_yaml", "get_date_col_property", "get_delta_partition"]
+__all__ = ["load_yaml"]

import os
from typing import Any

import pyspark.sql.functions as F
import yaml
+from env_yaml import EnvLoader
from pyspark.sql import DataFrame
from pyspark.sql.types import FloatType

@@ -34,7 +35,7 @@ def load_yaml(path: str) -> Any:
        raise FileNotFoundError(f"Can't find {path}.")

    with open(path, "r") as stream:
-        return yaml.safe_load(stream)
+        return yaml.load(stream, EnvLoader)


def get_date_col_property(spark, table: str, property: str) -> str:
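
A sketch of what the loader change above enables: configuration values can reference environment variables that are resolved when the YAML is loaded. The `${VAR}` substitution syntax is an assumption about the `env_yaml` package, not something this diff confirms.

```python
# Illustrative only; assumes env_yaml expands ${VAR} references at load time.
import os

from rialto.common.utils import load_yaml

os.environ["RUN_ENV"] = "dev"

# config.yaml (hypothetical contents):
#   catalog: my_catalog_${RUN_ENV}
config = load_yaml("config.yaml")

# under the assumed syntax, config["catalog"] resolves to "my_catalog_dev"
```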
2 changes: 2 additions & 0 deletions rialto/jobs/__init__.py
@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from rialto.jobs.decorators import datasource, job
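
With the re-export above, the `job` and `datasource` decorators can presumably also be imported from the shorter `rialto.jobs` path; a minimal sketch reusing the README's job example:

```python
import pyspark.sql.functions as F
from pyspark.sql import DataFrame

# shorter import path enabled by the re-export in rialto/jobs/__init__.py
from rialto.jobs import job


@job
def my_job(my_datasource: DataFrame) -> DataFrame:
    return my_datasource.withColumn("HelloWorld", F.lit(1))
```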