[DOP-9787] Improve read strategies in DBReader

dolfinus committed Dec 6, 2023
1 parent 4092134 commit d4080ad
Showing 48 changed files with 873 additions and 1,018 deletions.
2 changes: 1 addition & 1 deletion onetl/base/__init__.py
@@ -25,7 +25,7 @@
 from onetl.base.base_file_limit import BaseFileLimit
 from onetl.base.contains_exception import ContainsException
 from onetl.base.contains_get_df_schema import ContainsGetDFSchemaMethod
-from onetl.base.contains_get_min_max_bounds import ContainsGetMinMaxBounds
+from onetl.base.contains_get_min_max_values import ContainsGetMinMaxValues
 from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol
 from onetl.base.path_stat_protocol import PathStatProtocol
 from onetl.base.pure_path_protocol import PurePathProtocol
13 changes: 5 additions & 8 deletions onetl/base/base_db_connection.py
@@ -17,15 +17,13 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any

-from etl_entities.source import Table
-
 from onetl.base.base_connection import BaseConnection
-from onetl.hwm import Statement
+from onetl.hwm import Window

 if TYPE_CHECKING:
     from etl_entities.hwm import HWM, ColumnHWM
     from pyspark.sql import DataFrame
-    from pyspark.sql.types import StructType
+    from pyspark.sql.types import DataType, StructType


 class BaseDBDialect(ABC):
@@ -37,7 +35,7 @@ def __init__(self, connection: BaseDBConnection) -> None:
         self.connection = connection

     @abstractmethod
-    def validate_name(self, value: Table) -> str:
+    def validate_name(self, value: str) -> str:
         """Check if ``source`` or ``target`` value is valid.

         Raises
@@ -109,7 +107,7 @@ def validate_hint(self, hint: Any) -> Any | None:
"""

@abstractmethod
def detect_hwm_class(self, hwm_column_type: str) -> type[ColumnHWM]:
def detect_hwm_class(self, data_type: DataType) -> type[ColumnHWM] | None:
"""
Detects hwm column type based on specific data types in connections data stores
"""
@@ -141,8 +139,7 @@ def read_source_as_df(
         hint: Any | None = None,
         where: Any | None = None,
         df_schema: StructType | None = None,
-        start_from: Statement | None = None,
-        end_at: Statement | None = None,
+        window: Window | None = None,
     ) -> DataFrame:
         """
         Reads the source to dataframe. |support_hooks|
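The two Statement bounds collapse into a single Window argument here. A minimal sketch of the new call style, assuming Window and Edge accept the attributes this commit reads from them (expression, start_from, stop_at, value, including) — the constructor signatures are inferred, not confirmed:

# Hypothetical usage sketch; constructor arguments are assumptions inferred
# from the attributes this diff accesses, not a confirmed API.
from onetl.hwm import Window
from onetl.hwm.window import Edge

window = Window(
    expression="updated_at",                              # column or SQL expression to bound
    start_from=Edge(value="2023-01-01", including=True),  # lower bound, inclusive -> >=
    stop_at=Edge(value="2023-06-01", including=False),    # upper bound, exclusive -> <
)

# replaces the old start_from=Statement(...) / end_at=Statement(...) pair
df = connection.read_source_as_df(source="schema.table", columns=["*"], window=window)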
onetl/base/contains_get_min_max_bounds.py → onetl/base/contains_get_min_max_values.py (renamed)
@@ -18,18 +18,19 @@

 from typing_extensions import Protocol, runtime_checkable

+from onetl.hwm.window import Window
+

 @runtime_checkable
-class ContainsGetMinMaxBounds(Protocol):
+class ContainsGetMinMaxValues(Protocol):
     """
-    Protocol for objects containing ``get_min_max_bounds`` method
+    Protocol for objects containing ``get_min_max_values`` method
     """

-    def get_min_max_bounds(
+    def get_min_max_values(
         self,
         source: str,
-        column: str,
-        expression: str | None = None,
+        window: Window,
         hint: Any | None = None,
         where: Any | None = None,
     ) -> tuple[Any, Any]:
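Because the protocol is @runtime_checkable, a read strategy can feature-detect min/max support with a plain isinstance check. A sketch of how a caller might use it (the fetch_bounds helper is illustrative, not part of the codebase):

from onetl.base import ContainsGetMinMaxValues

def fetch_bounds(connection, source, window):
    # Illustrative helper: probe the connection for min/max support before use.
    if not isinstance(connection, ContainsGetMinMaxValues):
        raise TypeError(f"{type(connection).__name__} cannot compute min/max values")
    return connection.get_min_max_values(source=source, window=window)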
95 changes: 46 additions & 49 deletions onetl/connection/db_connection/db_connection/dialect.py
@@ -14,29 +14,39 @@

 from __future__ import annotations

-import operator
 from datetime import date, datetime
-from typing import Any, Callable, ClassVar, Dict
-
-from etl_entities.hwm import ColumnHWM
+from typing import TYPE_CHECKING, Any

 from onetl.base import BaseDBDialect
-from onetl.hwm import Statement
+from onetl.hwm import Window
+from onetl.hwm.store import SparkTypeToHWM
+from onetl.hwm.window import Edge
+
+if TYPE_CHECKING:
+    from etl_entities.hwm import ColumnHWM
+    from pyspark.sql.types import DataType


 class DBDialect(BaseDBDialect):
-    _compare_statements: ClassVar[Dict[Callable, str]] = {
-        operator.ge: "{} >= {}",
-        operator.gt: "{} > {}",
-        operator.le: "{} <= {}",
-        operator.lt: "{} < {}",
-        operator.eq: "{} == {}",
-        operator.ne: "{} != {}",
-    }
-
-    def detect_hwm_class(self, hwm_column_type: str) -> type[ColumnHWM]:
-        return SparkTypeToHWM.get(hwm_column_type)  # type: ignore
+    def detect_hwm_class(self, value: DataType) -> type[ColumnHWM] | None:
+        return SparkTypeToHWM.get(value.typeName())  # type: ignore
+
+    def apply_window(
+        self,
+        condition: Any,
+        window: Window | None = None,
+    ) -> Any:
+        window = window or Window()
+        conditions = [
+            condition,
+            self._edge_to_where(window.expression, window.start_from, position="start"),
+            self._edge_to_where(window.expression, window.stop_at, position="end"),
+        ]
+        conditions = list(filter(None, conditions))
+        if not conditions:
+            return None
+
+        return self._merge_conditions(conditions)

     def escape_column(self, value: str) -> str:
         return f'"{value}"'
@@ -58,45 +68,32 @@ def get_min_value(self, value: Any) -> str:
         result = self._serialize_value(value)
         return f"MIN({result})"

-    def condition_assembler(
-        self,
-        condition: Any,
-        start_from: Statement | None,
-        end_at: Statement | None,
-    ) -> Any:
-        conditions = [condition]
-
-        if start_from:
-            condition1 = self._get_compare_statement(
-                comparator=start_from.operator,
-                arg1=start_from.expression,
-                arg2=start_from.value,
-            )
-            conditions.append(condition1)
-
-        if end_at:
-            condition2 = self._get_compare_statement(
-                comparator=end_at.operator,
-                arg1=end_at.expression,
-                arg2=end_at.value,
-            )
-            conditions.append(condition2)
-
-        result: list[Any] = list(filter(None, conditions))
-        if not result:
-            return None
-        return self._merge_conditions(result)
-
-    def _get_compare_statement(self, comparator: Callable, arg1: Any, arg2: Any) -> Any:
-        template = self._compare_statements[comparator]
-        return template.format(arg1, self._serialize_value(arg2))
-
     def _merge_conditions(self, conditions: list[Any]) -> Any:
         if len(conditions) == 1:
             return conditions[0]

         return " AND ".join(f"({item})" for item in conditions)

+    def _edge_to_where(
+        self,
+        expression: str,
+        edge: Edge,
+        position: str,
+    ) -> Any:
+        if not expression or not edge.is_set():
+            return None
+
+        operators: dict[tuple[str, bool], str] = {
+            ("start", True): ">=",
+            ("start", False): ">",
+            ("end", True): "<=",
+            ("end", False): "<",
+        }
+
+        operator = operators[(position, edge.including)]
+        value = self._serialize_value(edge.value)
+        return f"{expression} {operator} {value}"
+
     def _serialize_value(self, value: Any) -> str | int | dict:
         """
         Transform the value into an SQL Dialect-supported form.
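The table-driven _edge_to_where replaces the old operator-keyed template dict, and apply_window merges the base condition with both window edges. Roughly what it produces, assuming the hypothetical Window/Edge construction sketched earlier and that _serialize_value renders dates as quoted ISO strings (an assumption; the exact literal format depends on the dialect):

from datetime import date

# Assumed construction; see the caveats above.
window = Window(
    expression="business_dt",
    start_from=Edge(value=date(2023, 1, 1), including=True),
    stop_at=Edge(value=date(2023, 2, 1), including=False),
)

# `dialect` stands in for any DBDialect instance.
where = dialect.apply_window("status = 'ok'", window)
# -> "(status = 'ok') AND (business_dt >= '2023-01-01') AND (business_dt < '2023-02-01')"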
@@ -1,8 +1,6 @@
 from __future__ import annotations

-from etl_entities.source import Table
-

 class SupportNameAny:
-    def validate_name(self, value: Table) -> Table:
+    def validate_name(self, value: str) -> str:
         return value
@@ -1,11 +1,9 @@
 from __future__ import annotations

-from etl_entities.source import Table
-

 class SupportNameWithSchemaOnly:
-    def validate_name(cls, value: Table) -> Table:
-        if value.name.count(".") != 1:
+    def validate_name(self, value: str) -> str:
+        if value.count(".") != 1:
             raise ValueError(
                 f"Name should be passed in `schema.name` format, got '{value}'",
             )
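With etl_entities.source.Table gone, both mixins now validate plain strings. Expected behavior, read directly off this diff:

SupportNameAny().validate_name("any_name")                  # -> "any_name"
SupportNameWithSchemaOnly().validate_name("schema.table")   # -> "schema.table"
SupportNameWithSchemaOnly().validate_name("no_schema")      # raises ValueError:
# "Name should be passed in `schema.name` format, got 'no_schema'"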
29 changes: 13 additions & 16 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -43,7 +43,7 @@
 from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions
 from onetl.exception import MISSING_JVM_CLASS_MSG, TooManyParallelJobsError
 from onetl.hooks import slot, support_hooks
-from onetl.hwm import Statement
+from onetl.hwm import Window
 from onetl.impl import GenericOptions
 from onetl.log import log_lines, log_with_indent

@@ -267,13 +267,12 @@ def read_source_as_df(
         hint: str | None = None,
         where: str | None = None,
         df_schema: StructType | None = None,
-        start_from: Statement | None = None,
-        end_at: Statement | None = None,
+        window: Window | None = None,
         options: GreenplumReadOptions | None = None,
     ) -> DataFrame:
         read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True)
         log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
-        where = self.dialect.condition_assembler(condition=where, start_from=start_from, end_at=end_at)
+        where = self.dialect.apply_window(where, window)
         query = get_sql_query(table=source, columns=columns, where=where)
         log_lines(log, query)

@@ -335,32 +334,30 @@ def get_df_schema(
         return df.schema

     @slot
-    def get_min_max_bounds(
+    def get_min_max_values(
         self,
         source: str,
-        column: str,
-        expression: str | None = None,
-        hint: str | None = None,
-        where: str | None = None,
+        window: Window,
+        hint: Any | None = None,
+        where: Any | None = None,
         options: JDBCOptions | None = None,
     ) -> tuple[Any, Any]:
-        log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column)
-
+        log.info("|%s| Getting min and max values for %r ...", self.__class__.__name__, window.expression)
         jdbc_options = self.JDBCOptions.parse(options).copy(update={"fetchsize": 1})

         query = get_sql_query(
             table=source,
             columns=[
                 self.dialect.aliased(
-                    self.dialect.get_min_value(expression or column),
+                    self.dialect.get_min_value(window.expression),
                     self.dialect.escape_column("min"),
                 ),
                 self.dialect.aliased(
-                    self.dialect.get_max_value(expression or column),
+                    self.dialect.get_max_value(window.expression),
                     self.dialect.escape_column("max"),
                 ),
             ],
-            where=where,
+            where=self.dialect.apply_window(where, window),
        )

         log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
@@ -372,8 +369,8 @@ def get_min_max_bounds(
         max_value = row["max"]

         log.info("|%s| Received values:", self.__class__.__name__)
-        log_with_indent(log, "MIN(%r) = %r", column, min_value)
-        log_with_indent(log, "MAX(%r) = %r", column, max_value)
+        log_with_indent(log, "MIN(%s) = %r", window.expression, min_value)
+        log_with_indent(log, "MAX(%s) = %r", window.expression, max_value)

         return min_value, max_value

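Putting it together for Greenplum: get_min_max_values now derives both aggregates and the WHERE clause from the window. A sketch of the call and, in comments, roughly the SQL issued on the driver (reusing the hypothetical Window from earlier; exact literal formatting depends on _serialize_value):

min_value, max_value = connection.get_min_max_values(source="schema.table", window=window)
# Executes roughly:
#   SELECT MIN(business_dt) AS "min", MAX(business_dt) AS "max"
#   FROM schema.table
#   WHERE (business_dt >= '2023-01-01') AND (business_dt < '2023-02-01')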