From 0733113a622e3e497919f8f0e342fabadb3d50de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D1=80=D1=82=D1=8B=D0=BD=D0=BE=D0=B2=20=D0=9C?= =?UTF-8?q?=D0=B0=D0=BA=D1=81=D0=B8=D0=BC=20=D0=A1=D0=B5=D1=80=D0=B3=D0=B5?= =?UTF-8?q?=D0=B5=D0=B2=D0=B8=D1=87?= Date: Tue, 5 Dec 2023 13:11:41 +0000 Subject: [PATCH] [DOP-9787] Convert all classmethods of Dialect class to regular methods --- onetl/base/base_db_connection.py | 51 +++------- .../db_connection/clickhouse/dialect.py | 20 ++-- .../db_connection/db_connection/dialect.py | 95 +++++++------------ .../support_hwm_expression_none.py | 9 +- .../support_hwm_expression_str.py | 11 ++- .../db_connection/greenplum/connection.py | 14 +-- .../db_connection/greenplum/dialect.py | 6 +- .../db_connection/hive/connection.py | 14 +-- .../connection/db_connection/hive/dialect.py | 3 +- .../jdbc_connection/connection.py | 22 ++--- .../db_connection/jdbc_connection/dialect.py | 6 +- .../connection/db_connection/kafka/dialect.py | 24 +++-- .../db_connection/mongodb/connection.py | 18 ++-- .../db_connection/mongodb/dialect.py | 38 +++----- .../connection/db_connection/mssql/dialect.py | 22 ++--- .../connection/db_connection/mysql/dialect.py | 23 ++--- .../db_connection/oracle/dialect.py | 20 ++-- .../db_connection/postgres/dialect.py | 22 ++--- .../db_connection/teradata/dialect.py | 22 ++--- onetl/db/db_reader/db_reader.py | 25 ++--- onetl/db/db_writer/db_writer.py | 3 +- .../test_mongodb_unit.py | 39 -------- 22 files changed, 185 insertions(+), 322 deletions(-) diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index fc0e1f11c..dea051e0a 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from etl_entities.source import Table @@ -33,9 +33,11 @@ class BaseDBDialect(ABC): Collection of methods used for validating input values before passing them to read_source_as_df/write_df_to_target """ - @classmethod + def __init__(self, connection: BaseDBConnection) -> None: + self.connection = connection + @abstractmethod - def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: + def validate_name(self, value: Table) -> str: """Check if ``source`` or ``target`` value is valid. Raises @@ -46,9 +48,8 @@ def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: If value is invalid """ - @classmethod @abstractmethod - def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | None) -> list[str] | None: + def validate_columns(self, columns: list[str] | None) -> list[str] | None: """Check if ``columns`` value is valid. Raises @@ -59,9 +60,8 @@ def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | Non If value is invalid """ - @classmethod @abstractmethod - def validate_hwm(cls, connection: BaseDBConnection, hwm: HWM) -> HWM: + def validate_hwm(self, hwm: HWM | None) -> HWM | None: """Check if ``HWM`` class is valid. Raises @@ -72,9 +72,8 @@ def validate_hwm(cls, connection: BaseDBConnection, hwm: HWM) -> HWM: If hwm is invalid """ - @classmethod @abstractmethod - def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None: + def validate_df_schema(self, df_schema: StructType | None) -> StructType | None: """Check if ``df_schema`` value is valid. Raises @@ -85,9 +84,8 @@ def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType If value is invalid """ - @classmethod @abstractmethod - def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None: + def validate_where(self, where: Any) -> Any | None: """Check if ``where`` value is valid. Raises @@ -98,9 +96,8 @@ def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None: If value is invalid """ - @classmethod @abstractmethod - def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None: + def validate_hint(self, hint: Any) -> Any | None: """Check if ``hint`` value is valid. Raises @@ -111,34 +108,12 @@ def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None: If value is invalid """ - @classmethod @abstractmethod - def detect_hwm_class(cls, hwm_column_type: str) -> ColumnHWM: + def detect_hwm_class(self, hwm_column_type: str) -> type[ColumnHWM]: """ Detects hwm column type based on specific data types in connections data stores """ - @classmethod - @abstractmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - """ - Convert multiple WHERE conditions to one - """ - - @classmethod - @abstractmethod - def _expression_with_alias(cls, expression: Any, alias: str) -> Any: - """ - Return "expression AS alias" statement - """ - - @classmethod - @abstractmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: - """ - Return "arg1 COMPARATOR arg2" statement - """ - class BaseDBConnection(BaseConnection): """ @@ -147,6 +122,10 @@ class BaseDBConnection(BaseConnection): Dialect = BaseDBDialect + @property + def dialect(self): + return self.Dialect(self) + @property @abstractmethod def instance_url(self) -> str: diff --git a/onetl/connection/db_connection/clickhouse/dialect.py b/onetl/connection/db_connection/clickhouse/dialect.py index 56fe44b33..3fa6eb9ad 100644 --- a/onetl/connection/db_connection/clickhouse/dialect.py +++ b/onetl/connection/db_connection/clickhouse/dialect.py @@ -20,20 +20,16 @@ class ClickhouseDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"modulo(halfMD5({partition_column}), {num_partitions})" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S") return f"CAST('{result}' AS DateTime)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.strftime("%Y-%m-%d") return f"CAST('{result}' AS Date)" - - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"modulo(halfMD5({partition_column}), {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/db_connection/dialect.py b/onetl/connection/db_connection/db_connection/dialect.py index 929d3c002..2f380321f 100644 --- a/onetl/connection/db_connection/db_connection/dialect.py +++ b/onetl/connection/db_connection/db_connection/dialect.py @@ -16,16 +16,10 @@ import operator from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict +from typing import Any, Callable, ClassVar, Dict from onetl.base import BaseDBDialect from onetl.hwm import Statement -from onetl.hwm.store import SparkTypeToHWM - -if TYPE_CHECKING: - from etl_entities.hwm import HWM, ColumnHWM - - from onetl.connection.db_connection.db_connection.connection import BaseDBConnection class DBDialect(BaseDBDialect): @@ -38,39 +32,28 @@ class DBDialect(BaseDBDialect): operator.ne: "{} != {}", } - @classmethod - def validate_hwm(cls, connection: BaseDBConnection, hwm: HWM) -> HWM: - if hasattr(cls, "validate_hwm_expression"): - cls.validate_hwm_expression(connection, hwm) - return hwm - - @classmethod - def detect_hwm_class(cls, hwm_column_type: str) -> ColumnHWM: - return SparkTypeToHWM.get(hwm_column_type) # type: ignore - - @classmethod - def _escape_column(cls, value: str) -> str: + def escape_column(self, value: str) -> str: return f'"{value}"' - @classmethod - def _expression_with_alias(cls, expression: str, alias: str) -> str: + def aliased(self, expression: str, alias: str) -> str: return f"{expression} AS {alias}" - @classmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: - template = cls._compare_statements[comparator] - return template.format(arg1, cls._serialize_datetime_value(arg2)) - - @classmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - if len(conditions) == 1: - return conditions[0] + def get_max_value(self, value: Any) -> str: + """ + Generate `MAX(value)` clause for given value + """ + result = self._serialize_value(value) + return f"MAX({result})" - return " AND ".join(f"({item})" for item in conditions) + def get_min_value(self, value: Any) -> str: + """ + Generate `MIN(value)` clause for given value + """ + result = self._serialize_value(value) + return f"MIN({result})" - @classmethod - def _condition_assembler( - cls, + def condition_assembler( + self, condition: Any, start_from: Statement | None, end_at: Statement | None, @@ -78,7 +61,7 @@ def _condition_assembler( conditions = [condition] if start_from: - condition1 = cls._get_compare_statement( + condition1 = self._get_compare_statement( comparator=start_from.operator, arg1=start_from.expression, arg2=start_from.value, @@ -86,7 +69,7 @@ def _condition_assembler( conditions.append(condition1) if end_at: - condition2 = cls._get_compare_statement( + condition2 = self._get_compare_statement( comparator=end_at.operator, arg1=end_at.expression, arg2=end_at.value, @@ -96,51 +79,41 @@ def _condition_assembler( result: list[Any] = list(filter(None, conditions)) if not result: return None + return result - return cls._merge_conditions(result) + def _get_compare_statement(self, comparator: Callable, arg1: Any, arg2: Any) -> Any: + template = self._compare_statements[comparator] + return template.format(arg1, self._serialize_value(arg2)) - @classmethod - def _serialize_datetime_value(cls, value: Any) -> str | int | dict: + def _merge_conditions(self, conditions: list[Any]) -> Any: + if len(conditions) == 1: + return conditions[0] + + return " AND ".join(f"({item})" for item in conditions) + + def _serialize_value(self, value: Any) -> str | int | dict: """ Transform the value into an SQL Dialect-supported form. """ if isinstance(value, datetime): - return cls._get_datetime_value_sql(value) + return self._serialize_datetime(value) if isinstance(value, date): - return cls._get_date_value_sql(value) + return self._serialize_date(value) return str(value) - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def _serialize_datetime(self, value: datetime) -> str: """ Transform the datetime value into supported by SQL Dialect """ result = value.isoformat() return repr(result) - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: """ Transform the date value into supported by SQL Dialect """ result = value.isoformat() return repr(result) - - @classmethod - def _get_max_value_sql(cls, value: Any) -> str: - """ - Generate `MAX(value)` clause for given value - """ - result = cls._serialize_datetime_value(value) - return f"MAX({result})" - - @classmethod - def _get_min_value_sql(cls, value: Any) -> str: - """ - Generate `MIN(value)` clause for given value - """ - result = cls._serialize_datetime_value(value) - return f"MIN({result})" diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py index 27b5cc89b..3118dc72e 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py +++ b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py @@ -6,10 +6,11 @@ class SupportHWMExpressionNone: - @classmethod - def validate_hwm_expression(cls, connection: BaseDBConnection, hwm: HWM) -> HWM | None: - if hwm.expression is not None: + connection: BaseDBConnection + + def validate_hwm_expression(self, hwm: HWM | None) -> HWM | None: + if hwm and hwm.expression is not None: raise ValueError( - f"'hwm.expression' parameter is not supported by {connection.__class__.__name__}", + f"'hwm.expression' parameter is not supported by {self.connection.__class__.__name__}", ) return hwm diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py index 4280860a3..1bcb49ac6 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py @@ -6,14 +6,15 @@ class SupportHWMExpressionStr: - @classmethod - def validate_hwm_expression(cls, connection: BaseDBConnection, hwm: HWM) -> HWM | None: - if hwm.expression is None: - return None + connection: BaseDBConnection + + def validate_hwm_expression(self, hwm: HWM | None) -> HWM | None: + if not hwm or hwm.expression is None: + return hwm if not isinstance(hwm.expression, str): raise TypeError( - f"{connection.__class__.__name__} requires 'hwm.expression' parameter type to be 'str', " + f"{self.connection.__class__.__name__} requires 'hwm.expression' parameter type to be 'str', " f"got {hwm.expression.__class__.__name__!r}", ) diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index d1eedff7f..fa0530cde 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -273,7 +273,7 @@ def read_source_as_df( ) -> DataFrame: read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) + where = self.dialect.condition_assembler(condition=where, start_from=start_from, end_at=end_at) query = get_sql_query(table=source, columns=columns, where=where) log_lines(log, query) @@ -351,13 +351,13 @@ def get_min_max_bounds( query = get_sql_query( table=source, columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - self.Dialect._escape_column("min"), + self.dialect.aliased( + self.dialect.get_min_value(expression or column), + self.dialect.escape_column("min"), ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - self.Dialect._escape_column("max"), + self.dialect.aliased( + self.dialect.get_max_value(expression or column), + self.dialect.escape_column("max"), ), ], where=where, diff --git a/onetl/connection/db_connection/greenplum/dialect.py b/onetl/connection/db_connection/greenplum/dialect.py index 76a572750..820c6ef9d 100644 --- a/onetl/connection/db_connection/greenplum/dialect.py +++ b/onetl/connection/db_connection/greenplum/dialect.py @@ -36,12 +36,10 @@ class GreenplumDialect( # noqa: WPS215 SupportHWMExpressionStr, DBDialect, ): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"cast('{result}' as timestamp)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"cast('{result}' as date)" diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 97cc034f5..2e07c2623 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -367,7 +367,7 @@ def read_source_as_df( start_from: Statement | None = None, end_at: Statement | None = None, ) -> DataFrame: - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) + where = self.dialect.condition_assembler(condition=where, start_from=start_from, end_at=end_at) sql_text = get_sql_query( table=source, columns=columns, @@ -407,13 +407,13 @@ def get_min_max_bounds( sql_text = get_sql_query( table=source, columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - self.Dialect._escape_column("min"), + self.dialect.aliased( + self.dialect.get_min_value(expression or column), + self.dialect.escape_column("min"), ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - self.Dialect._escape_column("max"), + self.dialect.aliased( + self.dialect.get_max_value(expression or column), + self.dialect.escape_column("max"), ), ], where=where, diff --git a/onetl/connection/db_connection/hive/dialect.py b/onetl/connection/db_connection/hive/dialect.py index bafd7d97a..5ed611008 100644 --- a/onetl/connection/db_connection/hive/dialect.py +++ b/onetl/connection/db_connection/hive/dialect.py @@ -34,6 +34,5 @@ class HiveDialect( # noqa: WPS215 SupportHWMExpressionStr, DBDialect, ): - @classmethod - def _escape_column(cls, value: str) -> str: + def escape_column(self, value: str) -> str: return f"`{value}`" diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index f5b611910..685d4c267 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -172,12 +172,12 @@ def read_source_as_df( if read_options.partition_column: if read_options.partitioning_mode == JDBCPartitioningMode.MOD: - partition_column = self.Dialect._get_partition_column_mod( + partition_column = self.dialect.get_partition_column_mod( read_options.partition_column, read_options.num_partitions, ) elif read_options.partitioning_mode == JDBCPartitioningMode.HASH: - partition_column = self.Dialect._get_partition_column_hash( + partition_column = self.dialect.get_partition_column_hash( read_options.partition_column, read_options.num_partitions, ) @@ -189,12 +189,12 @@ def read_source_as_df( # have the same name as the field in the table ( 2.4 version ) # https://github.com/apache/spark/pull/21379 alias = "generated_" + secrets.token_hex(5) - alias_escaped = self.Dialect._escape_column(alias) - aliased_column = self.Dialect._expression_with_alias(partition_column, alias_escaped) + alias_escaped = self.dialect.escape_column(alias) + aliased_column = self.dialect.aliased(partition_column, alias_escaped) read_options = read_options.copy(update={"partition_column": alias_escaped}) new_columns.append(aliased_column) - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) + where = self.dialect.condition_assembler(condition=where, start_from=start_from, end_at=end_at) query = get_sql_query( table=source, columns=new_columns, @@ -296,13 +296,13 @@ def get_min_max_bounds( query = get_sql_query( table=source, columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - self.Dialect._escape_column("min"), + self.dialect.aliased( + self.dialect.get_min_value(expression or column), + self.dialect.escape_column("min"), ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - self.Dialect._escape_column("max"), + self.dialect.aliased( + self.dialect.get_max_value(expression or column), + self.dialect.escape_column("max"), ), ], where=where, diff --git a/onetl/connection/db_connection/jdbc_connection/dialect.py b/onetl/connection/db_connection/jdbc_connection/dialect.py index c7fed5472..cd3371cb5 100644 --- a/onetl/connection/db_connection/jdbc_connection/dialect.py +++ b/onetl/connection/db_connection/jdbc_connection/dialect.py @@ -36,12 +36,10 @@ class JDBCDialect( # noqa: WPS215 SupportHWMExpressionStr, DBDialect, ): - @classmethod @abstractmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: ... - @classmethod @abstractmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: ... diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index 53714bedd..d27816e5f 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -20,7 +20,6 @@ from etl_entities.hwm import HWM from onetl._util.spark import get_spark_version -from onetl.base import BaseDBConnection from onetl.connection.db_connection.db_connection.dialect import DBDialect from onetl.connection.db_connection.dialect_mixins import ( SupportColumnsNone, @@ -43,26 +42,25 @@ class KafkaDialect( # noqa: WPS215 SupportHWMExpressionNone, DBDialect, ): - valid_hwm_columns = {"offset", "timestamp"} + VALID_HWM_COLUMNS = {"offset", "timestamp"} - @classmethod def validate_hwm( - cls, - connection: BaseDBConnection, - hwm: HWM, - ) -> HWM: - cls.validate_column(connection, hwm.entity) + self, + hwm: HWM | None, + ) -> HWM | None: + if hwm is None: + return None + self.validate_column(hwm.entity) return hwm - @classmethod - def validate_column(cls, connection: BaseDBConnection, column: str) -> None: - if column not in cls.valid_hwm_columns: - raise ValueError(f"{column} is not a valid hwm column. Valid options are: {cls.valid_hwm_columns}") + def validate_column(self, column: str) -> None: + if column not in self.VALID_HWM_COLUMNS: + raise ValueError(f"{column} is not a valid hwm column. Valid options are: {self.VALID_HWM_COLUMNS}") if column == "timestamp": # Spark version less 3.x does not support reading from Kafka with the timestamp parameter - spark_version = get_spark_version(connection.spark) # type: ignore[attr-defined] + spark_version = get_spark_version(self.connection.spark) # type: ignore[attr-defined] if spark_version.major < 3: raise ValueError( f"Spark version must be 3.x for the timestamp column. Current version is: {spark_version}", diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py index 280596d5d..51dd6f9cc 100644 --- a/onetl/connection/db_connection/mongodb/connection.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -359,7 +359,7 @@ def pipeline( log.info("|%s| Executing aggregation pipeline:", self.__class__.__name__) read_options = self.PipelineOptions.parse(options).dict(by_alias=True, exclude_none=True) - pipeline = self.Dialect.prepare_pipeline(pipeline) + pipeline = self.dialect.prepare_pipeline(pipeline) log_with_indent(log, "collection = %r", collection) log_json(log, pipeline, name="pipeline") @@ -370,7 +370,7 @@ def pipeline( log_options(log, read_options) read_options["collection"] = collection - read_options["aggregation.pipeline"] = self.Dialect.convert_to_str(pipeline) + read_options["aggregation.pipeline"] = self.dialect.convert_to_str(pipeline) read_options["connection.uri"] = self.connection_url spark_reader = self.spark.read.format("mongodb").options(**read_options) @@ -418,14 +418,14 @@ def get_min_max_bounds( if where: pipeline.insert(0, {"$match": where}) - pipeline = self.Dialect.prepare_pipeline(pipeline) + pipeline = self.dialect.prepare_pipeline(pipeline) read_options["connection.uri"] = self.connection_url read_options["collection"] = source - read_options["aggregation.pipeline"] = self.Dialect.convert_to_str(pipeline) + read_options["aggregation.pipeline"] = self.dialect.convert_to_str(pipeline) if hint: - read_options["hint"] = self.Dialect.convert_to_str(hint) + read_options["hint"] = self.dialect.convert_to_str(hint) log.info("|%s| Executing aggregation pipeline:", self.__class__.__name__) log_with_indent(log, "collection = %r", source) @@ -456,18 +456,18 @@ def read_source_as_df( options: MongoDBReadOptions | dict | None = None, ) -> DataFrame: read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) - final_where = self.Dialect._condition_assembler( + final_where = self.dialect.condition_assembler( condition=where, start_from=start_from, end_at=end_at, ) - pipeline = self.Dialect.prepare_pipeline({"$match": final_where}) if final_where else None + pipeline = self.dialect.prepare_pipeline({"$match": final_where}) if final_where else None if pipeline: - read_options["aggregation.pipeline"] = self.Dialect.convert_to_str(pipeline) + read_options["aggregation.pipeline"] = self.dialect.convert_to_str(pipeline) if hint: - read_options["hint"] = self.Dialect.convert_to_str(hint) + read_options["hint"] = self.dialect.convert_to_str(hint) read_options["connection.uri"] = self.connection_url read_options["collection"] = source diff --git a/onetl/connection/db_connection/mongodb/dialect.py b/onetl/connection/db_connection/mongodb/dialect.py index f408615aa..d1818a8a0 100644 --- a/onetl/connection/db_connection/mongodb/dialect.py +++ b/onetl/connection/db_connection/mongodb/dialect.py @@ -19,7 +19,6 @@ from datetime import datetime from typing import Any, Callable, ClassVar, Dict, Iterable, Mapping -from onetl.base.base_db_connection import BaseDBConnection from onetl.connection.db_connection.db_connection.dialect import DBDialect from onetl.connection.db_connection.dialect_mixins import ( SupportColumnsNone, @@ -88,10 +87,8 @@ class MongoDBDialect( # noqa: WPS215 operator.ne: "$ne", } - @classmethod def validate_where( - cls, - connection: BaseDBConnection, + self, where: Any, ) -> dict | None: if where is None: @@ -99,18 +96,16 @@ def validate_where( if not isinstance(where, dict): raise ValueError( - f"{connection.__class__.__name__} requires 'where' parameter type to be 'dict', " + f"{self.connection.__class__.__name__} requires 'where' parameter type to be 'dict', " f"got {where.__class__.__name__!r}", ) for key in where: - cls._validate_top_level_keys_in_where_parameter(key) + self._validate_top_level_keys_in_where_parameter(key) return where - @classmethod def validate_hint( - cls, - connection: BaseDBConnection, + self, hint: Any, ) -> dict | None: if hint is None: @@ -118,14 +113,13 @@ def validate_hint( if not isinstance(hint, dict): raise ValueError( - f"{connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " + f"{self.connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " f"got {hint.__class__.__name__!r}", ) return hint - @classmethod def prepare_pipeline( - cls, + self, pipeline: Any, ) -> Any: """ @@ -136,33 +130,30 @@ def prepare_pipeline( return {"$date": pipeline.astimezone().isoformat()} if isinstance(pipeline, Mapping): - return {cls.prepare_pipeline(key): cls.prepare_pipeline(value) for key, value in pipeline.items()} + return {self.prepare_pipeline(key): self.prepare_pipeline(value) for key, value in pipeline.items()} if isinstance(pipeline, Iterable) and not isinstance(pipeline, str): - return [cls.prepare_pipeline(item) for item in pipeline] + return [self.prepare_pipeline(item) for item in pipeline] return pipeline - @classmethod def convert_to_str( - cls, + self, value: Any, ) -> str: """ Converts the given dictionary, list or primitive to a string. """ - return json.dumps(cls.prepare_pipeline(value)) + return json.dumps(self.prepare_pipeline(value)) - @classmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: + def _merge_conditions(self, conditions: list[Any]) -> Any: if len(conditions) == 1: return conditions[0] return {"$and": conditions} - @classmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> dict: + def _get_compare_statement(self, comparator: Callable, arg1: Any, arg2: Any) -> dict: """ Returns the comparison statement in MongoDB syntax: @@ -176,12 +167,11 @@ def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> d """ return { arg1: { - cls._compare_statements[comparator]: arg2, + self._compare_statements[comparator]: arg2, }, } - @classmethod - def _validate_top_level_keys_in_where_parameter(cls, key: str): + def _validate_top_level_keys_in_where_parameter(self, key: str): """ Checks the 'where' parameter for illegal operators, such as ``$match``, ``$merge`` or ``$changeStream``. diff --git a/onetl/connection/db_connection/mssql/dialect.py b/onetl/connection/db_connection/mssql/dialect.py index 95e4ff022..f39568423 100644 --- a/onetl/connection/db_connection/mssql/dialect.py +++ b/onetl/connection/db_connection/mssql/dialect.py @@ -20,21 +20,17 @@ class MSSQLDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"CONVERT(BIGINT, HASHBYTES ( 'SHA' , {partition_column} )) % {num_partitions}" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"CAST('{result}' AS datetime2)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"CAST('{result}' AS date)" - - # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"CONVERT(BIGINT, HASHBYTES ( 'SHA' , {partition_column} )) % {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/mysql/dialect.py b/onetl/connection/db_connection/mysql/dialect.py index b3cd70a55..59f663aed 100644 --- a/onetl/connection/db_connection/mysql/dialect.py +++ b/onetl/connection/db_connection/mysql/dialect.py @@ -20,24 +20,19 @@ class MySQLDialect(JDBCDialect): - @classmethod - def _escape_column(cls, value: str) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16), 16, 2), 2, 10), {num_partitions})" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"MOD({partition_column}, {num_partitions})" + + def escape_column(self, value: str) -> str: return f"`{value}`" - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def _serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S.%f") return f"STR_TO_DATE('{result}', '%Y-%m-%d %H:%i:%s.%f')" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.strftime("%Y-%m-%d") return f"STR_TO_DATE('{result}', '%Y-%m-%d')" - - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16),16, 2), 2, 10), {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" diff --git a/onetl/connection/db_connection/oracle/dialect.py b/onetl/connection/db_connection/oracle/dialect.py index fb3fa715d..e413142c4 100644 --- a/onetl/connection/db_connection/oracle/dialect.py +++ b/onetl/connection/db_connection/oracle/dialect.py @@ -20,20 +20,16 @@ class OracleDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"ora_hash({partition_column}, {num_partitions})" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"MOD({partition_column}, {num_partitions})" + + def _serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S") return f"TO_DATE('{result}', 'YYYY-MM-DD HH24:MI:SS')" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.strftime("%Y-%m-%d") return f"TO_DATE('{result}', 'YYYY-MM-DD')" - - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"ora_hash({partition_column}, {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" diff --git a/onetl/connection/db_connection/postgres/dialect.py b/onetl/connection/db_connection/postgres/dialect.py index 05a44471e..babfd63aa 100644 --- a/onetl/connection/db_connection/postgres/dialect.py +++ b/onetl/connection/db_connection/postgres/dialect.py @@ -21,21 +21,17 @@ class PostgresDialect(SupportHintNone, JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + # https://stackoverflow.com/a/9812029 + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"'{result}'::timestamp" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"'{result}'::date" - - # https://stackoverflow.com/a/9812029 - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/teradata/dialect.py b/onetl/connection/db_connection/teradata/dialect.py index c449debc6..7845dc360 100644 --- a/onetl/connection/db_connection/teradata/dialect.py +++ b/onetl/connection/db_connection/teradata/dialect.py @@ -20,21 +20,17 @@ class TeradataDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + # https://docs.teradata.com/r/w4DJnG9u9GdDlXzsTXyItA/lkaegQT4wAakj~K_ZmW1Dg + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} mod {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"CAST('{result}' AS TIMESTAMP)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"CAST('{result}' AS DATE)" - - # https://docs.teradata.com/r/w4DJnG9u9GdDlXzsTXyItA/lkaegQT4wAakj~K_ZmW1Dg - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} mod {num_partitions}" diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index 16186503e..096977a13 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -382,18 +382,16 @@ class DBReader(FrozenModel): @validator("source", pre=True, always=True) def validate_source(cls, source, values): connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect if isinstance(source, str): # source="dbschema.table" or source="table", If source="dbschema.some.table" in class Table will raise error. source = Table(name=source, instance=connection.instance_url) # Here Table(name='source', db='dbschema', instance='some_instance') - return dialect.validate_name(connection, source) + return connection.dialect.validate_name(source) @validator("where", pre=True, always=True) def validate_where(cls, where: Any, values: dict) -> Any: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - result = dialect.validate_where(connection, where) + result = connection.dialect.validate_where(where) if isinstance(result, dict): return frozendict.frozendict(result) # type: ignore[attr-defined, operator] return result @@ -401,8 +399,7 @@ def validate_where(cls, where: Any, values: dict) -> Any: @validator("hint", pre=True, always=True) def validate_hint(cls, hint: Any, values: dict) -> Any: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - result = dialect.validate_hint(connection, hint) + result = connection.dialect.validate_hint(hint) if isinstance(result, dict): return frozendict.frozendict(result) # type: ignore[attr-defined, operator] return result @@ -410,8 +407,7 @@ def validate_hint(cls, hint: Any, values: dict) -> Any: @validator("df_schema", pre=True, always=True) def validate_df_schema(cls, df_schema: StructType | None, values: dict) -> StructType | None: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - return dialect.validate_df_schema(connection, df_schema) + return connection.dialect.validate_df_schema(df_schema) @root_validator(pre=True) def validate_hwm(cls, values: dict) -> dict: # noqa: WPS231 @@ -463,17 +459,12 @@ def validate_hwm(cls, values: dict) -> dict: # noqa: WPS231 "Otherwise DBReader cannot determine HWM type for this column", ) - dialect = connection.Dialect - dialect.validate_hwm(connection, hwm) - - values["hwm"] = hwm - + values["hwm"] = connection.dialect.validate_hwm(hwm) return values @root_validator(pre=True) # noqa: WPS231 def validate_columns(cls, values: dict) -> dict: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect columns: list[str] | str | None = values.get("columns") columns_list: list[str] | None @@ -482,7 +473,7 @@ def validate_columns(cls, values: dict) -> dict: else: columns_list = columns - columns_list = dialect.validate_columns(connection, columns_list) + columns_list = connection.dialect.validate_columns(columns_list) if columns_list is None: return values @@ -531,7 +522,7 @@ def detect_hwm(self, hwm: HWM) -> HWM: schema = {field.name.casefold(): field for field in self.get_df_schema()} column = hwm.entity.casefold() target_column_data_type = schema[column].dataType.typeName() - hwm_class_for_target = self.connection.Dialect.detect_hwm_class(target_column_data_type) + hwm_class_for_target = self.connection.dialect.detect_hwm_class(target_column_data_type) try: return hwm_class_for_target.deserialize(hwm.dict()) except ValueError as e: @@ -702,7 +693,7 @@ def _resolve_all_columns(self) -> list[str] | None: hwm_statement = ( self.hwm.entity if not self.hwm.expression - else self.connection.Dialect._expression_with_alias( # noqa: WPS437 + else self.connection.dialect.aliased( # noqa: WPS437 self.hwm.expression, self.hwm.entity, ) diff --git a/onetl/db/db_writer/db_writer.py b/onetl/db/db_writer/db_writer.py index ea36cd4b3..c072ffc3e 100644 --- a/onetl/db/db_writer/db_writer.py +++ b/onetl/db/db_writer/db_writer.py @@ -157,12 +157,11 @@ class DBWriter(FrozenModel): @validator("target", pre=True, always=True) def validate_target(cls, target, values): connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect if isinstance(target, str): # target="dbschema.table" or target="table", If target="dbschema.some.table" in class Table will raise error. target = Table(name=target, instance=connection.instance_url) # Here Table(name='target', db='dbschema', instance='some_instance') - return dialect.validate_name(connection, target) + return connection.dialect.validate_name(target) @validator("options", pre=True, always=True) def validate_options(cls, options, values): diff --git a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py index d53b4d614..209a05f6d 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py @@ -1,5 +1,4 @@ import re -from datetime import datetime import pytest @@ -202,44 +201,6 @@ def test_mongodb_with_extra(spark_mock): assert mongo.connection_url == "mongodb://user:password@host:27017/database?opt1=value1&tls=true" -def test_mongodb_convert_list_to_str(): - where = [ - {"$or": [{"col_1": {"$gt": 1, "$eq": True}}, {"col_2": {"$eq": None}}]}, - { - "$and": [ - {"col_3": {"$eq": "Hello"}}, - {"col_4": {"$eq": datetime.fromisoformat("2022-12-23T11:22:33.456+03:00")}}, - ], - }, - ] - - assert MongoDB.Dialect.convert_to_str(where) == ( - '[{"$or": [{"col_1": {"$gt": 1, "$eq": true}}, {"col_2": {"$eq": null}}]}, ' - '{"$and": [{"col_3": {"$eq": "Hello"}}, {"col_4": {"$eq": {"$date": "2022-12-23T08:22:33.456000+00:00"}}}]}]' - ) - - -def test_mongodb_convert_dict_to_str(): - where = { - "$and": [ - {"$or": [{"col_1": {"$gt": 1, "$eq": True}}, {"col_2": {"$eq": None}}]}, - { - "$and": [ - {"col_3": {"$eq": "Hello"}}, - {"col_4": {"$eq": datetime.fromisoformat("2022-12-23T11:22:33.456+03:00")}}, - ], - }, - ], - } - - assert MongoDB.Dialect.convert_to_str(where) == ( - '{"$and": ' - '[{"$or": [{"col_1": {"$gt": 1, "$eq": true}}, {"col_2": {"$eq": null}}]}, ' - '{"$and": [{"col_3": {"$eq": "Hello"}}, {"col_4": {"$eq": {"$date": "2022-12-23T08:22:33.456000+00:00"}}}]}]' - "}" - ) - - @pytest.mark.parametrize( "options, value", [