Skip to content

Commit

Permalink
Merge pull request #22 from bmsuisse/dev
Browse files Browse the repository at this point in the history
allow selection of fields for polars
  • Loading branch information
aersam authored Oct 31, 2024
2 parents 9b96398 + a5e91f8 commit 1adc11f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
7 changes: 6 additions & 1 deletion deltalake2db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from .duckdb import get_sql_for_delta_expr
from .duckdb import create_view_for_delta as duckdb_create_view_for_delta
from .duckdb import apply_storage_options as duckdb_apply_storage_options
from .polars import scan_delta_union as polars_scan_delta, get_polars_schema
from .polars import (
scan_delta_union as polars_scan_delta,
get_polars_schema,
PolarsSettings,
)
from .protocol_check import is_protocol_supported
from .delta_lake import get_delta_table

Expand All @@ -13,6 +17,7 @@
"duckdb_apply_storage_options",
"polars_scan_delta",
"get_polars_schema",
"PolarsSettings",
"is_protocol_supported",
"get_delta_table",
]
22 changes: 20 additions & 2 deletions deltalake2db/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


class PolarsSettings:
exclude_fields: Optional[list[str]] = None
fields: Optional[list[str]] = None

timestamp_ntz_type: "pl.Datetime"
timestamp_type: "pl.Datetime"

Expand All @@ -23,6 +26,8 @@ def __init__(
*,
timestamp_ntz_type: "Optional[pl.Datetime]" = None,
timestamp_type: "Optional[pl.Datetime]" = None,
exclude_fields: Optional[list[str]] = None,
fields: Optional[list[str]] = None,
):
import polars as pl

Expand All @@ -32,6 +37,8 @@ def __init__(
self.timestamp_type = timestamp_type or pl.Datetime(
time_unit="us", time_zone="utc"
)
self.exclude_fields = exclude_fields
self.fields = fields


def _get_expr(
Expand Down Expand Up @@ -187,9 +194,12 @@ def get_polars_schema(
delta_table = DeltaTable(delta_table)
check_is_supported(delta_table)
res_dict = OrderedDict()
meta = delta_table.metadata()
for f in delta_table.schema().fields:
pn = f.name
if settings.exclude_fields and f.name in settings.exclude_fields:
continue
if settings.fields and f.name not in settings.fields:
continue
if physical_name:
pn = f.metadata.get("delta.columnMapping.physicalName", f.name)
res_dict[pn] = _get_type(f.type, physical_name, settings)
Expand Down Expand Up @@ -217,12 +227,16 @@ def scan_delta_union(
check_is_supported(delta_table)
all_ds = []
all_fields = delta_table.schema().fields
physical_schema = get_polars_schema(delta_table, physical_name=True)
physical_schema = get_polars_schema(
delta_table, physical_name=True, settings=settings
)
physical_schema_no_parts = physical_schema.copy()

logical_to_physical = {
f.name: f.metadata.get("delta.columnMapping.physicalName", f.name)
for f in all_fields
if (settings.exclude_fields is None or f.name not in settings.exclude_fields)
and (settings.fields is None or f.name in settings.fields)
}
for pc in delta_table.metadata().partition_columns:
physical_schema_no_parts.pop(logical_to_physical.get(pc, pc))
Expand Down Expand Up @@ -268,6 +282,10 @@ def scan_delta_union(
)
selects = []
for field in all_fields:
if settings.exclude_fields and field.name in settings.exclude_fields:
continue
if settings.fields and field.name not in settings.fields:
continue
pn = field.metadata.get("delta.columnMapping.physicalName", field.name)
if "partition_values" in ac and pn in ac["partition_values"]:
part_vl = ac["partition_values"][pn]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "deltalake2db"
version = "0.7.1"
version = "0.7.2"
description = ""
authors = ["Adrian Ehrsam <[email protected]>"]
license = "MIT"
Expand Down
16 changes: 16 additions & 0 deletions tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_user_empty():
assert "time stämp" in df.columns


def test_select():
    """Verify that ``fields`` and ``exclude_fields`` narrow the scanned columns."""
    table = DeltaTable("tests/data/user")

    from deltalake2db import polars_scan_delta, PolarsSettings

    # Whitelisting a single field must yield exactly that one column.
    only_user_id = PolarsSettings(fields=["User - iD"])
    selected = polars_scan_delta(table, settings=only_user_id).collect()
    assert len(selected.columns) == 1
    assert "User - iD" in selected.columns

    # Excluding the same field must drop it while keeping the other columns.
    without_user_id = PolarsSettings(exclude_fields=["User - iD"])
    remaining = polars_scan_delta(table, settings=without_user_id).collect()
    assert len(remaining.columns) > 1
    assert "User - iD" not in remaining.columns


def test_strange_cols():
dt = DeltaTable("tests/data/user")

Expand Down

0 comments on commit 1adc11f

Please sign in to comment.