Skip to content

Commit

Permalink
[DOP-9787] Convert all classmethods of Dialect class to regular methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Dec 5, 2023
1 parent e1a27e0 commit 0733113
Show file tree
Hide file tree
Showing 22 changed files with 185 additions and 322 deletions.
51 changes: 15 additions & 36 deletions onetl/base/base_db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from etl_entities.source import Table

Expand All @@ -33,9 +33,11 @@ class BaseDBDialect(ABC):
Collection of methods used for validating input values before passing them to read_source_as_df/write_df_to_target
"""

@classmethod
def __init__(self, connection: BaseDBConnection) -> None:
self.connection = connection

@abstractmethod
def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table:
def validate_name(self, value: Table) -> str:
"""Check if ``source`` or ``target`` value is valid.
Raises
Expand All @@ -46,9 +48,8 @@ def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table:
If value is invalid
"""

@classmethod
@abstractmethod
def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | None) -> list[str] | None:
def validate_columns(self, columns: list[str] | None) -> list[str] | None:
"""Check if ``columns`` value is valid.
Raises
Expand All @@ -59,9 +60,8 @@ def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | Non
If value is invalid
"""

@classmethod
@abstractmethod
def validate_hwm(cls, connection: BaseDBConnection, hwm: HWM) -> HWM:
def validate_hwm(self, hwm: HWM | None) -> HWM | None:
"""Check if ``HWM`` class is valid.
Raises
Expand All @@ -72,9 +72,8 @@ def validate_hwm(cls, connection: BaseDBConnection, hwm: HWM) -> HWM:
If hwm is invalid
"""

@classmethod
@abstractmethod
def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None:
def validate_df_schema(self, df_schema: StructType | None) -> StructType | None:
"""Check if ``df_schema`` value is valid.
Raises
Expand All @@ -85,9 +84,8 @@ def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType
If value is invalid
"""

@classmethod
@abstractmethod
def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None:
def validate_where(self, where: Any) -> Any | None:
"""Check if ``where`` value is valid.
Raises
Expand All @@ -98,9 +96,8 @@ def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None:
If value is invalid
"""

@classmethod
@abstractmethod
def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None:
def validate_hint(self, hint: Any) -> Any | None:
"""Check if ``hint`` value is valid.
Raises
Expand All @@ -111,34 +108,12 @@ def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None:
If value is invalid
"""

@classmethod
@abstractmethod
def detect_hwm_class(cls, hwm_column_type: str) -> ColumnHWM:
def detect_hwm_class(self, hwm_column_type: str) -> type[ColumnHWM]:
"""
Detect the HWM class matching the given column type, based on the data types specific to the connection's data store
"""

@classmethod
@abstractmethod
def _merge_conditions(cls, conditions: list[Any]) -> Any:
"""
Convert multiple WHERE conditions to one
"""

@classmethod
@abstractmethod
def _expression_with_alias(cls, expression: Any, alias: str) -> Any:
"""
Return "expression AS alias" statement
"""

@classmethod
@abstractmethod
def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any:
"""
Return "arg1 COMPARATOR arg2" statement
"""


class BaseDBConnection(BaseConnection):
"""
Expand All @@ -147,6 +122,10 @@ class BaseDBConnection(BaseConnection):

Dialect = BaseDBDialect

@property
def dialect(self):
return self.Dialect(self)

@property
@abstractmethod
def instance_url(self) -> str:
Expand Down
20 changes: 8 additions & 12 deletions onetl/connection/db_connection/clickhouse/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@


class ClickhouseDialect(JDBCDialect):
@classmethod
def _get_datetime_value_sql(cls, value: datetime) -> str:
def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str:
return f"modulo(halfMD5({partition_column}), {num_partitions})"

def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str:
return f"{partition_column} % {num_partitions}"

def _serialize_datetime(self, value: datetime) -> str:
result = value.strftime("%Y-%m-%d %H:%M:%S")
return f"CAST('{result}' AS DateTime)"

