Skip to content

Commit

Permalink
fixes empty df for polars
Browse files Browse the repository at this point in the history
  • Loading branch information
aersam committed Mar 26, 2024
1 parent 6799d68 commit 58719ea
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
18 changes: 13 additions & 5 deletions deltalake2db/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import polars as pl
from deltalake import DeltaTable, Field, DataType
from deltalake.exceptions import DeltaProtocolError
from deltalake.schema import StructType, ArrayType, PrimitiveType
from deltalake.schema import StructType, ArrayType, PrimitiveType, MapType
import os


Expand Down Expand Up @@ -59,11 +59,15 @@ def _get_expr(
return base_expr.alias(meta.name) if meta else base_expr


def _try_get_type(dtype: "DataType") -> "pl.PolarsDataType | None":
def _get_type(dtype: "DataType") -> "pl.PolarsDataType | None":
import polars as pl

if not isinstance(dtype, PrimitiveType):
return None
if isinstance(dtype, StructType):
return pl.Struct([pl.Field(f.name, _get_type(f.type)) for f in dtype.fields])
if isinstance(dtype, ArrayType):
return pl.List(_get_type(dtype.element_type))
if isinstance(dtype, MapType):
raise NotImplementedError("MapType not supported in polars")

dtype_str = str(dtype.type)
if dtype_str == "string":
Expand Down Expand Up @@ -112,7 +116,7 @@ def scan_delta_union(delta_table: DeltaTable | Path) -> "pl.LazyFrame":
parquet_schema = base_ds.limit(0).schema
selects = []
for field in all_fields:
pl_dtype = _try_get_type(field.type)
pl_dtype = _get_type(field.type)
pn = field.metadata.get("delta.columnMapping.physicalName", field.name)
if "partition_values" in ac and pn in ac["partition_values"]:
part_vl = ac["partition_values"][pn]
Expand All @@ -136,4 +140,8 @@ def scan_delta_union(delta_table: DeltaTable | Path) -> "pl.LazyFrame":

ds = base_ds.select(*selects)
all_ds.append(ds)
if len(all_ds) == 0:
return pl.DataFrame(
data=[], schema={f.name: _get_type(f.type) for f in all_fields}
).lazy()
return pl.concat(all_ds, how="diagonal_relaxed")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "deltalake2db"
version = "0.2.1"
version = "0.2.2"
description = ""
authors = ["Adrian Ehrsam <[email protected]>"]
license = "MIT"
Expand Down
10 changes: 10 additions & 0 deletions tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def test_user_add():
assert diff == [1555]


def test_user_empty():
dt = DeltaTable("tests/data/user_empty")

from deltalake2db import polars_scan_delta

df = polars_scan_delta(dt).collect()
assert df.shape[0] == 0
assert "time stämp" in df.columns


def test_strange_cols():
dt = DeltaTable("tests/data/user")

Expand Down

0 comments on commit 58719ea

Please sign in to comment.