[DOP-9787] Improve read strategies in DBReader
dolfinus committed Dec 8, 2023
1 parent 4092134 commit dd7d06d
Showing 59 changed files with 1,764 additions and 1,417 deletions.
41 changes: 41 additions & 0 deletions docs/changelog/next_release/182.breaking.rst
@@ -0,0 +1,41 @@
Implementation of read strategies has been drastically improved.

Before 0.10:

- Get the table schema by making a query ``SELECT * FROM table WHERE 1=0`` (if ``DBReader.columns`` contained ``*``).
- Append the HWM column to the list of table columns, removing duplicates.
- Create a dataframe from a query like ``SELECT hwm.expression AS hwm.column, ...other table columns... FROM table WHERE prev_hwm.expression > prev_hwm.value``.
- Determine the HWM class by ``df.schema[hwm.column].dataType``.
- Calculate ``df.select(min(hwm.column), max(hwm.column)).collect()`` on the Spark side.
- Use ``max(hwm.column)`` as the next HWM value.
- Return the dataframe to the user.

This was far from ideal:

- The dataframe content (all rows, or just the changed ones) was loaded from the source into Spark only to get the min/max values of a specific column.
- Fetching the table schema and then substituting column names into the subsequent query could cause errors.

For example, a source could contain columns with mixed-case names, like ``"MyColumn"`` and ``"My column"``.
Column names were not escaped during query generation, producing queries the database could not execute,
so users had to explicitly pass a proper column list, wrapping each name in ``"``.

- The dataframe was created from a query with a clause like ``WHERE hwm.expression > prev_hwm.value``,
not ``WHERE hwm.expression > prev_hwm.value AND hwm.expression <= current_hwm.value``.

So if new rows appeared in the source after the HWM value was determined, they could be read by ``DBReader`` on the first run,
and then again on the next run, because they are still matched by the ``WHERE hwm.expression > prev_hwm.value`` query.

Since 0.10:

- Get the type of the HWM expression: ``SELECT hwm.expression FROM table WHERE 1=0``.
- Determine the HWM class by ``df.schema[0]``.
- Get min/max values by querying ``SELECT MIN(hwm.expression), MAX(hwm.expression) FROM table WHERE hwm.expression >= prev_hwm.value``.
- Use ``MAX(hwm.expression)`` as the next HWM value.
- Create the dataframe from the query ``SELECT * FROM table WHERE hwm.expression > prev_hwm.value AND hwm.expression <= current_hwm.value``, and return it to the user.

Improvements:

- Allow the source to calculate min/max instead of loading everything into Spark. This should be *really* fast, and the source can use indexes to speed it up even more.
- Restrict the dataframe content to always match the HWM values.
- Don't alter the columns list; pass it to the source as-is. So ``DBReader`` no longer fails on tables with mixed-case column names.

**Breaking change** - the HWM column is no longer implicitly added to the dataframe.
If it was not just a plain column but an expression whose result your code then accessed as a dataframe column,
you should now explicitly add the same expression to ``DBReader.columns``.
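
For example, if the HWM was based on an expression like ``CAST(updated_at AS DATE)``
and your code accessed that column in the resulting dataframe, add the same expression to ``DBReader.columns``
explicitly (a minimal sketch using the 0.10 API; the connection object, table and column names are hypothetical)::

    from onetl.db import DBReader
    from onetl.strategy import IncrementalStrategy

    reader = DBReader(
        connection=postgres,  # any supported DB connection, created elsewhere
        source="schema.table",
        # the HWM expression is no longer appended to the dataframe implicitly,
        # so list it explicitly along with the other columns:
        columns=["*", "CAST(updated_at AS DATE) AS updated_date"],
        hwm=DBReader.AutoDetectHWM(name="my_unique_hwm_name", expression="CAST(updated_at AS DATE)"),
    )

    with IncrementalStrategy():
        df = reader.run()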
38 changes: 0 additions & 38 deletions onetl/_internal.py
@@ -180,41 +180,3 @@ def generate_temp_path(root: PurePath) -> PurePath:
current_process = ProcessStackManager.get_current()
current_dt = datetime.now().strftime(DATETIME_FORMAT)
return root / "onetl" / current_process.host / current_process.full_name / current_dt


def get_sql_query(
table: str,
columns: list[str] | None = None,
where: str | None = None,
hint: str | None = None,
compact: bool = False,
) -> str:
"""
Generates a SQL query using input arguments
"""

if compact:
indent = " "
else:
indent = os.linesep + " " * 7

hint = f" /*+ {hint} */" if hint else ""

columns_str = "*"
if columns:
columns_str = indent + f",{indent}".join(column for column in columns)

if columns_str.strip() == "*":
columns_str = indent + "*"

where_str = ""
if where:
where_str = "WHERE" + indent + where

return os.linesep.join(
[
f"SELECT{hint}{columns_str}",
f"FROM{indent}{table}",
where_str,
],
).strip()
2 changes: 1 addition & 1 deletion onetl/base/__init__.py
@@ -25,7 +25,7 @@
from onetl.base.base_file_limit import BaseFileLimit
from onetl.base.contains_exception import ContainsException
from onetl.base.contains_get_df_schema import ContainsGetDFSchemaMethod
from onetl.base.contains_get_min_max_bounds import ContainsGetMinMaxBounds
from onetl.base.contains_get_min_max_values import ContainsGetMinMaxValues
from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol
from onetl.base.path_stat_protocol import PathStatProtocol
from onetl.base.pure_path_protocol import PurePathProtocol
13 changes: 5 additions & 8 deletions onetl/base/base_db_connection.py
@@ -17,15 +17,13 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from etl_entities.source import Table

from onetl.base.base_connection import BaseConnection
from onetl.hwm import Statement
from onetl.hwm import Window

if TYPE_CHECKING:
from etl_entities.hwm import HWM, ColumnHWM
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.types import StructField, StructType


class BaseDBDialect(ABC):
@@ -37,7 +35,7 @@ def __init__(self, connection: BaseDBConnection) -> None:
self.connection = connection

@abstractmethod
def validate_name(self, value: Table) -> str:
def validate_name(self, value: str) -> str:
"""Check if ``source`` or ``target`` value is valid.
Raises
@@ -109,7 +107,7 @@ def validate_hint(self, hint: Any) -> Any | None:
"""

@abstractmethod
def detect_hwm_class(self, hwm_column_type: str) -> type[ColumnHWM]:
def detect_hwm_class(self, field: StructField) -> type[ColumnHWM] | None:
"""
Detects hwm column type based on specific data types in connections data stores
"""
@@ -141,8 +139,7 @@ def read_source_as_df(
hint: Any | None = None,
where: Any | None = None,
df_schema: StructType | None = None,
start_from: Statement | None = None,
end_at: Statement | None = None,
window: Window | None = None,
) -> DataFrame:
"""
Reads the source to dataframe. |support_hooks|
onetl/base/contains_get_min_max_bounds.py → onetl/base/contains_get_min_max_values.py
@@ -18,18 +18,19 @@

from typing_extensions import Protocol, runtime_checkable

from onetl.hwm.window import Window


@runtime_checkable
class ContainsGetMinMaxBounds(Protocol):
class ContainsGetMinMaxValues(Protocol):
"""
Protocol for objects containing ``get_min_max_bounds`` method
Protocol for objects containing ``get_min_max_values`` method
"""

def get_min_max_bounds(
def get_min_max_values(
self,
source: str,
column: str,
expression: str | None = None,
window: Window,
hint: Any | None = None,
where: Any | None = None,
) -> tuple[Any, Any]:
127 changes: 78 additions & 49 deletions onetl/connection/db_connection/db_connection/dialect.py
@@ -14,29 +14,77 @@

from __future__ import annotations

import operator
import os
from datetime import date, datetime
from typing import Any, Callable, ClassVar, Dict

from etl_entities.hwm import ColumnHWM
from typing import TYPE_CHECKING, Any

from onetl.base import BaseDBDialect
from onetl.hwm import Statement
from onetl.hwm import Edge, Window
from onetl.hwm.store import SparkTypeToHWM

if TYPE_CHECKING:
from etl_entities.hwm import ColumnHWM
from pyspark.sql.types import StructField


class DBDialect(BaseDBDialect):
_compare_statements: ClassVar[Dict[Callable, str]] = {
operator.ge: "{} >= {}",
operator.gt: "{} > {}",
operator.le: "{} <= {}",
operator.lt: "{} < {}",
operator.eq: "{} == {}",
operator.ne: "{} != {}",
}

def detect_hwm_class(self, hwm_column_type: str) -> type[ColumnHWM]:
return SparkTypeToHWM.get(hwm_column_type) # type: ignore
def detect_hwm_class(self, field: StructField) -> type[ColumnHWM] | None:
return SparkTypeToHWM.get(field.dataType.typeName()) # type: ignore

def get_sql_query(
self,
table: str,
columns: list[str] | None = None,
where: str | list[str] | None = None,
hint: str | None = None,
compact: bool = False,
) -> str:
"""
Generates a SQL query using input arguments
"""

if compact:
indent = " "
else:
indent = os.linesep + " " * 7

hint = f" /*+ {hint} */" if hint else ""

columns_str = indent + "*"
if columns:
columns_str = indent + f",{indent}".join(column for column in columns)

where = where or []
if isinstance(where, str):
where = [where]

where_clauses = []
if len(where) == 1:
where_clauses.append("WHERE" + indent + where[0])
else:
for i, item in enumerate(where):
directive = "WHERE" if i == 0 else " AND"
where_clauses.append(directive + indent + f"({item})")

query_parts = [
f"SELECT{hint}{columns_str}",
f"FROM{indent}{table}",
*where_clauses,
]

return os.linesep.join(filter(None, query_parts)).strip()
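
# For example (illustrative values), calling this method with table="schema.table",
# columns=["id", "name"], where=["id > 100", "name IS NOT NULL"] and compact=True
# produces a query like:
#
#   SELECT id, name
#   FROM schema.table
#   WHERE (id > 100)
#    AND (name IS NOT NULL)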

def apply_window(
self,
condition: Any,
window: Window | None = None,
) -> Any:
conditions = [
condition,
self._edge_to_where(window.expression, window.start_from, position="start") if window else None,
self._edge_to_where(window.expression, window.stop_at, position="end") if window else None,
]
return list(filter(None, conditions))

def escape_column(self, value: str) -> str:
return f'"{value}"'
@@ -58,44 +106,25 @@ def get_min_value(self, value: Any) -> str:
result = self._serialize_value(value)
return f"MIN({result})"

def condition_assembler(
def _edge_to_where(
self,
condition: Any,
start_from: Statement | None,
end_at: Statement | None,
expression: str,
edge: Edge,
position: str,
) -> Any:
conditions = [condition]

if start_from:
condition1 = self._get_compare_statement(
comparator=start_from.operator,
arg1=start_from.expression,
arg2=start_from.value,
)
conditions.append(condition1)

if end_at:
condition2 = self._get_compare_statement(
comparator=end_at.operator,
arg1=end_at.expression,
arg2=end_at.value,
)
conditions.append(condition2)

result: list[Any] = list(filter(None, conditions))
if not result:
if not edge.is_set():
return None
return self._merge_conditions(result)

def _get_compare_statement(self, comparator: Callable, arg1: Any, arg2: Any) -> Any:
template = self._compare_statements[comparator]
return template.format(arg1, self._serialize_value(arg2))

def _merge_conditions(self, conditions: list[Any]) -> Any:
if len(conditions) == 1:
return conditions[0]
operators: dict[tuple[str, bool], str] = {
("start", True): ">=",
("start", False): "> ",
("end", True): "<=",
("end", False): "< ",
}

return " AND ".join(f"({item})" for item in conditions)
operator = operators[(position, edge.including)]
value = self._serialize_value(edge.value)
return f"{expression} {operator} {value}"

def _serialize_value(self, value: Any) -> str | int | dict:
"""
@@ -1,8 +1,6 @@
from __future__ import annotations

from etl_entities.source import Table


class SupportNameAny:
def validate_name(self, value: Table) -> Table:
def validate_name(self, value: str) -> str:
return value
@@ -1,11 +1,9 @@
from __future__ import annotations

from etl_entities.source import Table


class SupportNameWithSchemaOnly:
def validate_name(cls, value: Table) -> Table:
if value.name.count(".") != 1:
def validate_name(self, value: str) -> str:
if value.count(".") != 1:
raise ValueError(
f"Name should be passed in `schema.name` format, got '{value}'",
)