@classmethod
def _get_date_value_sql(cls, value: date) -> str:
def _serialize_date(self, value: date) -> str:
result = value.strftime("%Y-%m-%d")
return f"CAST('{result}' AS Date)"

@classmethod
def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str:
return f"modulo(halfMD5({partition_column}), {num_partitions})"

@classmethod
def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str:
return f"{partition_column} % {num_partitions}"
95 changes: 34 additions & 61 deletions onetl/connection/db_connection/db_connection/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,10 @@

import operator
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict
from typing import Any, Callable, ClassVar, Dict

from onetl.base import BaseDBDialect
from onetl.hwm import Statement
from onetl.hwm.store import SparkTypeToHWM

if TYPE_CHECKING:
from etl_entities.hwm import HWM, ColumnHWM

from onetl.connection.db_connection.db_connection.connection import BaseDBConnection


class DBDialect(BaseDBDialect):
Expand All @@ -38,55 +32,44 @@ class DBDialect(BaseDBDialect):
operator.ne: "{} != {}",
}

@classmethod
def validate_hwm(cls, connection: BaseDBConnection, hwm: HWM) -> HWM:
if hasattr(cls, "validate_hwm_expression"):
cls.validate_hwm_expression(connection, hwm)
return hwm

@classmethod
def detect_hwm_class(cls, hwm_column_type: str) -> ColumnHWM:
return SparkTypeToHWM.get(hwm_column_type) # type: ignore

@classmethod
def _escape_column(cls, value: str) -> str:
def escape_column(self, value: str) -> str:
return f'"{value}"'

@classmethod
def _expression_with_alias(cls, expression: str, alias: str) -> str:
def aliased(self, expression: str, alias: str) -> str:
return f"{expression} AS {alias}"

@classmethod
def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any:
template = cls._compare_statements[comparator]
return template.format(arg1, cls._serialize_datetime_value(arg2))

@classmethod
def _merge_conditions(cls, conditions: list[Any]) -> Any:
if len(conditions) == 1:
return conditions[0]
def get_max_value(self, value: Any) -> str:
"""
Generate `MAX(value)` clause for given value
"""
result = self._serialize_value(value)
return f"MAX({result})"

return " AND ".join(f"({item})" for item in conditions)
def get_min_value(self, value: Any) -> str:
"""
Generate `MIN(value)` clause for given value
"""
result = self._serialize_value(value)
return f"MIN({result})"

