Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improving array casting #1865

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

Conversation

FBruzzesi
Copy link
Member

What type of PR is this? (check all applicable)

  • πŸ’Ύ Refactor
  • ✨ Feature
  • πŸ› Bug Fix
  • πŸ”§ Optimization
  • πŸ“ Documentation
  • βœ… Test
  • 🐳 Other

Checklist

  • Code follows style guide (ruff)
  • Tests added
  • Documented the changes

If you have comments or can explain your changes, please do so below

Comment on lines -735 to +741
self: Self, inner: DType | type[DType], width: int | None = None
self: Self,
inner: DType | type[DType],
shape: int | tuple[int, ...] | None = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a breaking change :(

@@ -220,7 +224,7 @@ def broadcast_and_extract_dataframe_comparand(

if isinstance(other, ArrowSeries):
len_other = len(other)
if len_other == 1:
if len_other == 1 and length != 1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise for list/array types we end up getting the first element

@@ -111,10 +111,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
)
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
return dtypes.List(native_to_narwhals_dtype(match_.group(1), version))
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype):
if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array type in duckdb can also have multiple dimensions, rendering as: INNER[d1][d2][...]

Comment on lines +177 to +181
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr]
while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr]
dtype = dtype.inner # type: ignore[union-attr]
inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr]
return f"{inner}{duckdb_shape_fmt}"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First creates the shape [d1][d2]... then find the inner type recursively (first being non array)

@@ -160,7 +160,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any:
if isinstance(other, PandasLikeSeries):
len_other = other.len()

if len_other == 1:
if len_other == 1 and len(index) != 1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly as for pyarrow, otherwise for list/array types we end up getting the first element

Comment on lines +59 to +62
if isinstance(dtype, pyspark_types.ArrayType): # pragma: no cover
return dtypes.List(
inner=native_to_narwhals_dtype(dtype.elementType, version=version)
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spark arrays do not have a fixed dimension, so I am converting to a list

@FBruzzesi FBruzzesi marked this pull request as ready for review January 27, 2025 14:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant