#2249 sqlalchemy indexes off by default #2253

Merged: 2 commits, Feb 1, 2025
15 changes: 0 additions & 15 deletions dlt/destinations/impl/snowflake/configuration.py
@@ -141,21 +141,6 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration)
create_indexes: bool = False
"""Whether UNIQUE or PRIMARY KEY constrains should be created"""

def __init__(
self,
*,
credentials: SnowflakeCredentials = None,
create_indexes: bool = False,
destination_name: str = None,
environment: str = None,
) -> None:
super().__init__(
credentials=credentials,
destination_name=destination_name,
environment=environment,
)
self.create_indexes = create_indexes

def fingerprint(self) -> str:
"""Returns a fingerprint of host part of a connection string"""
if self.credentials and self.credentials.host:
8 changes: 8 additions & 0 deletions dlt/destinations/impl/snowflake/factory.py
@@ -140,6 +140,8 @@ def __init__(
stage_name: t.Optional[str] = None,
keep_staged_files: bool = True,
csv_format: t.Optional[CsvFormatConfiguration] = None,
query_tag: t.Optional[str] = None,
create_indexes: bool = False,
destination_name: t.Optional[str] = None,
environment: t.Optional[str] = None,
**kwargs: t.Any,
@@ -153,12 +155,18 @@ def __init__(
a connection string in the format `snowflake://user:password@host:port/database`
stage_name: Name of an existing stage to use for loading data. Default uses implicit stage per table
keep_staged_files: Whether to delete or keep staged files after loading
csv_format: Optional csv format configuration
query_tag: A tag with placeholders to tag sessions executing jobs
create_indexes: Whether UNIQUE or PRIMARY KEY constraints should be created

"""
super().__init__(
credentials=credentials,
stage_name=stage_name,
keep_staged_files=keep_staged_files,
csv_format=csv_format,
query_tag=query_tag,
create_indexes=create_indexes,
destination_name=destination_name,
environment=environment,
**kwargs,
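For reference, a minimal sketch of how the newly exposed `create_indexes` argument would be passed to the Snowflake factory; the pipeline and dataset names below are placeholders, not part of this PR:

```py
import dlt

# Hypothetical wiring: create_indexes=True asks the Snowflake destination to
# emit UNIQUE / PRIMARY KEY constraints (the default stays False).
pipeline = dlt.pipeline(
    pipeline_name="snowflake_constraints_demo",
    destination=dlt.destinations.snowflake(create_indexes=True),
    dataset_name="demo_dataset",
)
```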
4 changes: 4 additions & 0 deletions dlt/destinations/impl/sqlalchemy/configuration.py
@@ -61,6 +61,10 @@ class SqlalchemyClientConfiguration(DestinationClientDwhConfiguration):
destination_type: Final[str] = dataclasses.field(default="sqlalchemy", init=False, repr=False, compare=False) # type: ignore
credentials: SqlalchemyCredentials = None
"""SQLAlchemy connection string"""
create_unique_indexes: bool = False
"""Whether UNIQUE constrains should be created"""
create_primary_keys: bool = False
"""Whether PRIMARY KEY constrains should be created"""

engine_args: Dict[str, Any] = dataclasses.field(default_factory=dict)
"""Additional arguments passed to `sqlalchemy.create_engine`"""
5 changes: 4 additions & 1 deletion dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -400,6 +400,7 @@ def reflect_table(
table_name,
metadata,
autoload_with=self._current_connection,
resolve_fks=False,
schema=self.dataset_name,
include_columns=include_columns,
extend_existing=True,
@@ -442,7 +443,7 @@ def _make_database_exception(e: Exception) -> Exception:
# SQLite
r"no such table", # Missing table
r"no such database", # Missing table
# PostgreSQL / Trino / Vertica
# PostgreSQL / Trino / Vertica / Exasol (database)
r"does not exist", # Missing schema, relation
# r"does not exist", # Missing table
# MSSQL
@@ -457,6 +458,8 @@ def _make_database_exception(e: Exception) -> Exception:
# Apache Hive
r"table not found", # Missing table
r"database does not exist",
# Exasol
r" not found",
]
# entity not found
for pat_ in patterns:
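A simplified, standalone sketch of how a pattern list like the one above can be matched against driver error messages; this is an illustration only, the real `_make_database_exception` in dlt does more than this:

```py
import re

# Hypothetical reduced version of the "entity not found" detection shown above.
MISSING_ENTITY_PATTERNS = [
    r"no such table",    # SQLite
    r"does not exist",   # PostgreSQL / Trino / Vertica / Exasol (database)
    r"table not found",  # Apache Hive
    r" not found",       # Exasol
]

def is_missing_entity_error(exc: Exception) -> bool:
    msg = str(exc).lower()
    return any(re.search(pat, msg) for pat in MISSING_ENTITY_PATTERNS)
```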
20 changes: 15 additions & 5 deletions dlt/destinations/impl/sqlalchemy/factory.py
@@ -97,6 +97,8 @@ def client_class(self) -> t.Type["SqlalchemyJobClient"]:
def __init__(
self,
credentials: t.Union[SqlalchemyCredentials, t.Dict[str, t.Any], str, "Engine"] = None,
create_unique_indexes: bool = False,
create_primary_keys: bool = False,
destination_name: t.Optional[str] = None,
environment: t.Optional[str] = None,
engine_args: t.Optional[t.Dict[str, t.Any]] = None,
@@ -107,16 +109,24 @@ def __init__(
All arguments provided here supersede other configuration sources such as environment variables and dlt config files.

Args:
credentials: Credentials to connect to the sqlalchemy database. Can be an instance of `SqlalchemyCredentials` or
a connection string in the format `mysql://user:password@host:port/database`
destination_name: The name of the destination
environment: The environment to use
**kwargs: Additional arguments passed to the destination
credentials (Union[SqlalchemyCredentials, Dict[str, Any], str, Engine], optional): Credentials to connect to the sqlalchemy database. Can be an instance of
`SqlalchemyCredentials` or a connection string in the format `mysql://user:password@host:port/database`. Defaults to None.
create_unique_indexes (bool, optional): Whether UNIQUE constraints should be created. Defaults to False.
create_primary_keys (bool, optional): Whether PRIMARY KEY constraints should be created. Defaults to False.
destination_name (Optional[str], optional): The name of the destination. Defaults to None.
environment (Optional[str], optional): The environment to use. Defaults to None.
engine_args (Optional[Dict[str, Any]], optional): Additional arguments to pass to the SQLAlchemy engine. Defaults to None.
**kwargs (Any): Additional arguments passed to the destination.
Returns:
None
"""
super().__init__(
credentials=credentials,
create_unique_indexes=create_unique_indexes,
create_primary_keys=create_primary_keys,
destination_name=destination_name,
environment=environment,
engine_args=engine_args,
**kwargs,
)

29 changes: 21 additions & 8 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -22,6 +22,7 @@
pipeline_state_table,
normalize_table_identifiers,
is_complete_column,
get_columns_names_with_prop,
)
from dlt.destinations.exceptions import DatabaseUndefinedRelation
from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
@@ -53,7 +54,7 @@ def __init__(

self.schema = schema
self.capabilities = capabilities
self.config = config
self.config: SqlalchemyClientConfiguration = config
self.type_mapper = self.capabilities.get_type_mapper(self.sql_client.dialect)

def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table:
@@ -64,27 +65,39 @@ def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table:
# Re-generate the table if columns have changed
if existing_col_names == new_col_names:
return existing

# build the list of Column objects from the schema
table_columns = [
self._to_column_object(col, schema_table)
for col in schema_table["columns"].values()
if is_complete_column(col)
]

if self.config.create_primary_keys:
# if a primary key list is provided in the schema, add a PrimaryKeyConstraint.
pk_columns = get_columns_names_with_prop(schema_table, "primary_key")
if pk_columns:
table_columns.append(sa.PrimaryKeyConstraint(*pk_columns)) # type: ignore[arg-type]

return sa.Table(
schema_table["name"],
self.sql_client.metadata,
*[
self._to_column_object(col, schema_table)
for col in schema_table["columns"].values()
if is_complete_column(col)
],
*table_columns,
extend_existing=True,
schema=self.sql_client.dataset_name,
)

def _to_column_object(
self, schema_column: TColumnSchema, table: PreparedTableSchema
) -> sa.Column:
return sa.Column(
col_ = sa.Column(
schema_column["name"],
self.type_mapper.to_destination_type(schema_column, table),
nullable=schema_column.get("nullable", True),
unique=schema_column.get("unique", False),
)
if self.config.create_unique_indexes:
col_.unique = schema_column.get("unique", False)
return col_

def _create_replace_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
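As a standalone illustration (independent of dlt), the constraint objects built above translate into DDL roughly as follows; the table and column names are illustrative only:

```py
import sqlalchemy as sa
from sqlalchemy.schema import CreateTable

metadata = sa.MetaData()

# Mirrors what _to_table_object produces when both flags are enabled:
# a PrimaryKeyConstraint from the primary_key hint and unique=True per column hint.
table = sa.Table(
    "with_pk",
    metadata,
    sa.Column("id", sa.BigInteger),
    sa.Column("_dlt_id", sa.Text, unique=True),
    sa.PrimaryKeyConstraint("id"),
)

# Prints CREATE TABLE with PRIMARY KEY (id) and UNIQUE (_dlt_id)
print(CreateTable(table))
```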
21 changes: 18 additions & 3 deletions docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md
@@ -129,6 +129,16 @@ with engine.connect() as conn:
print(result.fetchall())
```

## Notes on other dialects
We tested this destination on the **mysql** and **sqlite** dialects. Below are a few notes that may help when enabling other dialects:
1. `dlt` must be able to recognize whether a database exception refers to a non-existent entity (such as a table or schema). We already
recognize such errors for most of the popular dialects (see `db_api_client.py`).
2. Primary keys and unique constraints are not created by default to avoid problems with particular dialects.
3. The `merge` write disposition uses only `DELETE` and `INSERT` operations so that as many dialects as possible are supported.

Please report issues with particular dialects. We'll try to make them work.


## Write dispositions

The following write dispositions are supported:
@@ -157,6 +167,11 @@ For example, SQLite does not have `DATETIME` or `TIMESTAMP` types, so `timestamp
* [Parquet](../file-formats/parquet.md) is supported.

## Supported column hints

* `unique` hints are translated to `UNIQUE` constraints via SQLAlchemy (granted the database supports it).

No indexes or constraints are created on the table by default. You can enable the following via the destination configuration:
```toml
[destination.sqlalchemy]
create_unique_indexes=true
create_primary_keys=true
```
* `unique` hints are translated to `UNIQUE` constraints via SQLAlchemy.
* `primary_key` hints are translated to `PRIMARY KEY` constraints via SQLAlchemy.
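
The same options can also be passed directly to the destination factory in Python; a sketch, where the connection string and pipeline name are placeholders:

```py
import dlt

dest = dlt.destinations.sqlalchemy(
    "sqlite:///example.db",  # placeholder connection string
    create_unique_indexes=True,
    create_primary_keys=True,
)
pipeline = dlt.pipeline("sqlalchemy_constraints_demo", destination=dest)
```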
56 changes: 56 additions & 0 deletions tests/load/pipeline/test_sqlalchemy_pipeline.py
@@ -0,0 +1,56 @@
import pytest

from tests.load.utils import (
destinations_configs,
DestinationTestConfiguration,
)

# mark all tests as essential, do not remove
pytestmark = pytest.mark.essential


@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["sqlalchemy"]),
ids=lambda x: x.name,
)
@pytest.mark.parametrize("create_unique_indexes", (True, False))
@pytest.mark.parametrize("create_primary_keys", (True, False))
def test_sqlalchemy_create_indexes(
destination_config: DestinationTestConfiguration,
create_unique_indexes: bool,
create_primary_keys: bool,
) -> None:
from dlt.destinations import sqlalchemy
from dlt.common.libs.sql_alchemy import Table, MetaData

alchemy_ = sqlalchemy(
create_unique_indexes=create_unique_indexes, create_primary_keys=create_primary_keys
)

pipeline = destination_config.setup_pipeline(
"test_snowflake_case_sensitive_identifiers", dev_mode=True, destination=alchemy_
)
# load table with indexes
pipeline.run([{"id": 1, "value": "X"}], table_name="with_pk", primary_key="id")
# load without indexes
pipeline.run([{"id": 1, "value": "X"}], table_name="without_pk")

dataset_ = pipeline.dataset()
assert len(dataset_.with_pk.fetchall()) == 1
assert len(dataset_.without_pk.fetchall()) == 1

from sqlalchemy import inspect

with pipeline.sql_client() as client:
with_pk: Table = client.reflect_table("with_pk", metadata=MetaData())
assert (with_pk.c.id.primary_key or False) is create_primary_keys
if client.dialect.name != "sqlite":
# reflection does not show unique constraints
# assert (with_pk.c._dlt_id.unique or False) is create_unique_indexes
inspector = inspect(client.engine)
indexes = inspector.get_indexes("with_pk", schema=pipeline.dataset_name)
if create_unique_indexes:
assert indexes[0]["column_names"][0] == "_dlt_id"
else:
assert len(indexes) == 0