@classmethod
def _condition_assembler(
cls,
def condition_assembler(
self,
condition: Any,
start_from: Statement | None,
end_at: Statement | None,
) -> Any:
conditions = [condition]

if start_from:
condition1 = cls._get_compare_statement(
condition1 = self._get_compare_statement(
comparator=start_from.operator,
arg1=start_from.expression,
arg2=start_from.value,
)
conditions.append(condition1)

if end_at:
condition2 = cls._get_compare_statement(
condition2 = self._get_compare_statement(
comparator=end_at.operator,
arg1=end_at.expression,
arg2=end_at.value,
Expand All @@ -96,51 +79,41 @@ def _condition_assembler(
result: list[Any] = list(filter(None, conditions))
if not result:
return None
return result

return cls._merge_conditions(result)
def _get_compare_statement(self, comparator: Callable, arg1: Any, arg2: Any) -> Any:
template = self._compare_statements[comparator]
return template.format(arg1, self._serialize_value(arg2))

@classmethod
def _serialize_datetime_value(cls, value: Any) -> str | int | dict:
def _merge_conditions(self, conditions: list[Any]) -> Any:
if len(conditions) == 1:
return conditions[0]

return " AND ".join(f"({item})" for item in conditions)

def _serialize_value(self, value: Any) -> str | int | dict:
"""
Transform the value into an SQL Dialect-supported form.
"""

if isinstance(value, datetime):
return cls._get_datetime_value_sql(value)
return self._serialize_datetime(value)

if isinstance(value, date):
return cls._get_date_value_sql(value)
return self._serialize_date(value)

return str(value)

@classmethod
def _get_datetime_value_sql(cls, value: datetime) -> str:
def _serialize_datetime(self, value: datetime) -> str:
"""
Transform the datetime value into a form supported by the SQL dialect
"""
result = value.isoformat()
return repr(result)

@classmethod
def _get_date_value_sql(cls, value: date) -> str:
def _serialize_date(self, value: date) -> str:
"""
Transform the date value into a form supported by the SQL dialect
"""
result = value.isoformat()
return repr(result)

@classmethod
def _get_max_value_sql(cls, value: Any) -> str:
"""
Generate `MAX(value)` clause for given value
"""
result = cls._serialize_datetime_value(value)
return f"MAX({result})"

@classmethod
def _get_min_value_sql(cls, value: Any) -> str:
"""
Generate `MIN(value)` clause for given value
"""
result = cls._serialize_datetime_value(value)
return f"MIN({result})"
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@


class SupportHWMExpressionNone:
@classmethod
def validate_hwm_expression(cls, connection: BaseDBConnection, hwm: HWM) -> HWM | None:
if hwm.expression is not None:
connection: BaseDBConnection

def validate_hwm_expression(self, hwm: HWM | None) -> HWM | None:
if hwm and hwm.expression is not None:
raise ValueError(
f"'hwm.expression' parameter is not supported by {connection.__class__.__name__}",
f"'hwm.expression' parameter is not supported by {self.connection.__class__.__name__}",
)
return hwm
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@


class SupportHWMExpressionStr:
@classmethod
def validate_hwm_expression(cls, connection: BaseDBConnection, hwm: HWM) -> HWM | None:
if hwm.expression is None:
return None
connection: BaseDBConnection

def validate_hwm_expression(self, hwm: HWM | None) -> HWM | None:
if not hwm or hwm.expression is None:
return hwm

if not isinstance(hwm.expression, str):
raise TypeError(
f"{connection.__class__.__name__} requires 'hwm.expression' parameter type to be 'str', "
f"{self.connection.__class__.__name__} requires 'hwm.expression' parameter type to be 'str', "
f"got {hwm.expression.__class__.__name__!r}",
)

Expand Down
14 changes: 7 additions & 7 deletions onetl/connection/db_connection/greenplum/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def read_source_as_df(
) -> DataFrame:
read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True)
log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at)
where = self.dialect.condition_assembler(condition=where, start_from=start_from, end_at=end_at)
query = get_sql_query(table=source, columns=columns, where=where)
log_lines(log, query)

Expand Down Expand Up @@ -351,13 +351,13 @@ def get_min_max_bounds(
query = get_sql_query(
table=source,
columns=[
self.Dialect._expression_with_alias(
self.Dialect._get_min_value_sql(expression or column),
self.Dialect._escape_column("min"),
self.dialect.aliased(
self.dialect.get_min_value(expression or column),
self.dialect.escape_column("min"),
),
self.Dialect._expression_with_alias(
self.Dialect._get_max_value_sql(expression or column),
self.Dialect._escape_column("max"),
self.dialect.aliased(
self.dialect.get_max_value(expression or column),
self.dialect.escape_column("max"),
),
],
where=where,
Expand Down
6 changes: 2 additions & 4 deletions onetl/connection/db_connection/greenplum/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@ class GreenplumDialect( # noqa: WPS215
SupportHWMExpressionStr,
DBDialect,
):
@classmethod
def _get_datetime_value_sql(cls, value: datetime) -> str:
def _serialize_datetime(self, value: datetime) -> str:
result = value.isoformat()
return f"cast('{result}' as timestamp)"

@classmethod
def _get_date_value_sql(cls, value: date) -> str:
def _serialize_date(self, value: date) -> str:
result = value.isoformat()
return f"cast('{result}' as date)"
Loading

0 comments on commit 0733113

Please sign in to comment.