diff --git a/deltalake2db/azure_helper.py b/deltalake2db/azure_helper.py
index 8049df0..f49dfd4 100644
--- a/deltalake2db/azure_helper.py
+++ b/deltalake2db/azure_helper.py
@@ -71,6 +71,15 @@ def get_storage_options_fsspec(storage_options: dict):
     return storage_options
 
 
+def get_account_name_from_path(path: str):
+    if ".blob.core.windows.net" in path or ".dfs.core.windows.net" in path:
+        from urllib.parse import urlparse
+
+        up = urlparse(path)
+        return up.netloc.split(".")[0]
+    return None
+
+
 def get_storage_options_object_store(
     path: Union[Path, str],
     storage_options: Optional[dict],
@@ -109,7 +118,6 @@ def _get_cred(chain: str):
     if chain is not None:
         new_opts = storage_options.copy()
         new_opts.pop("chain", None)
-        new_opts.pop("anon", None)
         cred = _get_credential_from_chain(chain, get_credential)
         new_opts["token"] = cred.get_token(STORAGE_SCOPE).token
         if account_name_from_url:
@@ -117,6 +125,12 @@
                 "account_name", account_name_from_url
             )
         return new_path, new_opts
+    elif "anon" in storage_options and str(storage_options["anon"]).lower() in [
+        "1",
+        "true",
+    ]:
+        storage_options = storage_options.copy()
+        storage_options.pop("anon")
     if account_name_from_url is not None and "account_name" not in storage_options:
         new_opts = storage_options.copy()
         new_opts["account_name"] = account_name_from_url
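The added helper derives the storage account from fully qualified Azure URLs only. A minimal sketch of the expected behavior (the URLs are illustrative):

```python
from deltalake2db.azure_helper import get_account_name_from_path

# The account name is the first dot-separated segment of the URL's host:
assert (
    get_account_name_from_path("https://myaccount.blob.core.windows.net/cont/table")
    == "myaccount"
)
# Paths without a *.blob/*.dfs.core.windows.net host fall through to None:
assert get_account_name_from_path("az://cont/table") is None
```

Note that `urlparse` keeps userinfo in `netloc`, so for `abfss://container@account.dfs.core.windows.net/...` URLs the first segment would be `container@account` rather than the bare account name. The other change in this file replaces the unconditional `anon` pop on the credential-chain path with an explicit branch: a truthy `anon` option is now also stripped when no chain is configured, instead of being forwarded.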
diff --git a/deltalake2db/duckdb.py b/deltalake2db/duckdb.py
index 478a527..78fc307 100644
--- a/deltalake2db/duckdb.py
+++ b/deltalake2db/duckdb.py
@@ -7,6 +7,7 @@
 from deltalake2db.azure_helper import (
     get_storage_options_fsspec,
     get_storage_options_object_store,
+    get_account_name_from_path,
     AZURE_EMULATOR_CONNECTION_STRING,
 )
 from deltalake2db.filter_by_meta import _can_filter
@@ -126,18 +127,21 @@ def _get_expr(
     return base_expr.as_(meta.name) if meta is not None and alias else base_expr
 
 
-def load_install_azure(con: "duckdb.DuckDBPyConnection"):
+def load_install_extension(con: "duckdb.DuckDBPyConnection", ext_name: str):
     with con.cursor() as cur:
         cur.execute(
-            "select loaded, installed from duckdb_extensions() where extension_name='azure' "
+            "select loaded, installed from duckdb_extensions() where extension_name=$ext_name ",
+            {
+                "ext_name": ext_name,
+            },
         )
         res = cur.fetchone()
         loaded = res[0] if res else False
         installed = res[0] if res else False
     if not installed:
-        con.install_extension("azure")
+        con.install_extension(ext_name)
     if not loaded:
-        con.load_extension("azure")
+        con.load_extension(ext_name)
 
 
 def apply_storage_options(
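With the extension name parameterized, the same install-and-load guard now serves both the azure extension used below and the delta extension that get_sql_for_delta_expr preloads. A usage sketch (the connection is illustrative):

```python
import duckdb

from deltalake2db.duckdb import load_install_extension

con = duckdb.connect()
# Consults duckdb_extensions() first, so install/load only run when the
# extension is actually missing; repeated calls are harmless.
load_install_extension(con, "delta")
```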
@@ -210,7 +214,7 @@ def apply_storage_options_azure_ext(
         type = "azure"
     if type != "azure":
         raise ValueError("Only azure is supported for now")
-    load_install_azure(con)
+    load_install_extension(con, "azure")
     with con.cursor() as cur:
         cur.execute("FROM duckdb_secrets() where type='azure';")
         secrets = cur.fetchall()
@@ -320,6 +324,7 @@ def create_view_for_delta(
     storage_options: Optional[dict] = None,
     get_credential: "Optional[Callable[[str], Optional[TokenCredential]]]" = None,
     use_fsspec: bool = False,
+    use_delta_ext=False,
 ):
     sql = get_sql_for_delta(
         delta_table,
@@ -328,6 +333,7 @@
         storage_options=storage_options,
         get_credential=get_credential,
         use_fsspec=use_fsspec,
+        use_delta_ext=use_delta_ext,
     )
     assert '"' not in view_name
     if overwrite:
@@ -337,7 +343,7 @@
 
 
 def get_sql_for_delta_expr(
-    dt: "Union[DeltaTable,Path , str]",
+    table_or_path: "Union[DeltaTable,Path , str]",
     conditions: Union[Optional[dict], Sequence[ex.Expression], ex.Expression] = None,
     select: Union[Sequence[Union[str, ex.Expression]], None] = None,
     distinct=False,
@@ -350,159 +356,184 @@ def get_sql_for_delta_expr(
     *,
     get_credential: "Optional[Callable[[str], Optional[TokenCredential]]]" = None,
     use_fsspec=False,
+    use_delta_ext=False,
 ) -> ex.Select:
     from deltalake import DeltaTable
     from .sql_utils import read_parquet, union, filter_via_dict
 
-    account_name_path = None
     base_path = (
-        dt.table_uri
-        if isinstance(dt, DeltaTable)
-        else (dt if isinstance(dt, str) else str(dt.absolute()))
+        table_or_path.table_uri
+        if isinstance(table_or_path, DeltaTable)
+        else (
+            table_or_path
+            if isinstance(table_or_path, str)
+            else str(table_or_path.absolute())
+        )
     )
     base_path = base_path.removesuffix("/")
-    if isinstance(dt, Path) or isinstance(dt, str):
-        path_for_delta, storage_options_for_delta = get_storage_options_object_store(
-            dt, storage_options, get_credential
-        )
-        account_name_path = (
-            storage_options_for_delta.get("account_name", None)
-            if storage_options_for_delta
-            else None
-        )
-        dt = DeltaTable(path_for_delta, storage_options=storage_options_for_delta)
-
-    from .protocol_check import check_is_supported
-
-    check_is_supported(dt)
-
-    delta_table_cte_name = delta_table_cte_name or sql_prefix + "_delta_table"
-    from deltalake.schema import PrimitiveType
+    is_azure = base_path.startswith("az://") or base_path.startswith("abfss://")
+    account_name_path = get_account_name_from_path(base_path) if is_azure else None
+    if storage_options is None and isinstance(table_or_path, DeltaTable):
+        storage_options = table_or_path._storage_options
 
-    file_selects: list[ex.Select] = []
-
-    delta_fields = dt.schema().fields
+    conds = filter_via_dict(conditions) if isinstance(conditions, dict) else conditions
     owns_con = False
     if duck_con is None:
         import duckdb
 
         duck_con = duckdb.connect()
         owns_con = True
+    if use_delta_ext:
+        load_install_extension(duck_con, "delta")
     if use_fsspec:
         fake_protocol = apply_storage_options_fsspec(
             duck_con,
             base_path,
-            storage_options or dt._storage_options or {},
+            storage_options or {},
             account_name_path=account_name_path,
         )
         base_path = fake_protocol + "://" + base_path.split("://")[1]
-    elif dt.table_uri.startswith("az://") or dt.table_uri.startswith("abfss://"):
+    elif is_azure:
         apply_storage_options_azure_ext(
             duck_con,
-            storage_options or dt._storage_options or {},
+            storage_options or {},
             type="azure",
             account_name_path=account_name_path,
         )  # type: ignore
 
     try:
-        for ac in dt.get_add_actions(flatten=True).to_pylist():
-            if (
-                conditions is not None
-                and isinstance(conditions, dict)
-                and _can_filter(ac, conditions)
-            ):
-                continue
-            if action_filter and not action_filter(ac):
-                continue
-            fullpath = base_path + "/" + ac["path"]
-            with duck_con.cursor() as cur:
-                cur.execute(f"select name from parquet_schema('{fullpath}')")
-                cols: list[str] = [c[0] for c in cur.fetchall()]
-            cols_sql: list[ex.Expression] = []
-            for field in delta_fields:
-                field_name = field.name
-                phys_name = field.metadata.get(
-                    "delta.columnMapping.physicalName", field_name
+        if not use_delta_ext:
+            if isinstance(table_or_path, Path) or isinstance(table_or_path, str):
+                path_for_delta, storage_options_for_delta = (
+                    get_storage_options_object_store(
+                        table_or_path, storage_options, get_credential
+                    )
                 )
-                cast_as = None
-
-                if isinstance(field.type, PrimitiveType):
-                    if str(field.type).startswith("decimal("):
-                        cast_as = ex.DataType.build(str(field.type))
-                    else:
-                        cast_as = type_map.get(field.type.type)
-                if "partition_values" in ac and phys_name in ac["partition_values"]:
-                    cols_sql.append(
-                        _cast(
-                            ex.convert(ac["partition_values"][phys_name]), cast_as
-                        ).as_(field_name)
+                dt = DeltaTable(
+                    path_for_delta, storage_options=storage_options_for_delta
+                )
+            else:
+                dt = table_or_path
+            from .protocol_check import check_is_supported
+
+            check_is_supported(dt)
+            delta_table_cte_name = delta_table_cte_name or sql_prefix + "_delta_table"
+            from deltalake.schema import PrimitiveType
+
+            file_selects: list[ex.Select] = []
+
+            delta_fields = dt.schema().fields
+            for ac in dt.get_add_actions(flatten=True).to_pylist():
+                if (
+                    conditions is not None
+                    and isinstance(conditions, dict)
+                    and _can_filter(ac, conditions)
+                ):
+                    continue
+                if action_filter and not action_filter(ac):
+                    continue
+                fullpath = base_path + "/" + ac["path"]
+                with duck_con.cursor() as cur:
+                    cur.execute(f"select name from parquet_schema('{fullpath}')")
+                    cols: list[str] = [c[0] for c in cur.fetchall()]
+                cols_sql: list[ex.Expression] = []
+                for field in delta_fields:
+                    field_name = field.name
+                    phys_name = field.metadata.get(
+                        "delta.columnMapping.physicalName", field_name
                     )
-                elif "partition." + phys_name in ac:
-                    cols_sql.append(
-                        _cast(ex.convert(ac["partition." + phys_name]), cast_as).as_(
-                            field_name
+
+                    cast_as = None
+
+                    if isinstance(field.type, PrimitiveType):
+                        if str(field.type).startswith("decimal("):
+                            cast_as = ex.DataType.build(str(field.type))
+                        else:
+                            cast_as = type_map.get(field.type.type)
+                    if "partition_values" in ac and phys_name in ac["partition_values"]:
+                        cols_sql.append(
+                            _cast(
+                                ex.convert(ac["partition_values"][phys_name]), cast_as
+                            ).as_(field_name)
                         )
-                    )
-                elif "partition_values" in ac and field.name in ac["partition_values"]:
-                    cols_sql.append(
-                        _cast(
-                            ex.convert(ac["partition_values"][field.name]), cast_as
-                        ).as_(field_name)
-                    )
-                elif "partition." + field.name in ac:
-                    cols_sql.append(
-                        _cast(ex.convert(ac["partition." + field.name]), cast_as).as_(
-                            field_name
+                    elif "partition." + phys_name in ac:
+                        cols_sql.append(
+                            _cast(
+                                ex.convert(ac["partition." + phys_name]), cast_as
+                            ).as_(field_name)
                         )
-                    )
-                elif phys_name in cols:
-                    cols_sql.append(
-                        _get_expr(ex.column(phys_name, quoted=True), field.type, field)
-                    )
-                else:
-                    cols_sql.append(ex.Null().as_(field_name))
-
-            select_pq = ex.select(
-                *cols_sql
-            ).from_(
-                read_parquet(ex.convert(fullpath))
-            )  # "SELECT " + ", ".join(cols_sql) + " FROM read_parquet('" + fullpath + "')"
-            file_selects.append(select_pq)
-        if len(file_selects) == 0:
-            file_selects = []
-            fields = [_dummy_expr(field.type).as_(field.name) for field in delta_fields]
-            file_selects.append(ex.select(*fields).where("1=0"))
-        file_sql = ex.CTE(
-            this=union(file_selects, distinct=False), alias=delta_table_cte_name
-        )
-        if select:
-            select_exprs = [
-                ex.column(s, quoted=True) if isinstance(s, str) else s for s in select
-            ]
-        else:
-            select_exprs = [ex.Star()]
-
-        conds = (
-            filter_via_dict(conditions) if isinstance(conditions, dict) else conditions
-        )
-
-        se = ex.select(*select_exprs)
-        if distinct:
-            se = se.distinct()
-        if conds is not None:
-            se = se.where(*conds)
-        se = se.from_(delta_table_cte_name)
-
-        if cte_wrap_name:
-            s = ex.select(ex.Star()).from_(cte_wrap_name)
-            # s = s.with_()
-            s.with_(file_sql.alias, file_sql.args["this"], copy=False)
-            s.with_(cte_wrap_name, se, copy=False)
-            return s
+                    elif (
+                        "partition_values" in ac
+                        and field.name in ac["partition_values"]
+                    ):
+                        cols_sql.append(
+                            _cast(
+                                ex.convert(ac["partition_values"][field.name]), cast_as
+                            ).as_(field_name)
+                        )
+                    elif "partition." + field.name in ac:
+                        cols_sql.append(
+                            _cast(
+                                ex.convert(ac["partition." + field.name]), cast_as
+                            ).as_(field_name)
+                        )
+                    elif phys_name in cols:
+                        cols_sql.append(
+                            _get_expr(
+                                ex.column(phys_name, quoted=True), field.type, field
+                            )
+                        )
+                    else:
+                        cols_sql.append(ex.Null().as_(field_name))
+
+                select_pq = ex.select(*cols_sql).from_(
+                    read_parquet(ex.convert(fullpath))
+                )  # "SELECT " + ", ".join(cols_sql) + " FROM read_parquet('" + fullpath + "')"
+                file_selects.append(select_pq)
+            if len(file_selects) == 0:
+                file_selects = []
+                fields = [
+                    _dummy_expr(field.type).as_(field.name) for field in delta_fields
+                ]
+                file_selects.append(ex.select(*fields).where("1=0"))
+            file_sql = ex.CTE(
+                this=union(file_selects, distinct=False), alias=delta_table_cte_name
+            )
+            if select:
+                select_exprs = [
+                    ex.column(s, quoted=True) if isinstance(s, str) else s
+                    for s in select
+                ]
+            else:
+                select_exprs = [ex.Star()]
+
+            se = ex.select(*select_exprs)
+            if distinct:
+                se = se.distinct()
+            if conds is not None:
+                se = se.where(*conds)
+            se = se.from_(delta_table_cte_name)
+
+            if cte_wrap_name:
+                s = ex.select(ex.Star()).from_(cte_wrap_name)
+                # s = s.with_()
+                s.with_(file_sql.alias, file_sql.args["this"], copy=False)
+                s.with_(cte_wrap_name, se, copy=False)
+                return s
+            else:
+                se.with_(file_sql.alias, file_sql.args["this"], copy=False)
+                return se
         else:
-            se.with_(file_sql.alias, file_sql.args["this"], copy=False)
+            se = ex.select(ex.Star()).from_(
+                ex.func("delta_scan", ex.convert(base_path))
+            )
+            if distinct:
+                se = se.distinct()
+            if conds is not None:
+                se = se.where(*conds)
             return se
+
     finally:
         if owns_con:
             duck_con.close()
@@ -521,9 +552,10 @@ def get_sql_for_delta(
     *,
     get_credential: "Optional[Callable[[str], Optional[TokenCredential]]]" = None,
     use_fsspec: bool = False,
+    use_delta_ext=False,
 ) -> str:
     expr = get_sql_for_delta_expr(
-        dt=dt,
+        table_or_path=dt,
         conditions=conditions,
         select=select,
         distinct=distinct,
@@ -534,6 +566,7 @@
         storage_options=storage_options,
         get_credential=get_credential,
         use_fsspec=use_fsspec,
+        use_delta_ext=use_delta_ext,
     )
     if cte_wrap_name:
         suffix_sql = ex.select(ex.Star()).from_(cte_wrap_name).sql(dialect="duckdb")
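Taken together, use_delta_ext switches the generated SQL from a union of per-file read_parquet selects (built from the Delta log's add actions) to a single delta_scan table-function call that DuckDB's delta extension resolves itself. A sketch mirroring the repository's tests, against the local test table:

```python
import duckdb
from deltalake import DeltaTable

from deltalake2db import duckdb_create_view_for_delta

dt = DeltaTable("tests/data/user")
with duckdb.connect() as con:
    # With use_delta_ext=True, protocol and column-mapping handling is
    # delegated to the delta extension instead of the add-action logic above.
    duckdb_create_view_for_delta(con, dt, "delta_table", use_delta_ext=True)
    print(con.execute("select * from delta_table").fetchall())
```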
diff --git a/pyproject.toml b/pyproject.toml
index df37d1c..f08490f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "deltalake2db"
-version = "0.5.2"
+version = "0.6.0-beta1"
 description = ""
 authors = ["Adrian Ehrsam "]
 license = "MIT"
diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py
index 5d79cc6..cbde035 100644
--- a/tests/test_duckdb.py
+++ b/tests/test_duckdb.py
@@ -2,6 +2,7 @@
 from deltalake import DeltaTable
 import duckdb
 import polars as pl
+import pytest
 
 
 def test_col_mapping():
@@ -39,13 +40,16 @@ def test_col_mapping():
     print(as_py_rows)
 
 
-def test_strange_cols():
+@pytest.mark.parametrize("use_delta_ext", [False, True])
+def test_strange_cols(use_delta_ext):
     dt = DeltaTable("tests/data/user")
 
     from deltalake2db import duckdb_create_view_for_delta
 
     with duckdb.connect() as con:
-        duckdb_create_view_for_delta(con, dt, "delta_table")
+        duckdb_create_view_for_delta(
+            con, dt, "delta_table", use_delta_ext=use_delta_ext
+        )
         con.execute("select * from delta_table")
         col_names = [c[0] for c in con.description]
         assert "time stämp" in col_names
diff --git a/tests/test_duckdb_az.py b/tests/test_duckdb_az.py
index a19a571..0491c95 100644
--- a/tests/test_duckdb_az.py
+++ b/tests/test_duckdb_az.py
@@ -42,8 +42,11 @@ def test_chain():
     print(as_py_rows)
 
 
-@pytest.mark.parametrize("use_fsspec", [True, False])
-def test_col_mapping(storage_options, use_fsspec: bool):
+@pytest.mark.parametrize(
+    "use_fsspec,use_delta_ext",
+    [(True, False), (False, False), (False, True)],
+)
+def test_col_mapping(storage_options, use_fsspec: bool, use_delta_ext: bool):
     from deltalake2db import duckdb_create_view_for_delta
 
     with duckdb.connect() as con:
@@ -53,6 +56,7 @@ def test_col_mapping(storage_options, use_fsspec: bool):
             "delta_table",
             storage_options=storage_options,
             use_fsspec=use_fsspec,
+            use_delta_ext=use_delta_ext,
         )
         duckdb_create_view_for_delta(
             con,
@@ -60,6 +64,7 @@
             "delta_table",
             storage_options=storage_options,
             use_fsspec=use_fsspec,
+            use_delta_ext=use_delta_ext,
         )  # do it twice to test duplicate secrets
         df = pl.from_arrow(con.execute("select * from delta_table").fetch_arrow_table())
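For a quick look at what the two modes emit, the SQL can be inspected without creating a view; a hedged sketch (importing straight from the module, since only duckdb_create_view_for_delta appears above as a package-level entry point):

```python
from deltalake import DeltaTable

from deltalake2db.duckdb import get_sql_for_delta

dt = DeltaTable("tests/data/user")
# Union of read_parquet selects wrapped in a CTE, built per add action:
print(get_sql_for_delta(dt, use_delta_ext=False))
# Single delta_scan(...) call handled by the delta extension:
print(get_sql_for_delta(dt, use_delta_ext=True))
```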